
Commit e95e44f

Fix CoMLRL's #26 (#23)

* stateless
* u
* Update README.md

1 parent 26c11e5 commit e95e44f

15 files changed: +580 −195 lines changed

README.md

Lines changed: 4 additions & 4 deletions

@@ -30,7 +30,7 @@ python LLM_Collaboration_with_MARL/train_magrpo.py \
 
 ### Joint Action
 
-`magrpo.joint_mode` determines how to combine each agent’s G generations into joint actions at each turn. Two modes are supported: 'align' (default), which pairs the g‑th generation of every agent to form G joint actions per node; and 'cross', which forms the Cartesian product within a node, yielding G^N joint actions per node (N agents). Total leaf joint trajectories after T turns (no early termination): align → G^T; cross → G^{N·T}.
+`magrpo.joint_mode` determines how to combine each agent’s $G$ generations into joint actions at each turn. Two modes are supported: 'align' (default), which pairs the $g$‑th generation of every agent to form $G$ joint actions per node; and 'cross', which forms the Cartesian product within a node, yielding $G^N$ joint actions per node ($N$ agents). Total leaf joint trajectories after $T$ turns (no early termination): align → $G^T$; cross → $G^{N\cdot T}$.
 
 Aligned is faster in wall‑time (fewer sibling evaluations per node), while cross is more sample‑efficient (better value estimation) without extra VRAM because it reuses the same G generations per agent and only crosses them within the node. We never cross across different nodes/prompts; this preserves causal state consistency (actions are conditioned on the same prompts), keeps siblings comparable for the baseline/advantage, maintains correct credit assignment (log‑probs matched to rewards from the same state), and remains computationally tractable.
 
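For intuition, here is a minimal standalone sketch of the two joint modes described above (illustrative only, not code from this repository; `generations` holds the $G$ samples of each of the $N$ agents at one node):

```python
from itertools import product

def joint_actions(generations, joint_mode="align"):
    """generations: list of N lists, each holding G per-agent samples at one node."""
    if joint_mode == "align":
        # Pair the g-th sample of every agent -> G joint actions per node.
        num_gen = len(generations[0])
        return [tuple(agent[g] for agent in generations) for g in range(num_gen)]
    if joint_mode == "cross":
        # Cartesian product within the node -> G**N joint actions per node.
        return list(product(*generations))
    raise ValueError(f"unknown joint_mode: {joint_mode}")

# N=2 agents, G=2 generations each.
gens = [["a1", "a2"], ["b1", "b2"]]
assert joint_actions(gens, "align") == [("a1", "b1"), ("a2", "b2")]  # G joint actions
assert len(joint_actions(gens, "cross")) == 4                        # G**N joint actions
```

Applied at every node of the tree, this yields the leaf counts quoted above: $G^T$ joint trajectories for align and $G^{N\cdot T}$ for cross after $T$ turns.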

@@ -40,15 +40,15 @@ Advantages are used to optimize the agents policies, which use a mean baseline w
 
 ### Number of Samples
 
-`magrpo.num_turns` is the number of turns in training and evaluation, and `magrpo.num_generations` is the number of samples per generation. Leaf (total samples at current turn) counts grow with T: `aligned` → G^T; `cross` → G^{N·T}. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size G for `aligned`, and G^N for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage Aᵢ = Returnᵢ − mean_sibling(Return).
+`magrpo.num_turns` is the number of turns in training and evaluation, and `magrpo.num_generations` is the number of samples per generation. Leaf (total samples at current turn) counts grow with $T$: `aligned` → $G^T$; `cross` → $G^{N\cdot T}$. At each node, the sibling set (competing joint actions under the same prompt/context/turn) has size $G$ for `aligned`, and $G^N$ for `cross`. The policy‑gradient baseline is the mean return over these siblings at that node, i.e., advantage $A_i = \mathrm{Return}_i - \operatorname{mean}_{\text{sibling}}(\mathrm{Return})$.
 
 ### Termination
 
 `magrpo.termination_threshold` is used to incentivize agents to find high‑reward solutions quickly instead of expanding the full Monte Carlo tree. At each node (branch, turn), we compute the mean immediate reward across that node’s sibling joint actions; if the mean exceeds the threshold, that branch stops expanding at this turn and the trainer backpropagates from the truncated subtree. Other branches continue.
 
-### New Prompts
+### History Controls
 
-`external.original_prompt` and `external.previous_response` both default as true. 2+ turn prompts include both the original first‑turn problem prompt and the previous response by default to preserve full context; you can shorten the context by setting either to false (for example, keep only the previous response to reduce tokens while retaining the most recent interaction).
+`external.memory_mode` controls how much history each agent is given and in what form; it must be one of `last`, `full`, or `memoryful`. `full` (default) includes all prior prompts/responses per the flags below, rendered as a compact "History" block; `last` includes only the first‑turn prompt and the last response per the flags; `memoryful` relies on the model’s internal state: the trainer carries per‑agent KV caches across turns and continues generation from them, so prompts omit explicit history. In addition, `external.previous_prompts` and `external.previous_responses` determine which parts of the agent‑wise history are inserted into the next‑turn prompt text: in `last`, `previous_prompts` includes the agent’s first‑turn prompt and `previous_responses` includes only the most recent response; in `full`, `previous_prompts` includes all prior prompts and `previous_responses` includes all prior responses; in `memoryful`, neither history is injected into the text because the per‑agent KV cache already carries this context.
 
 ### External Modes
 
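The two quantities described in these sections can be illustrated with a short sketch (placeholder names, not the trainer's implementation): the sibling-mean baseline behind the advantage, and the mean-reward early-termination check.

```python
from statistics import mean

def advantages(sibling_returns):
    """A_i = Return_i - mean over the node's siblings of Return."""
    baseline = mean(sibling_returns)
    return [r - baseline for r in sibling_returns]

def branch_terminates(sibling_rewards, threshold):
    """Stop expanding a branch when the mean immediate reward of the node's
    sibling joint actions exceeds magrpo.termination_threshold."""
    return mean(sibling_rewards) > threshold

# Example: 4 sibling joint actions at one node (e.g., cross mode with N=2, G=2).
print(advantages([1.0, 0.5, 0.0, 0.5]))                         # [0.5, 0.0, -0.5, 0.0]
print(branch_terminates([0.9, 0.8, 0.7, 0.95], threshold=0.8))  # True (mean = 0.8375)
```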

configs/grpo_che_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # grpo
 grpo:

configs/grpo_he_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # grpo
 grpo:

configs/grpo_mbpp_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # grpo
 grpo:

configs/magrpo_che_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # magrpo
 magrpo:

configs/magrpo_he_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # magrpo
 magrpo:

configs/magrpo_mbpp_config.yaml

Lines changed: 3 additions & 2 deletions

@@ -28,8 +28,9 @@ output:
 external:
   mode: "level_feedback"
   sandbox_slice: 1
-  original_prompt: true
-  previous_response: true
+  previous_prompts: true
+  previous_responses: true
+  memory_mode: full
 
 # magrpo
 magrpo:
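The six configs above all select `memory_mode: full` with both history flags enabled. As a rough sketch of what each setting injects into a 2+ turn prompt (hypothetical helper name, simplified from the implementation shown below):

```python
def history_block(first_prompt, prompt_hist, response_hist,
                  memory_mode="full", previous_prompts=True, previous_responses=True):
    """Illustrative only: the history portion of a follow-up prompt."""
    lines = []
    if memory_mode == "full":
        if previous_prompts and prompt_hist:
            lines.append("History: previous prompts:")
            lines += [f"- Turn {t} prompt:\n{p}" for t, p in enumerate(prompt_hist, 1)]
        if previous_responses and response_hist:
            lines.append("History: your previous responses:")
            lines += [f"- Turn {t} response:\n{r}" for t, r in enumerate(response_hist, 1)]
    elif memory_mode == "last":
        if previous_prompts:
            lines.append(first_prompt)  # first-turn prompt only
        if previous_responses and response_hist:
            lines += ["Your previous implementation:", response_hist[-1]]
    elif memory_mode == "memoryful":
        pass  # nothing injected; the per-agent KV cache carries the context
    return "\n".join(lines)
```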

external/__init__.py

Lines changed: 39 additions & 13 deletions

@@ -44,6 +44,14 @@ def get_external_transition(
     agent_completions: Union[List[str], Tuple[str, str]],
     num_agents: int = 2,
     mode: str = "expert_edits",
+    *,
+    # New history flags
+    previous_prompts: bool = False,
+    previous_responses: bool = True,
+    memory_mode: str = "last",
+    # Per-branch history from trainer
+    prompt_history_per_agent: Optional[List[List[str]]] = None,
+    response_history_per_agent: Optional[List[List[str]]] = None,
     **kwargs,
 ) -> Union[List[str], Tuple[str, str]]:
     """
@@ -84,9 +92,12 @@ def print(*args, **kwargs): # type: ignore
 
     # Route to the requested mode implementation
     mode = (mode or "").lower()
-    # Pull common flags controlling prompt composition
-    original_prompt_flag = kwargs.get("original_prompt", False)
-    previous_response_flag = kwargs.get("previous_response", True)
+    memory_mode = (memory_mode or "last").lower()
+    # Prepare normalized histories
+    if prompt_history_per_agent is None:
+        prompt_history_per_agent = [[] for _ in range(int(num_agents))]
+    if response_history_per_agent is None:
+        response_history_per_agent = [[] for _ in range(int(num_agents))]
 
     if mode == "expert_edits":
         if int(num_agents) == 1:
@@ -113,9 +124,12 @@ def print(*args, **kwargs): # type: ignore
             entry_point=entry_point,
             aux_completion=aux_comp,
             main_completion=main_comp,
-            original_prompt_flag=original_prompt_flag,
-            previous_response_flag=previous_response_flag,
+            previous_prompts=previous_prompts,
+            previous_responses=previous_responses,
+            memory_mode=memory_mode,
             num_agent=int(num_agents),
+            prompt_history_per_agent=prompt_history_per_agent,
+            response_history_per_agent=response_history_per_agent,
         )
 
         # Print preview
@@ -145,9 +159,12 @@ def print(*args, **kwargs): # type: ignore
             main_completion=main_comp,
             test_code=test_code,
             entry_point=entry_point,
-            original_prompt_flag=original_prompt_flag,
-            previous_response_flag=previous_response_flag,
+            previous_prompts=previous_prompts,
+            previous_responses=previous_responses,
+            memory_mode=memory_mode,
             num_agent=int(num_agents),
+            prompt_history_per_agent=prompt_history_per_agent,
+            response_history_per_agent=response_history_per_agent,
         )
         print("\n" + "=" * 60)
         print("EXTERNAL MODE PREVIEW: level_feedback")
@@ -174,9 +191,12 @@ def print(*args, **kwargs): # type: ignore
             main_completion=main_comp,
             test_code=test_code,
             entry_point=entry_point,
-            original_prompt_flag=original_prompt_flag,
-            previous_response_flag=previous_response_flag,
+            previous_prompts=previous_prompts,
+            previous_responses=previous_responses,
+            memory_mode=memory_mode,
             num_agent=int(num_agents),
+            prompt_history_per_agent=prompt_history_per_agent,
+            response_history_per_agent=response_history_per_agent,
         )
         print("\n" + "=" * 60)
         print("EXTERNAL MODE PREVIEW: level_passed")
@@ -203,9 +223,12 @@ def print(*args, **kwargs): # type: ignore
             main_completion=main_comp,
             test_code=test_code,
             entry_point=entry_point,
-            original_prompt_flag=original_prompt_flag,
-            previous_response_flag=previous_response_flag,
+            previous_prompts=previous_prompts,
+            previous_responses=previous_responses,
+            memory_mode=memory_mode,
             num_agent=int(num_agents),
+            prompt_history_per_agent=prompt_history_per_agent,
+            response_history_per_agent=response_history_per_agent,
         )
         print("\n" + "=" * 60)
         print("EXTERNAL MODE PREVIEW: passed")
@@ -232,9 +255,12 @@ def print(*args, **kwargs): # type: ignore
             main_completion=main_comp,
             test_code=test_code,
             entry_point=entry_point,
-            original_prompt_flag=original_prompt_flag,
-            previous_response_flag=previous_response_flag,
+            previous_prompts=previous_prompts,
+            previous_responses=previous_responses,
+            memory_mode=memory_mode,
             num_agent=int(num_agents),
+            prompt_history_per_agent=prompt_history_per_agent,
+            response_history_per_agent=response_history_per_agent,
         )
         print("\n" + "=" * 60)
         print("EXTERNAL MODE PREVIEW: plain")

external/expert_edits.py

Lines changed: 83 additions & 18 deletions

@@ -1,7 +1,7 @@
 import json
 import os
 import re
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 from anthropic import Anthropic
 from openai import OpenAI
@@ -164,9 +164,13 @@ def format_followup_prompts(
     entry_point: str = "",
     aux_completion: str = "",
     main_completion: str = "",
-    original_prompt_flag: bool = False,
-    previous_response_flag: bool = True,
+    *,
+    previous_prompts: bool = False,
+    previous_responses: bool = True,
+    memory_mode: str = "last",
     num_agent: int = 2,
+    prompt_history_per_agent: Optional[List[List[str]]] = None,
+    response_history_per_agent: Optional[List[List[str]]] = None,
 ) -> Tuple[str, str]:
     """
     Format the 2+ turn prompts for expert_edits mode to match other modes:
@@ -177,6 +181,23 @@ def format_followup_prompts(
 
     target_entry = entry_point or "main"
 
+    # Normalize histories
+    memory_mode = (memory_mode or "last").lower()
+    # If full history has only one prior turn, render identically to 'last'
+    if memory_mode == "full":
+        try:
+            counts_p = [len(x or []) for x in (prompt_history_per_agent or [])]
+            counts_r = [len(x or []) for x in (response_history_per_agent or [])]
+            if counts_p and max(counts_p) <= 1 and counts_r and max(counts_r) <= 1:
+                memory_mode = "last"
+        except Exception:
+            pass
+
+    if prompt_history_per_agent is None:
+        prompt_history_per_agent = [[] for _ in range(int(num_agent))]
+    if response_history_per_agent is None:
+        response_history_per_agent = [[] for _ in range(int(num_agent))]
+
     # Single-agent: only build main prompt; no aux references
     if int(num_agent) == 1:
         main_lines: List[str] = []
@@ -186,20 +207,33 @@ def format_followup_prompts(
             or "<no implementation found>"
         )
 
-        if original_prompt_flag:
-            _aux_base, main_base = build_first_turn_prompts(
-                original_prompt, target_entry
-            )
-            main_lines.extend([main_base, ""])  # context then blank line
-
-        if previous_response_flag:
+        if memory_mode == "full":
+            if previous_prompts and prompt_history_per_agent and prompt_history_per_agent[0]:
+                main_lines.extend(["History: previous prompts:"])
+                for t, ph in enumerate(prompt_history_per_agent[0], start=1):
+                    main_lines.append(f"- Turn {t} prompt:\n{ph}")
+                main_lines.append("")
+            if previous_responses and response_history_per_agent and response_history_per_agent[0]:
+                main_lines.extend(["History: your previous responses:"])
+                for t, resp in enumerate(response_history_per_agent[0], start=1):
+                    main_lines.append(f"- Turn {t} response:\n{resp}")
+                main_lines.append("")
+        elif memory_mode == "last":
+            if previous_prompts:
+                _aux_base, main_base = build_first_turn_prompts(
+                    original_prompt, target_entry
+                )
+                main_lines.extend([main_base, ""])  # context then blank line
+        if memory_mode == "last" and previous_responses:
             main_lines.extend(
                 [
                     "Your previous implementation:",
                     prev_main,
                     "",
                 ]
             )
+        elif memory_mode == "memoryful":
+            pass
 
         main_lines.extend(
             [
@@ -225,14 +259,45 @@ def format_followup_prompts(
         aux_lines: List[str] = []
         main_lines: List[str] = []
 
-        if original_prompt_flag:
-            aux_base, main_base = build_first_turn_prompts(original_prompt, target_entry)
-            aux_lines.extend([aux_base, ""])  # add a blank line after context
-            main_lines.extend([main_base, ""])
-
-        if previous_response_flag:
-            aux_lines.extend(["Your previous aux(...) implementation:", prev_aux, ""])
-            main_lines.extend(["Your previous main implementation:", prev_main, ""])
+        if memory_mode == "full":
+            if previous_prompts:
+                if prompt_history_per_agent and len(prompt_history_per_agent) >= 2:
+                    aux_ph = prompt_history_per_agent[0]
+                    main_ph = prompt_history_per_agent[1]
+                    if aux_ph:
+                        aux_lines.append("History: previous prompts:")
+                        for t, ph in enumerate(aux_ph, start=1):
+                            aux_lines.append(f"- Turn {t} prompt:\n{ph}")
+                        aux_lines.append("")
+                    if main_ph:
+                        main_lines.append("History: previous prompts:")
+                        for t, ph in enumerate(main_ph, start=1):
+                            main_lines.append(f"- Turn {t} prompt:\n{ph}")
+                        main_lines.append("")
+            if previous_responses:
+                if response_history_per_agent and len(response_history_per_agent) >= 2:
+                    aux_rh = response_history_per_agent[0]
+                    main_rh = response_history_per_agent[1]
+                    if aux_rh:
+                        aux_lines.append("History: your previous aux(...) responses:")
+                        for t, resp in enumerate(aux_rh, start=1):
+                            aux_lines.append(f"- Turn {t} response:\n{resp}")
+                        aux_lines.append("")
+                    if main_rh:
+                        main_lines.append("History: your previous main responses:")
+                        for t, resp in enumerate(main_rh, start=1):
+                            main_lines.append(f"- Turn {t} response:\n{resp}")
+                        main_lines.append("")
+        elif memory_mode == "last":
+            if previous_prompts:
+                aux_base, main_base = build_first_turn_prompts(original_prompt, target_entry)
+                aux_lines.extend([aux_base, ""])  # add a blank line after context
+                main_lines.extend([main_base, ""])
+            if previous_responses:
+                aux_lines.extend(["Your previous aux(...) implementation:", prev_aux, ""])
+                main_lines.extend(["Your previous main implementation:", prev_main, ""])
+        elif memory_mode == "memoryful":
+            pass
 
         aux_lines.extend(
             [
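One behavior worth noting in the hunk above: when `memory_mode` is `full` but every agent has at most one prior turn of history, the prompt is rendered exactly as in `last`. A standalone check mirroring that collapse rule (illustrative function name, toy values):

```python
def effective_memory_mode(memory_mode, prompt_hist_per_agent, response_hist_per_agent):
    """'full' with at most one prior turn per agent renders like 'last'."""
    mode = (memory_mode or "last").lower()
    if mode == "full":
        counts_p = [len(x or []) for x in (prompt_hist_per_agent or [])]
        counts_r = [len(x or []) for x in (response_hist_per_agent or [])]
        if counts_p and max(counts_p) <= 1 and counts_r and max(counts_r) <= 1:
            mode = "last"
    return mode

# Turn 2: one prior turn per agent -> rendered like 'last'.
print(effective_memory_mode("full", [["p1"], ["p1"]], [["r1"], ["r1"]]))      # last
# Turn 3: two prior turns per agent -> genuine 'full' rendering.
print(effective_memory_mode("full", [["p1", "p2"]] * 2, [["r1", "r2"]] * 2))  # full
```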
