
Commit 7ad5985

Early termination (#17)
* early
* change default num_gen
* allow verbose
* rm the single-turn yaml
* Update README.md
1 parent cd060e2 commit 7ad5985

13 files changed: +132, -321 lines changed

README.md

Lines changed: 28 additions & 64 deletions
@@ -1,101 +1,65 @@
 # LLM Collaboration with MARL
 
-This repository contains training scripts and configurations for the paper "LLM Collaboration with Multi‑Agent Reinforcement Learning".
-- [Benchmarks](#benchmarks)
-- [Training Scripts](#training-scripts)
-- [Default Configs](#default-configs)
-- [Parameter Overrides](#parameter-overrides)
-- [Multi-Turn Settings](#multi-turn-settings)
-- [2+Turn Prompt Composition](#2turn-prompt-composition)
-- [External Modes](#external-modes)
-- [Sandbox Tests](#sandbox-tests)
+Training scripts and configs for _"LLM Collaboration with Multi‑Agent Reinforcement Learning"_.
 
 ## Benchmarks
 
-- HumanEval (HE): 164 problems on split `test`
-- CoopHumanEval (CHE): 82 problems on split `test`
+- MBPP: 427 problems on split `sanitized`
+- HumanEval: 164 problems on split `test`
+- CoopHumanEval: 82 problems on split `test`
 
 ## Training Scripts
 
 ### Default Configs
 
 ```bash
-# Single-agent HumanEval (GRPO)
 python LLM_Collaboration_with_MARL/train_grpo.py \
     --config LLM_Collaboration_with_MARL/configs/grpo_he_config.yaml
 
-# Multi-agent CoopHumanEval (MAGRPO)
 python LLM_Collaboration_with_MARL/train_magrpo.py \
     --config LLM_Collaboration_with_MARL/configs/magrpo_che_config.yaml
-
-# Multi-turn HumanEval (MT-MAGRPO)
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml
 ```
 
 ### Parameter Overrides
 
-You can override any configuration parameter using `--override`:
+You can always override any configuration parameter using `--override`:
 
 ```bash
-# Change model
 python LLM_Collaboration_with_MARL/train_magrpo.py \
     --config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
-    --override model_name='bigcode/starcoder2-3b'
+    --override model.name='bigcode/starcoder2-3b' magrpo.num_turns=1
+```
 
-# Modify training params
-python LLM_Collaboration_with_MARL/train_grpo.py \
-    --config LLM_Collaboration_with_MARL/configs/grpo_che_config.yaml \
-    --override grpo.num_train_epochs=20 grpo.learning_rate=3e-5
+## Settings
 
-# Multi-turn override example
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \
-    --override dataset.train_split='test[16:]' dataset.eval_split='test[:16]' \
-               magrpo.num_turns=2
+### Joint Action Modes
 
-# Enable code-level training metrics (expensive; default is off)
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/magrpo_he_config.yaml \
-    --override magrpo.log_code_levels=true
-```
-## Multi-Turn Settings
+`magrpo.joint_mode` determines how each agent's K generations are combined into joint actions at each turn. Two modes are supported: with 'align' (the default), each agent's k-th generation is paired with the other agents' k-th generations to form a joint action; with 'cross', all combinations of the agents' K generations form joint actions (K^N joint actions for N agents).
 
-### 2+Turn Prompt Composition
+Since the number of samples also grows exponentially with the number of turns, aligned joint actions are **more flexible** (the number of samples need not be a perfect power) and hence faster to train in wall-clock time. Cross joint actions, however, are more sample-efficient (much lower VRAM than 'align' when num_generations=K^N) and tend to perform better, since the value estimation is more accurate.
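
As a concrete illustration of the two modes, here is a minimal sketch of the pairing logic (a hypothetical helper, not the repository's actual API):

```python
from itertools import product

def combine_generations(per_agent_gens, joint_mode="align"):
    """per_agent_gens: N lists, one per agent, each holding K generations."""
    if joint_mode == "align":
        # k-th generation of every agent forms the k-th joint action -> K joint actions
        return list(zip(*per_agent_gens))
    if joint_mode == "cross":
        # every cross-agent combination -> K^N joint actions
        return list(product(*per_agent_gens))
    raise ValueError(f"unknown joint_mode: {joint_mode}")

# N=2 agents, K=3 generations each
gens = [["a1", "a2", "a3"], ["b1", "b2", "b3"]]
assert len(combine_generations(gens, "align")) == 3   # K
assert len(combine_generations(gens, "cross")) == 9   # K^N
```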
 
-To save memory usage, 2+ turn prompts **include the previous response without the original first‑turn problem prompt by default**. You can add the original prompt to match the concept of observation-action history in MARL.
+### Number of Turns
 
-```bash
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
-    --override magrpo.external_original_prompt=True magrpo.external_previous_response=True
-```
+`magrpo.num_turns` determines the number of turns (`magrpo.num_turns=2` by default). The number of samples at each turn grows exponentially with the number of turns: K^TN at turn T with cross joint actions, K^N with aligned joint actions.
 
-### External Modes
+### Early Termination
 
-Multi-turn training supports external transition modes for 2nd+ turns, set via `external.mode`:
+`magrpo.termination_threshold` incentivizes agents to find high-reward solutions quickly instead of expanding the full Monte Carlo tree.
 
-- `level_feedback` **(default)**: Detailed diagnostics (impl found, syntax with line/col, per-test pass/fail errors, aux usage).
-  - Requires `external.expert_model` in config when using `expert_edits` (e.g., `deepseek-coder`, Claude, etc.). This parameter is ignored for other modes (`level_feedback`, `level_passed`, `passed`, `plain`).
-  - Requires corresponding API keys in env vars.
-- `level_passed`: Binary passed signals (impl found, syntax, tests summary, aux usage).
-- `passed`: A binary signal — "All levels passed" or "Not all levels passed".
-- `plain`: No signals or diagnostics.
+At each node (branch, turn), the mean immediate **reward across the sibling** joint actions at that node is computed. If the mean exceeds the threshold, that branch stops expanding at this turn and training backpropagates from the truncated subtree; other branches continue.
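
A minimal sketch of that check, assuming each node carries the immediate rewards of its sibling joint actions (names are illustrative, not the repository's API):

```python
def should_terminate(sibling_rewards, termination_threshold):
    """Stop expanding a branch when the mean immediate reward of the
    sibling joint actions at this node exceeds the threshold."""
    mean_reward = sum(sibling_rewards) / len(sibling_rewards)
    return mean_reward > termination_threshold

# With termination_threshold: -0.2 as in the MAGRPO configs below:
print(should_terminate([-0.5, -0.4, 0.1, 0.2], -0.2))    # True: mean -0.15 exceeds -0.2
print(should_terminate([-1.0, -0.8, -0.6, -0.2], -0.2))  # False: mean -0.65 does not
```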
 
-```bash
-# HumanEval with detailed feedback signals
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_he_config.yaml \
-    --override external.mode='level_feedback'
-```
+### Multi-Turn Prompt
 
-### Sandbox Tests
+`external.original_prompt` and `external.previous_response` both default to `true`: prompts for the second and later turns include both the original first‑turn problem prompt and the previous response, preserving full context. You can shorten the context by setting either to `false` (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
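
A sketch of how these two flags might shape the later-turn prompt (a hypothetical helper; the actual template lives in the training code):

```python
def build_next_turn_prompt(original_prompt, previous_response, external_signal,
                           include_original=True, include_previous=True):
    # external.original_prompt / external.previous_response both default to true,
    # so the full observation-action history is kept unless one flag is turned off.
    parts = []
    if include_original:
        parts.append(original_prompt)
    if include_previous:
        parts.append("Your previous response:\n" + previous_response)
    parts.append(external_signal)  # e.g. the level_feedback diagnostics
    return "\n\n".join(parts)
```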
 
-The external modes obtain `entry_point` and tests via an internal resolver registered by the training script. **By default, sandbox executes only the first assert (`sandbox_slice=1`).** Use all eval tests by setting `external.sandbox_slice` to `0`, `None`, or `'all'`. A negative value uses the last N asserts. Note: `external.sandbox_slice` only affects analysis-based modes (`level_feedback`, `level_passed`, `passed`), and it has no effect on `expert_edits`.
+### External Modes
 
-```bash
-# Add an external.sandbox_slice override
-python LLM_Collaboration_with_MARL/train_magrpo.py \
-    --config LLM_Collaboration_with_MARL/configs/mt_magrpo_che_config.yaml \
-    --override external.mode='level_feedback' external.sandbox_slice=-2
-```
+`external.mode` defaults to 'level_feedback'. It controls what additional information the external environment adds to prompts in the following turns: 'level_feedback' attaches test‑driven diagnostics, while the alternatives are 'expert_edits' (an LLM proposes edits), 'level_passed'/'passed' (binary outcomes), and 'plain' (no signals).
+
+A setting specific to 'level_feedback' is `external.sandbox_slice`, which controls how many eval tests are included in the feedback. By default, the sandbox executes only the first assert (sandbox_slice=1). Use all eval tests by setting `external.sandbox_slice` to 0, None, or 'all'; a negative value selects that many asserts from the end. `external.sandbox_slice` only affects the analysis-based modes ('level_feedback', 'level_passed', 'passed') and has no effect on 'expert_edits'.
+
+A setting specific to 'expert_edits' is `external.expert_edits_model`, which selects the LLM used to propose edits. By default it uses DeepSeek-Coder; you can also switch to Claude-3 or GPT-4 once the corresponding API keys/tokens are set in your environment variables.
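
For illustration, one plausible reading of the `sandbox_slice` values described above (assumed semantics; the resolver internals are not shown here):

```python
def select_asserts(asserts, sandbox_slice=1):
    # 1 (default): only the first assert; 0 / None / 'all': every assert;
    # a negative value: that many asserts from the end.
    if sandbox_slice in (0, None, "all"):
        return asserts
    if sandbox_slice < 0:
        return asserts[sandbox_slice:]
    return asserts[:sandbox_slice]

tests = ["assert f(1) == 2", "assert f(2) == 4", "assert f(3) == 6"]
print(select_asserts(tests))         # first assert only
print(select_asserts(tests, -2))     # last two asserts
print(select_asserts(tests, "all"))  # all asserts
```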
+
+### Output
+
+`output.save_model` defaults to `false` because of the large storage required by multiple LLMs. `verbose` enables debug printing on the cluster when set to `true`; it defaults to `false`, in which case you only see a tqdm bar showing training progress. You can also turn on `magrpo.log_code_levels` to log level rewards during training, but it slows training down considerably.
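
A small sketch of the intended logging behaviour (assumed wiring; only the flag semantics come from the paragraph above):

```python
from tqdm import tqdm

def training_iterator(batches, verbose=False):
    # verbose: false (default) -> only a tqdm progress bar;
    # verbose: true -> explicit per-step debug prints, easier to read in cluster logs.
    if verbose:
        for step, batch in enumerate(batches):
            print(f"[debug] step {step}: {batch}")
            yield batch
    else:
        yield from tqdm(batches, desc="training")
```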

configs/grpo_che_config.yaml

Lines changed: 10 additions & 7 deletions
@@ -9,7 +9,7 @@ model:
   trust_remote_code: true
   model_kwargs:
     trust_remote_code: true
-    torch_dtype: "auto"
+    torch_dtype: "bfloat16"
 
 # dataset
 dataset:
@@ -20,8 +20,9 @@ dataset:
 
 # output
 output:
-  base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
+  base_dir: "output"
   save_final_model: false
+  verbose: false
 
 # external
 external:
@@ -32,23 +33,25 @@ external:
 
 # grpo
 grpo:
-  num_train_epochs: 16
+  num_turns: 2
+  num_train_epochs: 8
   per_device_train_batch_size: 1
-  learning_rate: 1.0e-5
+  learning_rate: 2.0e-5
   logging_steps: 50
   save_steps: 200
   num_generations: 4
   max_new_tokens: 256
-  joint_mode: cross
+  joint_mode: aligned
   temperature: 0.8
   top_p: 0.95
   discount: 0.9
+  termination_threshold: -0.1
   reward_shift: -2.1
 
 # wandb
 wandb:
   project: "mlrl"
   entity: "nu-llpr"
   name: "grpo_coophumaneval"
-  dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
-  tags: ["grpo", "coophumaneval", "single-agent"]
+  dir: "output"
+  tags: ["grpo", "coophumaneval"]

configs/grpo_he_config.yaml

Lines changed: 10 additions & 7 deletions
@@ -9,7 +9,7 @@ model:
   trust_remote_code: true
   model_kwargs:
     trust_remote_code: true
-    torch_dtype: "auto"
+    torch_dtype: "bfloat16"
 
 # dataset
 dataset:
@@ -20,8 +20,9 @@ dataset:
 
 # output
 output:
-  base_dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
+  base_dir: "output"
   save_final_model: false
+  verbose: false
 
 # external
 external:
@@ -32,23 +33,25 @@ external:
 
 # grpo
 grpo:
-  num_train_epochs: 8
+  num_turns: 2
+  num_train_epochs: 6
   per_device_train_batch_size: 1
-  learning_rate: 1.0e-5
+  learning_rate: 2.0e-5
   logging_steps: 50
   save_steps: 200
   num_generations: 4
   max_new_tokens: 256
-  joint_mode: cross
+  joint_mode: aligned
   temperature: 0.8
   top_p: 0.95
   discount: 0.9
+  termination_threshold: -0.1
   reward_shift: -2.1
 
 # wandb
 wandb:
   project: "mlrl"
   entity: "nu-llpr"
   name: "grpo_humaneval"
-  dir: "../../../work/hdd/bepg/sliu30/output_st_grpo"
-  tags: ["grpo", "humaneval", "single-agent"]
+  dir: "output"
+  tags: ["grpo", "humaneval"]

configs/magrpo_che_config.yaml

Lines changed: 8 additions & 5 deletions
@@ -9,7 +9,7 @@ model:
   trust_remote_code: true
   model_kwargs:
     trust_remote_code: true
-    torch_dtype: "auto"
+    torch_dtype: "bfloat16"
 
 # dataset
 dataset:
@@ -20,8 +20,9 @@ dataset:
 
 # output
 output:
-  base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
+  base_dir: "output"
   save_final_model: false
+  verbose: false
 
 # external
 external:
@@ -32,7 +33,8 @@ external:
 
 # magrpo
 magrpo:
-  num_train_epochs: 16
+  num_turns: 2
+  num_train_epochs: 8
   per_device_train_batch_size: 1
   learning_rate: 2.0e-5
   logging_steps: 50
@@ -41,15 +43,16 @@ magrpo:
   max_new_tokens: 256
   temperature: 0.8
   top_p: 0.95
-  joint_mode: cross
+  joint_mode: aligned
   num_agents: 2
   discount: 0.9
+  termination_threshold: -0.2
   reward_shift: -4
 
 # wandb
 wandb:
   project: "mlrl"
   entity: "nu-llpr"
   name: "magrpo_coophumaneval"
-  dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
+  dir: "output"
   tags: ["magrpo", "coophumaneval", "multi-agent"]

configs/magrpo_he_config.yaml

Lines changed: 8 additions & 5 deletions
@@ -9,7 +9,7 @@ model:
   trust_remote_code: true
   model_kwargs:
     trust_remote_code: true
-    torch_dtype: "auto"
+    torch_dtype: "bfloat16"
 
 # dataset
 dataset:
@@ -20,8 +20,9 @@ dataset:
 
 # output
 output:
-  base_dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
+  base_dir: "output"
   save_final_model: false
+  verbose: false
 
 # external
 external:
@@ -32,22 +33,24 @@ external:
 
 # magrpo
 magrpo:
-  num_train_epochs: 8
+  num_turns: 2
+  num_train_epochs: 6
   per_device_train_batch_size: 1
   learning_rate: 2.0e-5
   logging_steps: 50
   save_steps: 200
   num_generations: 4
   max_new_tokens: 256
-  joint_mode: cross
+  joint_mode: aligned
   num_agents: 2
   discount: 0.9
+  termination_threshold: -0.2
   reward_shift: -4
 
 # wandb
 wandb:
   project: "mlrl"
   entity: "nu-llpr"
   name: "magrpo_humaneval"
-  dir: "../../../work/hdd/bepg/sliu30/output_st_magrpo"
+  dir: "output"
   tags: ["magrpo", "humaneval", "multi-agent"]

configs/mt_grpo_che_config.yaml

Lines changed: 0 additions & 55 deletions
This file was deleted.
