188 changes: 165 additions & 23 deletions agent_torch/core/llm/Variable.py
@@ -9,20 +9,56 @@
Template instance; no separate Slot class is needed.
"""

-from typing import Any, Optional, Tuple, Callable
from typing import Any, Optional, Tuple, Callable, Dict, List
import torch
import torch.nn as nn


class Variable:
-def __init__(self, desc: Optional[str] = None, learnable: bool = False, default: Any = None):
def __init__(self, desc: Optional[str] = None, learnable: bool = False, default: Any = None,
presentations: Optional[List[str]] = None):
"""
Args:
desc: Description of what this variable represents
learnable: Whether this variable's presentation can be optimized
default: Default value when variable is not set
presentations: List of format strings for presentation choices.
presentations[0] should always be "" (skip)
presentations[1+] are user-defined formats with {value} placeholder

Example:
skill = lm.Variable(
desc="Programming expertise",
learnable=True,
presentations=[
"", # Choice 0: Skip
"Skill: {value}", # Choice 1: Formal
"Expert in {value}", # Choice 2: Expertise
"{value} experience" # Choice 3: Casual
]
)
"""
self.desc = desc
self.learnable = learnable
self.default = default
self._name: Optional[str] = None
-# Number of presentation options for P3O (0..num_options-1)
-# 0: skip, 1: direct, 2: labeled, 3: contextual, 4: descriptive
-self.num_options: int = 5

# Set up presentation choices
if presentations is None:
# Default behavior: binary choice (skip or include with current format)
self.presentations = [
"", # Choice 0: Skip
"- {value}: {value}" # Choice 1: Default format
]
else:
# User-provided presentations
if not presentations or presentations[0] != "":
# Ensure first choice is always "skip"
presentations = [""] + (presentations or ["- {value}: {value}"])
self.presentations = presentations

# Number of presentation options for P3O
self.num_options: int = len(self.presentations)

def __set_name__(self, owner, name: str):
self._name = name
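
The constructor normalizes any user-supplied list so that choice 0 is always the skip option. A minimal sketch of that behavior (the variable name and two-format list are hypothetical):

```python
# Sketch: "" (skip) is prepended when the caller omits it, and
# num_options tracks the final list length.
skill = Variable(
    desc="Programming expertise",
    learnable=True,
    presentations=["Skill: {value}", "Expert in {value}"],  # no leading ""
)
assert skill.presentations[0] == ""  # skip choice injected at index 0
assert skill.num_options == 3        # "", "Skill: {value}", "Expert in {value}"
```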
@@ -42,16 +78,26 @@ def name(self) -> Optional[str]:
# --- Learnable parameter support (replaces Slot) ---
def get_parameter(self, instance: Any) -> Optional[nn.Parameter]:
"""Return/create the learnable parameter (logits over options) for this variable on the given instance."""
-if not self.learnable or self._name is None:
if not self.learnable:
return None

# Auto-discover name if not set
if self._name is None:
for name, var in getattr(instance, '_variables', {}).items():
if var is self:
self._name = name
break

if self._name is None:
return None

param_attr = f"__var_param__{self._name}"
param = getattr(instance, param_attr, None)

if not isinstance(param, nn.Parameter):
-# Initialize logits over presentation options
-init = torch.full((self.num_options,), 0.5, dtype=torch.float32)
-# Bias against skip option (index 0)
-if self.num_options > 0:
-    init[0] = -1.0
# Match original experiment: unbiased initialization with torch.zeros()
# This gives 50/50 probability for binary choices, letting P3O learn naturally
init = torch.zeros(self.num_options, dtype=torch.float32)
param = nn.Parameter(init, requires_grad=True)
setattr(instance, param_attr, param)
return param
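
Given the `torch.zeros` initialization above, a freshly created parameter yields a uniform distribution over the presentation choices. A small sketch, where the `SimpleNamespace` holder stands in for a real template instance and exercises the name auto-discovery:

```python
import torch
from types import SimpleNamespace

var = Variable(desc="Programming expertise", learnable=True,
               presentations=["Skill: {value}"])
holder = SimpleNamespace(_variables={"skill": var})  # _variables drives auto-discovery

param = var.get_parameter(holder)   # lazily creates zero logits on `holder`
probs = torch.softmax(param, dim=0)
assert torch.allclose(probs, torch.full_like(probs, 1.0 / var.num_options))
```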
@@ -87,29 +133,124 @@ def fmt(category: int, data: dict) -> str:
if not field_name:
return ""
raw_value = data.get(field_name)

# Handle sparse skill data: if skill is not relevant (0 or missing), always skip
if field_name != 'soc_code' and field_name != 'job_title':
# For skill fields, check if this skill is relevant to this job
if raw_value is None or raw_value == 0 or raw_value == '0':
return "" # Skill not relevant for this job, always skip

if raw_value is None:
return ""
value = map_value(raw_value)

-# If not learnable, always direct
# If not learnable, always use the first non-skip presentation
if not self.learnable:
if len(self.presentations) > 1:
return self.presentations[1].format(value=value)
return value

-if category == 0:
-    return ""
-if category == 1:
-    return value
-if category == 2:
-    return f"{field_name}: {value}"
-if category == 3:
-    return f"with {value}"
-if category == 4:
-    return f"The {field_name} is {value}"
-# Fallback
# For skills: only render if choice=1 AND skill is relevant (value=1)
if field_name != 'soc_code' and field_name != 'job_title':
if category == 0:
return "" # P3O chose to skip this skill
elif category == 1 and raw_value == 1:
# P3O chose to include AND skill is relevant for this job
if 1 < len(self.presentations):
return self.presentations[1].format(value=field_name.replace('_', ' ').title())
return f"- {field_name.replace('_', ' ').title()}: {field_name.replace('_', ' ').title()}"
else:
return "" # Skill not relevant or P3O chose to skip

# For non-skill fields (soc_code, etc.), use normal presentation logic
if 0 <= category < len(self.presentations):
presentation = self.presentations[category]
if presentation == "":
return ""
return presentation.format(value=value)

# Fallback to last presentation if category is out of range
if self.presentations:
last_presentation = self.presentations[-1]
return last_presentation.format(value=value) if last_presentation else ""

return value

return self.num_options, fmt
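
A behavior sketch for the closure above; the enclosing method whose name is elided from this hunk returns `(num_options, fmt)`, so `get_spec` below is a stand-in, and the skill field is hypothetical:

```python
# Assumes a learnable skill variable bound to field_name="python_programming"
# with the default "- {value}: {value}" presentation.
num_options, fmt = skill_var.get_spec()  # stand-in for the elided method name

fmt(0, {"python_programming": 1})  # -> ""  (P3O chose to skip)
fmt(1, {"python_programming": 1})  # -> "- Python Programming: Python Programming"
fmt(1, {"python_programming": 0})  # -> ""  (skill not relevant: early return)
```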

# --- DSPy Conversion Utilities ---
@classmethod
def from_dspy_field(cls, field_name: str, field_annotation, dspy_field, **kwargs) -> 'Variable':
"""Convert a DSPy InputField or OutputField to an lm.Variable.

Args:
field_name: Name of the field in the DSPy signature
field_annotation: Type annotation (e.g., str, JobMetrics)
dspy_field: The dspy.InputField() or dspy.OutputField() instance
**kwargs: Additional Variable constructor arguments

Returns:
Variable instance configured for use in AgentTorch templates

Example:
# From DSPy signature:
# job_info: str = dspy.InputField(desc="Job description")

var = lm.Variable.from_dspy_field(
"job_info", str, dspy.InputField(desc="Job description"),
learnable=True # Make it optimizable
)
"""
# Extract description from DSPy field
desc = getattr(dspy_field, 'desc', None) or f"Converted from DSPy field: {field_name}"

# InputFields are typically learnable (content we want to optimize)
# OutputFields are typically not learnable (LLM generates them)
default_learnable = 'InputField' in str(type(dspy_field))
learnable = kwargs.pop('learnable', default_learnable)

# Create Variable with DSPy metadata
return cls(
desc=desc,
learnable=learnable,
default=kwargs.pop('default', None),
**kwargs
)
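
The `learnable` default is inferred from the field's type name, so any class whose name contains `InputField` triggers it. A dspy-free sketch with hypothetical stub field classes:

```python
class FakeInputField:        # type name contains "InputField" -> learnable
    def __init__(self, desc): self.desc = desc

class FakeOutputField:       # no "InputField" in the type name -> not learnable
    def __init__(self, desc): self.desc = desc

job = Variable.from_dspy_field("job_info", str, FakeInputField("Job description"))
pred = Variable.from_dspy_field("prediction", dict, FakeOutputField("Predictions"))
assert job.learnable is True and pred.learnable is False
```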

@classmethod
def from_dspy_signature(cls, signature_class) -> Dict[str, 'Variable']:
"""Convert an entire DSPy Signature to a dictionary of lm.Variables.

Args:
signature_class: A DSPy Signature class

Returns:
Dictionary mapping field names to Variable instances

Example:
class JobSignature(dspy.Signature):
job_info: str = dspy.InputField(desc="Job skills")
prediction: JobMetrics = dspy.OutputField(desc="Predictions")

variables = lm.Variable.from_dspy_signature(JobSignature)
# Returns: {"job_info": Variable(...), "prediction": Variable(...)}
"""
import inspect
variables = {}

# Get signature fields
if hasattr(signature_class, '__annotations__'):
for field_name, field_type in signature_class.__annotations__.items():
if hasattr(signature_class, field_name):
dspy_field = getattr(signature_class, field_name)
# Skip non-field attributes
if hasattr(dspy_field, 'desc') or 'Field' in str(type(dspy_field)):
variables[field_name] = cls.from_dspy_field(
field_name, field_type, dspy_field
)

return variables
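
Because detection only needs `__annotations__` plus a `desc` attribute (or `Field` in the type name), the scan also works on a stub signature. Continuing the hypothetical stubs from the previous sketch:

```python
class JobSignature:          # stands in for a dspy.Signature subclass
    job_info: str = FakeInputField("Job skills")
    prediction: dict = FakeOutputField("Predictions")

variables = Variable.from_dspy_signature(JobSignature)
assert set(variables) == {"job_info", "prediction"}
assert variables["job_info"].desc == "Job skills"
```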

# --- Helpers for P3O optimization over Variable options ---
def sample_index(self, instance: Any) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Sample a presentation index using the instance-bound logits.
@@ -122,6 +263,7 @@ def sample_index(self, instance: Any) -> Tuple[int, torch.Tensor, torch.Tensor]:
logits = self.get_parameter(instance)
if logits is None:
return 1, torch.tensor(0.0), torch.tensor(0.0)

probs = torch.softmax(logits, dim=0)
dist = torch.distributions.Categorical(probs)
idx = dist.sample()
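
The tail of `sample_index` is collapsed in this view; per the signature it returns `(index, log_prob, entropy)`. A hedged sketch of how a P3O-style update might use it (the reward and entropy coefficient are made up):

```python
idx, log_prob, entropy = var.sample_index(holder)  # holder as in the earlier sketch
reward = 1.0                                  # stand-in for a task metric
loss = -(reward * log_prob) - 0.01 * entropy  # assumed entropy-bonus form
loss.backward()   # gradients reach the logits created by get_parameter()
```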
75 changes: 55 additions & 20 deletions agent_torch/core/llm/archetype.py
@@ -89,7 +89,7 @@ def configure(self, *, external_df=None, split: int | None = None):
setattr(self._prompt, "_external_df", df)
return self

-def sample(self, kwargs: Dict[str, Any] | None = None, verbose: bool = False) -> torch.Tensor:
def sample(self, kwargs: Dict[str, Any] | None = None, verbose: bool = False, batch_size: int = None) -> torch.Tensor:
"""Sample decisions.
- If broadcast not called: run a single prompt and return (1,)
- If broadcast called: run group-based prompts over population and return (n_agents,)
@@ -105,11 +105,19 @@ def sample(self, kwargs: Dict[str, Any] | None = None, verbose: bool = False) -> torch.Tensor:
# One prompt only
prompt_list = []
if isinstance(self._prompt, Template):
-# If an external_df was configured, generate prompts for all rows
# If an external_df was configured, generate prompts for rows (with optional batching)
external_df = getattr(self._prompt, "_external_df", None)
if external_df is not None:
# Apply batch sampling if specified
row_indices = list(range(len(external_df)))
if batch_size and batch_size < len(external_df):
import random
row_indices = random.sample(row_indices, batch_size)
if verbose:
print(f"Batch sampling: Using {batch_size} jobs out of {len(external_df)} total")

prompt_list = []
-for row_idx in range(len(external_df)):
for row_idx in row_indices:
# Pre-broadcast: show placeholders for fields not present in external_df
base_text = self._prompt.get_base_prompt_manager_template()
data = self._prompt.assemble_data(
@@ -163,27 +171,54 @@ def _safe_fill(m):
value = 0.0
if outputs:
out0 = outputs[0]
-try:
-    text_value = out0["text"] if isinstance(out0, dict) and "text" in out0 else out0
-    value = float(text_value)
-    if verbose:
-        print(f"Parsed value: {value}")
-except Exception:
-    value = 0.0
-    if verbose:
-        print(f"Failed to parse, using default: {value}")
# Process structured response
structured_data = out0["response"]
value = sum(float(v) for v in structured_data.values())
if verbose:
print(f"Parsed structured value: {value}")
if verbose:
print(f"=== End LLM Call ===\n")
tensor_out = torch.tensor([value], device=kwargs["device"]).float()
if len(prompt_list) > 1:
-try:
-    vals = []
-    for out in outputs:
-        tv = out["text"] if isinstance(out, dict) and "text" in out else out
-        vals.append(float(tv))
-    tensor_out = torch.tensor(vals, device=kwargs["device"]).float()
-except Exception:
-    pass
vals = []
for out in outputs:
structured_data = out["response"]
val = sum(float(v) for v in structured_data.values())
vals.append(val)
tensor_out = torch.tensor(vals, device=kwargs["device"]).float()
# Store data for P3O compatibility (pre-broadcast individual job mode)
# Create a mock behavior object to store the required P3O data
if not hasattr(self, '_mock_behavior'):
from types import SimpleNamespace
self._mock_behavior = SimpleNamespace()

# Store group data for P3O (each job is treated as its own "group")
group_keys = [f"job_{i}" for i in range(len(outputs))]
group_outputs = []
group_structured = []

for out in outputs:
structured_data = out["response"]
val = sum(float(v) for v in structured_data.values())
group_outputs.append(val)
group_structured.append(structured_data)

# Store in mock behavior for P3O to access
self._mock_behavior.last_group_keys = group_keys
self._mock_behavior.last_group_outputs = group_outputs
self._mock_behavior.last_group_structured = group_structured

# Store slot choices if template has learnable variables
if isinstance(self._prompt, Template):
slots = self._prompt.create_slots()
sampled_choices = {}
for name, var in slots.items():
if getattr(var, 'learnable', False):
# Sample choice from variable distribution
idx, _, _ = var.sample_index(self._prompt)
sampled_choices[name] = idx
self._mock_behavior.last_slot_choices = sampled_choices

# Always print meta summary regardless of verbosity
try:
_mean = float(tensor_out.mean().item())
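
The remainder of the method is collapsed above. A sketch of how a P3O trainer might consume the rollout data stashed on `_mock_behavior` (the `archetype` and `targets` names are hypothetical):

```python
behavior = archetype._mock_behavior
for key, score, data in zip(behavior.last_group_keys,        # ["job_0", "job_1", ...]
                            behavior.last_group_outputs,     # summed structured scores
                            behavior.last_group_structured): # raw response dicts
    reward = -abs(score - targets[key])  # stand-in reward signal
choices = behavior.last_slot_choices     # {"variable_name": sampled_index, ...}
```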