diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 00000000..04a091f3
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,60 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Repository Overview
+
+This is a monorepo of production-grade ML projects built with [ZenML](https://zenml.io). Each subdirectory is a self-contained project demonstrating MLOps best practices across domains: LLMOps, Computer Vision, traditional MLOps, and Data Science.
+
+## Essential Commands
+
+```bash
+# REQUIRED before any commit - formats entire repo
+bash scripts/format.sh
+
+# Run a specific project
+cd <project-name>/
+pip install -r requirements.txt  # or: uv pip install -r requirements.txt
+python run.py  # most projects use this entry point
+python run.py --help  # check project-specific options
+
+# View ZenML dashboard
+zenml up
+```
+
+## Code Style
+
+- **Formatter**: Ruff (line length: 79 chars)
+- **Docstrings**: Google style
+- **Imports**: Sorted automatically by ruff; group by stdlib → third-party → local
+
+The `scripts/format.sh` script handles all formatting (unused import removal, import sorting, code formatting).
+
+## Project Structure Pattern
+
+Each project follows this structure:
+```
+<project-name>/
+├── run.py            # Main entry point with CLI
+├── requirements.txt  # Dependencies
+├── configs/          # YAML pipeline configurations
+├── pipelines/        # ZenML @pipeline definitions
+├── steps/            # ZenML @step definitions
+├── utils/            # Helper functions
+└── materializers/    # Custom artifact serializers (optional)
+```
+
+## ZenML Patterns
+
+- Pipelines use `@pipeline` decorator, steps use `@step` decorator
+- Configuration is typically YAML-based in `configs/` directory
+- Run with `--no-cache` flag to disable ZenML caching during development
+- Artifacts are automatically versioned and tracked
+
+## Adding New Projects
+
+See `ADDING_PROJECTS.md` for the complete guide. Key requirements:
+- Include `requirements.txt`
+- Add comprehensive `README.md`
+- Add project to the table in root `README.md`
+- Run `bash scripts/format.sh` before committing
diff --git a/art-rl/.env.example b/art-rl/.env.example
new file mode 100644
index 00000000..5e5b2379
--- /dev/null
+++ b/art-rl/.env.example
@@ -0,0 +1,5 @@
+# Required for RULER scoring (uses LLM to judge agent performance)
+OPENAI_API_KEY=your-openai-api-key
+
+# Optional for experiment tracking with Weights & Biases
+WANDB_API_KEY=your-wandb-api-key
diff --git a/art-rl/README.md b/art-rl/README.md
new file mode 100644
index 00000000..65651607
--- /dev/null
+++ b/art-rl/README.md
@@ -0,0 +1,250 @@
+# ART Email Search Agent
+
+Train an email search agent using [OpenPipe ART](https://github.com/openpipe/art) (Agentic Reinforcement Training) with [ZenML](https://zenml.io) for production ML pipelines.
+
+## Overview
+
+This project demonstrates how to:
+
+- **Train an RL agent** using GRPO (Group Relative Policy Optimization) with RULER scoring
+- **Track artifacts** including scenarios, model checkpoints, and training metrics
+- **Orchestrate on Kubernetes** with GPU step operators for training
+- **Evaluate models** with automated correctness judging
+- **Deploy as HTTP service** using ZenML Pipeline Deployments
+
+The agent learns to search through emails and answer questions using LangGraph's ReAct pattern, starting from a Qwen 2.5 7B base model.
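+
+Under the hood, each rollout wires the ART-served policy into a LangGraph
+ReAct agent. A condensed sketch of the core loop (excerpted from
+`agent/rollout.py`; `model`, `tools`, `system_prompt`, and `scenario` come
+from the surrounding rollout code, and the call runs inside an async
+function):
+
+```python
+from art.langgraph import init_chat_model
+from langchain_core.messages import HumanMessage, SystemMessage
+from langgraph.prebuilt import create_react_agent
+
+# ART exposes the current policy as a LangChain-compatible chat model
+chat_model = init_chat_model(model.name, temperature=1.0)
+agent = create_react_agent(chat_model, tools)  # search/read/answer tools
+await agent.ainvoke(
+    {
+        "messages": [
+            SystemMessage(content=system_prompt),
+            HumanMessage(content=scenario.question),
+        ]
+    }
+)
+```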
+ +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ ZenML Pipelines │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ data_preparation_pipeline (cached, no GPU) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ download_ │─▶│ create_ │─▶│ load_ │ │ +│ │ enron_data │ │ database │ │ scenarios │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ +│ training_pipeline (GPU required) │ +│ ┌──────────────┐ ┌────────────────────────────────┐ │ +│ │ setup_art_ │─▶│ train_agent │ │ +│ │ model │ │ • LangGraph rollouts │ │ +│ └──────────────┘ │ • RULER scoring │ │ +│ │ • GRPO training │ │ +│ └────────────────────────────────┘ │ +│ │ +│ evaluation_pipeline (GPU required) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ load_trained │─▶│ run_ │─▶│ compute_ │ │ +│ │ _model │ │ inference │ │ metrics │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ +│ inference_pipeline (HTTP Deployment) │ +│ ┌────────────────────────────────────────────────────────────┐ │ +│ │ POST /invoke │ │ +│ │ → run_single_inference → { answer, source_ids } │ │ +│ └────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +## Quick Start + +### Prerequisites + +- Python 3.10+ +- CUDA-compatible GPU (for training/inference) +- OpenAI API key (for RULER scoring) + +### Installation + +```bash +cd art-rl +pip install -r requirements.txt # or: uv pip install -r requirements.txt +``` + +### Environment Setup + +Create a `.env` file: + +```bash +# Required for RULER scoring +OPENAI_API_KEY=your-openai-key + +# Optional for experiment tracking +WANDB_API_KEY=your-wandb-key +``` + +### Running the Pipelines + +```bash +# 1. Prepare data (run once, artifacts are cached) +python run.py --pipeline data + +# 2. Train the agent (requires GPU) +python run.py --pipeline train --config configs/training_local.yaml + +# 3. Evaluate on test scenarios +python run.py --pipeline eval + +# Or run everything: +python run.py --pipeline all +``` + +## Project Structure + +``` +art-rl/ +├── run.py # CLI entry point +├── requirements.txt # Dependencies +├── configs/ +│ ├── data_prep.yaml # Data preparation config +│ ├── training_local.yaml # Local GPU training +│ ├── training_k8s.yaml # Kubernetes training +│ ├── evaluation.yaml # Evaluation config +│ └── deployment.yaml # HTTP deployment config +├── pipelines/ +│ ├── data_preparation.py # Data pipeline +│ ├── training.py # Training pipeline +│ ├── evaluation.py # Evaluation pipeline +│ └── inference.py # Inference pipeline (deployable) +├── steps/ +│ ├── data/ # Data preparation steps +│ ├── training/ # Training steps +│ ├── evaluation/ # Evaluation steps +│ └── inference/ # Inference steps +├── environment/ +│ ├── models.py # Pydantic data models +│ ├── email_db.py # SQLite database operations +│ └── tools.py # LangGraph tools +└── agent/ + ├── rollout.py # LangGraph rollout function + └── judge.py # Correctness judging +``` + +## How It Works + +### The Email Search Task + +The agent is trained to answer questions about a user's email inbox. Given a question like: + +> "Who can I contact for Power Operations when Sally is in London?" + +The agent must: +1. Search the email database using relevant keywords +2. Read specific emails to find information +3. 
Return a final answer with source references + +### Training with ART + +[ART (Agentic Reinforcement Training)](https://art.openpipe.ai/) uses GRPO to train the agent: + +1. **Rollouts**: For each training scenario, generate multiple trajectories using the LangGraph ReAct agent +2. **RULER Scoring**: An LLM judge scores trajectories relative to each other (easier than absolute scoring) +3. **GRPO Update**: Policy is updated to favor higher-scoring trajectories + +### LangGraph Integration + +The agent uses LangGraph's ReAct pattern with three tools: + +- `search_inbox_tool`: Search emails by keywords +- `read_email_tool`: Read a specific email by ID +- `return_final_answer_tool`: Provide the final answer + +## Configuration + +### Training Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `base_model` | `Qwen/Qwen2.5-7B-Instruct` | Base model for fine-tuning | +| `groups_per_step` | 2 | Scenario groups per training step | +| `rollouts_per_group` | 4 | Rollouts per scenario for GRPO | +| `learning_rate` | 1e-5 | Optimizer learning rate | +| `max_steps` | 20 | Maximum training steps | +| `ruler_model` | `openai/o4-mini` | Model for RULER scoring | + +### Kubernetes Training + +For production training on Kubernetes with GPU nodes: + +```bash +python run.py --pipeline train --config configs/training_k8s.yaml +``` + +The config includes: +- GPU node affinity +- Resource requests/limits +- Shared memory volume for PyTorch + +### Deploying as HTTP Service + +After training, deploy the agent as a real-time HTTP service: + +```bash +# Deploy the inference pipeline +python run.py --pipeline deploy --name my-email-agent + +# Or use ZenML CLI directly +zenml pipeline deploy pipelines.inference.inference_pipeline --name my-email-agent +``` + +Once deployed, invoke the service via HTTP: + +```bash +curl -X POST http://localhost:8080/invoke \ + -H "Content-Type: application/json" \ + -d '{ + "parameters": { + "question": "What meeting is scheduled for next week?", + "inbox_address": "john.smith@enron.com", + "query_date": "2001-05-15" + } + }' +``` + +Response: +```json +{ + "question": "What meeting is scheduled for next week?", + "answer": "There is a team sync scheduled for Monday at 10am.", + "source_ids": [""], + "success": true +} +``` + +Manage deployments: +```bash +zenml deployment list # List all deployments +zenml deployment describe my-email-agent # Show deployment details +zenml deployment logs my-email-agent -f # Stream logs +zenml deployment deprovision my-email-agent # Stop deployment +``` + +## ZenML Features Used + +- **Artifact Tracking**: Scenarios, checkpoints, and metrics are versioned +- **Model Control Plane**: Training metrics logged with `log_model_metadata()` +- **Docker Settings**: Custom images with CUDA and dependencies +- **Pipeline Caching**: Data preparation runs once, reused for training +- **Kubernetes Orchestration**: GPU pod settings for training steps +- **Pipeline Deployments**: Deploy inference pipelines as HTTP services + +## Dataset + +This project uses the [Enron Email Dataset](https://huggingface.co/datasets/corbt/enron-emails) with [sample questions](https://huggingface.co/datasets/corbt/enron_emails_sample_questions) from Hugging Face. 
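+
+To poke at the raw data outside the pipelines, both datasets load directly
+with the `datasets` library. A minimal sketch (the `train` split name for
+the questions set is an assumption; `load_scenarios` holds the canonical
+loading logic, and the data preparation step additionally applies a feature
+schema, filtering, and deduplication):
+
+```python
+from datasets import load_dataset
+
+emails = load_dataset("corbt/enron-emails", split="train")
+questions = load_dataset("corbt/enron_emails_sample_questions", split="train")
+print(f"{len(emails)} emails, {len(questions)} sample questions")
+```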
+ +## References + +- [ART Documentation](https://art.openpipe.ai/) +- [LangGraph Integration](https://art.openpipe.ai/integrations/langgraph-integration) +- [RULER Documentation](https://art.openpipe.ai/fundamentals/ruler) +- [ZenML Documentation](https://docs.zenml.io/) +- [ZenML Pipeline Deployments](https://docs.zenml.io/concepts/deployment) +- [Original ART Notebook](https://github.com/openpipe/art) + +## License + +Apache 2.0 diff --git a/art-rl/agent/__init__.py b/art-rl/agent/__init__.py new file mode 100644 index 00000000..3b754709 --- /dev/null +++ b/art-rl/agent/__init__.py @@ -0,0 +1,10 @@ +# Agent module for email search agent +from agent.judge import CorrectnessJudgeResponse, judge_correctness +from agent.rollout import ProjectTrajectory, rollout + +__all__ = [ + "rollout", + "ProjectTrajectory", + "judge_correctness", + "CorrectnessJudgeResponse", +] diff --git a/art-rl/agent/judge.py b/art-rl/agent/judge.py new file mode 100644 index 00000000..1264d0bd --- /dev/null +++ b/art-rl/agent/judge.py @@ -0,0 +1,80 @@ +"""Correctness judging for agent answers using LLM evaluation.""" + +from textwrap import dedent + +from environment.models import Scenario +from litellm import acompletion +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt + + +class CorrectnessJudgeResponse(BaseModel): + """Response from the correctness judge LLM.""" + + reasoning: str = Field(description="Explanation of the reasoning process.") + accept: bool = Field( + description="Whether the AI answer should be accepted." + ) + + +@retry(stop=stop_after_attempt(3)) +async def judge_correctness( + scenario: Scenario, + answer: str, + judge_model: str = "openai/gpt-4.1", +) -> CorrectnessJudgeResponse: + """Judge whether an agent's answer is correct. + + Uses an LLM to compare the agent's answer against the reference answer, + checking for semantic correctness rather than exact matching. + + Args: + scenario: The scenario containing the question and reference answer. + answer: The agent's generated answer. + judge_model: The LiteLLM model identifier for the judge. + + Returns: + CorrectnessJudgeResponse with reasoning and accept/reject decision. + """ + system_prompt = dedent( + """ + You are given a question, the reference answer (labelled **Reference + answer**), and an answer generated by an AI assistant (labelled + **AI answer**). + + Your task is to decide whether the AI answer is correct and should be + accepted. You should accept the answer if it contains the relevant + information from the reference answer. You should not accept the answer + if it is missing information relevant to the question, or if it + contradicts the reference answer. 
+ """ + ) + + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": ( + f"Question: {scenario.question}\n" + f"Reference answer: {scenario.answer}\n" + f"AI answer: {answer}" + ), + }, + ] + + response = await acompletion( + model=judge_model, + messages=messages, + response_format=CorrectnessJudgeResponse, + ) + + first_choice = response.choices[0] + raw_content = first_choice.message.content or "{}" + + try: + return CorrectnessJudgeResponse.model_validate_json(raw_content) + except Exception as e: + return CorrectnessJudgeResponse( + reasoning=f"Parse error: {e}\nRaw: {raw_content}", + accept=False, + ) diff --git a/art-rl/agent/rollout.py b/art-rl/agent/rollout.py new file mode 100644 index 00000000..e3ae1dcf --- /dev/null +++ b/art-rl/agent/rollout.py @@ -0,0 +1,137 @@ +"""LangGraph rollout function for the email search agent. + +This module implements the core agent loop using LangGraph's ReAct pattern. +Each rollout represents one episode of the agent attempting to answer a +question by searching through emails. +""" + +import uuid +from textwrap import dedent +from typing import Optional + +import art +from environment.models import EmailScenario, FinalAnswer +from environment.tools import create_email_tools +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.prebuilt import create_react_agent + +from agent.judge import judge_correctness + +# Maximum number of agent turns before stopping +MAX_TURNS = 20 + + +class ProjectTrajectory(art.Trajectory): + """Extended trajectory that captures the agent's final answer.""" + + final_answer: Optional[FinalAnswer] = None + + +async def rollout( + model: art.Model, + email_scenario: EmailScenario, + db_path: str = "./enron_emails.db", + judge_model: str = "openai/gpt-4.1", +) -> ProjectTrajectory: + """Execute a single rollout of the email search agent. + + This function: + 1. Sets up the LangGraph ReAct agent with email search tools + 2. Runs the agent on the given scenario + 3. Judges the correctness of the final answer + 4. Returns a trajectory with metrics for training + + Args: + model: The ART model to use for inference. + email_scenario: The scenario containing the question and metadata. + db_path: Path to the email database. + judge_model: LiteLLM model identifier for correctness judging. + + Returns: + ProjectTrajectory with the agent's conversation and metrics. + """ + # Import here to avoid circular imports and allow lazy loading + from art.langgraph import init_chat_model + + scenario = email_scenario.scenario + + traj = ProjectTrajectory( + reward=0.0, + messages_and_choices=[], + metadata={ + "scenario_id": scenario.id, + "step": email_scenario.step, + }, + ) + + system_prompt = dedent( + f""" + You are an email search agent. You are given a user query and a list + of tools you can use to search the user's email. Use the tools to + search the user's emails and find the answer to the user's query. + You may take up to {MAX_TURNS} turns to find the answer, so if your + first search doesn't find the answer, you can try with different + keywords. + + User's email address is {scenario.inbox_address} + Today's date is {scenario.query_date} + + When you have found the answer, use the return_final_answer_tool to + provide your final answer along with the source message IDs. 
+ """ + ) + + # Mutable container for the final answer (captured by tool closure) + final_answer_container: dict[str, Optional[FinalAnswer]] = {"value": None} + + def on_final_answer(answer: FinalAnswer) -> None: + final_answer_container["value"] = answer + + # Create scenario-specific tools + tools = create_email_tools( + scenario=scenario, + db_path=db_path, + on_final_answer=on_final_answer, + ) + + # Initialize the chat model from ART + chat_model = init_chat_model(model.name, temperature=1.0) + + # Create the LangGraph ReAct agent + react_agent = create_react_agent(chat_model, tools) + + try: + # Run the agent + config = { + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": MAX_TURNS, + } + + await react_agent.ainvoke( + { + "messages": [ + SystemMessage(content=system_prompt), + HumanMessage(content=scenario.question), + ] + }, + config=config, + ) + + # Check if we got a final answer + if final_answer_container["value"]: + traj.final_answer = final_answer_container["value"] + # Score the trajectory using the judge + correctness_response = await judge_correctness( + scenario, + traj.final_answer.answer, + judge_model=judge_model, + ) + traj.metrics["correct"] = float(correctness_response.accept) + + except Exception as e: + print(f"Error running LangGraph agent: {e}") + traj.messages_and_choices.append( + {"role": "assistant", "content": f"Error: {str(e)}"} + ) + + return traj diff --git a/art-rl/configs/data_prep.yaml b/art-rl/configs/data_prep.yaml new file mode 100644 index 00000000..c6b177f7 --- /dev/null +++ b/art-rl/configs/data_prep.yaml @@ -0,0 +1,21 @@ +# Data preparation pipeline configuration +# +# This pipeline downloads the Enron email dataset and creates +# a searchable SQLite database. Run once before training. +# +# Usage: +# python run.py --pipeline data --config configs/data_prep.yaml + +parameters: + db_path: "./enron_emails.db" + max_body_length: 5000 + max_recipients: 30 + train_limit: 50 + test_limit: 20 + max_messages: 1 + seed: 42 + +settings: + docker: + requirements: requirements.txt + python_package_installer: uv diff --git a/art-rl/configs/deployment.yaml b/art-rl/configs/deployment.yaml new file mode 100644 index 00000000..1e375f3b --- /dev/null +++ b/art-rl/configs/deployment.yaml @@ -0,0 +1,63 @@ +# Deployment configuration for the inference pipeline +# +# Deploy the trained agent as an HTTP service for real-time queries. +# +# Usage: +# # Deploy locally +# python run.py --pipeline deploy +# +# # Or using ZenML CLI directly +# zenml pipeline deploy pipelines.inference.inference_pipeline --name art-email-agent +# +# # Invoke the deployed service +# curl -X POST http://localhost:8080/invoke \ +# -H "Content-Type: application/json" \ +# -d '{ +# "parameters": { +# "question": "What meeting is scheduled for next week?", +# "inbox_address": "john.smith@enron.com", +# "query_date": "2001-05-15" +# } +# }' + +model: + name: art-email-agent + description: "Email search agent deployed as HTTP service" + tags: + - art + - langgraph + - email-agent + - deployment + - inference + +parameters: + checkpoint_path: "./.art/checkpoints/latest" + db_path: "./enron_emails.db" + judge_model: "openai/gpt-4.1" + model_name: "art-email-agent" + project_name: "email-search-agent" + art_path: "./.art" + +settings: + deployment: + app_title: "ART Email Search Agent" + app_description: >- + Email search agent trained with OpenPipe ART + LangGraph. + Answers questions about emails using a ReAct agent pattern. 
+ app_version: "1.0.0" + docs_url_path: "/docs" + invoke_url_path: "/invoke" + health_url_path: "/health" + cors: + allow_origins: ["*"] + allow_methods: ["GET", "POST", "OPTIONS"] + allow_headers: ["*"] + uvicorn_host: "0.0.0.0" + uvicorn_port: 8080 + + docker: + parent_image: pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime + requirements: requirements.txt + python_package_installer: uv + apt_packages: + - git diff --git a/art-rl/configs/evaluation.yaml b/art-rl/configs/evaluation.yaml new file mode 100644 index 00000000..017b6495 --- /dev/null +++ b/art-rl/configs/evaluation.yaml @@ -0,0 +1,55 @@ +# Evaluation pipeline configuration +# +# Evaluates a trained model on test scenarios. +# Requires GPU for inference. +# +# Usage: +# python run.py --pipeline eval --config configs/evaluation.yaml + +model: + name: art-email-agent + description: "Email search agent trained with ART + LangGraph" + tags: + - art + - langgraph + - email-agent + - evaluation + +parameters: + judge_model: "openai/gpt-4.1" + art_path: "./.art" + +settings: + docker: + parent_image: pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime + requirements: requirements.txt + python_package_installer: uv + apt_packages: + - git + +steps: + run_inference: + settings: + orchestrator.kubernetes: + pod_settings: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: nvidia.com/gpu + operator: Exists + resources: + limits: + nvidia.com/gpu: "1" + requests: + nvidia.com/gpu: "1" + memory: "32Gi" + volumes: + - emptyDir: + medium: Memory + sizeLimit: 16Gi + name: dshm + volume_mounts: + - mountPath: /dev/shm + name: dshm diff --git a/art-rl/configs/training_k8s.yaml b/art-rl/configs/training_k8s.yaml new file mode 100644 index 00000000..88c6ae92 --- /dev/null +++ b/art-rl/configs/training_k8s.yaml @@ -0,0 +1,69 @@ +# Kubernetes GPU training configuration +# +# For production training on a Kubernetes cluster with GPU nodes. +# Requires a GPU step operator or node affinity configuration. 
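+# The train_agent pod settings below pin the step to GPU nodes and mount a
+# memory-backed emptyDir at /dev/shm for PyTorch shared memory.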
+# +# Usage: +# python run.py --pipeline train --config configs/training_k8s.yaml + +model: + name: art-email-agent + description: "Email search agent trained with ART + LangGraph" + tags: + - art + - langgraph + - email-agent + - rl + - kubernetes + +parameters: + model_name: "art-email-agent" + project_name: "email-search-agent" + base_model: "Qwen/Qwen2.5-7B-Instruct" + groups_per_step: 2 + num_epochs: 20 + rollouts_per_group: 4 + learning_rate: 1.0e-5 + max_steps: 20 + ruler_model: "openai/o4-mini" + art_path: "./.art" + +settings: + docker: + parent_image: pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime + requirements: requirements.txt + python_package_installer: uv + apt_packages: + - git + +steps: + train_agent: + enable_cache: false + settings: + orchestrator.kubernetes: + pod_settings: + # GPU node affinity + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: nvidia.com/gpu + operator: Exists + # GPU resource requests + resources: + limits: + nvidia.com/gpu: "1" + requests: + nvidia.com/gpu: "1" + memory: "32Gi" + cpu: "8" + # Shared memory for PyTorch DataLoader + volumes: + - emptyDir: + medium: Memory + sizeLimit: 16Gi + name: dshm + volume_mounts: + - mountPath: /dev/shm + name: dshm diff --git a/art-rl/configs/training_local.yaml b/art-rl/configs/training_local.yaml new file mode 100644 index 00000000..e1f4654b --- /dev/null +++ b/art-rl/configs/training_local.yaml @@ -0,0 +1,34 @@ +# Local GPU training configuration +# +# For development and testing on a local machine with GPU. +# Uses smaller batch sizes and fewer steps for faster iteration. +# +# Usage: +# python run.py --pipeline train --config configs/training_local.yaml + +model: + name: art-email-agent + description: "Email search agent trained with ART + LangGraph" + tags: + - art + - langgraph + - email-agent + - rl + - local + +parameters: + model_name: "art-email-agent-local" + project_name: "email-search-agent" + base_model: "Qwen/Qwen2.5-7B-Instruct" + groups_per_step: 2 + num_epochs: 5 + rollouts_per_group: 4 + learning_rate: 1.0e-5 + max_steps: 10 + ruler_model: "openai/o4-mini" + art_path: "./.art" + +settings: + docker: + requirements: requirements.txt + python_package_installer: uv diff --git a/art-rl/environment/__init__.py b/art-rl/environment/__init__.py new file mode 100644 index 00000000..c67a2847 --- /dev/null +++ b/art-rl/environment/__init__.py @@ -0,0 +1,21 @@ +# Environment utilities for email search agent +from environment.email_db import ( + create_email_database, + get_db_connection, + read_email, + search_emails, +) +from environment.models import Email, FinalAnswer, Scenario, SearchResult +from environment.tools import create_email_tools + +__all__ = [ + "Email", + "Scenario", + "SearchResult", + "FinalAnswer", + "get_db_connection", + "search_emails", + "read_email", + "create_email_database", + "create_email_tools", +] diff --git a/art-rl/environment/email_db.py b/art-rl/environment/email_db.py new file mode 100644 index 00000000..6ca4216a --- /dev/null +++ b/art-rl/environment/email_db.py @@ -0,0 +1,385 @@ +"""Email database operations using SQLite with FTS5 for full-text search.""" + +import os +import sqlite3 +from datetime import datetime +from typing import List, Optional + +from datasets import Features, Sequence, Value, load_dataset +from tqdm import tqdm + +from environment.models import Email, SearchResult + +# Database configuration +DEFAULT_DB_PATH = "./enron_emails.db" +EMAIL_DATASET_REPO_ID = 
"corbt/enron-emails" + +# Module-level connection cache +_db_connections: dict[str, sqlite3.Connection] = {} + + +def get_db_connection(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection: + """Get or create a database connection. + + Uses a module-level cache to reuse connections within the same process. + """ + if db_path not in _db_connections: + if not os.path.exists(db_path): + raise FileNotFoundError( + f"Database not found at {db_path}. " + "Run the data preparation pipeline first." + ) + _db_connections[db_path] = sqlite3.connect( + db_path, check_same_thread=False + ) + return _db_connections[db_path] + + +def create_email_database( + db_path: str = DEFAULT_DB_PATH, + max_body_length: int = 5000, + max_recipients: int = 30, +) -> str: + """Create the email database from Hugging Face dataset. + + Filters out emails that are too long or have too many recipients, + and deduplicates based on (subject, body, from_address). + + Returns: + Path to the created database. + """ + print("Creating email database from Hugging Face dataset...") + print("This may take several minutes for the full Enron dataset...") + + # Database schema with FTS5 for full-text search + sql_create_tables = """ + DROP TABLE IF EXISTS recipients; + DROP TABLE IF EXISTS emails_fts; + DROP TABLE IF EXISTS emails; + + CREATE TABLE emails ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id TEXT UNIQUE, + subject TEXT, + from_address TEXT, + date TEXT, + body TEXT, + file_name TEXT + ); + + CREATE TABLE recipients ( + email_id TEXT, + recipient_address TEXT, + recipient_type TEXT + ); + """ + + sql_create_indexes = """ + CREATE INDEX idx_emails_from ON emails(from_address); + CREATE INDEX idx_emails_date ON emails(date); + CREATE INDEX idx_emails_message_id ON emails(message_id); + CREATE INDEX idx_recipients_address ON recipients(recipient_address); + CREATE INDEX idx_recipients_type ON recipients(recipient_type); + CREATE INDEX idx_recipients_email_id ON recipients(email_id); + CREATE INDEX idx_recipients_address_email + ON recipients(recipient_address, email_id); + + CREATE VIRTUAL TABLE emails_fts USING fts5( + subject, + body, + content='emails', + content_rowid='id' + ); + + CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN + INSERT INTO emails_fts (rowid, subject, body) + VALUES (new.id, new.subject, new.body); + END; + + CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN + DELETE FROM emails_fts WHERE rowid=old.id; + END; + + CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN + UPDATE emails_fts SET subject=new.subject, body=new.body + WHERE rowid=old.id; + END; + """ + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.executescript(sql_create_tables) + conn.commit() + + # Load dataset with expected schema + print("Loading email dataset from Hugging Face...") + expected_features = Features( + { + "message_id": Value("string"), + "subject": Value("string"), + "from": Value("string"), + "to": Sequence(Value("string")), + "cc": Sequence(Value("string")), + "bcc": Sequence(Value("string")), + "date": Value("timestamp[us]"), + "body": Value("string"), + "file_name": Value("string"), + } + ) + + dataset = load_dataset( + EMAIL_DATASET_REPO_ID, features=expected_features, split="train" + ) + print(f"Dataset contains {len(dataset)} total emails") + + # Populate database with filtering and deduplication + print("Populating database...") + conn.execute("PRAGMA synchronous = OFF;") + conn.execute("PRAGMA journal_mode = MEMORY;") + conn.execute("BEGIN TRANSACTION;") + + record_count = 0 + 
skipped_count = 0 + duplicate_count = 0 + processed_emails: set[tuple] = set() + + for email_data in tqdm(dataset, desc="Inserting emails"): + message_id = email_data["message_id"] + subject = email_data["subject"] + from_address = email_data["from"] + date_obj: datetime = email_data["date"] + body = email_data["body"] + file_name = email_data["file_name"] + to_list = [str(addr) for addr in email_data["to"] if addr] + cc_list = [str(addr) for addr in email_data["cc"] if addr] + bcc_list = [str(addr) for addr in email_data["bcc"] if addr] + + total_recipients = len(to_list) + len(cc_list) + len(bcc_list) + + # Filter out very long emails and those with too many recipients + if len(body) > max_body_length: + skipped_count += 1 + continue + + if total_recipients > max_recipients: + skipped_count += 1 + continue + + # Deduplication check + email_key = (subject, body, from_address) + if email_key in processed_emails: + duplicate_count += 1 + continue + processed_emails.add(email_key) + + date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S") + + cursor.execute( + """ + INSERT INTO emails + (message_id, subject, from_address, date, body, file_name) + VALUES (?, ?, ?, ?, ?, ?) + """, + (message_id, subject, from_address, date_str, body, file_name), + ) + + # Insert recipients + recipient_data = [] + for addr in to_list: + recipient_data.append((message_id, addr, "to")) + for addr in cc_list: + recipient_data.append((message_id, addr, "cc")) + for addr in bcc_list: + recipient_data.append((message_id, addr, "bcc")) + + if recipient_data: + cursor.executemany( + """ + INSERT INTO recipients + (email_id, recipient_address, recipient_type) + VALUES (?, ?, ?) + """, + recipient_data, + ) + + record_count += 1 + + conn.commit() + + # Create indexes and FTS triggers + print("Creating indexes and FTS...") + cursor.executescript(sql_create_indexes) + cursor.execute('INSERT INTO emails_fts(emails_fts) VALUES("rebuild")') + conn.commit() + + print(f"Successfully created database with {record_count} emails.") + print(f"Skipped {skipped_count} emails due to length/recipient limits.") + print(f"Skipped {duplicate_count} duplicate emails.") + + # Cache the connection + _db_connections[db_path] = conn + + return db_path + + +def search_emails( + inbox: str, + keywords: List[str], + db_path: str = DEFAULT_DB_PATH, + from_addr: Optional[str] = None, + to_addr: Optional[str] = None, + sent_after: Optional[str] = None, + sent_before: Optional[str] = None, + max_results: int = 10, +) -> List[SearchResult]: + """Search the email database based on keywords and filters. + + Args: + inbox: Email address of the inbox to search. + keywords: List of keywords to search for (AND logic). + db_path: Path to the SQLite database. + from_addr: Optional filter for sender address. + to_addr: Optional filter for recipient address. + sent_after: Optional date filter (YYYY-MM-DD). + sent_before: Optional date filter (YYYY-MM-DD). + max_results: Maximum number of results to return. + + Returns: + List of SearchResult objects with message_id and snippet. 
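+
+    Raises:
+        ValueError: If no keywords are provided, or if max_results
+            exceeds 10.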
+ """ + conn = get_db_connection(db_path) + cursor = conn.cursor() + + where_clauses: List[str] = [] + params: List[str | int] = [] + + if not keywords: + raise ValueError("No keywords provided for search.") + + if max_results > 10: + raise ValueError("max_results must be less than or equal to 10.") + + # FTS5 query - escape quotes for safety + fts_query = " ".join( + f'"{k.replace(chr(34), chr(34)+chr(34))}"' for k in keywords + ) + where_clauses.append("fts.emails_fts MATCH ?") + params.append(fts_query) + + # Inbox filter - emails sent by or received by this address + where_clauses.append( + """ + (e.from_address = ? OR EXISTS ( + SELECT 1 FROM recipients r_inbox + WHERE r_inbox.recipient_address = ? + AND r_inbox.email_id = e.message_id + )) + """ + ) + params.extend([inbox, inbox]) + + if from_addr: + where_clauses.append("e.from_address = ?") + params.append(from_addr) + + if to_addr: + where_clauses.append( + """ + EXISTS ( + SELECT 1 FROM recipients r_to + WHERE r_to.recipient_address = ? + AND r_to.email_id = e.message_id + ) + """ + ) + params.append(to_addr) + + if sent_after: + where_clauses.append("e.date >= ?") + params.append(f"{sent_after} 00:00:00") + + if sent_before: + where_clauses.append("e.date < ?") + params.append(f"{sent_before} 00:00:00") + + sql = f""" + SELECT + e.message_id, + snippet(emails_fts, -1, '', '', ' ... ', 15) as snippet + FROM + emails e JOIN emails_fts fts ON e.id = fts.rowid + WHERE + {" AND ".join(where_clauses)} + ORDER BY + e.date DESC + LIMIT ?; + """ + params.append(max_results) + + cursor.execute(sql, params) + results = cursor.fetchall() + + return [SearchResult(message_id=row[0], snippet=row[1]) for row in results] + + +def read_email( + message_id: str, + db_path: str = DEFAULT_DB_PATH, +) -> Optional[Email]: + """Retrieve a single email by its message_id. + + Args: + message_id: The unique message ID of the email. + db_path: Path to the SQLite database. + + Returns: + Email object if found, None otherwise. + """ + conn = get_db_connection(db_path) + cursor = conn.cursor() + + cursor.execute( + """ + SELECT message_id, date, subject, from_address, body, file_name + FROM emails WHERE message_id = ? 
+ """, + (message_id,), + ) + email_row = cursor.fetchone() + + if not email_row: + return None + + msg_id, date, subject, from_addr, body, file_name = email_row + + # Get recipients + cursor.execute( + "SELECT recipient_address, recipient_type FROM recipients " + "WHERE email_id = ?", + (message_id,), + ) + recipient_rows = cursor.fetchall() + + to_addresses = [] + cc_addresses = [] + bcc_addresses = [] + + for addr, type_val in recipient_rows: + if type_val.lower() == "to": + to_addresses.append(addr) + elif type_val.lower() == "cc": + cc_addresses.append(addr) + elif type_val.lower() == "bcc": + bcc_addresses.append(addr) + + return Email( + message_id=msg_id, + date=date, + subject=subject, + from_address=from_addr, + to_addresses=to_addresses, + cc_addresses=cc_addresses, + bcc_addresses=bcc_addresses, + body=body, + file_name=file_name, + ) diff --git a/art-rl/environment/models.py b/art-rl/environment/models.py new file mode 100644 index 00000000..07529734 --- /dev/null +++ b/art-rl/environment/models.py @@ -0,0 +1,55 @@ +"""Data models for the email search agent environment.""" + +from dataclasses import dataclass +from typing import List, Literal, Optional + +from pydantic import BaseModel, Field + + +class Email(BaseModel): + """Represents an email from the Enron dataset.""" + + message_id: str + date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS' + subject: Optional[str] = None + from_address: Optional[str] = None + to_addresses: List[str] = Field(default_factory=list) + cc_addresses: List[str] = Field(default_factory=list) + bcc_addresses: List[str] = Field(default_factory=list) + body: Optional[str] = None + file_name: Optional[str] = None + + +class Scenario(BaseModel): + """A question-answer scenario for training/evaluation.""" + + id: int + question: str + answer: str + message_ids: List[str] # message_ids of referenced emails + how_realistic: float + inbox_address: str + query_date: str + split: Literal["train", "test"] + + +@dataclass +class SearchResult: + """Result from searching the email database.""" + + message_id: str + snippet: str + + +class FinalAnswer(BaseModel): + """The agent's final answer with source references.""" + + answer: str + source_ids: List[str] + + +class EmailScenario(BaseModel): + """Wrapper for scenario with training step info.""" + + step: int + scenario: Scenario diff --git a/art-rl/environment/tools.py b/art-rl/environment/tools.py new file mode 100644 index 00000000..efc94113 --- /dev/null +++ b/art-rl/environment/tools.py @@ -0,0 +1,88 @@ +"""LangGraph tools for the email search agent.""" + +from dataclasses import asdict +from typing import Callable, List, Optional + +from langchain_core.tools import tool + +from environment.email_db import read_email, search_emails +from environment.models import FinalAnswer, Scenario + + +def create_email_tools( + scenario: Scenario, + db_path: str, + on_final_answer: Optional[Callable[[FinalAnswer], None]] = None, +) -> List: + """Create LangGraph tools for a specific scenario. + + The tools are scenario-specific because they need access to the + inbox address and query date for filtering. + + Args: + scenario: The current scenario being processed. + db_path: Path to the email database. + on_final_answer: Callback invoked when the agent provides a final answer. + + Returns: + List of LangChain tools for the LangGraph agent. + """ + + @tool + def search_inbox_tool(keywords: List[str]) -> List[dict]: + """Search the inbox for emails matching the given keywords. 
+ + Args: + keywords: List of keywords to search for (uses AND logic). + + Returns: + List of search results with message_id and snippet. + """ + results = search_emails( + inbox=scenario.inbox_address, + keywords=keywords, + db_path=db_path, + sent_before=scenario.query_date, + ) + return [asdict(result) for result in results] + + @tool + def read_email_tool(message_id: str) -> Optional[dict]: + """Read a specific email by message ID. + + Args: + message_id: The unique identifier of the email to read. + + Returns: + Email content as a dictionary, or None if not found. + """ + email = read_email(message_id, db_path=db_path) + if email: + return email.model_dump() + return None + + @tool + def return_final_answer_tool( + answer: str, + reference_message_ids: List[str], + ) -> dict: + """Return the final answer with source references. + + Use this tool when you have found the answer to the user's question. + + Args: + answer: The answer to the user's question. + reference_message_ids: List of message IDs that support the answer. + + Returns: + The final answer as a dictionary. + """ + final_answer = FinalAnswer( + answer=answer, + source_ids=reference_message_ids, + ) + if on_final_answer: + on_final_answer(final_answer) + return final_answer.model_dump() + + return [search_inbox_tool, read_email_tool, return_final_answer_tool] diff --git a/art-rl/materializers/__init__.py b/art-rl/materializers/__init__.py new file mode 100644 index 00000000..a608af4b --- /dev/null +++ b/art-rl/materializers/__init__.py @@ -0,0 +1,2 @@ +# Custom materializers for the email search agent +# (Reserved for future checkpoint materializers if needed) diff --git a/art-rl/pipelines/__init__.py b/art-rl/pipelines/__init__.py new file mode 100644 index 00000000..33345da7 --- /dev/null +++ b/art-rl/pipelines/__init__.py @@ -0,0 +1,12 @@ +# ZenML pipelines for the email search agent +from pipelines.data_preparation import data_preparation_pipeline +from pipelines.evaluation import evaluation_pipeline +from pipelines.inference import inference_pipeline +from pipelines.training import training_pipeline + +__all__ = [ + "data_preparation_pipeline", + "training_pipeline", + "evaluation_pipeline", + "inference_pipeline", +] diff --git a/art-rl/pipelines/data_preparation.py b/art-rl/pipelines/data_preparation.py new file mode 100644 index 00000000..f818af9a --- /dev/null +++ b/art-rl/pipelines/data_preparation.py @@ -0,0 +1,68 @@ +"""Data preparation pipeline for the email search agent. + +This pipeline downloads and prepares all data artifacts needed for training: +1. Downloads the Enron email dataset from Hugging Face +2. Creates a SQLite database with FTS5 for fast email search +3. Loads Q&A scenarios for training and testing + +Run this pipeline once before training - artifacts are cached and reused. +""" + +from steps.data import ( + create_database, + download_enron_data, + load_scenarios, +) +from zenml import Model, pipeline + + +@pipeline( + model=Model( + name="art-email-agent", + description="Email search agent trained with ART + LangGraph", + tags=["art", "langgraph", "email-agent", "rl"], + ), +) +def data_preparation_pipeline( + db_path: str = "./enron_emails.db", + max_body_length: int = 5000, + max_recipients: int = 30, + train_limit: int = 50, + test_limit: int = 20, + max_messages: int = 1, + seed: int = 42, +): + """Prepare data artifacts for the email search agent. + + This pipeline is designed to run once and cache all artifacts. + Subsequent training runs will reuse these cached artifacts. 
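+    The pipeline returns the database path plus the train and test scenario
+    lists, which the training and evaluation pipelines consume.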
+
+    Args:
+        db_path: Path for the SQLite email database.
+        max_body_length: Filter out emails longer than this.
+        max_recipients: Filter out emails with more recipients than this.
+        train_limit: Maximum training scenarios to load.
+        test_limit: Maximum test scenarios to load.
+        max_messages: Filter to scenarios with at most this many source emails.
+        seed: Random seed for reproducible scenario shuffling.
+    """
+    # Step 1: Download raw emails from Hugging Face
+    raw_emails = download_enron_data()
+
+    # Step 2: Create searchable SQLite database
+    db_path_out = create_database(
+        raw_emails=raw_emails,
+        db_path=db_path,
+        max_body_length=max_body_length,
+        max_recipients=max_recipients,
+    )
+
+    # Step 3: Load training and test scenarios
+    train_scenarios, test_scenarios = load_scenarios(
+        train_limit=train_limit,
+        test_limit=test_limit,
+        max_messages=max_messages,
+        seed=seed,
+    )
+
+    return db_path_out, train_scenarios, test_scenarios
diff --git a/art-rl/pipelines/evaluation.py b/art-rl/pipelines/evaluation.py
new file mode 100644
index 00000000..df97224d
--- /dev/null
+++ b/art-rl/pipelines/evaluation.py
@@ -0,0 +1,74 @@
+"""Evaluation pipeline for the email search agent.
+
+This pipeline evaluates a trained model on test scenarios:
+1. Loads the trained model from checkpoint
+2. Runs inference on each test scenario
+3. Computes accuracy and other metrics
+
+Requires GPU resources for inference.
+"""
+
+from typing import List
+
+from environment.models import Scenario
+from steps.evaluation import (
+    compute_metrics,
+    load_trained_model,
+    run_inference,
+)
+from zenml import Model, pipeline
+from zenml.config import DockerSettings
+
+docker_settings = DockerSettings(
+    parent_image="pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime",
+    requirements="requirements.txt",
+    python_package_installer="uv",
+    apt_packages=["git"],
+)
+
+
+@pipeline(
+    model=Model(
+        name="art-email-agent",
+        description="Email search agent trained with ART + LangGraph",
+        tags=["art", "langgraph", "email-agent", "rl"],
+    ),
+    settings={"docker": docker_settings},
+)
+def evaluation_pipeline(
+    model_config: dict,
+    checkpoint_path: str,
+    test_scenarios: List[Scenario],
+    db_path: str,
+    judge_model: str = "openai/gpt-4.1",
+    art_path: str = "./.art",
+):
+    """Evaluate the trained email search agent.
+
+    Args:
+        model_config: Model configuration from training.
+        checkpoint_path: Path to the trained checkpoint.
+        test_scenarios: Test scenarios from data preparation.
+        db_path: Path to the email database.
+        judge_model: LiteLLM model for correctness judging.
+        art_path: Directory for ART files.
+    """
+    # Step 1: Prepare model loading configuration
+    inference_config = load_trained_model(
+        model_config=model_config,
+        checkpoint_path=checkpoint_path,
+        art_path=art_path,
+    )
+
+    # Step 2: Run inference on test scenarios (requires GPU)
+    predictions = run_inference(
+        inference_config=inference_config,
+        test_scenarios=test_scenarios,
+        db_path=db_path,
+        judge_model=judge_model,
+    )
+
+    # Step 3: Compute metrics
+    metrics = compute_metrics(predictions=predictions)
+
+    return predictions, metrics
diff --git a/art-rl/pipelines/inference.py b/art-rl/pipelines/inference.py
new file mode 100644
index 00000000..e9afb197
--- /dev/null
+++ b/art-rl/pipelines/inference.py
@@ -0,0 +1,92 @@
+"""Inference pipeline for deployment as HTTP service.
+
+This pipeline is designed to be deployed via ZenML Pipeline Deployments,
+exposing the trained email agent as an HTTP endpoint for real-time queries.
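+
+Deploy it with `python run.py --pipeline deploy` (see
+`configs/deployment.yaml`) and query the running service via POST /invoke.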
+""" + +from steps.inference import run_single_inference +from zenml import Model, pipeline +from zenml.config import DeploymentSettings + +# Deployment settings for the HTTP service +deployment_settings = DeploymentSettings( + app_title="ART Email Search Agent", + app_description=( + "Email search agent trained with OpenPipe ART + LangGraph. " + "Answers questions about emails using a ReAct agent pattern." + ), + app_version="1.0.0", + docs_url_path="/docs", + invoke_url_path="/invoke", + health_url_path="/health", + cors={ + "allow_origins": ["*"], + "allow_methods": ["GET", "POST", "OPTIONS"], + "allow_headers": ["*"], + }, + uvicorn_host="0.0.0.0", + uvicorn_port=8080, +) + + +@pipeline( + model=Model( + name="art-email-agent", + description="Email search agent for inference", + tags=["art", "langgraph", "inference", "deployment"], + ), + settings={"deployment": deployment_settings}, +) +def inference_pipeline( + question: str, + inbox_address: str, + query_date: str, + checkpoint_path: str = "./.art/checkpoints/latest", + db_path: str = "./enron_emails.db", + judge_model: str = "openai/gpt-4.1", + model_name: str = "art-email-agent", + project_name: str = "email-search-agent", + art_path: str = "./.art", +) -> dict: + """Inference pipeline for single email search queries. + + This pipeline can be deployed as an HTTP service using ZenML's + Pipeline Deployments feature. Each invocation runs a single query + through the trained agent. + + Example invocation via HTTP: + POST /invoke + { + "parameters": { + "question": "What meeting is scheduled for next week?", + "inbox_address": "john.smith@enron.com", + "query_date": "2001-05-15" + } + } + + Args: + question: The question to answer about the user's emails. + inbox_address: The email address of the inbox to search. + query_date: The reference date for the query (YYYY-MM-DD). + checkpoint_path: Path to the trained model checkpoint. + db_path: Path to the email database. + judge_model: LiteLLM model ID for correctness judging. + model_name: Name of the ART model. + project_name: Name of the ART project. + art_path: Path to ART artifacts directory. + + Returns: + Dictionary with the agent's answer and metadata. + """ + result = run_single_inference( + question=question, + inbox_address=inbox_address, + query_date=query_date, + checkpoint_path=checkpoint_path, + db_path=db_path, + judge_model=judge_model, + model_name=model_name, + project_name=project_name, + art_path=art_path, + ) + return result diff --git a/art-rl/pipelines/training.py b/art-rl/pipelines/training.py new file mode 100644 index 00000000..0a59f85e --- /dev/null +++ b/art-rl/pipelines/training.py @@ -0,0 +1,90 @@ +"""Training pipeline for the email search agent. + +This pipeline runs the ART training loop using: +- GRPO (Group Relative Policy Optimization) for policy updates +- RULER for relative trajectory scoring +- LangGraph ReAct agents for executing rollouts + +Requires GPU resources for the train_agent step. 
+""" + +from typing import List + +from environment.models import Scenario +from steps.training import setup_art_model, train_agent +from zenml import Model, pipeline +from zenml.config import DockerSettings + +# Docker settings for GPU training +docker_settings = DockerSettings( + parent_image="pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime", + requirements="requirements.txt", + python_package_installer="uv", + apt_packages=["git"], +) + + +@pipeline( + model=Model( + name="art-email-agent", + description="Email search agent trained with ART + LangGraph", + tags=["art", "langgraph", "email-agent", "rl"], + ), + settings={"docker": docker_settings}, +) +def training_pipeline( + train_scenarios: List[Scenario], + db_path: str, + model_name: str = "art-email-agent", + project_name: str = "email-search-agent", + base_model: str = "Qwen/Qwen2.5-7B-Instruct", + groups_per_step: int = 2, + num_epochs: int = 20, + rollouts_per_group: int = 4, + learning_rate: float = 1e-5, + max_steps: int = 20, + ruler_model: str = "openai/o4-mini", + art_path: str = "./.art", +): + """Train the email search agent using ART. + + This pipeline: + 1. Configures the ART model + 2. Runs the training loop with GRPO and RULER + + Args: + train_scenarios: Training scenarios from data preparation. + db_path: Path to the email database. + model_name: Name for the trained model. + project_name: Project name for experiment tracking. + base_model: Hugging Face model ID for the base model. + groups_per_step: Scenario groups per training step. + num_epochs: Number of passes through training data. + rollouts_per_group: Rollouts per scenario for GRPO. + learning_rate: Optimizer learning rate. + max_steps: Maximum training steps. + ruler_model: LiteLLM model for RULER scoring. + art_path: Directory for ART checkpoints. + """ + # Step 1: Configure the model + model_config = setup_art_model( + model_name=model_name, + project_name=project_name, + base_model=base_model, + ) + + # Step 2: Run training (requires GPU) + checkpoint_path, training_metrics = train_agent( + model_config=model_config, + train_scenarios=train_scenarios, + db_path=db_path, + groups_per_step=groups_per_step, + num_epochs=num_epochs, + rollouts_per_group=rollouts_per_group, + learning_rate=learning_rate, + max_steps=max_steps, + ruler_model=ruler_model, + art_path=art_path, + ) + + return checkpoint_path, training_metrics diff --git a/art-rl/requirements.txt b/art-rl/requirements.txt new file mode 100644 index 00000000..8ad55fe1 --- /dev/null +++ b/art-rl/requirements.txt @@ -0,0 +1,22 @@ +# Core dependencies +zenml>=0.70.0 + +# ART (Agentic Reinforcement Training) with LangGraph integration +openpipe-art[backend,langgraph]>=0.4.11 + +# LangChain/LangGraph for agent framework +langchain-core>=0.3.0 +langgraph>=0.2.0 +langchain-openai>=0.2.0 + +# LLM inference +litellm>=1.50.0 + +# Data handling +datasets>=3.0.0 +pydantic>=2.0.0 + +# Utilities +tenacity>=8.0.0 +tqdm>=4.60.0 +python-dotenv>=1.0.0 diff --git a/art-rl/run.py b/art-rl/run.py new file mode 100644 index 00000000..f44b07f4 --- /dev/null +++ b/art-rl/run.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +"""CLI entry point for the ART Email Search Agent project. 
+ +This script provides commands to run the different pipelines: +- data: Prepare data artifacts (download emails, create database, load scenarios) +- train: Train the agent using ART with GRPO and RULER +- eval: Evaluate a trained model on test scenarios +- all: Run the complete workflow (data → train → eval) +- deploy: Deploy the inference pipeline as an HTTP service + +Examples: + # Prepare data (run once, artifacts are cached) + python run.py --pipeline data + + # Train with local GPU + python run.py --pipeline train --config configs/training_local.yaml + + # Train on Kubernetes + python run.py --pipeline train --config configs/training_k8s.yaml + + # Evaluate trained model + python run.py --pipeline eval + + # Run complete workflow + python run.py --pipeline all + + # Deploy as HTTP service + python run.py --pipeline deploy --name my-agent-service +""" + +import argparse +import os +from typing import Optional + +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + + +def check_environment(): + """Verify required environment variables are set.""" + if not os.environ.get("OPENAI_API_KEY"): + print("WARNING: OPENAI_API_KEY not set. Required for RULER scoring.") + print("Set it in .env file or export OPENAI_API_KEY=your-key") + + if not os.environ.get("WANDB_API_KEY"): + print( + "INFO: WANDB_API_KEY not set. Weights & Biases logging disabled." + ) + + +def run_data_pipeline(config_path: Optional[str] = None): + """Run the data preparation pipeline.""" + from pipelines.data_preparation import data_preparation_pipeline + + print("=" * 60) + print("Running Data Preparation Pipeline") + print("=" * 60) + + if config_path: + data_preparation_pipeline.with_options(config_path=config_path)() + else: + data_preparation_pipeline() + + print("\nData preparation complete!") + print("Artifacts are cached and will be reused in subsequent runs.") + + +def run_training_pipeline( + config_path: Optional[str] = None, + no_cache: bool = False, +): + """Run the training pipeline.""" + from pipelines.training import training_pipeline + from zenml.client import Client + + print("=" * 60) + print("Running Training Pipeline") + print("=" * 60) + + # Get artifacts from most recent data preparation run + client = Client() + + try: + # Find the most recent data preparation pipeline run + runs = client.list_pipeline_runs( + pipeline_name_or_id="data_preparation_pipeline", + sort_by="desc:created", + size=1, + ) + + if not runs.items: + print("ERROR: No data preparation run found.") + print("Please run: python run.py --pipeline data") + return + + latest_run = runs.items[0] + print(f"Using data from run: {latest_run.id}") + + # Get artifacts from the run + db_path = latest_run.steps["create_database"].output.load() + train_scenarios = ( + latest_run.steps["load_scenarios"] + .outputs["train_scenarios"] + .load() + ) + + print(f"Loaded {len(train_scenarios)} training scenarios") + print(f"Database path: {db_path}") + + except Exception as e: + print(f"ERROR: Failed to load data artifacts: {e}") + print("Please run: python run.py --pipeline data") + return + + # Run training + pipeline_instance = training_pipeline + + if config_path: + pipeline_instance = pipeline_instance.with_options( + config_path=config_path + ) + + if no_cache: + pipeline_instance = pipeline_instance.with_options(enable_cache=False) + + pipeline_instance( + train_scenarios=train_scenarios, + db_path=db_path, + ) + + print("\nTraining complete!") + + +def run_evaluation_pipeline( + config_path: 
Optional[str] = None, + checkpoint_path: Optional[str] = None, +): + """Run the evaluation pipeline.""" + from pipelines.evaluation import evaluation_pipeline + from zenml.client import Client + + print("=" * 60) + print("Running Evaluation Pipeline") + print("=" * 60) + + client = Client() + + # Get test scenarios from data preparation + try: + data_runs = client.list_pipeline_runs( + pipeline_name_or_id="data_preparation_pipeline", + sort_by="desc:created", + size=1, + ) + + if not data_runs.items: + print("ERROR: No data preparation run found.") + print("Please run: python run.py --pipeline data") + return + + data_run = data_runs.items[0] + db_path = data_run.steps["create_database"].output.load() + test_scenarios = ( + data_run.steps["load_scenarios"].outputs["test_scenarios"].load() + ) + + print(f"Loaded {len(test_scenarios)} test scenarios") + + except Exception as e: + print(f"ERROR: Failed to load data artifacts: {e}") + return + + # Get checkpoint from training + if not checkpoint_path: + try: + train_runs = client.list_pipeline_runs( + pipeline_name_or_id="training_pipeline", + sort_by="desc:created", + size=1, + ) + + if not train_runs.items: + print("ERROR: No training run found.") + print("Please run: python run.py --pipeline train") + return + + train_run = train_runs.items[0] + checkpoint_path = ( + train_run.steps["train_agent"] + .outputs["checkpoint_path"] + .load() + ) + model_config = train_run.steps["setup_art_model"].output.load() + + print(f"Using checkpoint: {checkpoint_path}") + + except Exception as e: + print(f"ERROR: Failed to load training artifacts: {e}") + return + else: + # Use default model config if checkpoint provided manually + model_config = { + "name": "art-email-agent", + "project": "email-search-agent", + "base_model": checkpoint_path, + } + + # Run evaluation + pipeline_instance = evaluation_pipeline + + if config_path: + pipeline_instance = pipeline_instance.with_options( + config_path=config_path + ) + + pipeline_instance( + model_config=model_config, + checkpoint_path=checkpoint_path, + test_scenarios=test_scenarios, + db_path=db_path, + ) + + print("\nEvaluation complete!") + + +def run_all_pipelines(config_path: Optional[str] = None): + """Run the complete workflow: data → train → eval.""" + print("=" * 60) + print("Running Complete Workflow") + print("=" * 60) + + run_data_pipeline() + run_training_pipeline(config_path=config_path) + run_evaluation_pipeline() + + print("\n" + "=" * 60) + print("Complete workflow finished!") + print("=" * 60) + + +def run_deploy( + deployment_name: str = "art-email-agent", + config_path: Optional[str] = None, +): + """Deploy the inference pipeline as an HTTP service.""" + from pipelines.inference import inference_pipeline + + print("=" * 60) + print("Deploying Inference Pipeline") + print("=" * 60) + + pipeline_instance = inference_pipeline + + if config_path: + pipeline_instance = pipeline_instance.with_options( + config_path=config_path + ) + + # Deploy the pipeline as an HTTP service + deployment = pipeline_instance.deploy(deployment_name=deployment_name) + + print(f"\nDeployment '{deployment_name}' created successfully!") + print(f"URL: {deployment.url}") + print("\nExample invocation:") + print(f" curl -X POST {deployment.url}/invoke \\") + print(' -H "Content-Type: application/json" \\') + print(" -d '{") + print(' "parameters": {') + print(' "question": "What meeting is scheduled?",') + print(' "inbox_address": "john.smith@enron.com",') + print(' "query_date": "2001-05-15"') + print(" }") + print(" 
}'") + print("\nManagement commands:") + print(f" zenml deployment describe {deployment_name}") + print(f" zenml deployment logs {deployment_name} -f") + print(f" zenml deployment deprovision {deployment_name}") + + +def main(): + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + description="ART Email Search Agent - ZenML Pipeline Runner", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument( + "--pipeline", + "-p", + choices=["data", "train", "eval", "all", "deploy"], + required=True, + help="Pipeline to run", + ) + + parser.add_argument( + "--config", + "-c", + type=str, + help="Path to YAML configuration file", + ) + + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable ZenML caching", + ) + + parser.add_argument( + "--checkpoint", + type=str, + help="Path to model checkpoint (for eval pipeline)", + ) + + parser.add_argument( + "--name", + type=str, + default="art-email-agent", + help="Deployment name (for deploy pipeline)", + ) + + args = parser.parse_args() + + # Check environment + check_environment() + + # Run the selected pipeline + if args.pipeline == "data": + run_data_pipeline(config_path=args.config) + + elif args.pipeline == "train": + run_training_pipeline( + config_path=args.config, + no_cache=args.no_cache, + ) + + elif args.pipeline == "eval": + run_evaluation_pipeline( + config_path=args.config, + checkpoint_path=args.checkpoint, + ) + + elif args.pipeline == "all": + run_all_pipelines(config_path=args.config) + + elif args.pipeline == "deploy": + run_deploy( + deployment_name=args.name, + config_path=args.config, + ) + + +if __name__ == "__main__": + main() diff --git a/art-rl/steps/__init__.py b/art-rl/steps/__init__.py new file mode 100644 index 00000000..891dd29d --- /dev/null +++ b/art-rl/steps/__init__.py @@ -0,0 +1,25 @@ +# ZenML steps for the email search agent +from steps.data import ( + create_database, + download_enron_data, + load_scenarios, +) +from steps.evaluation import compute_metrics, load_trained_model, run_inference +from steps.inference import run_single_inference +from steps.training import setup_art_model, train_agent + +__all__ = [ + # Data steps + "download_enron_data", + "create_database", + "load_scenarios", + # Training steps + "setup_art_model", + "train_agent", + # Evaluation steps + "load_trained_model", + "run_inference", + "compute_metrics", + # Inference steps + "run_single_inference", +] diff --git a/art-rl/steps/data/__init__.py b/art-rl/steps/data/__init__.py new file mode 100644 index 00000000..179503a1 --- /dev/null +++ b/art-rl/steps/data/__init__.py @@ -0,0 +1,10 @@ +# Data preparation steps +from steps.data.create_database import create_database +from steps.data.download_enron import download_enron_data +from steps.data.load_scenarios import load_scenarios + +__all__ = [ + "download_enron_data", + "create_database", + "load_scenarios", +] diff --git a/art-rl/steps/data/create_database.py b/art-rl/steps/data/create_database.py new file mode 100644 index 00000000..d9b084f4 --- /dev/null +++ b/art-rl/steps/data/create_database.py @@ -0,0 +1,192 @@ +"""Step to create the SQLite email database with FTS5 search.""" + +import os +import sqlite3 +from datetime import datetime +from typing import Annotated + +from datasets import Dataset +from tqdm import tqdm +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def create_database( + raw_emails: Dataset, + db_path: str = "./enron_emails.db", + 
max_body_length: int = 5000,
+    max_recipients: int = 30,
+) -> Annotated[str, "db_path"]:
+    """Create a SQLite database with FTS5 full-text search from raw emails.
+
+    This step processes the raw email dataset and creates an optimized
+    SQLite database with:
+    - Emails table with core email fields
+    - Recipients table for to/cc/bcc addresses
+    - FTS5 virtual table for fast full-text search
+    - Indexes for common query patterns
+
+    Emails are filtered to exclude:
+    - Very long emails (> max_body_length characters)
+    - Emails with too many recipients (> max_recipients)
+    - Duplicate emails (same subject, body, and sender)
+
+    Args:
+        raw_emails: The raw email dataset from Hugging Face.
+        db_path: Path where the SQLite database will be created.
+        max_body_length: Maximum allowed email body length.
+        max_recipients: Maximum allowed total recipients.
+
+    Returns:
+        Path to the created database file.
+    """
+    logger.info(f"Creating email database at {db_path}...")
+
+    # Remove existing database if present
+    if os.path.exists(db_path):
+        os.remove(db_path)
+
+    # Database schema
+    sql_create_tables = """
+    CREATE TABLE emails (
+        id INTEGER PRIMARY KEY AUTOINCREMENT,
+        message_id TEXT UNIQUE,
+        subject TEXT,
+        from_address TEXT,
+        date TEXT,
+        body TEXT,
+        file_name TEXT
+    );
+
+    CREATE TABLE recipients (
+        email_id TEXT,
+        recipient_address TEXT,
+        recipient_type TEXT
+    );
+    """
+
+    sql_create_indexes = """
+    CREATE INDEX idx_emails_from ON emails(from_address);
+    CREATE INDEX idx_emails_date ON emails(date);
+    CREATE INDEX idx_emails_message_id ON emails(message_id);
+    CREATE INDEX idx_recipients_address ON recipients(recipient_address);
+    CREATE INDEX idx_recipients_type ON recipients(recipient_type);
+    CREATE INDEX idx_recipients_email_id ON recipients(email_id);
+    CREATE INDEX idx_recipients_address_email
+        ON recipients(recipient_address, email_id);
+
+    CREATE VIRTUAL TABLE emails_fts USING fts5(
+        subject,
+        body,
+        content='emails',
+        content_rowid='id'
+    );
+
+    CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
+        INSERT INTO emails_fts (rowid, subject, body)
+        VALUES (new.id, new.subject, new.body);
+    END;
+
+    -- External-content FTS5 tables must not be modified with plain
+    -- DELETE/UPDATE; removals go through the special 'delete' command.
+    CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
+        INSERT INTO emails_fts (emails_fts, rowid, subject, body)
+        VALUES ('delete', old.id, old.subject, old.body);
+    END;
+
+    CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
+        INSERT INTO emails_fts (emails_fts, rowid, subject, body)
+        VALUES ('delete', old.id, old.subject, old.body);
+        INSERT INTO emails_fts (rowid, subject, body)
+        VALUES (new.id, new.subject, new.body);
+    END;
+    """
+
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+    cursor.executescript(sql_create_tables)
+    conn.commit()
+
+    # Optimize for bulk inserts
+    conn.execute("PRAGMA synchronous = OFF;")
+    conn.execute("PRAGMA journal_mode = MEMORY;")
+    conn.execute("BEGIN TRANSACTION;")
+
+    record_count = 0
+    skipped_count = 0
+    duplicate_count = 0
+    processed_emails: set[tuple] = set()
+
+    for email_data in tqdm(raw_emails, desc="Inserting emails"):
+        message_id = email_data["message_id"]
+        subject = email_data["subject"]
+        from_address = email_data["from"]
+        date_obj: datetime = email_data["date"]
+        body = email_data["body"]
+        file_name = email_data["file_name"]
+        to_list = [str(addr) for addr in email_data["to"] if addr]
+        cc_list = [str(addr) for addr in email_data["cc"] if addr]
+        bcc_list = [str(addr) for addr in email_data["bcc"] if addr]
+
+        total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
+
+        # Apply filters
+        if len(body) > max_body_length:
+            skipped_count += 1
+            continue
+
+        if total_recipients > max_recipients:
+            skipped_count += 1
+            continue
+
+        # Deduplication
+        email_key = (subject, body, from_address)
+        if 
email_key in processed_emails: + duplicate_count += 1 + continue + processed_emails.add(email_key) + + date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S") + + cursor.execute( + """ + INSERT INTO emails + (message_id, subject, from_address, date, body, file_name) + VALUES (?, ?, ?, ?, ?, ?) + """, + (message_id, subject, from_address, date_str, body, file_name), + ) + + # Insert recipients + recipient_data = [] + for addr in to_list: + recipient_data.append((message_id, addr, "to")) + for addr in cc_list: + recipient_data.append((message_id, addr, "cc")) + for addr in bcc_list: + recipient_data.append((message_id, addr, "bcc")) + + if recipient_data: + cursor.executemany( + """ + INSERT INTO recipients + (email_id, recipient_address, recipient_type) + VALUES (?, ?, ?) + """, + recipient_data, + ) + + record_count += 1 + + conn.commit() + + # Create indexes and FTS + logger.info("Creating indexes and FTS virtual table...") + cursor.executescript(sql_create_indexes) + cursor.execute('INSERT INTO emails_fts(emails_fts) VALUES("rebuild")') + conn.commit() + conn.close() + + logger.info(f"Created database with {record_count} emails") + logger.info(f"Skipped {skipped_count} emails (length/recipient limits)") + logger.info(f"Skipped {duplicate_count} duplicate emails") + + return db_path diff --git a/art-rl/steps/data/download_enron.py b/art-rl/steps/data/download_enron.py new file mode 100644 index 00000000..7f7bfcb7 --- /dev/null +++ b/art-rl/steps/data/download_enron.py @@ -0,0 +1,49 @@ +"""Step to download the Enron email dataset from Hugging Face.""" + +from typing import Annotated + +from datasets import Dataset, Features, Sequence, Value, load_dataset +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + +EMAIL_DATASET_REPO_ID = "corbt/enron-emails" + + +@step +def download_enron_data() -> Annotated[Dataset, "raw_emails"]: + """Download the Enron email dataset from Hugging Face. + + This step fetches the complete Enron email corpus, which contains + approximately 500,000 emails from Enron employees. The dataset is + used to create a searchable email database for the agent. + + Returns: + The raw email dataset from Hugging Face. 
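+
+    Example:
+        Illustrative sanity check on the returned Dataset; outside a
+        pipeline a ZenML step can be invoked like a plain function:
+
+            emails = download_enron_data()
+            print(len(emails), emails[0]["subject"])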
+ """ + logger.info(f"Downloading email dataset from {EMAIL_DATASET_REPO_ID}...") + + # Define expected schema for type safety + expected_features = Features( + { + "message_id": Value("string"), + "subject": Value("string"), + "from": Value("string"), + "to": Sequence(Value("string")), + "cc": Sequence(Value("string")), + "bcc": Sequence(Value("string")), + "date": Value("timestamp[us]"), + "body": Value("string"), + "file_name": Value("string"), + } + ) + + dataset = load_dataset( + EMAIL_DATASET_REPO_ID, + features=expected_features, + split="train", + ) + + logger.info(f"Downloaded {len(dataset)} emails") + return dataset diff --git a/art-rl/steps/data/load_scenarios.py b/art-rl/steps/data/load_scenarios.py new file mode 100644 index 00000000..13032636 --- /dev/null +++ b/art-rl/steps/data/load_scenarios.py @@ -0,0 +1,80 @@ +"""Step to load training and test scenarios from Hugging Face.""" + +from typing import Annotated, List, Tuple + +from datasets import load_dataset +from environment.models import Scenario +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + +SCENARIO_DATASET_REPO_ID = "corbt/enron_emails_sample_questions" + + +@step +def load_scenarios( + train_limit: int = 50, + test_limit: int = 20, + max_messages: int = 1, + seed: int = 42, +) -> Tuple[ + Annotated[List[Scenario], "train_scenarios"], + Annotated[List[Scenario], "test_scenarios"], +]: + """Load Q&A scenarios from Hugging Face for training and evaluation. + + Each scenario contains: + - A question about emails in a user's inbox + - The reference answer + - Message IDs of the emails containing the answer + - Metadata (inbox address, query date, realism score) + + Args: + train_limit: Maximum number of training scenarios to load. + test_limit: Maximum number of test scenarios to load. + max_messages: Filter to scenarios with at most this many source emails. + Simpler scenarios (fewer source emails) are easier to learn from. + seed: Random seed for reproducible shuffling. + + Returns: + Tuple of (train_scenarios, test_scenarios). 
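+
+    Example:
+        Illustrative call; with the default max_messages=1, every kept
+        scenario cites at most one source email:
+
+            train, test = load_scenarios(train_limit=10, test_limit=5)
+            assert all(len(s.message_ids) <= 1 for s in train)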
+ """ + logger.info(f"Loading scenarios from {SCENARIO_DATASET_REPO_ID}...") + + # Load train split + train_dataset = load_dataset(SCENARIO_DATASET_REPO_ID, split="train") + test_dataset = load_dataset(SCENARIO_DATASET_REPO_ID, split="test") + + # Filter by max_messages + if max_messages is not None: + train_dataset = train_dataset.filter( + lambda x: len(x["message_ids"]) <= max_messages + ) + test_dataset = test_dataset.filter( + lambda x: len(x["message_ids"]) <= max_messages + ) + + # Shuffle with seed for reproducibility + train_dataset = train_dataset.shuffle(seed=seed) + test_dataset = test_dataset.shuffle(seed=seed) + + # Convert to Scenario objects + train_scenarios = [Scenario(**row, split="train") for row in train_dataset] + test_scenarios = [Scenario(**row, split="test") for row in test_dataset] + + # Apply limits + if train_limit: + train_scenarios = train_scenarios[:train_limit] + if test_limit: + test_scenarios = test_scenarios[:test_limit] + + logger.info(f"Loaded {len(train_scenarios)} training scenarios") + logger.info(f"Loaded {len(test_scenarios)} test scenarios") + + # Log sample scenario + if train_scenarios: + sample = train_scenarios[0] + logger.info(f"Sample scenario: {sample.question[:100]}...") + + return train_scenarios, test_scenarios diff --git a/art-rl/steps/evaluation/__init__.py b/art-rl/steps/evaluation/__init__.py new file mode 100644 index 00000000..34fbf9d8 --- /dev/null +++ b/art-rl/steps/evaluation/__init__.py @@ -0,0 +1,10 @@ +# Evaluation steps +from steps.evaluation.compute_metrics import compute_metrics +from steps.evaluation.load_model import load_trained_model +from steps.evaluation.run_inference import run_inference + +__all__ = [ + "load_trained_model", + "run_inference", + "compute_metrics", +] diff --git a/art-rl/steps/evaluation/compute_metrics.py b/art-rl/steps/evaluation/compute_metrics.py new file mode 100644 index 00000000..5fdaca40 --- /dev/null +++ b/art-rl/steps/evaluation/compute_metrics.py @@ -0,0 +1,72 @@ +"""Step to compute evaluation metrics from predictions.""" + +from typing import Annotated, List + +from zenml import log_model_metadata, step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def compute_metrics( + predictions: List[dict], +) -> Annotated[dict, "evaluation_metrics"]: + """Compute evaluation metrics from inference predictions. + + Calculates: + - Accuracy: Fraction of scenarios answered correctly + - Answer rate: Fraction of scenarios that got any answer + - Source precision: How often predicted sources match expected + + Args: + predictions: List of prediction dictionaries from run_inference. + + Returns: + Dictionary of evaluation metrics. + """ + total = len(predictions) + if total == 0: + logger.warning("No predictions to evaluate") + return { + "total_scenarios": 0, + "correct": 0, + "accuracy": 0.0, + "answer_rate": 0.0, + "source_precision": 0.0, + } + + correct = sum(p["correct"] for p in predictions) + answered = sum(1 for p in predictions if p.get("predicted_answer")) + errors = sum(1 for p in predictions if p.get("error")) + + # Calculate source precision (did we find the right emails?) 
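+    # A scenario counts as a match when any predicted source id overlaps
+    # the expected ids; scenarios without expected ids are excluded from
+    # the denominator, so this is a per-scenario hit rate rather than a
+    # strict precision over all predicted ids.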
+ source_matches = 0 + source_total = 0 + for p in predictions: + expected = set(p.get("expected_message_ids", [])) + predicted = set(p.get("predicted_source_ids", [])) + if expected: + source_total += 1 + if expected & predicted: # Any overlap + source_matches += 1 + + metrics = { + "total_scenarios": total, + "correct": correct, + "accuracy": correct / total, + "answered": answered, + "answer_rate": answered / total, + "errors": errors, + "error_rate": errors / total, + "source_matches": source_matches, + "source_precision": ( + source_matches / source_total if source_total > 0 else 0.0 + ), + } + + # Log to ZenML Model Control Plane + log_model_metadata(metadata={"evaluation": metrics}) + + logger.info(f"Evaluation metrics: {metrics}") + return metrics diff --git a/art-rl/steps/evaluation/load_model.py b/art-rl/steps/evaluation/load_model.py new file mode 100644 index 00000000..47ebd5b0 --- /dev/null +++ b/art-rl/steps/evaluation/load_model.py @@ -0,0 +1,40 @@ +"""Step to load a trained model for inference.""" + +from typing import Annotated + +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def load_trained_model( + model_config: dict, + checkpoint_path: str, + art_path: str = "./.art", +) -> Annotated[dict, "inference_config"]: + """Prepare configuration for loading the trained model. + + Since the actual model loading requires GPU resources and must happen + in the inference step, this step prepares the configuration needed + to load the model from the checkpoint. + + Args: + model_config: Original model configuration. + checkpoint_path: Path to the trained checkpoint. + art_path: Directory for ART files. + + Returns: + Configuration dict for loading the model during inference. + """ + inference_config = { + **model_config, + "checkpoint_path": checkpoint_path, + "art_path": art_path, + } + + logger.info( + f"Prepared inference config from checkpoint: {checkpoint_path}" + ) + return inference_config diff --git a/art-rl/steps/evaluation/run_inference.py b/art-rl/steps/evaluation/run_inference.py new file mode 100644 index 00000000..83047806 --- /dev/null +++ b/art-rl/steps/evaluation/run_inference.py @@ -0,0 +1,141 @@ +"""Step to run inference on test scenarios.""" + +import asyncio +from typing import Annotated, List + +from environment.models import Scenario +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +def _run_inference_loop( + inference_config: dict, + test_scenarios: List[Scenario], + db_path: str, + judge_model: str, +) -> List[dict]: + """Internal async inference loop.""" + import art + from agent.rollout import rollout + from art.langgraph import wrap_rollout + from art.local import LocalBackend + from environment.models import EmailScenario + + async def _async_inference(): + # Initialize model from checkpoint + model = art.TrainableModel( + name=inference_config["name"], + project=inference_config["project"], + base_model=inference_config["checkpoint_path"], + ) + + model._internal_config = art.dev.InternalModelConfig( + init_args=art.dev.InitArgs(max_seq_length=8192), + engine_args=art.dev.EngineArgs( + enforce_eager=True, + gpu_memory_utilization=0.8, + ), + ) + + backend = LocalBackend( + in_process=True, + path=inference_config["art_path"], + ) + await model.register(backend) + + predictions = [] + + for i, scenario in enumerate(test_scenarios): + logger.info( + f"Running inference on scenario {i+1}/{len(test_scenarios)}" + ) + + email_scenario = 
EmailScenario(step=0, scenario=scenario) + + try: + traj = await wrap_rollout(model, rollout)( + model, + email_scenario, + db_path=db_path, + judge_model=judge_model, + ) + + predictions.append( + { + "scenario_id": scenario.id, + "question": scenario.question, + "expected_answer": scenario.answer, + "expected_message_ids": scenario.message_ids, + "predicted_answer": ( + traj.final_answer.answer + if traj.final_answer + else None + ), + "predicted_source_ids": ( + traj.final_answer.source_ids + if traj.final_answer + else [] + ), + "correct": traj.metrics.get("correct", 0), + } + ) + except Exception as e: + logger.error(f"Error on scenario {scenario.id}: {e}") + predictions.append( + { + "scenario_id": scenario.id, + "question": scenario.question, + "expected_answer": scenario.answer, + "expected_message_ids": scenario.message_ids, + "predicted_answer": None, + "predicted_source_ids": [], + "correct": 0, + "error": str(e), + } + ) + + return predictions + + return asyncio.run(_async_inference()) + + +@step +def run_inference( + inference_config: dict, + test_scenarios: List[Scenario], + db_path: str, + judge_model: str = "openai/gpt-4.1", +) -> Annotated[List[dict], "predictions"]: + """Run the trained agent on test scenarios. + + For each scenario, the agent: + 1. Searches the email database using its learned strategy + 2. Provides a final answer with source references + 3. Gets judged for correctness by the judge model + + Args: + inference_config: Configuration from load_trained_model step. + test_scenarios: List of test scenarios to evaluate. + db_path: Path to the email database. + judge_model: LiteLLM model ID for correctness judging. + + Returns: + List of prediction dictionaries with results for each scenario. + """ + logger.info(f"Running inference on {len(test_scenarios)} test scenarios") + + predictions = _run_inference_loop( + inference_config=inference_config, + test_scenarios=test_scenarios, + db_path=db_path, + judge_model=judge_model, + ) + + correct_count = sum(p["correct"] for p in predictions) + logger.info( + f"Inference complete: {correct_count}/{len(predictions)} correct" + ) + + return predictions diff --git a/art-rl/steps/inference/__init__.py b/art-rl/steps/inference/__init__.py new file mode 100644 index 00000000..9eb66df2 --- /dev/null +++ b/art-rl/steps/inference/__init__.py @@ -0,0 +1,7 @@ +"""Inference steps for deployment.""" + +from steps.inference.single_inference import run_single_inference + +__all__ = [ + "run_single_inference", +] diff --git a/art-rl/steps/inference/single_inference.py b/art-rl/steps/inference/single_inference.py new file mode 100644 index 00000000..2a63dbba --- /dev/null +++ b/art-rl/steps/inference/single_inference.py @@ -0,0 +1,151 @@ +"""Step to run a single inference query.""" + +import asyncio +from typing import Annotated + +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +def _run_single_query( + question: str, + inbox_address: str, + query_date: str, + checkpoint_path: str, + db_path: str, + judge_model: str, + model_name: str, + project_name: str, + art_path: str, +) -> dict: + """Internal async inference for a single query.""" + import art + from agent.rollout import rollout + from art.langgraph import wrap_rollout + from art.local import LocalBackend + from environment.models import EmailScenario, Scenario + + async def _async_query(): + # Initialize model from checkpoint + model = art.TrainableModel( + name=model_name, + project=project_name, + base_model=checkpoint_path, + 
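+            # The local checkpoint directory is passed where a Hugging
+            # Face model id would normally go, so ART serves the
+            # fine-tuned weights rather than the original base model.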
) + + model._internal_config = art.dev.InternalModelConfig( + init_args=art.dev.InitArgs(max_seq_length=8192), + engine_args=art.dev.EngineArgs( + enforce_eager=True, + gpu_memory_utilization=0.8, + ), + ) + + backend = LocalBackend( + in_process=True, + path=art_path, + ) + await model.register(backend) + + # Create a scenario from the query parameters + scenario = Scenario( + id=0, + question=question, + answer="", # Unknown - this is what we're trying to find + message_ids=[], + how_realistic=1.0, + inbox_address=inbox_address, + query_date=query_date, + split="inference", + ) + + email_scenario = EmailScenario(step=0, scenario=scenario) + + try: + traj = await wrap_rollout(model, rollout)( + model, + email_scenario, + db_path=db_path, + judge_model=judge_model, + ) + + return { + "question": question, + "inbox_address": inbox_address, + "query_date": query_date, + "answer": ( + traj.final_answer.answer if traj.final_answer else None + ), + "source_ids": ( + traj.final_answer.source_ids if traj.final_answer else [] + ), + "success": traj.final_answer is not None, + } + except Exception as e: + logger.error(f"Error during inference: {e}") + return { + "question": question, + "inbox_address": inbox_address, + "query_date": query_date, + "answer": None, + "source_ids": [], + "success": False, + "error": str(e), + } + + return asyncio.run(_async_query()) + + +@step +def run_single_inference( + question: str, + inbox_address: str, + query_date: str, + checkpoint_path: str = "./.art/checkpoints/latest", + db_path: str = "./enron_emails.db", + judge_model: str = "openai/gpt-4.1", + model_name: str = "art-email-agent", + project_name: str = "email-search-agent", + art_path: str = "./.art", +) -> Annotated[dict, "inference_result"]: + """Run a single inference query through the trained agent. + + This step loads the trained model and runs a single email search + query, returning the agent's answer with source references. + + Args: + question: The question to answer about the user's emails. + inbox_address: The email address of the inbox to search. + query_date: The reference date for the query (YYYY-MM-DD). + checkpoint_path: Path to the trained model checkpoint. + db_path: Path to the email database. + judge_model: LiteLLM model ID for correctness judging. + model_name: Name of the ART model. + project_name: Name of the ART project. + art_path: Path to ART artifacts directory. + + Returns: + Dictionary with the agent's answer and metadata. 
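+
+    Example:
+        Direct call for local testing, assuming the default checkpoint
+        and database paths exist (the deployed service receives the same
+        parameters via POST /invoke):
+
+            result = run_single_inference(
+                question="What meeting is scheduled?",
+                inbox_address="john.smith@enron.com",
+                query_date="2001-05-15",
+            )
+            print(result["answer"], result["source_ids"])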
+ """ + logger.info(f"Running inference for question: {question[:50]}...") + + result = _run_single_query( + question=question, + inbox_address=inbox_address, + query_date=query_date, + checkpoint_path=checkpoint_path, + db_path=db_path, + judge_model=judge_model, + model_name=model_name, + project_name=project_name, + art_path=art_path, + ) + + if result.get("success"): + logger.info(f"Inference successful: {result['answer'][:100]}...") + else: + logger.warning(f"Inference failed: {result.get('error', 'No answer')}") + + return result diff --git a/art-rl/steps/training/__init__.py b/art-rl/steps/training/__init__.py new file mode 100644 index 00000000..54330374 --- /dev/null +++ b/art-rl/steps/training/__init__.py @@ -0,0 +1,8 @@ +# Training steps +from steps.training.setup_model import setup_art_model +from steps.training.train_agent import train_agent + +__all__ = [ + "setup_art_model", + "train_agent", +] diff --git a/art-rl/steps/training/setup_model.py b/art-rl/steps/training/setup_model.py new file mode 100644 index 00000000..63ae5c76 --- /dev/null +++ b/art-rl/steps/training/setup_model.py @@ -0,0 +1,40 @@ +"""Step to configure the ART trainable model.""" + +from typing import Annotated + +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def setup_art_model( + model_name: str = "art-email-agent", + project_name: str = "email-search-agent", + base_model: str = "Qwen/Qwen2.5-7B-Instruct", +) -> Annotated[dict, "model_config"]: + """Configure the ART model for training. + + This step creates a configuration dictionary that will be used to + initialize the ART TrainableModel in the training step. We separate + configuration from initialization because the actual model loading + requires GPU resources. + + Args: + model_name: Name for the trained model (used for checkpoints). + project_name: Project name for organizing experiments. + base_model: Hugging Face model ID for the base model. + Qwen 2.5 7B Instruct is recommended for this task. + + Returns: + Configuration dictionary for model initialization. + """ + config = { + "name": model_name, + "project": project_name, + "base_model": base_model, + } + + logger.info(f"Model configuration: {config}") + return config diff --git a/art-rl/steps/training/train_agent.py b/art-rl/steps/training/train_agent.py new file mode 100644 index 00000000..9c657d30 --- /dev/null +++ b/art-rl/steps/training/train_agent.py @@ -0,0 +1,229 @@ +"""Step to run the ART training loop with GRPO and RULER scoring.""" + +import asyncio +from typing import Annotated, List, Tuple + +from environment.models import Scenario +from zenml import log_model_metadata, step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +def _run_training_loop( + model_config: dict, + train_scenarios: List[Scenario], + db_path: str, + groups_per_step: int, + num_epochs: int, + rollouts_per_group: int, + learning_rate: float, + max_steps: int, + ruler_model: str, + art_path: str, +) -> Tuple[str, dict]: + """Internal async training loop wrapped for sync execution. + + This function contains the core ART training logic using GRPO + (Group Relative Policy Optimization) with RULER scoring. 
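+
+    One training step, concretely: generate rollouts_per_group
+    trajectories for each scenario in the batch, have RULER rank the
+    trajectories within each group, then apply a GRPO update to the
+    judged groups.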
+ """ + import art + from agent.rollout import rollout + from art.langgraph import wrap_rollout + from art.local import LocalBackend + from art.rewards import ruler_score_group + from art.utils import iterate_dataset + from environment.models import EmailScenario + + async def _async_train(): + # Initialize model with memory-optimized config for GPU containers + model = art.TrainableModel( + name=model_config["name"], + project=model_config["project"], + base_model=model_config["base_model"], + ) + + # Configure for GPU memory efficiency + model._internal_config = art.dev.InternalModelConfig( + init_args=art.dev.InitArgs(max_seq_length=8192), + engine_args=art.dev.EngineArgs( + enforce_eager=True, + gpu_memory_utilization=0.8, + ), + ) + + # LocalBackend runs inference/training in the same process + backend = LocalBackend(in_process=True, path=art_path) + await model.register(backend) + + # Training iterator handles epoch cycling and batching + training_iterator = iterate_dataset( + train_scenarios, + groups_per_step=groups_per_step, + num_epochs=num_epochs, + initial_step=await model.get_step(), + ) + + all_metrics = [] + + for batch in training_iterator: + logger.info( + f"Training step {batch.step}, epoch {batch.epoch}, " + f"batch size {len(batch.items)}" + ) + + # Create trajectory groups - each scenario gets multiple rollouts + groups = [] + for scenario in batch.items: + email_scenario = EmailScenario( + step=batch.step, + scenario=scenario, + ) + groups.append( + art.TrajectoryGroup( + ( + wrap_rollout(model, rollout)( + model, email_scenario, db_path=db_path + ) + for _ in range(rollouts_per_group) + ) + ) + ) + + # Gather all trajectories (parallel execution) + finished_groups = await art.gather_trajectory_groups( + groups, + pbar_desc="rollouts", + max_exceptions=rollouts_per_group * len(batch.items), + ) + + # RULER scoring - uses LLM judge to rank trajectories + judged_groups = [] + for group in finished_groups: + judged_group = await ruler_score_group(group, ruler_model) + if judged_group: + judged_groups.append(judged_group) + + if not judged_groups: + logger.warning( + "No valid trajectory groups after RULER scoring" + ) + continue + + # Calculate metrics for logging + total_trajectories = sum( + len(g.trajectories) for g in judged_groups + ) + avg_reward = ( + sum(t.reward for g in judged_groups for t in g.trajectories) + / total_trajectories + ) + accuracy = ( + sum( + t.metrics.get("correct", 0) + for g in judged_groups + for t in g.trajectories + ) + / total_trajectories + ) + + step_metrics = { + "step": batch.step, + "epoch": batch.epoch, + "avg_reward": avg_reward, + "accuracy": accuracy, + "num_trajectories": total_trajectories, + } + all_metrics.append(step_metrics) + + # Log to ZenML Model Control Plane + log_model_metadata(metadata=step_metrics) + logger.info( + f"Step {batch.step}: reward={avg_reward:.3f}, " + f"accuracy={accuracy:.3f}" + ) + + # GRPO training step + await model.delete_checkpoints() + await model.train( + judged_groups, + config=art.TrainConfig(learning_rate=learning_rate), + _config={"logprob_calculation_chunk_size": 8}, + ) + + if batch.step >= max_steps: + logger.info(f"Reached max_steps={max_steps}, stopping") + break + + checkpoint_path = f"{art_path}/checkpoints/latest" + return checkpoint_path, {"history": all_metrics} + + return asyncio.run(_async_train()) + + +@step(enable_cache=False) +def train_agent( + model_config: dict, + train_scenarios: List[Scenario], + db_path: str, + groups_per_step: int = 2, + num_epochs: int = 20, + 
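+    # GRPO learns from reward differences within a trajectory group, so
+    # rollouts_per_group below 2 gives the optimizer no relative signal.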
rollouts_per_group: int = 4, + learning_rate: float = 1e-5, + max_steps: int = 20, + ruler_model: str = "openai/o4-mini", + art_path: str = "./.art", +) -> Tuple[ + Annotated[str, "checkpoint_path"], + Annotated[dict, "training_metrics"], +]: + """Run the ART training loop for the email search agent. + + This step implements the core RL training using: + - GRPO (Group Relative Policy Optimization) for policy updates + - RULER for scoring trajectories relative to each other + - LangGraph ReAct agents for executing rollouts + + The training process: + 1. For each batch of scenarios, generate multiple rollout trajectories + 2. Use RULER to score trajectories relative to each other + 3. Apply GRPO update using the scored trajectories + 4. Log metrics to ZenML for tracking + + Args: + model_config: Configuration from setup_art_model step. + train_scenarios: List of training scenarios. + db_path: Path to the email database. + groups_per_step: Number of scenario groups per training step. + num_epochs: Total number of passes through the training data. + rollouts_per_group: Number of rollouts per scenario for GRPO. + learning_rate: Learning rate for the optimizer. + max_steps: Maximum training steps (for early stopping). + ruler_model: LiteLLM model ID for RULER scoring. + art_path: Directory for ART checkpoints and logs. + + Returns: + Tuple of (checkpoint_path, training_metrics). + """ + logger.info("Starting ART training loop...") + logger.info(f"Training on {len(train_scenarios)} scenarios") + logger.info( + f"Config: groups_per_step={groups_per_step}, " + f"rollouts_per_group={rollouts_per_group}" + ) + + checkpoint_path, metrics = _run_training_loop( + model_config=model_config, + train_scenarios=train_scenarios, + db_path=db_path, + groups_per_step=groups_per_step, + num_epochs=num_epochs, + rollouts_per_group=rollouts_per_group, + learning_rate=learning_rate, + max_steps=max_steps, + ruler_model=ruler_model, + art_path=art_path, + ) + + logger.info(f"Training complete. Checkpoint saved to {checkpoint_path}") + return checkpoint_path, metrics
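+
+
+# Illustrative sketch (not executed anywhere in this step): the
+# group-relative idea behind GRPO. Within a TrajectoryGroup, a
+# trajectory's advantage is its RULER reward minus the group mean, so
+# only relative quality within the group drives the policy update:
+#
+#     rewards = [t.reward for t in group.trajectories]
+#     baseline = sum(rewards) / len(rewards)
+#     advantages = [r - baseline for r in rewards]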