
Commit d775bec

Remove the Node-Level Specific Metrics Logging During Training / Add MBPP / Dr. GRPO (#20)
* change log
* format stuff
* change default clip and std
* add mbpp
* DR GRPO
* clip to be null
1 parent 81cb0b6 commit d775bec

31 files changed: +2207 −126 lines changed

README.md

Lines changed: 16 additions & 20 deletions
@@ -28,44 +28,40 @@ python LLM_Collaboration_with_MARL/train_magrpo.py \
 
 ## Settings
 
-### Joint Action Modes
+### Joint Action
 
-`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: `align` (default), which pairs the g‑th generation of every agent to form G joint actions per node; and `cross`, which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): `align` → G^T; `cross` → (G^N)^T = G^{N·T}.
+`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: 'align' (default), which pairs the g‑th generation of every agent to form G joint actions per node; and 'cross', which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): align → G^T; cross → G^{N·T}.

 Aligned is faster in wall‑time (fewer sibling evaluations per node), while cross is more sample‑efficient (better value estimation) without extra VRAM because it reuses the same G generations per agent and only crosses them within the node. We never cross across different nodes/prompts; this preserves causal state consistency (actions are conditioned on the same prompts), keeps siblings comparable for the baseline/advantage, maintains correct credit assignment (log‑probs matched to rewards from the same state), and remains computationally tractable.
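A minimal sketch of how the two modes combine per-agent generations at a node (illustrative only; `build_joint_actions` and its inputs are hypothetical names, not the repository's API):

```python
from itertools import product

def build_joint_actions(per_agent_generations, joint_mode="align"):
    """per_agent_generations: N lists (one per agent), each holding G generations."""
    G = len(per_agent_generations[0])
    if joint_mode == "align":
        # Pair the g-th generation of every agent: G joint actions per node.
        return [tuple(gens[g] for gens in per_agent_generations) for g in range(G)]
    if joint_mode == "cross":
        # Cartesian product within the node: G^N joint actions per node.
        return list(product(*per_agent_generations))
    raise ValueError(f"unknown joint_mode: {joint_mode}")

# Example: N = 2 agents, G = 3 generations each.
agents = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
assert len(build_joint_actions(agents, "align")) == 3   # G
assert len(build_joint_actions(agents, "cross")) == 9   # G^N
```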

-### Advantage Calculation
+### Advantage

-`magrpo.normalize_advantage` is false by default. When true, compute z-scored advantages over sibling returns; when false, use a mean baseline without normalization.
+Advantages are used to optimize the agents' policies; they use a mean baseline without any standard‑deviation normalization, which keeps training unbiased (see [Dr. GRPO](https://arxiv.org/pdf/2503.20783)). We do not apply importance-sampling ratios either, since training is on-policy (the same policy is used for sampling and training).
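A minimal sketch of this advantage computation (illustrative; `sibling_advantages` is a hypothetical helper, not the trainer's code):

```python
import numpy as np

def sibling_advantages(returns):
    """returns: returns of the sibling joint actions at one node."""
    returns = np.asarray(returns, dtype=np.float64)
    # Mean baseline only; no division by the standard deviation (Dr. GRPO),
    # and no importance-sampling ratio since sampling is on-policy.
    return returns - returns.mean()

print(sibling_advantages([1.0, 0.2, -0.5, 0.8]))  # [ 0.625 -0.175 -0.875  0.425]
```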

-`magrpo.epsilon_clip` clamps the advantage to [-epsilon_clip, +epsilon_clip] after normalization (default: None). 0 or None skips clamping entirely.
+### Number of Samples

-We do not apply the importance sampling ratio because the policy changes slowly with LLMs, and the ratio is close to 1.0. This avoids numerical instability from multiplying many small probabilities.
+`magrpo.num_turns` is the number of turns in training and evaluation, and `magrpo.num_generations` is the number of generations G sampled per agent at each turn. Leaf counts (the total number of joint samples at the current turn) grow with T: `align` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `align` and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
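A small worked example of these counts (plain arithmetic, not repository code):

```python
def sibling_count(G, N, joint_mode):
    # Joint actions that compete under the same node (prompt/context/turn).
    return G if joint_mode == "align" else G ** N

def leaf_count(G, N, T, joint_mode):
    # Joint trajectories after T turns, assuming no early termination.
    return sibling_count(G, N, joint_mode) ** T

# Example: N = 2 agents, G = 4 generations, T = 2 turns.
print(leaf_count(4, 2, 2, "align"))  # 16  = G^T
print(leaf_count(4, 2, 2, "cross"))  # 256 = G^(N*T)
```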

-### Number of Turns
+### Termination

-`magrpo.num_turns` determines the number of turns (default: 2). Leaf counts grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned`, and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
+`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly instead of expanding the full Monte Carlo tree. At each node (branch, turn), we compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue.
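A minimal sketch of the per-node check described above, assuming a scalar threshold such as the -0.1 used in `configs/grpo_che_config.yaml` (the helper name is hypothetical):

```python
import numpy as np

def should_stop_expanding(sibling_rewards, termination_threshold):
    """Stop expanding a branch when the mean immediate reward of its
    sibling joint actions exceeds the threshold."""
    if termination_threshold is None:
        return False
    return float(np.mean(sibling_rewards)) > termination_threshold

print(should_stop_expanding([0.5, 0.1, 0.3], -0.1))  # True  -> backpropagate from truncated subtree
print(should_stop_expanding([-0.6, -0.2], -0.1))     # False -> branch keeps expanding
```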

-### Early Termination
+### New Prompts

-`magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly, instead of expanding the full Monte Carlo tree. At each node (branch, turn), compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue.
-
-### 2+Turn Prompt
-
-`external.original_prompt` and `external.previous_response` both default as `true`. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to `false` (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
+`external.original_prompt` and `external.previous_response` both default to true. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to false (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
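A minimal sketch of how a 2+ turn prompt could be assembled from these two flags (the wording and the `build_followup_prompt` helper are illustrative, not the repository's implementation):

```python
def build_followup_prompt(original_prompt, previous_response, feedback="",
                          include_original=True, include_previous=True):
    """Assemble a 2+ turn prompt; drop either context piece to save tokens."""
    parts = []
    if include_original:      # external.original_prompt
        parts.append(f"Original problem:\n{original_prompt}")
    if include_previous:      # external.previous_response
        parts.append(f"Your previous response:\n{previous_response}")
    if feedback:              # e.g. diagnostics attached by 'level_feedback'
        parts.append(f"Feedback:\n{feedback}")
    parts.append("Revise your previous response.")
    return "\n\n".join(parts)

print(build_followup_prompt("Write add(a, b).", "def add(a, b): return a - b",
                            feedback="assert add(1, 2) == 3 failed"))
```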

 ### External Modes

-`external.mode` is set to 'level_feedback' by default. This gives additional information from external to prompts in the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include:
+`external.mode` is used to imitate the environment transition and is set to 'level_feedback' by default. It adds information from the external environment to the prompts in the following turns; 'level_feedback' attaches test‑driven diagnostics, while alternatives include:

-- `expert_edits`: an LLM proposes edits; prompts include edit suggestions plus context.
-- `level_passed` / `passed`: binary outcome oriented prompts with minimal context.
-- `plain`: no diagnostics, but still includes previous response (unless disabled) and a "Revise ..." instruction.
+- 'expert_edits': an LLM proposes edits; prompts include edit suggestions plus context.
+- 'level_passed' / 'passed': binary, outcome-oriented prompts with minimal context.
+- 'plain': no diagnostics, but still includes the previous response (unless disabled) and a "revise your previous response" instruction.

 The specific setting for 'level_feedback' is `external.sandbox_slice`, which controls how many eval tests to include in the feedback. By default, the sandbox executes only the first assert (sandbox_slice=1). Use all eval tests by setting `external.sandbox_slice` to 0, None, or 'all'. Negative values use the last asserts. `external.sandbox_slice` only affects analysis-based modes ('level_feedback', 'level_passed', 'passed') and has no effect on 'expert_edits'.
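A minimal sketch of how `external.sandbox_slice` could select the asserts to execute (illustrative helper, not the repository's implementation):

```python
def select_asserts(asserts, sandbox_slice=1):
    """asserts: the eval tests' assert statements, in order."""
    if sandbox_slice in (0, None, "all"):
        return asserts                     # include every eval test
    if sandbox_slice > 0:
        return asserts[:sandbox_slice]     # first N asserts (default: 1)
    return asserts[sandbox_slice:]         # negative: the last |N| asserts

tests = ["assert f(1) == 2", "assert f(2) == 4", "assert f(3) == 6"]
print(select_asserts(tests))         # first assert only
print(select_asserts(tests, -2))     # last two asserts
print(select_asserts(tests, "all"))  # all three
```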

-Specific settings for 'expert_edits' is `external.expert_edits_model`, which controls which LLM to use for proposing edits. By default, it uses DeepSeek-Coder. You can also change it to Claude-3, GPT-4, once you have keys/tokens in your global environment variables.
+The specific setting for 'expert_edits' is `external.expert_edits_model`, which controls which LLM to use for proposing edits. By default, it uses DeepSeek-Coder. You can also change it to Claude, GPT, and other models once you have keys/tokens in your environment.

 ### Output

-`output.save_model` is set to `false` by default because of the huge storage required by multiple LLMs. `verbose` is used for debug printing on cluster if set to be true, but it is default to be false and you can only see a tqdm bar that shows the training progress. You can also turn on `magrpo.log_code_levels` to log the level-rewards during training, but it will crazily slow down the training.
+`output.save_model` is set to false by default because of the large storage required by multiple LLMs. `output.verbose` enables debug printing on the cluster when set to true; it defaults to false, in which case you only see a tqdm bar showing training progress.

baselines/che_concat.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch

baselines/che_discuss.py

Lines changed: 9 additions & 3 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch
@@ -1009,6 +1008,9 @@ def evaluate_coophumaneval_two_round(
 
 
 def main():
+    # --------------------------------------------------------------
+    # CLI: parse arguments
+    # --------------------------------------------------------------
     parser = argparse.ArgumentParser(
         description="CoopHumanEval Two-Round Model Evaluation"
     )
@@ -1043,14 +1045,18 @@ def main():
 
     args = parser.parse_args()
 
-    # Initialize two-round evaluator
+    # --------------------------------------------------------------
+    # Initialize evaluator
+    # --------------------------------------------------------------
     evaluator = QwenCoopHumanEvalTwoRoundEvaluator(
         aux_model_name=args.aux_model,
         main_model_name=args.main_model,
         device=args.device,
     )
 
+    # --------------------------------------------------------------
     # Run evaluation
+    # --------------------------------------------------------------
     aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_two_round(
         num_samples=args.samples,
         num_generations=args.generations,

baselines/che_sequential.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch

baselines/che_single_agent.py

Lines changed: 9 additions & 3 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch
@@ -731,6 +730,9 @@ def evaluate_coophumaneval_baseline(
 
 
 def main():
+    # --------------------------------------------------------------
+    # CLI: parse arguments
+    # --------------------------------------------------------------
     parser = argparse.ArgumentParser(
         description="CoopHumanEval Single Agent Baseline Evaluation"
     )
@@ -760,12 +762,16 @@ def main():
 
     args = parser.parse_args()
 
-    # Initialize baseline evaluator
+    # --------------------------------------------------------------
+    # Initialize evaluator
+    # --------------------------------------------------------------
     evaluator = QwenCoopHumanEvalSingleAgentBaseline(
         model_name=args.model, device=args.device
     )
 
+    # --------------------------------------------------------------
     # Run evaluation
+    # --------------------------------------------------------------
     aggregated_metrics, sample_results = evaluator.evaluate_coophumaneval_baseline(
         num_samples=args.samples,
         num_generations=args.generations,

baselines/he_concat.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch

baselines/he_discuss.py

Lines changed: 9 additions & 3 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch
@@ -1012,6 +1011,9 @@ def evaluate_humaneval_two_round(
 
 
 def main():
+    # --------------------------------------------------------------
+    # CLI: parse arguments
+    # --------------------------------------------------------------
     parser = argparse.ArgumentParser(description="HumanEval Two-Round Model Evaluation")
     parser.add_argument(
         "--aux-model", default="Qwen/Qwen2.5-Coder-3B", help="Auxiliary model name"
@@ -1044,14 +1046,18 @@ def main():
 
     args = parser.parse_args()
 
-    # Initialize two-round evaluator
+    # --------------------------------------------------------------
+    # Initialize evaluator
+    # --------------------------------------------------------------
     evaluator = QwenHumanEvalTwoRoundEvaluator(
         aux_model_name=args.aux_model,
         main_model_name=args.main_model,
         device=args.device,
     )
 
+    # --------------------------------------------------------------
     # Run evaluation
+    # --------------------------------------------------------------
     aggregated_metrics, sample_results = evaluator.evaluate_humaneval_two_round(
         num_samples=args.samples,
         num_generations=args.generations,

baselines/he_sequential.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch

baselines/he_single_agent.py

Lines changed: 9 additions & 3 deletions
@@ -3,8 +3,7 @@
 import re
 import signal
 import time
-from collections import defaultdict
-from math import comb
+
 
 import numpy as np
 import torch
@@ -734,6 +733,9 @@ def evaluate_humaneval_baseline(
 
 
 def main():
+    # --------------------------------------------------------------
+    # CLI: parse arguments
+    # --------------------------------------------------------------
     parser = argparse.ArgumentParser(
         description="HumanEval Single Agent Baseline Evaluation"
     )
@@ -763,12 +765,16 @@ def main():
 
     args = parser.parse_args()
 
-    # Initialize baseline evaluator
+    # --------------------------------------------------------------
+    # Initialize evaluator
+    # --------------------------------------------------------------
     evaluator = QwenHumanEvalSingleAgentBaseline(
         model_name=args.model, device=args.device
    )
 
+    # --------------------------------------------------------------
     # Run evaluation
+    # --------------------------------------------------------------
     aggregated_metrics, sample_results = evaluator.evaluate_humaneval_baseline(
         num_samples=args.samples,
         num_generations=args.generations,

configs/grpo_che_config.yaml

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@ grpo:
   discount: 0.9
   termination_threshold: -0.1
   reward_shift: -2.1
-  normalize_advantage: false
   epsilon_clip: null
 
   # wandb
