From b697d271f59f56b6df6ff4b19f8ddfb1f7f85f0b Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Sun, 17 Aug 2025 00:12:55 -0700
Subject: [PATCH 1/3] experimenting
---
compare_structured_output.py | 314 ++++++++++++++++++++++++++
dataset.py | 121 ++++++++++
experimental.py | 426 +++++++++++++++++++++++++++++++++++
3 files changed, 861 insertions(+)
create mode 100644 compare_structured_output.py
create mode 100644 dataset.py
create mode 100644 experimental.py
diff --git a/compare_structured_output.py b/compare_structured_output.py
new file mode 100644
index 00000000..ea1d83db
--- /dev/null
+++ b/compare_structured_output.py
@@ -0,0 +1,314 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+import adalflow as adal
+from adalflow.core import DataClass, required_field
+from adalflow.components.model_client.openai_client import OpenAIClient
+from adalflow.components.model_client.anthropic_client import AnthropicAPIClient
+
+from adalflow.components.output_parsers.outputs import JsonOutputParser, JsonOutputParserPydanticModel
+from adalflow.utils import setup_env
+
+# OpenAI's structured output approach
+from openai import OpenAI
+from pydantic import BaseModel
+
+setup_env()
+
+# ===============================
+# SHARED DATA MODELS
+# ===============================
+
+# Pydantic models for OpenAI structured output
+class Participants(BaseModel):
+ names: List[str]
+ addresses: List[str]
+
+class CalendarEvent(BaseModel):
+ name: str
+ date: str
+ participants: Participants
+
+# AdalFlow DataClass models
+@dataclass
+class ParticipantsData(DataClass):
+ names: List[str] = field(
+ metadata={"desc": "List of participant names"},
+ default_factory=list
+ )
+ addresses: Optional[List[str]] = field(
+ metadata={"desc": "List of participant addresses"},
+ default_factory=list
+ )
+
+@dataclass
+class CalendarEventData(DataClass):
+ name: str = field(
+ metadata={"desc": "Name of the calendar event"},
+ default_factory=required_field()
+ )
+ date: str = field(
+ metadata={"desc": "Date of the event"},
+ default_factory=required_field()
+ )
+ participants: ParticipantsData = field(
+ metadata={"desc": "Event participants information"},
+ default_factory=required_field()
+ )
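+
+# Optional sanity check (my assumption: DataClass exposes to_json_signature, the same call
+# DataClassParser uses): print the JSON signature these "desc" metadata entries produce,
+# which is roughly what JsonOutputParser embeds in its format instructions.
+# print(CalendarEventData.to_json_signature())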
+
+# ===============================
+# OPENAI STRUCTURED OUTPUT
+# ===============================
+
+def test_openai_structured_output():
+ """Test OpenAI's native structured output parsing."""
+ print("\n=== Testing OpenAI Structured Output ===")
+
+ client = OpenAI()
+
+ try:
+ response = client.responses.parse(
+ model="gpt-4o-2024-08-06",
+ input=[
+ {"role": "system", "content": "Extract the event information. Use a synthesic, very complciated string as the address for the particpant to test very complicated json parsing"},
+ {
+ "role": "user",
+ "content": "Alice and Bob are going to a science fair on Friday.",
+ },
+ ],
+ text_format=CalendarEvent,
+ )
+
+ event = response.output_parsed
+ print(f"OpenAI Response: {event}")
+
+ return event
+
+ except Exception as e:
+ print(f"OpenAI structured output failed: {e}")
+ return None
+
+# ===============================
+# ADALFLOW GENERATOR + JSON PARSER
+# ===============================
+
+class AdalFlowEventExtractor(adal.Component):
+ """AdalFlow component using Generator with JsonOutputParser."""
+
+ def __init__(self, model_client: adal.ModelClient, model_kwargs: dict):
+ super().__init__()
+
+ # Set up output parser
+ self.output_parser = JsonOutputParser(
+ data_class=CalendarEventData,
+ return_data_class=True
+ )
+ self.output_parser_pydantic = JsonOutputParserPydanticModel(
+ pydantic_model=CalendarEvent,
+ return_pydantic_object=True
+ )
+ # Template for the Generator
+ self.template = r"""
+
+{{system_prompt}}
+
+{{output_format_str}}
+
+
+{{user_input}}
+
+ """.strip()
+
+ # Set up Generator
+ self.llm = adal.Generator(
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ template=self.template,
+ output_processors=self.output_parser,
+ use_cache=False,
+ )
+
+ def call(self, user_input: str, system_prompt: str) -> adal.GeneratorOutput:
+ """Extract event information using AdalFlow Generator + JsonOutputParser."""
+ # print(f" output_format_str: {self.output_parser.format_instructions()}")
+ prompt_kwargs = {
+ "system_prompt": system_prompt,
+ "output_format_str": self.output_parser.format_instructions(),
+ "user_input": user_input
+ }
+
+ return self.llm(prompt_kwargs=prompt_kwargs)
+
+def test_adalflow_json_parser():
+ """Test AdalFlow's Generator with JsonOutputParser."""
+ print("\n=== Testing AdalFlow Generator + JsonOutputParser ===")
+
+ model_config = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 500,
+ }
+ }
+
+ extractor = AdalFlowEventExtractor(
+ model_client=model_config["model_client"],
+ model_kwargs=model_config["model_kwargs"]
+ )
+
+ system_prompt = "Extract the event information. g"
+ user_input = "Alice and Bob are going to a science fair on Friday."
+
+ try:
+ response = extractor.call(user_input, system_prompt)
+
+ # print(f"AdalFlow Response: {response}")
+
+ if response.data:
+ event_data = response.data
+ print(f"Event Name: {event_data.name}")
+ print(f"Date: {event_data.date}")
+ print(f"Participants: {event_data.participants.names}")
+ print(f"Addresses: {event_data.participants.addresses}")
+ else:
+ print(f"Failed to parse, raw response: {response.raw_response}")
+
+ return response
+
+ except Exception as e:
+ print(f"AdalFlow extraction failed: {e}")
+ return None
+
+# ===============================
+# COMPARISON AND BENCHMARKING
+# ===============================
+
+def compare_approaches():
+ """Compare both approaches side by side."""
+ print("\n" + "="*60)
+ print("STRUCTURED OUTPUT COMPARISON")
+ print("="*60)
+
+ # Test cases with increasingly complex scenarios
+ test_cases = [
+ "Alice and Bob are going to a science fair on Friday.",
+ "John, Mary, and David will attend the birthday party on December 25th.",
+ "The team meeting with Sarah, Mike, Tom, and Lisa is scheduled for next Monday at the conference room.",
+ # Complex address challenge
+ "Dr. Katherine Martinez-O'Sullivan and Prof. Ahmed bin Rashid Al-Maktoum will attend the International Conference on AI Ethics on March 15th, 2024. Katherine lives at 1247-B Château de Malmaison, Apt. #47/C, Neuilly-sur-Seine, Île-de-France 92200, France (GPS: 48.8738°N, 2.1667°E) and Ahmed resides at Building 47/Tower C/Floor 23/Unit 2301-A, Sheikh Zayed Road Complex, Near Dubai Mall Metro Station, Dubai, United Arab Emirates, P.O. Box 112233-ABUDHABI-UAE (Emergency Contact: +971-4-XXX-YYYY).",
+ # JSON-breaking characters and edge cases
+ "Meeting participants: Alex \"The Coder\" Johnson (address: 123 Main St., Unit #5-B\nSecond Floor\n\"Special Building\"\nCity: Austin\"Texas\" 78701\nCountry: USA\"America\"), Maria José Rodríguez-Pérez (Calle de José Ortega y Gasset, 29\n28006 Madrid\nEspaña), and Zhang Wei 张伟 (北京市朝阳区建国门外大街1号\nBeijing 100001\n中华人民共和国). Event: Tech Summit 2024 on February 29th.",
+ # Unicode, emojis, and special formatting
+ "🎉 Grand Opening Party 🎊 on Saturday, Jan 20th! Attendees: Σωκράτης Παπαδόπουλος (Πλατεία Συντάγματος 1, Αθήνα 10563, Ελλάδα 🇬🇷), Владимир Иванович Петров (Красная площадь, дом 1, Москва 109012, Российская Федерация 🇷🇺), and السيد أحمد محمد علي (شارع التحرير رقم 15، القاهرة 11511، جمهورية مصر العربية 🇪🇬).",
+ # Nested quotes and complex punctuation
+ "Annual \"Best Practices\" Workshop featuring: Robert \"Bob\" O'Malley-Smith Jr., Ph.D., M.D. (address: \"The Heights\", Building A-1, Suite 100-C, Floor 2.5, 987 Oak Tree Lane & Maple Street Intersection, South Hampton, NY 11968-1234 [Note: Use side entrance during construction]), Dr. Mary-Catherine Van Der Berg-Williams III (Château \"Les Trois Rosés\", 45 Rue de la Paix & Boulevard Saint-Germain, 7ème Arrondissement, Paris 75007, République Française [Buzzer code: \"2024-SECRET\"]), scheduled for December 31st, 2024 at 11:59 PM.",
+ # JSON structure breakers and extreme formatting
+ "Emergency meeting tomorrow! Participants: {\"name\": \"John Smith\", \"role\": \"CEO\"} living at [Address Object]: {\"street\": \"123 {Main} Street\", \"apt\": \"#[456-B]\", \"city\": \"New {York}\", \"state\": \"NY\", \"zip\": \"[10001-2345]\", \"coordinates\": {\"lat\": 40.7589, \"lng\": -73.9851}, \"special_notes\": \"Ring bell 3x, say 'pizza delivery', wait 30sec, then ring 2x more\"}, and Jane Doe at \"[CLASSIFIED LOCATION]\" {GPS: [REDACTED], Building: [UNKNOWN], Floor: [N/A]}"
+ ]
+ adalflow_count = 0
+ openai_count = 0
+
+ for i, test_case in enumerate(test_cases, 1):
+ print(f"\n--- Test Case {i} ---")
+ print(f"Input: {test_case}")
+
+ # Test OpenAI
+ openai_result = test_openai_structured_output_simple(test_case)
+
+ # Test AdalFlow
+ adalflow_result = test_adalflow_json_parser_simple(test_case)
+
+ print("adalflow_result:", adalflow_result.data)
+ print("openai_result:", openai_result)
+
+
+
+ # Compare results
+ print("\n--- Comparison ---")
+ if openai_result and isinstance(openai_result, CalendarEvent):
+ print("✅ OpenAI")
+ openai_count += 1
+ if adalflow_result and adalflow_result.data and isinstance(adalflow_result.data, (CalendarEvent, CalendarEventData)):
+ print("✅ AdalFlow")
+ adalflow_count += 1
+
+
+ print(f"OpenAI Count: {openai_count}, AdalFlow Count: {adalflow_count}")
+
+def test_openai_structured_output_simple(user_input: str):
+ """Simplified OpenAI test for comparison."""
+ client = OpenAI()
+
+ try:
+ response = client.responses.parse(
+ model="gpt-4o-2024-08-06",
+ input=[
+ {"role": "system", "content": "Extract the event information."},
+ {
+ "role": "user",
+ "content": user_input,
+ },
+ ],
+ text_format=CalendarEvent,
+ )
+
+ event = response.output_parsed
+ return event
+ except Exception as e:
+ print(f"OpenAI error: {e}")
+ return None
+
+def test_adalflow_json_parser_simple(user_input: str):
+ """Simplified AdalFlow test for comparison."""
+
+
+
+ # Alternative client, kept commented out so it does not silently get overridden by the
+ # OpenAI config below:
+ # model_config = {
+ #     "model_client": AnthropicAPIClient(),
+ #     "model_kwargs": {"model": "claude-sonnet-4-20250514"},
+ # }
+ model_config = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "gpt-4o-mini",
+ },
+ }
+
+ extractor = AdalFlowEventExtractor(
+ model_client=model_config["model_client"],
+ model_kwargs=model_config["model_kwargs"]
+ )
+
+ system_prompt = "Extract the event information. "
+
+ try:
+ return extractor.call(user_input, system_prompt)
+ except Exception as e:
+ print(f"AdalFlow error: {e}")
+ return None
+
+if __name__ == "__main__":
+ # Run individual tests
+ openai_result = test_openai_structured_output()
+ adalflow_result = test_adalflow_json_parser()
+
+ # Run comparison
+ compare_approaches()
+
+ print("\n" + "="*60)
+ print("SUMMARY")
+ print("="*60)
+ print("OpenAI Structured Output:")
+ print(" + Native JSON schema validation")
+ print(" + Guaranteed structure compliance")
+ print(" - Requires specific OpenAI models")
+ print(" - Less flexibility in processing pipeline")
+
+ print("\nAdalFlow Generator + JsonOutputParser:")
+ print(" + Model-agnostic approach")
+ print(" + Flexible processing pipeline")
+ print(" + Integration with optimization framework")
+ print(" - May require retry logic for parsing failures")
+ print(" + Better integration with AdalFlow ecosystem")
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 00000000..c87af0c8
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,121 @@
+dataset = [
+ {
+ "pages": [
+ "NOVALUX SERVICES LLC — internal memo v3\nrev 03/16/2024 | cycle Q1 | ticket 004271\nnotes: last spring round 03-14-2024; prev annual 2019-03-17; audit 15 Mar 2024\nthread list: Liam Chen; Maria Gomez; Henry Stone; Noah Patel; Andrea Ruiz; David Roe\nrooftop set: AC-7C, AC-9F; site: Northwind Offices; zone: PINE-WALK\nroster excerpt: Ruiz, Andrea; Patel, Noah; Chen, Olivia; Long, Peter; Gomez, Maria\nline items:\n- 7C belts swapped 03/15/24\n- 9F filters swapped 03/14/24\n- motor check 2023-03-17\nnames mentioned elsewhere: Omar Diaz, Jason Miller, Karen Bell, Raj Patel\n--- mid-block ------------------------------------------------\nroute grid R2: Chen, Liam | Gomez, Maria | Stone, Emily | Roe, David | Patel, Noah\ndispatch key [core-record: 2024-03-15] ref A1-7C\nnorth list crossref: Henry Stone; Andrea Ruiz\n----------------------------------------------------------------\nmisc dates: 03/15/2024 09:40; 2024/03/15; 15-03-2024\nfooter tag id NLX-INT-004271\n"
+ ],
+ "date_rule": "Extract the ISO date appearing between the exact marker '[core-record: ' and the closing ']'.",
+ "expected_output": {
+ "document_main_date": "2024-03-15",
+ "client_information": {
+ "first_name": "Emily",
+ "last_name": "Stone"
+ }
+ }
+ },
+ {
+ "pages": [
+ "ORION INSURANCE COMPANY — coverage packet\nprint 2025/01/02; cycle close 2024-12-29; sample form rev 12-2024\nreference codes: AX-77-193; BQ-11-004\nnames in circulation: Ethan Li, Priya Nair, Omar Diaz, Laurel Kim, Julia Park\npage markers: P1/2\nlist A (mail sweep): Mendes, Ana; Park, Julia; Li, Ethan; Singh, Maya; Patel, Raj\n--- center band ----------------------------------------------\nwindow text row: Mendes, Carlos | addr token G-441B | @when=2025-02-01@\n--------------------------------------------------------------\nreminders: renewal cycle hits 2026-02-01; example date 02/01/2025 (US)\ntrailing mentions: Raj Patel; Maya Singh; Laurel Kim\n"
+ ],
+ "date_rule": "Select the YYYY-MM-DD date enclosed by the markers '@when=' and '@'.",
+ "expected_output": {
+ "document_main_date": "2025-02-01",
+ "client_information": {
+ "first_name": "Carlos",
+ "last_name": "Mendes"
+ }
+ }
+ },
+ {
+ "pages": [
+ "NORTH RIVER BANK — consolidated lines\nhdr date 11/05/2023; period 10/01/23–10/31/23; ref NRB-STAT-1180\nledger notes: auditor Olivia Chen 02-Nov-2023; prior msg 2023-10-29; contact Jason Miller\nmasked acct **** 4831 | manager Tomas Rivera | approver Ellen Wu\npeople mentioned: Peter Sand; Daniel Cho; Cathy Nguyen; Henry Stone\nflow:\n- ACH in 10/30/23\n- fee waive 11-02-2023\n--- center ledger --------------------------------------------\nparticipants: Sand, Peter | Nair, Priya | Chen, Olivia | Miller, Jason\nseal batch 44A\n--------------------------------------------------------------\nfooter fragments: 05-11-2023; Nov 1, 2023; 2023/11/01\n"
+ ],
+ "date_rule": "Use the date contained between '' exactly as YYYY-MM-DD.",
+ "expected_output": {
+ "document_main_date": "2023-11-01",
+ "client_information": {
+ "first_name": "Priya",
+ "last_name": "Nair"
+ }
+ }
+ },
+ {
+ "pages": [
+ "OAK CREST PROPERTY MGMT. — unit packet\nbldg: Lakeview Rd., unit 5A; intercom map rev 08/23; parking memo 08-23-2022\nnames across building: Amy Tran; Daniel Ortiz; Raj Patel; Sarah Onu; Michael Lin\nstack A roster: Blake, Sarah (prev); Grant, Oliver (current); Ruiz, Andrea; Gomez, Maria\npage 1/3\n--- carryover data -------------------------------------------\nmailbox panel: 5A GRANT O | 5B TRAN A | 6C ORTIZ D | 3D PATEL R\ninspection mentions: 08/22/2022; utilities 08/25/2022; move target 09/01/2022\n--------------------------------------------------------------\nnext page\n",
+ "OAK CREST PROPERTY MGMT. — notes\nvisitors seen: Andrea Ruiz; Maria Gomez; David Roe; Jason Miller; Olivia Chen\nrandom dates: 20-08-2022; 2022/08/20; 08/20/22\n--- mid-strip -------------------------------------------------\nkey timeline |dt|2022-08-20|dt| for archive tag LC-5A\n--------------------------------------------------------------\nother: parking review 2022-08-23; form edits 08-18-2022\npage 2/3\n",
+ "OAK CREST PROPERTY MGMT. — misc\nunit map checksum 5A-7F-2C; contact index L.Park; badge review 2022-08-21\nfooter copy ids: OCP-5A-AG\npage 3/3\n"
+ ],
+ "date_rule": "Extract the date between the exact tokens '|dt|' and '|dt|' on page 2.",
+ "expected_output": {
+ "document_main_date": "2022-08-20",
+ "client_information": {
+ "first_name": "Oliver",
+ "last_name": "Grant"
+ }
+ }
+ },
+ {
+ "pages": [
+ "CITYCARE CLINIC — visit archive\nprint 2021-07-14; prior vaccination 2021-06-10; relative visit 06/30/2021\nstaff roll: Mark Holloway; Eva Burns; Nora Lee; Raj Patel\nname scatter: Ramirez, Luis; Mendes, Carlos; Stone, Emily; Lee, Marcus; Petrova, Sofia\nsymptoms log id CC-7781\n--- middle row ------------------------------------------------\nconsent line: Ramirez, Zoe {iso:2021-07-20} sig on file\n--------------------------------------------------------------\nother timestamps: 2021/07/18; 07-20-21; 2021-01-01 (policy)\n"
+ ],
+ "date_rule": "Choose the YYYY-MM-DD value inside the braces after 'iso:' on the consent line.",
+ "expected_output": {
+ "document_main_date": "2021-07-20",
+ "client_information": {
+ "first_name": "Zoe",
+ "last_name": "Ramirez"
+ }
+ }
+ },
+ {
+ "pages": [
+ "XELTRONICS CORPORATION — hiring bundle\nrev 05/11/2020 meeting; start target 06/01/2020; approvals 2020-05-09\nnames across shortlist: Bell, Karen; Park, Julia; Diaz, Omar; Young, Samuel; Novak, Diana\npipeline list: Chen, Olivia; Patel, Raj; Lee, Marcus; Ali, Hassan; Brooks, Natalie\n--- middle band ----------------------------------------------\nroster slot SE-I: Ali, Hassan #on:2020-05-12# marker SE1-B\n--------------------------------------------------------------\nother dates: 05/12/20; 2020/05/12; 12-05-2020\n"
+ ],
+ "date_rule": "Extract the YYYY-MM-DD date between '#on:' and '#'.",
+ "expected_output": {
+ "document_main_date": "2020-05-12",
+ "client_information": {
+ "first_name": "Hassan",
+ "last_name": "Ali"
+ }
+ }
+ },
+ {
+ "pages": [
+ "GREENFIELD UNIVERSITY — decision file\nprint 2019-03-10; committee 11/03/2019; orientation 2019-08-26; deadline 04/15/2019\nstaff: Harold King; Maya Singh; Raj Patel; Olivia Chen\nnames in cohort: Cole, Jason; Lee, Marcus; Brooks, Natalie; Mendes, Carlos; Nair, Priya\n--- central cut ----------------------------------------------\nfile tag GU-2019-4412: Petrova, Sofia <<2019-03-12>> status ADMIT\n--------------------------------------------------------------\nmirrored dates: 03-12-2019; 2019/03/12\n"
+ ],
+ "date_rule": "Use the date inside the double angle brackets '<<' and '>>'.",
+ "expected_output": {
+ "document_main_date": "2019-03-12",
+ "client_information": {
+ "first_name": "Sofia",
+ "last_name": "Petrova"
+ }
+ }
+ },
+ {
+ "pages": [
+ "SKYQUEST TRAVEL — booking ledger\nprint 2018-09-04; quote 09/02/2018; depart 2018-10-05 07:25; return 2018-10-12 19:40\nagent: Irene Zhao; group coord: Hannah Park\nnames log: Cho, Daniel; Nguyen, Cathy; Cole, Jason; Petrova, Sofia; Mendes, Carlos\nPNR refs: LEE/MARCUS; CHO/DANIEL; NGUYEN/CATHY\n--- mid block -------------------------------------------------\nrecord LEE/MARCUS (iso) 2018-09-03 (iso) JFK leg confirm\n--------------------------------------------------------------\nother styles: 09-03-2018; 2018/09/03\n"
+ ],
+ "date_rule": "Extract the YYYY-MM-DD that appears between the tokens '(iso) ' and ' (iso)'.",
+ "expected_output": {
+ "document_main_date": "2018-09-03",
+ "ClientInformation": {
+ "first_name": "Marcus",
+ "last_name": "Lee"
+ }
+ }
+ },
+ {
+ "pages": [
+ "RIVERSIDE ENERGY — statement dump\nperiod 11/01/2017–11/30/2017; read 2017-11-28; prior payment 2017-11-10; rebate 2017-10-05\ntouchpoints: Olga Ivanov; Sean Murphy; Emily Stone; Oliver Grant; Priya Nair\naccount roster sample: Murphy, Sean; Grant, Oliver; Nair, Priya; Brooks, Natalie; Lee, Marcus\n--- center line ----------------------------------------------\nacct row Orchard Ave: Brooks, Natalie [bill.iso=2017-12-01] due 12/21/2017\n--------------------------------------------------------------\nfooter echoes: 2017/12/01; 01-12-2017; Dec 1, 2017\n"
+ ],
+ "date_rule": "Use the date that appears after 'bill.iso=' inside the square brackets.",
+ "expected_output": {
+ "document_main_date": "2017-12-01",
+ "client_information": {
+ "first_name": "Natalie",
+ "last_name": "Brooks"
+ }
+ }
+ }
+]
\ No newline at end of file
diff --git a/experimental.py b/experimental.py
new file mode 100644
index 00000000..810c8375
--- /dev/null
+++ b/experimental.py
@@ -0,0 +1,426 @@
+from dataclasses import dataclass, field
+from typing import Literal
+from adalflow.datasets.types import BaseData
+from adalflow.core import DataClass, required_field
+from dataset import dataset
+from adalflow.utils import setup_env
+
+setup_env()
+
+
+@dataclass
+class DataExtractionInput(BaseData):
+ pages: list[str] = field(
+ metadata={"desc": "The pages of the document"},
+ default_factory=required_field()
+ )
+ date_rule: str = field(
+ metadata={"desc": "The rule for extracting the date from the document"},
+ default_factory=required_field()
+ )
+ expected_output: dict = field(
+ metadata={"desc": "The ground truth of the data extraction"},
+ default=None
+ )
+
+
+ __input_fields__ = ["pages", "date_rule"]
+
+
+@dataclass
+class ClientInformation(DataClass):
+
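+ # The *_reasoning_scratch_pad fields below intentionally ask the model to write out its
+ # reasoning inside the JSON output before committing to the first/middle/last name values.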
+ primary_client_reasoning_scratch_pad: str = field(
+ metadata={"desc": "The reasoning process for determining the primary client name. This is the name that should be used for the client information."},
+ default_factory=required_field()
+ )
+ first_name_reasoning_scratch_pad: str = field(
+ metadata={"desc": "The reasoning process for determining the client's first name."},
+ default_factory=required_field()
+ )
+ first_name: str = field(
+ metadata={"desc": "The client's given first name, parsed according to rules."},
+ default_factory=required_field()
+ )
+ middle_name: str = field(
+ metadata={"desc": "The client's middle name, if present as a full word after parsing."},
+ default_factory=required_field()
+ )
+ last_name_reasoning_scratch_pad: str = field(
+ metadata={"desc": "The reasoning process for determining the client's last name."},
+ default_factory=required_field()
+ )
+ last_name: str = field(
+ metadata={"desc": "The client's surname or family name, parsed according to rules."},
+ default_factory=required_field()
+ )
+
+
+@dataclass
+class DataExtractionOutput(BaseData):
+
+ document_dates: list[str] = field(
+ metadata={"desc": "The list of dates found in the document"},
+ default_factory=list,
+ )
+ document_main_date: str = field(
+ metadata={"desc": "The main date of the document, extracted from the list document_dates"},
+ default_factory=required_field()
+ )
+ client_information: ClientInformation = field(
+ metadata={"desc": "The client information of the document"},
+ default_factory=required_field(),
+ )
+
+ __output_fields__ = ["document_dates", "document_main_date", "client_information"]
+
+input_dataclass_list = []
+
+for item in dataset:
+ dataset_item = DataExtractionInput.from_dict(item)
+ # dataset_item_1 = DataExtractionInput(
+ # pages=item["pages"],
+ # date_rule=item["date_rule"],
+ # expected_output=item["expected_output"]
+ # )
+ # assert dataset_item == dataset_item_1, "Dataset item does not match the expected structure"
+ input_dataclass_list.append(dataset_item)
+
+
+train_dataset = input_dataclass_list[:2]
+val_dataset = input_dataclass_list[3:5]
+test_dataset = input_dataclass_list[6:8]
+
+print(len(train_dataset), len(val_dataset), len(test_dataset))
+
+from adalflow.components.model_client.openai_client import OpenAIClient
+import adalflow as adal
+
+model_openai_o4_mini = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "o4-mini", # or "o1"
+ "reasoning": {
+ "effort": "medium", # low, medium, high
+ }
+ }
+}
+
+template = r"""
+
+ {{system_prompt}}
+
+ {{output_format_str}}
+
+ {% if few_shot_demos is not none %}
+ Here are some examples:
+ {{few_shot_demos}}
+ {% endif %}
+
+
+ {{input_str}}
+
+ """.strip()
+
+
+task_prompt_document_extraction = r"""
+You are a helpful assistant specialized in data processing and extraction.
+""".strip()
+
+
+from typing import Union, Optional
+import adalflow as adal
+
+class DataExtractor(adal.Component):
+
+ def __init__(self, model_client: adal.ModelClient, model_kwargs: dict):
+ super().__init__()
+
+ # INPUT
+ self.data_class_input = DataExtractionInput
+ self.parser_input = adal.DataClassParser(
+ data_class=self.data_class_input, return_data_class=True, format_type="json"
+ )
+
+ # OUTPUT
+ task_desc_str = adal.Prompt(
+ template=task_prompt_document_extraction,
+ # prompt_kwargs={"classes": label_desc} #prompt variables to be hydrated
+ )()
+ self.data_class_output = DataExtractionOutput
+ self.data_class_output.set_task_desc(task_desc_str)
+
+ self.parser_output = adal.DataClassParser(
+ data_class=self.data_class_output, return_data_class=True, format_type="json"
+ )
+
+ print(f"oututput format: {self.parser_output.get_output_format_str()}")
+
+ # GENERATOR PARAMS
+ prompt_kwargs = {
+ "system_prompt": adal.Parameter(
+ data=self.parser_output.get_task_desc_str(),
+ role_desc="Task description",
+ requires_opt=True,
+ param_type=adal.ParameterType.PROMPT,
+ ),
+ "output_format_str": adal.Parameter(
+ data=self.parser_output.get_output_format_str(),
+ role_desc="Output format requirements",
+ requires_opt=True,
+ param_type=adal.ParameterType.PROMPT,
+ ),
+
+ # I didn't enable few-shot demos here to keep things simple, but it would be nice to get this working too. :D
+
+ # "few_shot_demos": adal.Parameter(
+ # data=None,
+ # requires_opt=True,
+ # role_desc="Few shot examples to help the model",
+ # param_type=adal.ParameterType.DEMOS,
+ # ),
+ }
+
+ self.llm = adal.Generator(
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ prompt_kwargs=prompt_kwargs,
+ template=template,
+ output_processors=self.parser_output,
+ use_cache=False,
+ )
+
+ print(f"system prompt: {self.llm.get_prompt()}")
+
+ def _prepare_input(self, dataset_item: DataExtractionInput):
+
+ # QUESTION:
+ # In my use case, I have some arguments to pass to the prompt that differ for each dataset item.
+ # Normally I would put them in the system prompt, but here I was not sure if changing the system prompt inside _prepare_input could break something.
+ # So in this example code, I'm just passing them as a dump of the input data.
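+ # A possible alternative (untested assumption): add a {{date_rule}} placeholder to the
+ # template and pass the per-item value here as a non-optimizable Parameter, e.g.
+ #   prompt_kwargs["date_rule"] = adal.Parameter(
+ #       data=dataset_item.date_rule, requires_opt=False, role_desc="per-item date rule")
+ # Call-time prompt_kwargs should then be merged with the ones given at Generator init.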
+
+ input_data = self.data_class_input(pages=dataset_item.pages, date_rule=dataset_item.date_rule)
+ input_str = self.parser_input.get_input_str(input_data)
+
+ prompt_kwargs = {
+ "input_str": adal.Parameter(
+ data=input_str, requires_opt=False, role_desc="input to the LLM"
+ ) }
+ return prompt_kwargs
+
+ def bicall(self,
+ dataset_item: DataExtractionInput,
+ id: Optional[str] = None
+ ) -> Union[adal.GeneratorOutput, adal.Parameter]:
+ prompt_kwargs = self._prepare_input(dataset_item)
+ output = self.llm(prompt_kwargs=prompt_kwargs, id=id)
+
+ return output
+
+task = DataExtractor(
+ model_client=model_openai_o4_mini["model_client"],
+ model_kwargs=model_openai_o4_mini["model_kwargs"],
+)
+print(task)
+
+from typing import Dict, Callable, Any, Tuple
+
+from adalflow.eval.answer_match_acc import AnswerMatchAcc
+
+def eval_fn_data_extraction(
+ y: DataExtractionOutput,
+ y_gt: DataExtractionInput
+) -> float:
+ # I got some runs where the LLM output failed to parse correctly.
+ # TODO: try to add some retry logic somewhere in the code.
+
+ # QUESTION:
+ # I don't know whether the AdalFlow framework supports this.
+ # Maybe if the framework used the API's response-format feature, it would increase the JSON output accuracy?
+
+ try:
+ patient_first_name_pred = y.client_information.first_name
+ patient_first_name_gt = y_gt.expected_output["client_information"]["first_name"]
+ return AnswerMatchAcc(type="exact_match").compute_single_item(patient_first_name_pred, patient_first_name_gt)
+
+ except Exception as e:
+ print(f"Parse error: {e}")
+ return 0
+
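+# A minimal retry sketch (my own helper, not an AdalFlow API): re-run the task when the
+# structured output fails to parse, up to a small number of attempts.
+def call_with_retry(task: DataExtractor, dataset_item: DataExtractionInput, max_retries: int = 2):
+    output = None
+    for attempt in range(max_retries + 1):
+        output = task.bicall(dataset_item)  # returns adal.GeneratorOutput in eval mode
+        if output is not None and output.data is not None:
+            return output
+        print(f"attempt {attempt + 1}/{max_retries + 1}: output parsing failed")
+    return output
+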
+class DataExtractorTrainer(adal.AdalComponent):
+ def __init__(
+ self,
+ model_client: adal.ModelClient,
+ model_kwargs: Dict,
+ teacher_model_config: Dict,
+ backward_engine_model_config: Dict,
+ text_optimizer_model_config: Dict,
+ ):
+ task = DataExtractor(model_client, model_kwargs)
+ # eval_fn = AnswerMatchAcc(type="exact_match").compute_single_item
+ eval_fn = eval_fn_data_extraction
+ loss_fn = adal.EvalFnToTextLoss(
+ eval_fn=eval_fn,
+ eval_fn_desc="exact_match: 1 if str(y) == str(y_gt) else 0. When the LLM prediction failed with format parsing which results with errors, we set y_pred = -1",
+ )
+ super().__init__(
+ task=task,
+ eval_fn=eval_fn,
+ loss_fn=loss_fn,
+ backward_engine_model_config=backward_engine_model_config,
+ text_optimizer_model_config=text_optimizer_model_config,
+ teacher_model_config=teacher_model_config,
+ )
+
+ def prepare_task(self, sample: DataExtractionInput):
+ return self.task.call, {"dataset_item": sample, "id": sample.id}
+
+ def prepare_eval(
+ self, sample: DataExtractionInput, y_pred: adal.GeneratorOutput
+ ) -> float:
+
+ prediction = -1
+ if y_pred and y_pred.data is not None:
+ prediction = y_pred.data
+ return self.eval_fn, {"y": prediction, "y_gt": sample}
+
+ def prepare_loss(
+ self, dataset_item: DataExtractionInput, y_pred: adal.Parameter, *args, **kwargs
+ ) -> Tuple[Callable[..., Any], Dict]:
+
+ # QUESTION:
+ # How can I create a system that has multiple target variables to compute loss?
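+ # One possible direction (assumption, not verified against the framework): compute several
+ # field-level scores inside eval_fn (e.g. exact match on first_name, last_name and
+ # document_main_date), combine them into a single scalar, and keep a single y_gt Parameter
+ # whose eval_input carries the full expected_output dict.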
+
+
+ full_response = y_pred.data
+ y_label = -1 # default value for failed prediction
+ try:
+ first_name = full_response.data.client_information.first_name
+ except Exception as e:
+ print(f"Parse error: {e}")
+ first_name = -1
+
+ if isinstance(first_name, str) and first_name.strip():
+ y_label = first_name.strip()
+
+ y_pred.eval_input = y_label
+
+ y_gt = adal.Parameter(
+ name="y_gt",
+ data=dataset_item.expected_output["client_information"]["first_name"],
+ eval_input=dataset_item.expected_output["client_information"]["first_name"],
+ requires_opt=False,
+ )
+
+ return self.loss_fn, {
+ "kwargs": {
+ "y": y_pred,
+ "y_gt": y_gt
+ },
+ "id": dataset_item.id,
+ "gt": y_gt.eval_input,
+ }
+
+def train(
+ model_client: adal.ModelClient,
+ model_kwargs: Dict,
+ train_batch_size=4,
+ raw_shots: int = 0,
+ bootstrap_shots: int = 1,
+ max_steps=12,
+ num_workers=4,
+ strategy="constrained",
+ optimization_order="sequential",
+ debug=False,
+):
+ print("Starting training process...")
+
+ # Define the model configuration for all components
+ # gpt_4o_model = {
+ # "model_client": OpenAIClient(),
+ # "model_kwargs": {
+ # "model": "gpt-4o-mini",
+ # "temperature": 1,
+ # "top_p": 0.99,
+ # "max_tokens": 1000,
+ # # "frequency_penalty": 1, # high for nto repeating prompt
+ # },
+ # }
+ model_openai_o4_mini = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "o4-mini", # or "o1"
+ "reasoning": {
+ "effort": "medium", # low, medium, high
+ }
+ }
+ }
+
+ model_openai_gpt_5 = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "gpt-5", # or "o1"
+ "reasoning": {
+ "effort": "medium", # low, medium, high
+ }
+ }
+ }
+
+ print(f"Component model configuration: {model_openai_o4_mini}")
+
+ try:
+ print("Initializing ADAL component...")
+ adal_component = DataExtractorTrainer(
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ text_optimizer_model_config=model_openai_gpt_5,
+ backward_engine_model_config=model_openai_o4_mini,
+ teacher_model_config=model_openai_o4_mini,
+ )
+ print("ADAL component initialized successfully")
+
+ print("Initializing trainer...")
+ trainer = adal.Trainer(
+ train_batch_size=train_batch_size,
+ adaltask=adal_component,
+ strategy=strategy,
+ max_steps=max_steps,
+ num_workers=num_workers,
+ raw_shots=raw_shots,
+ bootstrap_shots=bootstrap_shots,
+ debug=debug,
+ weighted_sampling=True,
+ optimization_order=optimization_order,
+ exclude_input_fields_from_bootstrap_demos=True,
+ )
+ print("Trainer initialized successfully")
+
+ print("Loading datasets...")
+ # train_dataset, val_dataset, test_dataset = load_datasets()
+ print(
+ f"Datasets loaded - Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}"
+ )
+
+ print("Starting model training...")
+ trainer.fit(
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ test_dataset=test_dataset,
+ debug=debug,
+ )
+ print("Training completed successfully")
+
+ except Exception as e:
+ print(f"Error occurred: {str(e)}")
+ raise
+
+
+model_openai_o4_mini = {
+ "model_client": OpenAIClient(),
+ "model_kwargs": {
+ "model": "o4-mini", # or "o1"
+ "reasoning": {
+ "effort": "medium", # low, medium, high
+ }
+ }
+ }
+
+train(**model_openai_o4_mini)
\ No newline at end of file
From 87f9b59aa0346383de3421506cbb4a79742dba2b Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Tue, 19 Aug 2025 12:26:01 -0700
Subject: [PATCH 2/3] quick fix of memory add issue causing missing execution_complete
 events
---
adalflow/adalflow/components/agent/agent.py | 16 +-
adalflow/adalflow/components/agent/runner.py | 127 ++++--
.../components/memory/flexible_memory.py | 362 ++++++++++++++++++
adalflow/adalflow/components/memory/memory.py | 14 +-
.../output_parsers/dataclass_parser.py | 2 +-
.../components/output_parsers/outputs.py | 1 -
adalflow/adalflow/core/types.py | 136 ++++++-
adalflow/tests/test_memory.py | 12 +-
adalflow/tests/test_openai_client.py | 202 ++++++++++
adalflow/tests/test_runner.py | 227 +++++++++++
10 files changed, 1037 insertions(+), 62 deletions(-)
create mode 100644 adalflow/adalflow/components/memory/flexible_memory.py
diff --git a/adalflow/adalflow/components/agent/agent.py b/adalflow/adalflow/components/agent/agent.py
index 96d3cc4e..ad96d9bc 100644
--- a/adalflow/adalflow/components/agent/agent.py
+++ b/adalflow/adalflow/components/agent/agent.py
@@ -18,6 +18,8 @@
from adalflow.core.tool_manager import ToolManager
from adalflow.core.prompt_builder import Prompt
from adalflow.core.types import GeneratorOutput, ModelType, Function
+from adalflow.core.base_data_class import DataClass, DataClassFormatType
+
from adalflow.optim.parameter import Parameter, ParameterType
from adalflow.components.output_parsers import JsonOutputParser
from adalflow.utils import printc
@@ -54,7 +56,7 @@
# '
-
+# TODO: switch the agent's Function output to the pydantic Function model. But it can't control the fields.
# the context will wrap the whole component
def create_default_tool_manager(
# Tool manager parameters
@@ -154,9 +156,17 @@ def create_default_planner(
include_fields = ["name", "kwargs", "_is_answer_final", "_answer"]
else:
include_fields = ["thought", "name", "kwargs", "_is_answer_final", "_answer"]
+
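+ # One filled-in Function example so the parser's format instructions can show the model a
+ # concrete sample of the expected JSON fields (including _is_answer_final and _answer).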
+ examples = [
+ Function(
+ name="example_function",
+ kwargs={"param1": "value1", "param2": "value2"},
+ _is_answer_final=False,
+ _answer=None,)
+ ]
output_parser = JsonOutputParser(
data_class=ouput_data_class,
- examples=None,
+ examples=examples,
# examples=self._examples, # TODO: add examples
return_data_class=True,
include_fields=include_fields,
@@ -169,7 +179,7 @@ def create_default_planner(
prompt_kwargs = {
"tools": tool_manager.yaml_definitions,
- "output_format_str": output_parser.format_instructions(),
+ "output_format_str": output_parser.format_instructions(format_type=DataClassFormatType.SIGNATURE_JSON),
"task_desc": Parameter(
name="react_agent_task_desc",
data=task_desc,
diff --git a/adalflow/adalflow/components/agent/runner.py b/adalflow/adalflow/components/agent/runner.py
index 2fc486c2..5ecfd068 100644
--- a/adalflow/adalflow/components/agent/runner.py
+++ b/adalflow/adalflow/components/agent/runner.py
@@ -20,7 +20,7 @@
AsyncIterable,
)
from typing_extensions import TypeAlias
-import sys
+import uuid
from adalflow.optim.parameter import Parameter
@@ -323,7 +323,7 @@ def _check_last_step(self, step: Function) -> bool:
def _get_final_answer(self, function: Function) -> Any:
"""Get and process the final answer from the function."""
- if hasattr(function, "_answer"):
+ if hasattr(function, "_answer") or (hasattr(function, "_is_answer_final") and function._is_answer_final):
return self._process_data(function._answer)
return None
@@ -338,30 +338,44 @@ def _create_runner_result(self, answer: Any, step_history, error: Optional[str]
)
def _create_execution_complete_stream_event(self, streaming_result: RunnerStreamingResult, final_output_item: FinalOutputItem):
"""Complete the streaming execution by adding a sentinel."""
- final_output_event = RunItemStreamEvent(
- name="agent.execution_complete",
- item=final_output_item,
- )
- streaming_result.put_nowait(final_output_event)
+ try:
+ final_output_event = RunItemStreamEvent(
+ name="agent.execution_complete",
+ item=final_output_item,
+ )
+ streaming_result.put_nowait(final_output_event)
- runner_result: RunnerResult = final_output_item.data
+ runner_result: RunnerResult = final_output_item.data
- # set up the final answer
- streaming_result.answer = runner_result.answer if runner_result else None
- streaming_result.step_history = self.step_history.copy()
- streaming_result._is_complete = True
+ # set up the final answer
+ streaming_result.answer = runner_result.answer if runner_result else None
+ streaming_result.step_history = self.step_history.copy()
+ streaming_result._is_complete = True
+
+ except Exception as e:
+ log.error(f"Failed to create execution complete stream event: {e}")
+ raise e
- def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem):
+ def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem, turn_id: Any = None):
# add the assistant response to the conversation memory
- if self.use_conversation_memory and self.conversation_memory._pending_user_query is not None:
- self.conversation_memory.add_assistant_response(
- AssistantResponse(
- response_str=final_output_item.data.answer,
- metadata={
- "step_history": final_output_item.data.step_history.copy()
- },
+ try:
+ if self.use_conversation_memory and turn_id is not None:
+ # Only add if we have a valid turn_id
+ self.conversation_memory.add_assistant_response(
+ AssistantResponse(
+ response_str=final_output_item.data.answer,
+ metadata={
+ "step_history": final_output_item.data.step_history.copy()
+ },
+ ),
+ turn_id=turn_id
)
- )
+ elif self.use_conversation_memory:
+ log.warning("Skipping add_assistant_response - no turn_id available")
+ except Exception as e:
+ log.error(f"Failed to add assistant response to memory: {e}")
+ # Don't re-raise, just log the error
+ return
def create_response_span(self, runner_result, step_count: int, streaming_result: RunnerStreamingResult, runner_span_instance, workflow_status: str = "stream_completed"):
@@ -395,7 +409,8 @@ async def _process_stream_final_step(
answer: Any,
step_count: int,
streaming_result,
- runner_span_instance
+ runner_span_instance,
+ turn_id: Any = None,
) -> FinalOutputItem:
"""Process the final step and trace it."""
# processed_data = self._get_final_answer(function)
@@ -422,8 +437,13 @@ async def _process_stream_final_step(
self._create_execution_complete_stream_event(
streaming_result, final_output_item
)
+
+ # put_nowait is synchronous, so the event is already queued; yield control once so any
+ # consumer awaiting the queue can pick it up before we return.
+ await asyncio.sleep(0)
+
# add the assistant response to the conversation memory
- self._add_assistant_response_to_memory(final_output_item)
+ self._add_assistant_response_to_memory(final_output_item, turn_id)
return final_output_item
# TODO: improved after the finish function is refactored
@@ -561,6 +581,7 @@ def call(
self.step_history
) # a reference to the step history
+ turn_id = None
if self.use_conversation_memory:
# Reset any pending query state before starting a new query
self.conversation_memory.reset_pending_query()
@@ -570,7 +591,7 @@ def call(
# meta data is all keys in the list of context_str
query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
- self.conversation_memory.add_user_query(
+ turn_id = self.conversation_memory.add_user_query(
UserQuery(
query_str=prompt_kwargs.get("input_str", None),
metadata=query_metadata,
@@ -672,14 +693,15 @@ def call(
)
# Add assistant response to conversation memory
- if self.use_conversation_memory:
+ if self.use_conversation_memory and turn_id is not None:
self.conversation_memory.add_assistant_response(
AssistantResponse(
response_str=processed_data,
metadata={
"step_history": self.step_history.copy()
},
- )
+ ),
+ turn_id=turn_id
)
step_count += 1 # Increment step count before breaking
@@ -917,6 +939,7 @@ async def acall(
self.step_history
) # a reference to the step history
+ turn_id = None
if self.use_conversation_memory:
# Reset any pending query state before starting a new query
self.conversation_memory.reset_pending_query()
@@ -926,7 +949,7 @@ async def acall(
# meta data is all keys in the list of context_str
query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
- self.conversation_memory.add_user_query(
+ turn_id = self.conversation_memory.add_user_query(
UserQuery(
query_str=prompt_kwargs.get("input_str", None),
metadata=query_metadata,
@@ -1043,14 +1066,15 @@ async def acall(
)
# Add assistant response to conversation memory
- if self.use_conversation_memory:
+ if self.use_conversation_memory and turn_id is not None:
self.conversation_memory.add_assistant_response(
AssistantResponse(
response_str=answer,
metadata={
"step_history": self.step_history.copy()
},
- )
+ ),
+ turn_id=turn_id
)
@@ -1268,6 +1292,7 @@ async def impl_astream(
"""
workflow_status: Literal["streaming", "stream_completed", "stream_failed", "stream_incomplete"] = "streaming"
# Create runner span for tracing streaming execution
+ turn_id = None
with runner_span(
runner_id=id or f"stream_runner_{hash(str(prompt_kwargs))}",
max_steps=self.max_steps,
@@ -1288,7 +1313,7 @@ async def impl_astream(
# meta data is all keys in the list of context_str
query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
- self.conversation_memory.add_user_query(
+ turn_id = self.conversation_memory.add_user_query(
UserQuery(
query_str=prompt_kwargs.get("input_str", None),
metadata=query_metadata,
@@ -1386,10 +1411,10 @@ async def impl_astream(
else: # non-streaming cases
# yield the final planner response
- if output.data is None:
+ if output.data is None or (not isinstance(output.data, Function)):
# recoverable errors, continue to create stepout
- current_error = output.error
+ current_error = f"Error: {output.error} - data: {output.data}, raw_response: {output.raw_response}"
# wrap the error in a RawResponsesStreamEvent
wrapped_event = RawResponsesStreamEvent(
data=None, # no data in this case
@@ -1411,6 +1436,10 @@ async def impl_astream(
item=step_item,
)
streaming_result.put_nowait(step_complete_event)
+
+ # Ensure event is processed before continuing
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
self.step_history.append(step_output)
if output.error is not None:
@@ -1467,14 +1496,26 @@ async def impl_astream(
log.debug(f"function: {function}")
if self._check_last_step(function): # skip stepoutput
- answer = self._get_final_answer(function)
+ try:
+ answer = self._get_final_answer(function)
+ except Exception as e:
+ # If processing the final answer fails, use the raw answer
+ log.warning(f"Failed to process final answer: {e}. Using raw answer.")
+ answer = function._answer if hasattr(function, "_answer") else str(e)
+
final_output_item = await self._process_stream_final_step(
answer=answer,
step_count=step_count,
streaming_result=streaming_result,
runner_span_instance=runner_span_instance,
+ turn_id=turn_id,
)
workflow_status = "stream_completed"
+
+ # Ensure the queue has processed the execution_complete event
+ # Add a small yield to allow the event loop to process the queued events
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
break
# Check if permission is required and emit permission event
@@ -1556,6 +1597,10 @@ async def impl_astream(
name="agent.step_complete", item=step_item
)
streaming_result.put_nowait(step_event)
+
+ # Ensure event is processed before continuing
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
step_count += 1
except asyncio.CancelledError:
@@ -1580,7 +1625,7 @@ async def impl_astream(
streaming_result._is_complete = True
# Add cancellation response to conversation memory
- if self.use_conversation_memory:
+ if self.use_conversation_memory and turn_id is not None:
self.conversation_memory.add_assistant_response(
AssistantResponse(
response_str="I apologize, but the execution was cancelled by the user.",
@@ -1589,7 +1634,8 @@ async def impl_astream(
"status": "cancelled",
"timestamp": datetime.now().isoformat()
}
- )
+ ),
+ turn_id=turn_id
)
# Signal completion and break
@@ -1627,6 +1673,12 @@ async def impl_astream(
workflow_status = "stream_incomplete"
current_error = f"No output generated after {step_count} steps (max_steps: {self.max_steps})"
+ # Only emit execution_complete if we created a new final_output_item
+ # (i.e., when the loop ended without a final answer)
+ self._create_execution_complete_stream_event(
+ streaming_result=streaming_result,
+ final_output_item=final_output_item,
+ )
runner_span_instance.span_data.update_attributes(
{
@@ -1644,11 +1696,6 @@ async def impl_astream(
error=current_error,
)
- self._create_execution_complete_stream_event(
- streaming_result=streaming_result,
- final_output_item=final_output_item,
- )
-
# create response span for final output
# if workflow_status in ["stream_incomplete", "stream_failed"]:
self.create_response_span(
diff --git a/adalflow/adalflow/components/memory/flexible_memory.py b/adalflow/adalflow/components/memory/flexible_memory.py
new file mode 100644
index 00000000..d6adbec5
--- /dev/null
+++ b/adalflow/adalflow/components/memory/flexible_memory.py
@@ -0,0 +1,362 @@
+"""Flexible conversation memory with turns containing multiple messages.
+
+This memory design uses an OrderedDict where each turn_id maps to a list of messages.
+This allows multiple user queries and assistant responses within the same turn.
+"""
+
+from uuid import uuid4
+from typing import Optional, List, Dict, Any, Literal
+from dataclasses import dataclass, field
+from datetime import datetime
+from collections import OrderedDict
+from adalflow.core.component import Component
+from adalflow.core.db import LocalDB
+from adalflow.core.types import DataClass
+from adalflow.core.prompt_builder import Prompt
+
+
+@dataclass
+class Message(DataClass):
+ """A single message in a conversation."""
+ id: str = field(default_factory=lambda: str(uuid4()))
+ role: Literal["user", "assistant", "system"] = "user"
+ content: str = ""
+ metadata: Optional[Dict[str, Any]] = None
+ timestamp: datetime = field(default_factory=datetime.now)
+
+ @classmethod
+ def from_user(cls, content: str, metadata: Optional[Dict] = None):
+ """Create a user message."""
+ return cls(role="user", content=content, metadata=metadata)
+
+ @classmethod
+ def from_assistant(cls, content: str, metadata: Optional[Dict] = None):
+ """Create an assistant message."""
+ return cls(role="assistant", content=content, metadata=metadata)
+
+ @classmethod
+ def from_system(cls, content: str, metadata: Optional[Dict] = None):
+ """Create a system message."""
+ return cls(role="system", content=content, metadata=metadata)
+
+
+@dataclass
+class Conversation(DataClass):
+ """A conversation organized as turns, where each turn can have multiple messages."""
+ id: str = field(default_factory=lambda: str(uuid4()))
+ user_id: Optional[str] = None
+ turns: OrderedDict = field(default_factory=OrderedDict) # turn_id -> List[Message]
+ metadata: Optional[Dict[str, Any]] = None
+ created_at: datetime = field(default_factory=datetime.now)
+ _current_turn_id: Optional[str] = field(default=None, init=False)
+
+ def add_message_to_turn(self, turn_id: str, message: Message) -> str:
+ """Add a message to a specific turn.
+
+ Args:
+ turn_id: The turn identifier
+ message: The message to add
+
+ Returns:
+ str: The message ID
+ """
+ if turn_id not in self.turns:
+ self.turns[turn_id] = []
+ self.turns[turn_id].append(message)
+ return message.id
+
+ def get_turn_messages(self, turn_id: str) -> List[Message]:
+ """Get all messages in a specific turn."""
+ return self.turns.get(turn_id, [])
+
+ def get_all_messages(self) -> List[Message]:
+ """Get all messages in order across all turns."""
+ messages = []
+ for turn_messages in self.turns.values():
+ messages.extend(turn_messages)
+ return messages
+
+ def get_messages_by_role(self, role: str) -> List[Message]:
+ """Get all messages from a specific role."""
+ messages = []
+ for turn_messages in self.turns.values():
+ messages.extend([msg for msg in turn_messages if msg.role == role])
+ return messages
+
+ def get_last_user_message(self) -> Optional[Message]:
+ """Get the most recent user message."""
+ for turn_messages in reversed(list(self.turns.values())):
+ for msg in reversed(turn_messages):
+ if msg.role == "user":
+ return msg
+ return None
+
+ def get_last_assistant_message(self) -> Optional[Message]:
+ """Get the most recent assistant message."""
+ for turn_messages in reversed(list(self.turns.values())):
+ for msg in reversed(turn_messages):
+ if msg.role == "assistant":
+ return msg
+ return None
+
+ def create_turn(self) -> str:
+ """Create a new turn and return its ID."""
+ turn_id = str(uuid4())
+ self.turns[turn_id] = []
+ self._current_turn_id = turn_id
+ return turn_id
+
+
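+# Usage sketch (illustrative only; mirrors the methods defined below): a single turn can hold
+# both the user query and the assistant response, keyed by the same turn_id.
+#
+#   memory = FlexibleConversationMemory()
+#   turn_id = memory.add_user_query("What is AdalFlow?")
+#   memory.add_assistant_response("An LLM application framework.", turn_id=turn_id)
+#   print(memory())  # "User: ...\nAssistant: ..."
+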
+# Template for conversation formatting
+CONVERSATION_TEMPLATE = r"""
+{% for turn_id, messages in turns.items() -%}
+{% for message in messages -%}
+{% if message.role == "user" -%}
+User: {{ message.content }}
+{% if message.metadata -%}
+{% for key, value in message.metadata.items() -%}
+{% if not metadata_filter or key in metadata_filter -%}
+{{ key }}: {{ value }}
+{% endif -%}
+{% endfor -%}
+{% endif -%}
+{% elif message.role == "assistant" -%}
+Assistant: {{ message.content }}
+{% if message.metadata -%}
+{% for key, value in message.metadata.items() -%}
+{% if not metadata_filter or key in metadata_filter -%}
+{{ key }}: {{ value }}
+{% endif -%}
+{% endfor -%}
+{% endif -%}
+{% elif message.role == "system" -%}
+System: {{ message.content }}
+{% endif -%}
+{% endfor -%}
+{% endfor -%}"""
+
+
+class FlexibleConversationMemory(Component):
+ """A flexible conversation memory with turns containing multiple messages."""
+
+ def __init__(self, turn_db: LocalDB = None, user_id: str = None):
+ """Initialize the flexible memory component.
+
+ Args:
+ turn_db: Database for storing messages
+ user_id: Optional user identifier
+ """
+ super().__init__()
+ self.current_conversation = Conversation(user_id=user_id)
+ self.message_db = turn_db or LocalDB() # Store all messages
+ self.conver_db = LocalDB() # Store complete conversations
+ self.user_id = user_id
+
+ def clear_conversation(self):
+ """Clear all turns and messages in the current conversation."""
+ self.current_conversation.turns.clear()
+ self.current_conversation._current_turn_id = None
+
+ def clear_conversation_turns(self):
+ """Alias for clear_conversation for compatibility."""
+ self.clear_conversation()
+
+ def new_conversation(self):
+ """Start a new conversation, saving the current one."""
+ # Save current conversation if it has messages
+ if self.current_conversation.turns:
+ self.conver_db.add(self.current_conversation)
+
+ # Create new conversation
+ self.current_conversation = Conversation(user_id=self.user_id)
+
+ def create_turn(self) -> str:
+ """Create a new turn and return its ID.
+
+ Returns:
+ str: The new turn ID
+ """
+ return self.current_conversation.create_turn()
+
+ def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str:
+ """Add a user message to a turn.
+
+ Args:
+ content: The user's message content
+ metadata: Optional metadata
+ turn_id: Optional turn ID. If None, creates a new turn.
+
+ Returns:
+ str: The turn ID the message was added to
+ """
+ # Use provided turn_id or create new turn
+ if turn_id is None:
+ turn_id = self.create_turn()
+ elif turn_id not in self.current_conversation.turns:
+ # Turn doesn't exist, create it
+ self.current_conversation.turns[turn_id] = []
+
+ # Track as current turn
+ self.current_conversation._current_turn_id = turn_id
+
+ # Create and add the user message
+ message = Message.from_user(content, metadata)
+ self.current_conversation.add_message_to_turn(turn_id, message)
+
+ # Store in database
+ self.message_db.add({
+ "message_id": message.id,
+ "turn_id": turn_id,
+ "role": "user",
+ "content": content,
+ "metadata": metadata,
+ "timestamp": message.timestamp
+ })
+
+ return turn_id
+
+ def add_assistant_response(
+ self,
+ content: str,
+ metadata: Optional[Dict] = None,
+ turn_id: Optional[str] = None
+ ) -> str:
+ """Add an assistant message to a turn.
+
+ Args:
+ content: The assistant's message content
+ metadata: Optional metadata
+ turn_id: Optional turn ID. If None, uses current turn or creates new.
+
+ Returns:
+ str: The turn ID the message was added to
+ """
+ # Determine which turn to use
+ if turn_id is None:
+ if self.current_conversation._current_turn_id:
+ turn_id = self.current_conversation._current_turn_id
+ else:
+ # No active turn, create new one for standalone response
+ turn_id = self.create_turn()
+ elif turn_id not in self.current_conversation.turns:
+ # Turn doesn't exist, create it
+ self.current_conversation.turns[turn_id] = []
+
+ # Create and add the assistant message
+ message = Message.from_assistant(content, metadata)
+ self.current_conversation.add_message_to_turn(turn_id, message)
+
+ # Store in database
+ self.message_db.add({
+ "message_id": message.id,
+ "turn_id": turn_id,
+ "role": "assistant",
+ "content": content,
+ "metadata": metadata,
+ "timestamp": message.timestamp
+ })
+
+ return turn_id
+
+ def add_system_message(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str:
+ """Add a system message to a turn.
+
+ Args:
+ content: The system message content
+ metadata: Optional metadata
+ turn_id: Optional turn ID. If None, creates a new turn.
+
+ Returns:
+ str: The turn ID the message was added to
+ """
+ # Use provided turn_id or create new turn
+ if turn_id is None:
+ turn_id = self.create_turn()
+ elif turn_id not in self.current_conversation.turns:
+ self.current_conversation.turns[turn_id] = []
+
+ # Create and add the system message
+ message = Message.from_system(content, metadata)
+ self.current_conversation.add_message_to_turn(turn_id, message)
+
+ # Store in database
+ self.message_db.add({
+ "message_id": message.id,
+ "turn_id": turn_id,
+ "role": "system",
+ "content": content,
+ "metadata": metadata,
+ "timestamp": message.timestamp
+ })
+
+ return turn_id
+
+ def get_turn_messages(self, turn_id: str) -> List[Message]:
+ """Get all messages for a specific turn.
+
+ Args:
+ turn_id: The turn identifier
+
+ Returns:
+ List of messages in that turn
+ """
+ return self.current_conversation.get_turn_messages(turn_id)
+
+ def get_current_turn_id(self) -> Optional[str]:
+ """Get the current turn ID if any."""
+ return self.current_conversation._current_turn_id
+
+ def call(self, metadata_filter: Optional[List[str]] = None) -> str:
+ """Get the conversation history as a formatted string.
+
+ Args:
+ metadata_filter: Optional list of metadata keys to include
+
+ Returns:
+ str: Formatted conversation history
+ """
+ if not self.current_conversation.turns:
+ return ""
+
+ prompt = Prompt(
+ template=CONVERSATION_TEMPLATE,
+ prompt_kwargs={
+ "turns": self.current_conversation.turns,
+ "metadata_filter": metadata_filter,
+ },
+ )
+ return prompt.call().strip()
+
+ def get_all_messages(self) -> List[Message]:
+ """Get all messages in order across all turns."""
+ return self.current_conversation.get_all_messages()
+
+ def get_last_n_messages(self, n: int) -> List[Message]:
+ """Get the last n messages across all turns."""
+ messages = self.get_all_messages()
+ return messages[-n:] if len(messages) >= n else messages
+
+ def count_messages(self) -> Dict[str, int]:
+ """Count messages by role."""
+ counts = {"user": 0, "assistant": 0, "system": 0}
+ for msg in self.get_all_messages():
+ counts[msg.role] = counts.get(msg.role, 0) + 1
+ return counts
+
+ def count_turns(self) -> int:
+ """Count the number of turns."""
+ return len(self.current_conversation.turns)
+
+ def reset_pending_query(self):
+ """Reset current turn tracking. Included for compatibility."""
+ # Don't clear the current turn, just unset it as "current"
+ # This allows adding more messages to existing turns if needed
+ self.current_conversation._current_turn_id = None
+
+ def __call__(self, metadata_filter: Optional[List[str]] = None) -> str:
+ """Make the memory callable to get conversation history."""
+ return self.call(metadata_filter)
+
+ def __len__(self) -> int:
+ """Return the number of messages."""
+ return len(self.get_all_messages())
\ No newline at end of file
diff --git a/adalflow/adalflow/components/memory/memory.py b/adalflow/adalflow/components/memory/memory.py
index e30743a3..3b65cdd0 100644
--- a/adalflow/adalflow/components/memory/memory.py
+++ b/adalflow/adalflow/components/memory/memory.py
@@ -177,15 +177,17 @@ def add_assistant_response(
assistant_response (Union[str, AssistantResponse]): The assistant's response message.
Returns:
- str: The ID of the completed dialog turn.
-
- Raises:
- ValueError: If there's no pending user query to respond to.
+ str: The ID of the completed dialog turn, or None if no pending query exists.
"""
if self._pending_user_query is None:
- raise ValueError(
- "No pending user query found. Please add a user query first."
+ # Log a warning instead of raising an error
+ import logging
+ logging.warning(
+ "No pending user query found when adding assistant response. "
+ "This might happen if the response was already added or the conversation was reset. "
+ "Ignoring this assistant response to avoid duplication."
)
+ return None # Return None to indicate no turn was added
assistant_response = (
assistant_response
diff --git a/adalflow/adalflow/components/output_parsers/dataclass_parser.py b/adalflow/adalflow/components/output_parsers/dataclass_parser.py
index 32c80bc9..7e4647e5 100644
--- a/adalflow/adalflow/components/output_parsers/dataclass_parser.py
+++ b/adalflow/adalflow/components/output_parsers/dataclass_parser.py
@@ -135,7 +135,7 @@ def get_output_format_str(self) -> str:
if self._format_type == "yaml":
schema = self._data_class.to_yaml_signature(include=self._output_fields)
output_format_str = Prompt(template=YAML_OUTPUT_FORMAT)(schema=schema)
- else:
+ elif self._format_type == "json":
schema = self._data_class.to_json_signature(include=self._output_fields)
output_format_str = Prompt(template=JSON_OUTPUT_FORMAT)(schema=schema)
return output_format_str
diff --git a/adalflow/adalflow/components/output_parsers/outputs.py b/adalflow/adalflow/components/output_parsers/outputs.py
index 1378b428..6286a60e 100644
--- a/adalflow/adalflow/components/output_parsers/outputs.py
+++ b/adalflow/adalflow/components/output_parsers/outputs.py
@@ -57,7 +57,6 @@
- Properly escape special characters: use \\" for quotes, \\\\ for backslashes
- For multiline strings, keep them on a single line with \\n characters
**WARNING:** The JSON must be parseable by standard JSON parsers. Malformed JSON will cause parsing failures. When handling complex text with special characters, quotes, or formatting, prioritize proper escaping over readability.
-
"""
"""**CRITICAL JSON FORMATTING REQUIREMENTS:**
diff --git a/adalflow/adalflow/core/types.py b/adalflow/adalflow/core/types.py
index 8d7b4f2b..cea05b36 100644
--- a/adalflow/adalflow/core/types.py
+++ b/adalflow/adalflow/core/types.py
@@ -27,6 +27,7 @@
field,
InitVar,
)
+from pydantic import BaseModel, Field as PydanticField
from uuid import UUID
from datetime import datetime
import uuid
@@ -449,6 +450,128 @@ def add(a, b):
__output_fields__ = ["thought", "name", "kwargs", "_is_answer_final", "_answer"]
+class FunctionPydantic(BaseModel):
+ """The data modeling of a function call using Pydantic BaseModel.
+
+ This is a Pydantic-based version of the Function class that uses BaseModel
+ and Field for validation and serialization.
+
+ Example:
+ >>> def add(a, b):
+ ... return a + b
+ >>>
+ >>> # Create function call with kwargs
+ >>> func = FunctionPydantic(name="add", kwargs={"a": 1, "b": 2})
+ >>> # Evaluate the function
+ >>> result = context_map[func.name](**func.kwargs)
+ >>>
+ >>> # Create function call with positional args
+ >>> func = FunctionPydantic(name="add", args=[1, 2])
+ >>> result = context_map[func.name](*func.args)
+ """
+
+ id: Optional[str] = PydanticField(
+ default=None,
+ description="The id of the function call"
+ )
+ thought: Optional[str] = PydanticField(
+ default=None,
+ description="Your reasoning for this step. Be short for simple queries. For complex queries, provide a clear chain of thought."
+ )
+ name: str = PydanticField(
+ default="",
+ description="The name of the function"
+ )
+ args: Optional[List[Any]] = PydanticField(
+ default_factory=list,
+ description="The positional arguments of the function"
+ )
+ kwargs: Optional[Dict[str, Any]] = PydanticField(
+ default_factory=dict,
+ description="The keyword arguments of the function"
+ )
+ is_answer_final: Optional[bool] = PydanticField(
+ default=None,
+ alias="_is_answer_final",
+ description="Whether this current output is the final answer"
+ )
+ answer: Optional[Any] = PydanticField(
+ default=None,
+ alias="_answer",
+ description="The final answer if this is the final output"
+ )
+
+ class Config:
+ populate_by_name = True # Allow population by field name or alias
+ json_encoders = {
+ datetime: lambda v: v.isoformat(),
+ }
+
+ @classmethod
+ def from_function(
+ cls,
+ func: Union[Callable[..., Any], AsyncCallable],
+ thought: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> "FunctionPydantic":
+ """Create a FunctionPydantic object from a function.
+
+ Args:
+ func: The function to be converted
+ thought: Optional reasoning for this function call
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ FunctionPydantic: The FunctionPydantic object
+
+ Example:
+ >>> def add(a, b):
+ ... return a + b
+ >>>
+ >>> # Create with positional arguments
+ >>> func = FunctionPydantic.from_function(add, "Add two numbers", 1, 2)
+ >>> print(func)
+ >>> # FunctionPydantic(thought='Add two numbers', name='add', args=[1, 2])
+ """
+ return cls(
+ thought=thought,
+ name=func.__name__,
+ args=list(args) if args else [],
+ kwargs=kwargs if kwargs else {},
+ )
+
+ def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]:
+ """Convert to dictionary representation.
+
+ Args:
+ exclude_none: Whether to exclude None values
+
+ Returns:
+ Dictionary representation of the function call
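+
+        Example (illustrative; the exact keys depend on which fields are set):
+            >>> func = FunctionPydantic(name="add", kwargs={"a": 1, "b": 2})
+            >>> data = func.to_dict()  # e.g. {'name': 'add', 'args': [], 'kwargs': {'a': 1, 'b': 2}}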
+ """
+ data = self.model_dump(exclude_none=exclude_none)
+ # Include aliased fields with their original names if present
+ if self.is_answer_final is not None:
+ data['_is_answer_final'] = self.is_answer_final
+ if self.answer is not None:
+ data['_answer'] = self.answer
+ return data
+
+ def to_json(self, exclude_none: bool = True, indent: int = 2) -> str:
+ """Convert to JSON string representation.
+
+ Args:
+ exclude_none: Whether to exclude None values
+ indent: JSON indentation level
+
+ Returns:
+ JSON string representation
+ """
+ return self.model_dump_json(exclude_none=exclude_none, indent=indent)
+
+
_action_desc = """FuncName() \
Valid function call expression. \
Example: "FuncName(a=1, b=2)" \
@@ -1634,12 +1757,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
yield event
# mark the task as done
self._event_queue.task_done()
- # if the event is a RunItemStreamEvent and the name is agent.execution_complete then additionally break the loop
- if (
- isinstance(event, RunItemStreamEvent)
- and event.name == "agent.execution_complete"
- ):
- break
+ # Don't break on agent.execution_complete - let QueueCompleteSentinel handle the break
+ # This ensures the execution_complete event is properly processed by all consumers
+ # if (
+ # isinstance(event, RunItemStreamEvent)
+ # and event.name == "agent.execution_complete"
+ # ):
+ # break
except asyncio.CancelledError:
# Clean up and re-raise to allow proper cancellation
diff --git a/adalflow/tests/test_memory.py b/adalflow/tests/test_memory.py
index 3680f683..20093b5e 100644
--- a/adalflow/tests/test_memory.py
+++ b/adalflow/tests/test_memory.py
@@ -98,13 +98,15 @@ def test_error_double_user_query():
def test_error_assistant_response_without_query():
- """Test error when adding assistant response without user query."""
+ """Test that adding assistant response without user query returns None and logs warning."""
memory = Memory()
- with pytest.raises(ValueError) as exc_info:
- memory.add_assistant_response("Random response")
-
- assert "No pending user query found" in str(exc_info.value)
+ # Should return None instead of raising error
+ result = memory.add_assistant_response("Random response")
+ assert result is None
+
+ # The conversation should remain empty
+ assert len(memory.current_conversation.dialog_turns) == 0
def test_mixed_methods():
diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py
index e0b50a39..3946320f 100644
--- a/adalflow/tests/test_openai_client.py
+++ b/adalflow/tests/test_openai_client.py
@@ -761,6 +761,208 @@ def test_multiple_images_input(self):
self.assertEqual(image_contents[1]["image_url"], "https://example.com/image2.jpg")
self.assertTrue(image_contents[2]["image_url"].startswith("data:image/png;base64,"))
+ @patch(
+ "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client"
+ )
+ @patch("adalflow.components.model_client.openai_client.OpenAI")
+ def test_sync_streaming_response_format(self, MockSyncOpenAI, mock_init_sync_client):
+ """Test that sync streaming returns raw_response as generator and data field contains complete result."""
+ mock_sync_client = Mock()
+ MockSyncOpenAI.return_value = mock_sync_client
+ mock_init_sync_client.return_value = mock_sync_client
+
+ # Create streaming events generator
+ def mock_stream():
+ for event in self.streaming_events:
+ yield event
+
+ mock_sync_client.responses.create.return_value = mock_stream()
+ self.client.sync_client = mock_sync_client
+
+ # Test streaming API kwargs
+ api_kwargs = {
+ "model": "gpt-4",
+ "input": "Tell me a story",
+ "stream": True
+ }
+
+ # Call sync streaming method
+ result = self.client.call(api_kwargs, ModelType.LLM)
+
+ # Verify raw response is a generator
+ self.assertTrue(hasattr(result, '__iter__'), "raw_response should be iterable/generator")
+
+ # Parse the response through the client's parser
+ parsed_result = self.client.parse_chat_completion(result)
+
+ # Verify structure: raw_response should be the generator, data should be None for streaming
+ self.assertIsNotNone(parsed_result.raw_response, "raw_response should contain the generator")
+ self.assertIsNone(parsed_result.data, "data should be None for streaming responses")
+
+ # Consume the generator to verify it works
+ events_collected = list(result)
+ self.assertEqual(len(events_collected), 3, "Should collect all streaming events")
+
+ # Verify the final completed event contains the full result
+ final_event = events_collected[-1]
+ self.assertEqual(final_event.type, "response.completed")
+ self.assertEqual(final_event.response.output_text, "Once upon ")
+
+ async def test_async_streaming_response_format(self):
+ """Test that async streaming returns raw_response as async generator and data field handling."""
+ mock_async_client = AsyncMock()
+
+ # Create async streaming events generator
+ async def mock_async_stream():
+ for event in self.streaming_events:
+ yield event
+ await asyncio.sleep(0.001) # Small delay to simulate real streaming
+
+ mock_async_client.responses.create.return_value = mock_async_stream()
+ self.client.async_client = mock_async_client
+
+ # Test streaming API kwargs
+ api_kwargs = {
+ "model": "gpt-4",
+ "input": "Tell me a story",
+ "stream": True
+ }
+
+ # Call async streaming method
+ result = await self.client.acall(api_kwargs, ModelType.LLM)
+
+ # Verify raw response is an async generator
+ self.assertTrue(hasattr(result, '__aiter__'), "raw_response should be async iterable")
+
+ # Parse the response through the client's parser
+ parsed_result = self.client.parse_chat_completion(result)
+
+ # Verify structure: raw_response should contain the generator, data should be None for streaming
+ self.assertIsNotNone(parsed_result.raw_response, "raw_response should contain the async generator")
+ self.assertIsNone(parsed_result.data, "data should be None for streaming responses")
+
+ # Consume the async generator to verify it works and extract complete text
+ events_collected = []
+ complete_text = ""
+
+ async for event in result:
+ events_collected.append(event)
+ # Extract text from streaming events
+ if hasattr(event, 'delta'):
+ complete_text += event.delta
+ elif hasattr(event, 'response') and hasattr(event.response, 'output_text'):
+ complete_text = event.response.output_text # Final complete result
+
+ self.assertEqual(len(events_collected), 3, "Should collect all streaming events")
+ self.assertEqual(complete_text, "Once upon ", "Should extract complete text from stream")
+
+ # Verify the final completed event contains the full result
+ final_event = events_collected[-1]
+ self.assertEqual(final_event.type, "response.completed")
+ self.assertEqual(final_event.response.output_text, "Once upon ")
+
+ async def test_streaming_text_extraction_and_final_result(self):
+ """Test that streaming properly extracts incremental text and provides final complete result."""
+ mock_async_client = AsyncMock()
+
+ # Create more comprehensive streaming events with incremental text
+ async def comprehensive_mock_stream():
+ # Start event
+ start_event = Mock()
+ start_event.type = "response.created"
+ yield start_event
+
+ # Multiple delta events building up text
+ delta_texts = ["Hello", " there!", " How", " are", " you?"]
+ for delta_text in delta_texts:
+ delta_event = Mock()
+ delta_event.type = "response.output_text.delta"
+ delta_event.delta = delta_text
+ yield delta_event
+ await asyncio.sleep(0.001)
+
+ # Final completion event with complete text
+ complete_response = Mock()
+ complete_response.id = "resp-final"
+ complete_response.model = "gpt-4"
+ complete_response.output_text = "Hello there! How are you?"
+ complete_response.usage = ResponseUsage(
+ input_tokens=5, output_tokens=6, total_tokens=11,
+ input_tokens_details={"cached_tokens": 0},
+ output_tokens_details={"reasoning_tokens": 0}
+ )
+
+ completion_event = Mock()
+ completion_event.type = "response.completed"
+ completion_event.response = complete_response
+ yield completion_event
+
+ mock_async_client.responses.create.return_value = comprehensive_mock_stream()
+ self.client.async_client = mock_async_client
+
+ api_kwargs = {
+ "model": "gpt-4",
+ "input": "Say hello",
+ "stream": True
+ }
+
+ # Get streaming result
+ result = await self.client.acall(api_kwargs, ModelType.LLM)
+
+ # Verify it's an async generator
+ self.assertTrue(hasattr(result, '__aiter__'))
+
+ # Process stream and extract text using the utility function
+ incremental_text = ""
+ final_complete_text = None
+ event_count = 0
+
+ async for event in result:
+ event_count += 1
+
+ # Use the utility function to extract text
+ text_fragment = extract_text_from_response_stream(event)
+ if text_fragment:
+ incremental_text += text_fragment
+
+ # Check for final completion
+ if hasattr(event, 'type') and event.type == "response.completed":
+ if hasattr(event, 'response') and hasattr(event.response, 'output_text'):
+ final_complete_text = event.response.output_text
+
+ # Verify results
+ self.assertEqual(event_count, 7, "Should have start + 5 deltas + completion events")
+ self.assertEqual(incremental_text, "Hello there! How are you?", "Incremental text should match")
+ self.assertEqual(final_complete_text, "Hello there! How are you?", "Final complete text should match")
+ self.assertEqual(incremental_text, final_complete_text, "Incremental and final text should be identical")
+
+ def test_streaming_vs_non_streaming_data_field_behavior(self):
+ """Test that data field behavior differs correctly between streaming and non-streaming responses."""
+ # Test non-streaming: data field should contain the result
+ non_streaming_result = self.client.parse_chat_completion(self.mock_response)
+ self.assertIsNotNone(non_streaming_result.data, "Non-streaming should have data field populated")
+ self.assertEqual(non_streaming_result.data, "Hello, world!", "Data should contain response text")
+ self.assertEqual(non_streaming_result.raw_response, "Hello, world!", "Raw response should match data for non-streaming")
+
+ # Test streaming: simulate the actual client behavior for streaming
+ # Set the parser to streaming mode first
+ original_parser = self.client.response_parser
+ self.client.response_parser = self.client.streaming_response_parser_sync
+
+ def mock_generator():
+ yield from self.streaming_events
+
+ streaming_result = self.client.parse_chat_completion(mock_generator())
+ self.assertIsNone(streaming_result.data, "Streaming should have data field as None")
+ self.assertIsNotNone(streaming_result.raw_response, "Streaming should have raw_response as generator")
+
+ # Verify we can iterate over the raw_response
+ events_from_raw = list(streaming_result.raw_response)
+ self.assertEqual(len(events_from_raw), 3, "Should be able to consume raw_response generator")
+
+ # Restore original parser
+ self.client.response_parser = original_parser
+
if __name__ == "__main__":
unittest.main()
diff --git a/adalflow/tests/test_runner.py b/adalflow/tests/test_runner.py
index cb098504..08e289a9 100644
--- a/adalflow/tests/test_runner.py
+++ b/adalflow/tests/test_runner.py
@@ -1082,6 +1082,233 @@ class TestModel(BaseModel):
runner._process_data('"just a string"')
self.assertIn("Expected dict after JSON parsing", str(cm.exception))
+ def test_execution_complete_event_emitted_once(self):
+ """Test that agent.execution_complete event is emitted exactly once.
+
+ Edge Case: Previously, execution_complete was emitted twice - once in
+ _process_stream_final_step and again outside the loop. This test ensures
+ it's only emitted once.
+ """
+ async def async_test():
+ from adalflow.core.types import FunctionOutput
+
+ # Create a function with _is_answer_final=True
+ fn = DummyFunction(
+ name="answer_output",
+ _is_answer_final=True,
+ _answer="test_complete"
+ )
+ agent = DummyAgent(
+ planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]),
+ answer_data_type=None,
+ )
+ runner = Runner(agent=agent)
+
+ async def mock_tool_execute_async(func, streaming_result=None):
+ return FunctionOutput(name=func.name, input=func, output="test_complete")
+
+ runner._tool_execute_async = mock_tool_execute_async
+
+ # Start streaming
+ streaming_result = runner.astream(prompt_kwargs={})
+
+ # Collect all execution_complete events
+ execution_complete_events = []
+ all_events = []
+
+ async for event in streaming_result.stream_events():
+ all_events.append(event)
+ if (isinstance(event, RunItemStreamEvent) and
+ event.name == "agent.execution_complete"):
+ execution_complete_events.append(event)
+
+ # Should have exactly one execution_complete event
+ self.assertEqual(len(execution_complete_events), 1,
+ f"Expected 1 execution_complete event, got {len(execution_complete_events)}")
+
+ # Verify the event contains correct data
+ final_event = execution_complete_events[0]
+ self.assertIsInstance(final_event.item, FinalOutputItem)
+ self.assertIsInstance(final_event.item.data, RunnerResult)
+ self.assertEqual(final_event.item.data.answer, "test_complete")
+
+ asyncio.run(async_test())
+
+ def test_execution_complete_event_properly_consumed(self):
+ """Test that execution_complete event is properly consumed by stream_to_json.
+
+ Edge Case: Previously, stream_events() would break immediately after yielding
+ execution_complete, preventing proper consumption by stream_to_json and other
+ consumers. This test ensures the event is properly consumed.
+ """
+ async def async_test():
+ from adalflow.core.types import FunctionOutput
+ import json
+ import os
+ import tempfile
+
+ # Create a function with _is_answer_final=True
+ fn = DummyFunction(
+ name="answer_output",
+ _is_answer_final=True,
+ _answer="stream_to_json_test"
+ )
+ agent = DummyAgent(
+ planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]),
+ answer_data_type=None,
+ )
+ runner = Runner(agent=agent)
+
+ async def mock_tool_execute_async(func, streaming_result=None):
+ return FunctionOutput(name=func.name, input=func, output="stream_to_json_test")
+
+ runner._tool_execute_async = mock_tool_execute_async
+
+ # Create temp file for JSON output
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ temp_file = f.name
+
+ try:
+ # Start streaming to JSON
+ streaming_result = runner.astream(prompt_kwargs={})
+
+ # Stream to JSON file
+ events_consumed = []
+ async for event in streaming_result.stream_to_json(temp_file):
+ if isinstance(event, RunItemStreamEvent):
+ events_consumed.append(event.name)
+
+ # Verify execution_complete was consumed
+ self.assertIn("agent.execution_complete", events_consumed,
+ "execution_complete event was not consumed by stream_to_json")
+
+ # Verify JSON file contains the execution_complete event
+ with open(temp_file, 'r') as f:
+ json_content = json.load(f)
+
+ # Check that execution_complete event is in the JSON
+ execution_complete_found = False
+ for event_entry in json_content:
+ if (event_entry.get("event_type") == "RunItemStreamEvent" and
+ "agent.execution_complete" in str(event_entry.get("event_data", {}))):
+ execution_complete_found = True
+ break
+
+ self.assertTrue(execution_complete_found,
+ "execution_complete event not found in JSON file")
+
+ finally:
+ # Clean up temp file
+ if os.path.exists(temp_file):
+ os.remove(temp_file)
+
+ asyncio.run(async_test())
+
+ def test_no_duplicate_execution_complete_on_incomplete(self):
+ """Test that execution_complete is only emitted once even when max_steps reached.
+
+ Edge Case: When the loop ends without a final answer (max_steps reached),
+ execution_complete should be emitted only once in the cleanup section.
+ """
+ async def async_test():
+ from adalflow.core.types import FunctionOutput
+
+ # Create functions that are NOT final
+ functions = [
+ DummyFunction(name=f"step_{i}", _is_answer_final=False)
+ for i in range(3)
+ ]
+ outputs = [GeneratorOutput(data=fn) for fn in functions]
+
+ agent = DummyAgent(
+ planner=FakeStreamingPlanner(outputs),
+ answer_data_type=None,
+ max_steps=3 # Will reach max_steps without final answer
+ )
+ runner = Runner(agent=agent)
+
+ async def mock_tool_execute_async(func, streaming_result=None):
+ return FunctionOutput(name=func.name, input=func, output=f"output_{func.name}")
+
+ runner._tool_execute_async = mock_tool_execute_async
+
+ # Start streaming
+ streaming_result = runner.astream(prompt_kwargs={})
+
+ # Collect all execution_complete events
+ execution_complete_events = []
+
+ async for event in streaming_result.stream_events():
+ if (isinstance(event, RunItemStreamEvent) and
+ event.name == "agent.execution_complete"):
+ execution_complete_events.append(event)
+
+ # Should have exactly one execution_complete event even when incomplete
+ self.assertEqual(len(execution_complete_events), 1,
+ f"Expected 1 execution_complete event for incomplete run, got {len(execution_complete_events)}")
+
+ # Verify it contains the "No output generated" message
+ final_event = execution_complete_events[0]
+ self.assertIsInstance(final_event.item, FinalOutputItem)
+ self.assertIsInstance(final_event.item.data, RunnerResult)
+ self.assertIn("No output generated", final_event.item.data.answer)
+
+ asyncio.run(async_test())
+
+ def test_execution_complete_with_early_break_scenario(self):
+ """Test execution_complete event when consumer breaks early from stream.
+
+ Edge Case: Test that even if a consumer breaks early from the stream,
+ the execution_complete event is available for consumption if they
+ resume iteration (after our fix).
+ """
+ async def async_test():
+ from adalflow.core.types import FunctionOutput
+
+ fn = DummyFunction(
+ name="answer_output",
+ _is_answer_final=True,
+ _answer="early_break_test"
+ )
+ agent = DummyAgent(
+ planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]),
+ answer_data_type=None,
+ )
+ runner = Runner(agent=agent)
+
+ async def mock_tool_execute_async(func, streaming_result=None):
+ return FunctionOutput(name=func.name, input=func, output="early_break_test")
+
+ runner._tool_execute_async = mock_tool_execute_async
+
+ # Start streaming
+ streaming_result = runner.astream(prompt_kwargs={})
+
+ # First consumer breaks after raw event
+ first_consumer_events = []
+ async for event in streaming_result.stream_events():
+ first_consumer_events.append(event)
+ if isinstance(event, RawResponsesStreamEvent):
+ break # Break early after first event
+
+ # Second consumer continues from where first left off
+ second_consumer_events = []
+            async for event in streaming_result.stream_events():
+                second_consumer_events.append(event)
+
+ # Check if execution_complete was available to second consumer
+ execution_complete_found = any(
+ isinstance(e, RunItemStreamEvent) and e.name == "agent.execution_complete"
+ for e in second_consumer_events
+ )
+
+ self.assertTrue(execution_complete_found,
+ "execution_complete should be available even after early break")
+
+ asyncio.run(async_test())
+
if __name__ == "__main__":
unittest.main()
From 268737881086286bae8e1cfba0055db99bc7cbc3 Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Tue, 19 Aug 2025 15:33:44 -0700
Subject: [PATCH 3/3] Add flexible memory components and runner with tests
---
adalflow/adalflow/components/agent/prompts.py | 3 +-
.../components/agent/runner_flexible.py | 1957 +++++++++++++++++
.../components/memory/flexible_memory.py | 96 +-
adalflow/tests/test_flexible_memory.py | 1121 ++++++++++
.../tests/test_flexible_memory_template.py | 538 +++++
adalflow/tests/test_runner_flexible.py | 517 +++++
6 files changed, 4189 insertions(+), 43 deletions(-)
create mode 100644 adalflow/adalflow/components/agent/runner_flexible.py
create mode 100644 adalflow/tests/test_flexible_memory.py
create mode 100644 adalflow/tests/test_flexible_memory_template.py
create mode 100644 adalflow/tests/test_runner_flexible.py
diff --git a/adalflow/adalflow/components/agent/prompts.py b/adalflow/adalflow/components/agent/prompts.py
index dc37eb2b..d35c38c4 100644
--- a/adalflow/adalflow/components/agent/prompts.py
+++ b/adalflow/adalflow/components/agent/prompts.py
@@ -28,7 +28,8 @@
- If the last observation starts with "Run into error", you should try to fix the error in the next step.
"""
-# TODO: access the max steps in the agent prompt or not
+# Chat history is rendered as user/assistant messages plus metadata ({step_history});
+# step_history holds the per-step observations.
DEFAULT_ADALFLOW_AGENT_SYSTEM_PROMPT = r"""
{{task_desc}}
- You cant use more than {{max_steps}} steps. At the {{max_steps}}th current step, must set `_is_answer_final` to True and provide the answer.
diff --git a/adalflow/adalflow/components/agent/runner_flexible.py b/adalflow/adalflow/components/agent/runner_flexible.py
new file mode 100644
index 00000000..a942d7cf
--- /dev/null
+++ b/adalflow/adalflow/components/agent/runner_flexible.py
@@ -0,0 +1,1957 @@
+"""Agent runner component with flexible memory for managing and executing agent workflows."""
+
+from pydantic import BaseModel
+import logging
+import inspect
+import asyncio
+import uuid
+import json
+from datetime import datetime
+
+from typing import (
+ Any,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ AsyncIterable,
+)
+from typing_extensions import TypeAlias
+import uuid
+
+
+from adalflow.optim.parameter import Parameter
+from adalflow.utils import printc
+from adalflow.core.component import Component
+from adalflow.components.agent.agent import Agent
+
+from adalflow.core.types import (
+ GeneratorOutput,
+ FunctionOutput,
+ Function,
+ StepOutput,
+ RawResponsesStreamEvent,
+ RunItemStreamEvent,
+ ToolCallRunItem,
+ ToolOutputRunItem,
+ StepRunItem,
+ FinalOutputItem,
+ RunnerStreamingResult,
+ RunnerResult,
+ QueueCompleteSentinel,
+ ToolOutput,
+ ToolCallActivityRunItem,
+ UserQuery,
+ AssistantResponse,
+)
+from adalflow.apps.permission_manager import PermissionManager
+from adalflow.components.memory.flexible_memory import FlexibleConversationMemory
+from adalflow.core.functional import _is_pydantic_dataclass, _is_adalflow_dataclass
+from adalflow.tracing import (
+ runner_span,
+ tool_span,
+ response_span,
+ step_span,
+)
+
+
+__all__ = ["RunnerFlexible"]
+
+log = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel) # Changed to use Pydantic BaseModel
+
+
+def _is_unrecoverable_error(error: Optional[str]) -> bool: # pragma: no cover
+ """Check if an error string indicates an unrecoverable error.
+
+ Unrecoverable errors include:
+ - HTTP 400: Bad request (e.g., context too long)
+ - HTTP 429: Rate limit exceeded
+ - HTTP 404: Model not found
+ - "Connection error": Network connection issues
+
+    This function is excluded from coverage reporting (``# pragma: no cover``).
+
+ Args:
+ error: Error string to check
+
+ Returns:
+ True if the error is unrecoverable, False otherwise
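+
+    Example:
+        >>> _is_unrecoverable_error("Error code: 429 - rate limit exceeded")
+        True
+        >>> _is_unrecoverable_error("JSONDecodeError: Expecting value")
+        False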
+ """
+ if not error:
+ return False
+
+ # Check for connection error string pattern (case insensitive)
+ if "connection error" in error.lower():
+ return True
+
+ # Check for HTTP error codes
+ if "400" in error or "429" in error or "404" in error:
+ return True
+
+ return False
+
+BuiltInType: TypeAlias = Union[str, int, float, bool, list, dict, tuple, set, None]
+PydanticDataClass: TypeAlias = Type[BaseModel]
+AdalflowDataClass: TypeAlias = Type[
+ Any
+] # Replace with your actual Adalflow dataclass type if available
+
+
+# The runner creates each tool call request and assigns it a unique call id.
+# Runner with flexible memory and robust error handling
+class RunnerFlexible(Component):
+ """Executes Agent instances with multi-step iterative planning and tool execution.
+
+ The Runner orchestrates the execution of an Agent through multiple reasoning and action
+ cycles. It manages the step-by-step execution loop where the Agent's planner generates
+ Function calls that get executed by the ToolManager, with results fed back into the
+ planning context for the next iteration.
+
+ Execution Flow:
+ 1. Initialize step history and prompt context
+ 2. For each step (up to max_steps):
+ a. Call Agent's planner to get next Function
+ b. Execute the Function using ToolManager
+ c. Add step result to history
+ d. Check if Function is "finish" to terminate
+ 3. Process final answer to expected output type
+
+ The Runner supports both synchronous and asynchronous execution modes, as well as
+ streaming execution with real-time event emission. It includes comprehensive tracing
+ and error handling throughout the execution pipeline.
+
+ Attributes:
+ agent (Agent): The Agent instance to execute
+ max_steps (int): Maximum number of execution steps allowed
+ answer_data_type (Type): Expected type for final answer processing
+ step_history (List[StepOutput]): History of all execution steps
+ ctx (Optional[Dict]): Additional context passed to tools
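+
+    Example (illustrative sketch; assumes an already-configured ``Agent`` instance
+    named ``agent`` and that FlexibleConversationMemory can be built with defaults):
+        >>> memory = FlexibleConversationMemory()
+        >>> runner = RunnerFlexible(agent=agent, conversation_memory=memory, max_steps=5)
+        >>> result = runner.call(prompt_kwargs={"input_str": "What is 2 + 2?"})
+        >>> print(result.answer)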
+ """
+
+ def __init__(
+ self,
+ agent: Agent,
+ ctx: Optional[Dict] = None,
+ max_steps: Optional[int] = None, # this will overwrite the agent's max_steps
+ permission_manager: Optional[PermissionManager] = None,
+ conversation_memory: Optional[FlexibleConversationMemory] = None,
+ **kwargs,
+ ) -> None:
+ """Initialize runner with an agent and configuration.
+
+ Args:
+ agent: The agent instance to execute
+            ctx: Optional context dict shared with tools and attached to the final response
+            max_steps: Maximum number of steps to execute (overrides the agent's max_steps)
+ permission_manager: Optional permission manager for tool approval
+ conversation_memory: Optional conversation memory
+ """
+ super().__init__(**kwargs)
+ self.agent = agent
+ self.tool_manager = agent.tool_manager
+ self.permission_manager = permission_manager
+ # pass the tool_manager to the permission_manager
+ if permission_manager is not None:
+ permission_manager.set_tool_manager(self.tool_manager)
+
+ self.conversation_memory = conversation_memory
+
+ self.use_conversation_memory = conversation_memory is not None
+
+ # get agent requirements
+ self.max_steps = max_steps
+ if max_steps is None:
+ self.max_steps = agent.max_steps
+ else:
+ # overwrite the agent's max_steps
+ self.agent.max_steps = max_steps
+ self.answer_data_type = agent.answer_data_type or str
+
+ self.step_history: List[StepOutput] = []
+
+ # add ctx (it is just a reference, and only get added to the final response)
+ # assume intermediate tool is gonna modify the ctx
+ self.ctx = ctx
+
+ # Initialize permission manager
+ self._init_permission_manager()
+
+ # Initialize cancellation flag
+ self._cancelled = False
+ self._cancel_callbacks = []
+ self._current_task = None # Track the current running task
+ self._current_streaming_result = None # Track the current streaming result
+
+ # support thinking model
+        self.is_thinking_model = getattr(agent, 'is_thinking_model', False)
+
+ # Token tracking
+ self._token_consumption: Dict[str, Any] = {
+ 'total_prompt_tokens': 0,
+ 'current_step_tokens': 0,
+ 'steps_token_history': [],
+ 'last_total_tokens': 0 # Track last total to calculate step difference
+ }
+
+ # ============== Safe Memory Operations ==============
+ def _safe_create_turn(self) -> Optional[str]:
+ """Safely create a new turn in memory.
+
+ Returns:
+ Turn ID if successful, None if failed
+ """
+ if not self.use_conversation_memory:
+ return None
+
+ try:
+ return self.conversation_memory.create_turn()
+ except Exception as e:
+ log.warning(f"Failed to create turn in memory: {e}")
+ return None
+
+ def _safe_add_user_query(self, query: Union[str, UserQuery], turn_id: Optional[str], metadata: Optional[Dict] = None) -> Optional[str]:
+ """Safely add user query to memory.
+
+ Returns:
+ Turn ID if successful, None if failed
+ """
+ if not self.use_conversation_memory or turn_id is None:
+ return None
+
+ try:
+ if isinstance(query, str):
+ return self.conversation_memory.add_user_query(query, turn_id, metadata)
+ else:
+ return self.conversation_memory.add_user_query(query.query_str, turn_id, metadata or query.metadata)
+ except Exception as e:
+ log.warning(f"Failed to add user query to memory: {e}")
+ return None
+
+ def _safe_add_assistant_response(self, response: Union[str, AssistantResponse], turn_id: Optional[str], metadata: Optional[Dict] = None) -> Optional[str]:
+ """Safely add assistant response to memory.
+
+ Returns:
+ Turn ID if successful, None if failed
+ """
+ if not self.use_conversation_memory or turn_id is None:
+ return None
+
+ try:
+ if isinstance(response, str):
+ return self.conversation_memory.add_assistant_response(response, turn_id, metadata)
+ else:
+ return self.conversation_memory.add_assistant_response(
+ response.response_str,
+ turn_id,
+ metadata or response.metadata
+ )
+ except Exception as e:
+ log.warning(f"Failed to add assistant response to memory: {e}")
+ return None
+
+ def _safe_get_conversation_history(self) -> str:
+ """Safely get conversation history from memory.
+
+ Returns:
+ Conversation history string, empty string if failed
+ """
+ if not self.use_conversation_memory:
+ return ""
+
+ try:
+ return self.conversation_memory() or ""
+ except Exception as e:
+ log.warning(f"Failed to get conversation history: {e}")
+ return ""
+
+ def _safe_reset_pending_query(self) -> None:
+ """Safely reset pending query in memory."""
+ if not self.use_conversation_memory:
+ return
+
+ try:
+ self.conversation_memory.reset_pending_query()
+ except Exception as e:
+ log.warning(f"Failed to reset pending query: {e}")
+ # Continue execution even if reset fails
+
+ def _init_permission_manager(self):
+ """Initialize the permission manager and register tools that require approval."""
+ if self.permission_manager and hasattr(self.agent, "tool_manager"):
+ # Iterate through tools in the ComponentList
+ for tool in self.agent.tool_manager.tools:
+ if hasattr(tool, "definition") and hasattr(tool, "require_approval"):
+ tool_name = tool.definition.func_name
+ self.permission_manager.register_tool(
+ tool_name, tool.require_approval
+ )
+
+ def set_permission_manager(
+ self, permission_manager: Optional[PermissionManager]
+ ) -> None:
+ """Set or update the permission manager after runner initialization.
+
+ Args:
+ permission_manager: The permission manager instance to use for tool approval
+ """
+ self.permission_manager = permission_manager
+ # Re-initialize to register tools with the new permission manager
+ self._init_permission_manager()
+
+ # pass the tool_manager to the permission_manager
+ if permission_manager is not None:
+ permission_manager.set_tool_manager(self.tool_manager)
+
+ def is_cancelled(self) -> bool:
+ """Check if execution has been cancelled."""
+ return self._cancelled
+
+ def reset_cancellation(self) -> None:
+ """Reset the cancellation flag for a new execution."""
+ self._cancelled = False
+
+ def get_token_consumption(self) -> Dict[str, Any]:
+ """Get the current token consumption statistics.
+
+ Returns:
+ Dict containing token consumption data:
+ - total_prompt_tokens: Total tokens consumed across all steps
+ - current_step_tokens: Tokens from the most recent step
+ - steps_token_history: List of token counts per step
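+              - last_total_tokens: Planner total recorded at the previous update (internal bookkeeping)
+
+        Example (illustrative; assumes ``runner`` is a RunnerFlexible that has run at least one step):
+            >>> usage = runner.get_token_consumption()
+            >>> usage["total_prompt_tokens"], len(usage["steps_token_history"])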
+ """
+ return self._token_consumption.copy()
+
+    def _update_token_consumption(self) -> int:
+        """Update token consumption statistics from the planner's accumulated token count.
+
+        Since the generator accumulates tokens, the step tokens are computed as the
+        difference from the last recorded total.
+
+        Returns:
+            int: Tokens consumed by the most recent step, or 0 if the planner does not
+            expose ``estimated_token_count``.
+        """
+ if hasattr(self.agent.planner, 'estimated_token_count'):
+ current_total = self.agent.planner.estimated_token_count
+ step_tokens = current_total - self._token_consumption['last_total_tokens']
+
+ self._token_consumption['current_step_tokens'] = step_tokens
+ self._token_consumption['total_prompt_tokens'] = current_total
+ self._token_consumption['steps_token_history'].append(step_tokens)
+ self._token_consumption['last_total_tokens'] = current_total
+
+ return step_tokens
+ return 0
+
+ def register_cancel_callback(self, callback) -> None:
+ """Register a callback to be called when execution is cancelled."""
+ self._cancel_callbacks.append(callback)
+
+ async def cancel(self) -> None:
+ """Cancel the current execution.
+
+ This will stop the current execution but preserve state like memory.
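+
+        Example (illustrative; typically awaited from another task while acall or a streaming call is running):
+            >>> await runner.cancel()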
+ """
+ log.info("Runner.cancel() called - setting cancelled flag")
+ self._cancelled = True
+
+ # Try to emit a test event if we have a streaming result
+ if hasattr(self, '_current_streaming_result') and self._current_streaming_result:
+ try:
+ cancel_received_event = RunItemStreamEvent(
+ name="runner.cancel_received",
+ item=FinalOutputItem(
+ data={
+ "status": "cancel_received",
+ "message": "Cancel request received",
+ })
+ )
+ self._current_streaming_result.put_nowait(cancel_received_event)
+ log.info("Emitted cancel_received event")
+ except Exception as e:
+ log.error(f"Failed to emit cancel_received event: {e}")
+
+ # Cancel the current streaming task if it exists
+ if self._current_task and not self._current_task.done():
+ log.info(f"Cancelling runner task: {self._current_task}")
+ self._current_task.cancel()
+
+ # Create a task to wait for cancellation to complete
+ await self._wait_for_cancellation()
+
+ async def _wait_for_cancellation(self):
+ """Wait for task to be cancelled with timeout."""
+ if self._current_task:
+ try:
+ # Wait up to 1 second for task to cancel gracefully
+ await asyncio.wait_for(
+ self._current_task,
+ timeout=1.0
+ )
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ # Task didn't cancel in time or was cancelled - that's ok
+ pass
+
+ def _check_last_step(self, step: Function) -> bool:
+ """Check if the last step has is_answer_final set to True."""
+ if hasattr(step, "_is_answer_final") and step._is_answer_final:
+ return True
+
+ return False
+
+ def _get_final_answer(self, function: Function) -> Any:
+ """Get and process the final answer from the function."""
+ if hasattr(function, "_answer") or (hasattr(function, "_is_answer_final") and function._is_answer_final):
+ return self._process_data(function._answer)
+ return None
+
+
+    def _create_runner_result(self, answer: Any, step_history: List[StepOutput], error: Optional[str] = None) -> RunnerResult:
+ """Create a RunnerResult object with the final answer and error."""
+ return RunnerResult(
+ answer=answer,
+ step_history=step_history.copy(),
+ error=error,
+ # ctx=self.ctx,
+ )
+
+    def _create_execution_complete_stream_event(self, streaming_result: RunnerStreamingResult, final_output_item: FinalOutputItem):
+        """Emit the final execution_complete event and mark the streaming result as complete."""
+ try:
+ final_output_event = RunItemStreamEvent(
+ name="agent.execution_complete",
+ item=final_output_item,
+ )
+ streaming_result.put_nowait(final_output_event)
+
+ runner_result: RunnerResult = final_output_item.data
+
+ # set up the final answer
+ streaming_result.answer = runner_result.answer if runner_result else None
+ streaming_result.step_history = self.step_history.copy()
+ streaming_result._is_complete = True
+
+ except Exception as e:
+ log.error(f"Failed to create execution complete stream event: {e}")
+ raise e
+
+ def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem, turn_id: Any = None):
+ # add the assistant response to the conversation memory using safe wrapper
+ if self.use_conversation_memory and turn_id is not None:
+ # Only add if we have a valid turn_id
+ self._safe_add_assistant_response(
+ AssistantResponse(
+ response_str=final_output_item.data.answer,
+ metadata={
+ "step_history": final_output_item.data.step_history.copy()
+ },
+ ),
+ turn_id=turn_id
+ )
+ elif self.use_conversation_memory:
+ log.warning("Skipping add_assistant_response - no turn_id available")
+
+ def create_response_span(self, runner_result, step_count: int, streaming_result: RunnerStreamingResult, runner_span_instance, workflow_status: str = "stream_completed"):
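+        """Update the runner span attributes and record a response span for the final streaming result."""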
+
+ runner_span_instance.span_data.update_attributes(
+ {
+ "steps_executed": step_count + 1,
+ "final_answer": runner_result.answer,
+ "workflow_status": workflow_status,
+ }
+ )
+
+ # Create response span for tracking final streaming result
+ with response_span(
+ answer=runner_result.answer,
+ result_type=type(runner_result.answer).__name__,
+ execution_metadata={
+ "steps_executed": step_count + 1,
+ "max_steps": self.max_steps,
+ "workflow_status": workflow_status,
+ "streaming": True,
+ },
+ response=runner_result,
+ ):
+ pass
+
+ async def _process_stream_final_step(
+ self,
+ answer: Any,
+ step_count: int,
+ streaming_result,
+ runner_span_instance,
+ turn_id: Any = None,
+ ) -> FinalOutputItem:
+ """Process the final step and trace it."""
+
+ # Runner result is the same as the sync/async call result
+
+ runner_result = self._create_runner_result(
+ answer=answer,
+ step_history=self.step_history,
+ )
+
+ # Emit execution complete event
+ final_output_item = FinalOutputItem(data=runner_result)
+ self._create_execution_complete_stream_event(
+ streaming_result, final_output_item
+ )
+
+ # Ensure event is in queue by yielding control briefly
+ # This allows the event loop to process the put_nowait before we return
+ await asyncio.sleep(0) # Small delay to ensure event is properly queued
+
+ # add the assistant response to the conversation memory
+ self._add_assistant_response_to_memory(final_output_item, turn_id)
+ return final_output_item
+
+ # TODO: improved after the finish function is refactored
+ def _process_data(
+ self,
+ data: Union[BuiltInType, PydanticDataClass, AdalflowDataClass],
+ id: Optional[str] = None,
+ ) -> T:
+        """Process the planner's final answer and convert it to the configured ``answer_data_type``.
+
+ Args:
+ data: The data to process
+ id: Optional identifier for the output
+
+ Returns:
+            The final answer converted to ``answer_data_type`` (built-in type, Pydantic model, or AdalFlow DataClass).
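+
+        Example (illustrative; assumes ``answer_data_type`` was set to a hypothetical
+        Pydantic model ``AnswerModel`` with a single ``name`` field):
+            >>> runner._process_data('{"name": "alice"}')  # -> AnswerModel(name='alice')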
+ """
+
+ try:
+ model_output = None
+ log.info(f"answer_data_type: {type(self.answer_data_type)}")
+
+ # returns a dictionary in this case
+ if _is_pydantic_dataclass(self.answer_data_type):
+ log.info(
+ f"initial answer returned by finish when user passed a pydantic type: {data}, type: {type(data)}"
+ )
+ # if it has not yet been deserialized then deserialize into dictionary using json loads
+ if isinstance(data, str):
+ try:
+ data = json.loads(data)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON in data: {e}")
+ if not isinstance(data, dict):
+ raise ValueError(f"Expected dict after JSON parsing, got {type(data)}")
+ log.info(
+ f"initial answer after being evaluated using json: {data}, type: {type(data)}"
+ )
+ # data should be a string that represents a dictionary
+ model_output = self.answer_data_type(**data)
+ elif _is_adalflow_dataclass(self.answer_data_type):
+ log.info(
+ f"initial answer returned by finish when user passed a adalflow type: {data}, type: {type(data)}"
+ )
+
+ if isinstance(data, str):
+ try:
+ data = json.loads(data)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON in data: {e}")
+ if not isinstance(data, dict):
+ raise ValueError(f"Expected dict after JSON parsing, got {type(data)}")
+ log.info(
+ f"initial answer after being evaluated using json: {data}, type: {type(data)}"
+ )
+ # data should be a string that represents a dictionary
+ model_output = self.answer_data_type.from_dict(data)
+ else: # expect data to be a python built_in_type
+ log.info(
+                        f"answer type is neither a pydantic dataclass nor an adalflow dataclass; casting with answer_data_type for safety: {data}, type: {type(data)}"
+ )
+ data = self.answer_data_type(
+ data
+ ) # directly cast using the answer_data_type
+ if not isinstance(data, self.answer_data_type):
+ raise ValueError(
+ f"Expected data of type {self.answer_data_type}, but got {type(data)}"
+ )
+ model_output = data
+
+ if not model_output:
+ raise ValueError(f"Failed to parse output: {data}")
+
+ return model_output
+
+ except Exception as e:
+ log.error(f"Error processing output: {str(e)}")
+ raise ValueError(f"Error processing output: {str(e)}")
+
+ @classmethod
+    def _get_planner_function(cls, output: GeneratorOutput) -> Optional[Function]:
+ """Check the planner output and return the function.
+
+ Args:
+ output: The planner output
+ """
+ if not isinstance(output, GeneratorOutput):
+ raise ValueError(
+ f"Expected GeneratorOutput, but got {type(output)}, value: {output}"
+ )
+
+ function = output.data
+
+ if not isinstance(function, Function):
+ # can still self-recover in the agent for formatting.
+ # raise ValueError(
+ # f"Expected Function in the data field of the GeneratorOutput, but got {type(function)}, value: {function}"
+ # )
+ return None
+
+ return function
+
+ def call(
+ self,
+ prompt_kwargs: Dict[str, Any],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ use_cache: Optional[bool] = None,
+ id: Optional[str] = None, # global run id
+ ) -> RunnerResult:
+ """Execute the planner synchronously for multiple steps with function calling support.
+
+        At the final step the planner sets `_is_answer_final` to True, which terminates the loop.
+
+ Args:
+ prompt_kwargs: Dictionary of prompt arguments for the generator
+ model_kwargs: Optional model parameters to override defaults
+ use_cache: Whether to use cached results if available
+ id: Optional unique identifier for the request
+
+ Returns:
+ RunnerResult containing step history and final processed output
+ """
+ # Create runner span for tracing
+ with runner_span(
+ runner_id=id or f"runner_{hash(str(prompt_kwargs))}",
+ max_steps=self.max_steps,
+ workflow_status="starting",
+ ) as runner_span_instance:
+ # reset the step history
+ self.step_history = []
+
+ # take in the query in prompt_kwargs
+ prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {}
+ prompt_kwargs["step_history"] = (
+ self.step_history
+ ) # a reference to the step history
+
+ turn_id = None
+ if self.use_conversation_memory:
+ # Reset any pending query state before starting a new query
+ self._safe_reset_pending_query()
+
+ # Create new turn
+ turn_id = self._safe_create_turn()
+
+ prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history()
+ # save the user query to the conversation memory
+
+            # metadata carries the optional context_str that accompanies the query
+ query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
+ if turn_id:
+ self._safe_add_user_query(
+ UserQuery(
+ query_str=prompt_kwargs.get("input_str", None),
+ metadata=query_metadata,
+ ),
+ turn_id=turn_id
+ )
+
+ # set maximum number of steps for the planner into the prompt
+ prompt_kwargs["max_steps"] = self.max_steps
+
+ model_kwargs = model_kwargs.copy() if model_kwargs else {}
+
+ step_count = 0
+ last_output = None
+ current_error = None
+
+ while step_count < self.max_steps:
+ try:
+ log.debug(f"Running step {step_count + 1}/{self.max_steps} with prompt_kwargs: {prompt_kwargs}")
+ # Create step span for each iteration
+ with step_span(
+ step_number=step_count, action_type="planning"
+ ) as step_span_instance:
+
+ # Call the planner first to get the output
+ output = self.agent.planner.call(
+ prompt_kwargs=prompt_kwargs,
+ model_kwargs=model_kwargs,
+ use_cache=use_cache,
+ id=id,
+ )
+
+ # Track token usage
+ step_tokens = self._update_token_consumption()
+ if step_tokens > 0:
+ log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}")
+
+ log.debug(f"planner output: {output}")
+
+ # consistency with impl_astream, break if output is not a Generator Output
+ if not isinstance(output, GeneratorOutput):
+                        # Record the error in the step history and stop the loop
+ current_error = (
+ f"Expected GeneratorOutput, but got {output}"
+ )
+ # add this to the step history
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ self.step_history.append(step_output)
+ break
+
+ function = output.data
+
+ log.debug(f"function: {function}")
+ if function is None:
+ error_msg = f"Run into error: {output.error}, raw response: {output.raw_response}"
+ # Handle recoverable vs unrecoverable errors
+ if output.error is not None:
+ if _is_unrecoverable_error(output.error):
+ # Unrecoverable errors: context too long, rate limit, model not found
+ current_error = output.error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=f"Unrecoverable error: {output.error}",
+ )
+ self.step_history.append(step_output)
+ break # Stop execution for unrecoverable errors
+ # Recoverable errors: JSON format errors, parsing errors, etc.
+ current_error = output.error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+ continue # Continue to next step for recoverable errors
+
+ # start to process correct function
+ function.id = str(uuid.uuid4()) # add function id
+ thinking = output.thinking if hasattr(output, 'thinking') else None
+ if thinking is not None and self.is_thinking_model:
+ function.thought = thinking
+
+
+ if self._check_last_step(function):
+ processed_data = self._process_data(function._answer)
+ # Wrap final output in RunnerResult
+ last_output = RunnerResult(
+ answer=processed_data,
+ step_history=self.step_history.copy(),
+ # ctx=self.ctx,
+ )
+
+ # Add assistant response to conversation memory
+ if self.use_conversation_memory and turn_id is not None:
+ self._safe_add_assistant_response(
+ AssistantResponse(
+ response_str=processed_data,
+ metadata={
+ "step_history": self.step_history.copy()
+ },
+ ),
+ turn_id=turn_id
+ )
+
+ step_count += 1 # Increment step count before breaking
+ break
+
+ step_output: Optional[StepOutput] = None
+
+ # Create tool span for function execution
+ with tool_span(
+ tool_name=function.name,
+ function_name=function.name,
+ function_args=function.args,
+ function_kwargs=function.kwargs,
+ ) as tool_span_instance:
+ function_results = self._tool_execute_sync(function)
+ # Update span attributes using update_attributes for MLflow compatibility
+ tool_span_instance.span_data.update_attributes(
+ {"output_result": function_results.output}
+ )
+
+ function_output = function_results.output
+ real_function_output = None
+
+ # Handle generator outputs in sync call
+ if inspect.iscoroutine(function_output):
+ # For sync call, we need to run the coroutine
+ real_function_output = asyncio.run(function_output)
+ elif inspect.isasyncgen(function_output):
+ # Collect all values from async generator
+ async def collect_async_gen():
+ collected_items = []
+ async for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # Skip activity items
+ continue
+ else:
+ collected_items.append(item)
+ return collected_items
+ real_function_output = asyncio.run(collect_async_gen())
+ elif inspect.isgenerator(function_output):
+ # Collect all values from sync generator
+ collected_items = []
+ for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # Skip activity items
+ continue
+ else:
+ collected_items.append(item)
+ real_function_output = collected_items
+ else:
+ real_function_output = function_output
+
+ # Use the processed output
+ function_output = real_function_output
+ function_output_observation = function_output
+ if isinstance(function_output, ToolOutput) and hasattr(
+ function_output, "observation"
+ ):
+ function_output_observation = function_output.observation
+
+ # create a step output
+ step_output = StepOutput(
+ step=step_count,
+ action=function,
+ function=function,
+ observation=function_output_observation,
+ )
+
+ # Update step span with results
+ step_span_instance.span_data.update_attributes(
+ {
+ "tool_name": function.name,
+ "tool_output": function_results,
+ "is_final": self._check_last_step(function),
+ "observation": function_output_observation,
+ }
+ )
+
+ log.debug(
+ "The prompt with the prompt template is {}".format(
+ self.agent.planner.get_prompt(**prompt_kwargs)
+ )
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+
+ except Exception as e:
+ error_msg = f"Error in step {step_count}: {str(e)}"
+ log.error(error_msg)
+
+ # Create response span for error tracking
+ with response_span(
+ answer=error_msg,
+ result_type="error",
+ execution_metadata={
+ "steps_executed": step_count,
+ "max_steps": self.max_steps,
+ "workflow_status": "failed",
+ },
+ response=None,
+ ):
+ pass
+
+                    # Record the error and stop the loop
+ step_count += 1
+ current_error = error_msg
+ break
+
+ # Update runner span with final results
+ # Update runner span with completion info using update_attributes
+ runner_span_instance.span_data.update_attributes(
+ {
+ "steps_executed": step_count,
+ "final_answer": last_output.answer if last_output else None,
+ "workflow_status": "completed",
+ }
+ )
+
+ # Create response span for tracking final result
+ with response_span(
+ answer=(
+ last_output.answer
+ if last_output
+ else f"No output generated after {step_count} steps (max_steps: {self.max_steps})"
+ ),
+ result_type=(
+ type(last_output.answer).__name__ if last_output else "no_output"
+ ),
+ execution_metadata={
+ "steps_executed": step_count,
+ "max_steps": self.max_steps,
+ "workflow_status": "completed" if last_output else "incomplete",
+ },
+ response=last_output, # can be None if Runner has not finished in the max steps
+ ):
+ pass
+
+ # Always return a RunnerResult, even if no successful completion
+ return last_output or RunnerResult(
+ answer=current_error or f"No output generated after {step_count} steps (max_steps: {self.max_steps})",
+ step_history=self.step_history.copy(),
+ error=current_error,
+ )
+
+ def _tool_execute_sync(
+ self,
+ func: Function,
+ ) -> Union[FunctionOutput, Parameter]:
+ """
+        Execute a single tool function synchronously; used by :meth:`call`.
+        Handles both sync and async functions by running async ones in an event loop.
+ Includes permission checking if permission_manager is configured.
+ """
+
+ # execute permission and blocking mechanism in check_permission
+ # TODO: permission manager might be better to be put inside of tool manager
+ if self.permission_manager:
+
+ result = asyncio.run(self.permission_manager.check_permission(func))
+
+ # Handle both old (2 values) and new (3 values) return formats
+ if len(result) == 3:
+ allowed, modified_func, _ = result
+ else:
+ allowed, modified_func = result
+
+ if not allowed:
+ return FunctionOutput(
+ name=func.name,
+ input=func,
+ output=ToolOutput(
+ output="Tool execution cancelled by user",
+ observation="Tool execution cancelled by user",
+ display="Permission denied",
+ status="cancelled",
+ ),
+ )
+
+ # Use modified function if user edited it
+ func = modified_func or func
+
+ result = self.agent.tool_manager.execute_func(func=func)
+
+ if not isinstance(result, FunctionOutput):
+ raise ValueError("Result is not a FunctionOutput")
+
+ # check error
+ if result.error is not None:
+ log.warning(f"Error in tool execution: {result.error}")
+ # TODO: specify how to handle this error
+
+ return result
+
+ # support both astream and non-stream
+ async def acall(
+ self,
+ prompt_kwargs: Dict[str, Any],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ use_cache: Optional[bool] = None,
+ id: Optional[str] = None,
+ ) -> Optional[RunnerResult]:
+ """Execute the planner asynchronously for multiple steps with function calling support.
+
+        At the final step the planner sets `_is_answer_final` to True, which terminates the loop.
+
+ Args:
+ prompt_kwargs: Dictionary of prompt arguments for the generator
+ model_kwargs: Optional model parameters to override defaults
+ use_cache: Whether to use cached results if available
+ id: Optional unique identifier for the request
+
+ Returns:
+ RunnerResult containing step history and final processed output
+ """
+
+ workflow_status = "starting"
+ runner_id = id or f"async_runner_{hash(str(prompt_kwargs))}"
+
+ # Create runner span for tracing
+ with runner_span(
+ runner_id=runner_id,
+ max_steps=self.max_steps,
+ workflow_status=workflow_status,
+ ) as runner_span_instance:
+ # Reset cancellation flag at start of new execution
+ self.reset_cancellation()
+
+ self.step_history = []
+ prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {}
+
+ prompt_kwargs["step_history"] = (
+ self.step_history
+ ) # a reference to the step history
+
+ turn_id = None
+ if self.use_conversation_memory:
+ # Reset any pending query state before starting a new query
+ self._safe_reset_pending_query()
+
+ # Create new turn
+ turn_id = self._safe_create_turn()
+
+ prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history()
+ # save the user query to the conversation memory
+
+ # the query metadata carries the context keys (currently only context_str)
+ query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
+ if turn_id:
+ self._safe_add_user_query(
+ UserQuery(
+ query_str=prompt_kwargs.get("input_str", None),
+ metadata=query_metadata,
+ ),
+ turn_id=turn_id
+ )
+
+ # set maximum number of steps for the planner into the prompt
+ prompt_kwargs["max_steps"] = self.max_steps
+
+ model_kwargs = model_kwargs.copy() if model_kwargs else {}
+
+ step_count = 0
+ last_output = None
+ current_error = None
+
+ while step_count < self.max_steps and not self.is_cancelled():
+ try:
+ log.debug(f"Running async step {step_count + 1}/{self.max_steps} with prompt_kwargs: {prompt_kwargs}")
+
+ # Create step span for each iteration
+ with step_span(
+ step_number=step_count, action_type="async_planning"
+ ) as step_span_instance:
+
+
+ if self.is_cancelled():
+ raise asyncio.CancelledError("Execution cancelled by user")
+
+ # Call the planner first to get the output
+ output: GeneratorOutput = await self.agent.planner.acall(
+ prompt_kwargs=prompt_kwargs,
+ model_kwargs=model_kwargs,
+ use_cache=use_cache,
+ id=id,
+ )
+
+ # Track token usage
+ step_tokens = self._update_token_consumption()
+ if step_tokens > 0:
+ log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}")
+
+ log.debug(f"planner output: {output}")
+
+ if not isinstance(output, GeneratorOutput):
+ # Create runner finish event with error and stop the loop
+ current_error = (
+ f"Expected GeneratorOutput, but got {type(output)}"
+ )
+ # create a step output for the error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+ break
+
+ function = output.data
+
+ log.debug(f"function: {function}")
+ if function is None:
+ error_msg = f"Run into error: {output.error}, raw response: {output.raw_response}"
+ # Handle recoverable vs unrecoverable errors
+ if output.error is not None:
+ if _is_unrecoverable_error(output.error):
+ # Unrecoverable errors: context too long, rate limit, model not found
+ current_error = output.error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=f"Unrecoverable error: {output.error}",
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+ break # Stop execution for unrecoverable errors
+ # Recoverable errors: JSON format errors, parsing errors, etc.
+ current_error = output.error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+
+ continue # Continue to next step for recoverable errors
+
+ thinking = output.thinking if hasattr(output, 'thinking') else None
+ if function is not None:
+ # add a function id
+ function.id = str(uuid.uuid4())
+ if thinking is not None and self.is_thinking_model:
+ function.thought = thinking
+
+ if self._check_last_step(function):
+ answer = self._get_final_answer(function)
+ # Wrap final output in RunnerResult
+ last_output = RunnerResult(
+ answer=answer,
+ step_history=self.step_history.copy(),
+ error=current_error,
+ # ctx=self.ctx,
+ )
+
+ # Add assistant response to conversation memory
+ if self.use_conversation_memory and turn_id is not None:
+ self._safe_add_assistant_response(
+ AssistantResponse(
+ response_str=answer,
+ metadata={
+ "step_history": self.step_history.copy()
+ },
+ ),
+ turn_id=turn_id
+ )
+
+ step_count += 1 # Increment step count before breaking
+
+ break
+
+ # Create tool span for function execution
+ with tool_span(
+ tool_name=function.name,
+ function_name=function.name,
+ function_args=function.args,
+ function_kwargs=function.kwargs,
+ ) as tool_span_instance:
+ function_results = await self._tool_execute_async(
+ func=function
+ )
+ function_output = function_results.output
+ # process coroutine, async generator, and sync generator outputs
+ real_function_output = None
+
+ # Handle generator outputs similar to astream implementation
+ if inspect.iscoroutine(function_output):
+ real_function_output = await function_output
+ elif inspect.isasyncgen(function_output):
+ # Collect all values from async generator
+ collected_items = []
+ async for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # Skip activity items in acall
+ continue
+ else:
+ collected_items.append(item)
+ # Use collected items as output
+ real_function_output = collected_items
+ elif inspect.isgenerator(function_output):
+ # Collect all values from sync generator
+ collected_items = []
+ for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # Skip activity items in acall
+ continue
+ else:
+ collected_items.append(item)
+ # Use collected items as output
+ real_function_output = collected_items
+ else:
+ real_function_output = function_output
+
+ # Use the processed output
+ function_output = real_function_output
+ function_output_observation = function_output
+
+ if isinstance(function_output, ToolOutput) and hasattr(
+ function_output, "observation"
+ ):
+ function_output_observation = (
+ function_output.observation
+ )
+
+ # Update tool span attributes using update_attributes for MLflow compatibility
+ tool_span_instance.span_data.update_attributes(
+ {"output_result": function_output}
+ )
+
+ step_output: StepOutput = StepOutput(
+ step=step_count,
+ action=function,
+ function=function,
+ observation=function_output_observation,
+ )
+ self.step_history.append(step_output)
+
+ # Update step span with results
+ step_span_instance.span_data.update_attributes(
+ {
+ "tool_name": function.name,
+ "tool_output": function_results,
+ "is_final": self._check_last_step(function),
+ "observation": function_output_observation,
+ }
+ )
+
+ log.debug(
+ "The prompt with the prompt template is {}".format(
+ self.agent.planner.get_prompt(**prompt_kwargs)
+ )
+ )
+
+ step_count += 1
+
+ except Exception as e:
+ error_msg = f"Error in step {step_count}: {str(e)}"
+ log.error(error_msg)
+
+ # Create response span for error tracking
+ with response_span(
+ answer=error_msg,
+ result_type="error",
+ execution_metadata={
+ "steps_executed": step_count,
+ "max_steps": self.max_steps,
+ "workflow_status": "failed",
+ },
+ response=None,
+ ):
+ pass
+
+ # Record the error and stop the loop instead of returning immediately
+ step_count += 1
+ current_error = error_msg
+ break
+
+ # Update runner span with completion info using update_attributes
+ runner_span_instance.span_data.update_attributes(
+ {
+ "steps_executed": step_count,
+ "final_answer": last_output.answer if last_output else None,
+ "workflow_status": "completed",
+ }
+ )
+
+ # Create response span for tracking final result
+ with response_span(
+ answer=(
+ last_output.answer
+ if last_output
+ else f"No output generated after {step_count} steps (max_steps: {self.max_steps})"
+ ),
+ result_type=(
+ type(last_output.answer).__name__ if last_output else "no_output"
+ ),
+ execution_metadata={
+ "steps_executed": step_count,
+ "max_steps": self.max_steps,
+ "workflow_status": "completed" if last_output else "incomplete",
+ },
+ response=last_output, # can be None if Runner has not finished in the max steps
+ ):
+ pass
+
+ # Always return a RunnerResult, even if no successful completion
+ return last_output or RunnerResult(
+ answer=current_error or f"No output generated after {step_count} steps (max_steps: {self.max_steps})",
+ step_history=self.step_history.copy(),
+ error=current_error,
+ )
+
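+ # Illustrative sketch (not part of the runner): acall is awaited like any coroutine and
+ # always resolves to a RunnerResult, even when the loop stops on an error or reaches
+ # max_steps. Hypothetical caller code:
+ #
+ #   result = await runner.acall(
+ #       prompt_kwargs={"input_str": "Summarize the report"},
+ #       model_kwargs={"temperature": 0},  # optional per-call override (illustrative)
+ #   )
+ #   print(result.answer, result.error, len(result.step_history))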
+
+ def astream(
+ self,
+ prompt_kwargs: Dict[str, Any],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ use_cache: Optional[bool] = None,
+ id: Optional[str] = None,
+ ) -> RunnerStreamingResult:
+ """
+ Execute the runner asynchronously with streaming support.
+
+ Returns:
+ RunnerStreamingResult: A streaming result object with stream_events() method
+ """
+ # Cancel any previous task that might still be running
+ # TODO: this may overwrite and cancel another in-flight task if astream is awaited twice concurrently on the same runner / agent instance.
+ if self._current_task and not self._current_task.done():
+ self._current_task.cancel()
+ log.info("Cancelled previous streaming task")
+ # Don't wait for cancellation here - just cancel and move on
+ self._current_task = None
+
+ # Reset cancellation flag for new execution
+ self._cancelled = False
+
+ result = RunnerStreamingResult()
+ # Store the streaming result so we can emit events to it during cancellation
+ self._current_streaming_result = result
+
+ self.reset_cancellation()
+
+ # Store the task so we can cancel it if needed
+ self._current_task = asyncio.get_event_loop().create_task(
+ self.impl_astream(prompt_kwargs, model_kwargs, use_cache, id, result)
+ )
+ result._run_task = self._current_task
+ return result
+
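+ # Illustrative sketch (not part of the runner): the RunnerStreamingResult returned by
+ # astream is consumed by iterating its stream_events() generator. The event names below
+ # are the ones emitted by impl_astream; the caller code itself is hypothetical:
+ #
+ #   streaming = runner.astream(prompt_kwargs={"input_str": "Plan my trip"})
+ #   async for event in streaming.stream_events():
+ #       if isinstance(event, RawResponsesStreamEvent):
+ #           ...  # raw planner output (token deltas when the planner streams)
+ #       elif isinstance(event, RunItemStreamEvent) and event.name == "agent.step_complete":
+ #           ...  # one StepOutput finished
+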
+ async def impl_astream(
+ self,
+ prompt_kwargs: Dict[str, Any],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ use_cache: Optional[bool] = None,
+ id: Optional[str] = None,
+ streaming_result: Optional[RunnerStreamingResult] = None,
+ ) -> None:
+ """
+ Behaves exactly like `acall` but with streaming support.
+
+ - GeneratorOutput will be emitted as RawResponsesStreamEvent
+
+ - StepOutput will be emitted as RunItemStreamEvent with name "agent.step_complete".
+
+ - Finally, there will be a FinalOutputItem with the final answer or error.
+
+ Execute the planner asynchronously for multiple steps with function calling support.
+
+ At the last step, the action should be set to "finish", which terminates the sequence.
+
+ Args:
+ prompt_kwargs: Dictionary of prompt arguments for the generator
+ model_kwargs: Optional model parameters to override defaults
+ use_cache: Whether to use cached results if available
+ id: Optional unique identifier for the request
+ """
+ workflow_status: Literal["streaming", "stream_completed", "stream_failed", "stream_incomplete"] = "streaming"
+ # Create runner span for tracing streaming execution
+ turn_id = None
+ with runner_span(
+ runner_id=id or f"stream_runner_{hash(str(prompt_kwargs))}",
+ max_steps=self.max_steps,
+ workflow_status=workflow_status,
+ ) as runner_span_instance:
+
+ # Reset cancellation flag at start of new execution
+ self.step_history = []
+ prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {}
+
+ prompt_kwargs["step_history"] = self.step_history
+ if self.use_conversation_memory:
+ # Reset any pending query state before starting a new query
+ self._safe_reset_pending_query()
+
+ # Create new turn
+ turn_id = self._safe_create_turn()
+
+ prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history()
+ # save the user query to the conversation memory
+ # the query metadata carries the context keys (currently only context_str)
+ query_metadata = {"context_str": prompt_kwargs.get("context_str", None)}
+ if turn_id:
+ self._safe_add_user_query(
+ UserQuery(
+ query_str=prompt_kwargs.get("input_str", None),
+ metadata=query_metadata,
+ ),
+ turn_id
+ )
+ # a reference to the step history
+ # set maximum number of steps for the planner into the prompt
+ prompt_kwargs["max_steps"] = self.max_steps
+
+ model_kwargs = model_kwargs.copy() if model_kwargs else {}
+ step_count = 0
+ final_output_item = None
+ current_error = None
+
+ # Whenever we have the final output we break the loop. This happens on:
+ # (1) a final answer (the "finish" step)
+ # (2) an unrecoverable error from the LLM planner
+ # (3) any exception
+
+ # On the normal path, a step emits: raw_response_event, request_permission (if needed),
+ # tool_call_start, tool_call_activity, tool_call_complete, and step_complete.
+ # On the error path, any of these may be skipped, but step_complete is always emitted.
+
+ # ToolOutput has three statuses: success, error, cancelled
+ while step_count < self.max_steps and not self.is_cancelled():
+ try:
+ # Create step span for each streaming iteration
+ # Error handling: when any error occurs, a runner finish event is created and the loop stops;
+ # the execution-complete event carrying the error should be sent directly.
+ with step_span(
+ step_number=step_count, action_type="stream_planning"
+ ) as step_span_instance:
+ # important to ensure the prompt at each step is correct
+ log.debug(
+ "The prompt with the prompt template is {}".format(
+ self.agent.planner.get_prompt(**prompt_kwargs)
+ )
+ )
+
+ # Check cancellation before calling planner
+ # TODO: this check may be unnecessary since task.cancel() already raises CancelledError; keep it only if we want to terminate earlier via the cancelled flag
+ if self.is_cancelled():
+ raise asyncio.CancelledError("Execution cancelled by user")
+
+ # when it's streaming, the output will be an async generator
+ output: GeneratorOutput = await self.agent.planner.acall(
+ prompt_kwargs=prompt_kwargs,
+ model_kwargs=model_kwargs,
+ use_cache=use_cache,
+ id=id,
+ )
+ planner_prompt = self.agent.planner.get_prompt(**prompt_kwargs)
+ log.debug(f"Planner output: {output}, prompt: {planner_prompt}")
+
+ # Track token usage
+ step_tokens = self._update_token_consumption()
+ if step_tokens > 0:
+ log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}")
+
+ if not isinstance(output, GeneratorOutput):
+ # Create runner finish event with error and stop the loop
+ error_msg = (
+ f"Expected GeneratorOutput, but got {type(output)}"
+ )
+ final_output_item = FinalOutputItem(
+ error=error_msg,
+ )
+ workflow_status = "stream_failed"
+ current_error = error_msg
+ # create a step output for the error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+ break
+
+
+ # handle the generator output data and error
+ wrapped_event = None
+
+ if isinstance(output.raw_response, AsyncIterable):
+ log.debug(
+ f"Streaming raw response from planner: {output.raw_response}"
+ )
+ # Streaming llm call - iterate through the async generator
+ async for event in output.raw_response:
+ # TODO: this check may be unnecessary since task.cancel() already raises CancelledError here
+ if self.is_cancelled():
+ raise asyncio.CancelledError("Execution cancelled by user")
+ wrapped_event = RawResponsesStreamEvent(data=event)
+ streaming_result.put_nowait(wrapped_event)
+
+ else: # non-streaming cases
+ # yield the final planner response
+ if output.data is None or (not isinstance(output.data, Function)):
+
+ # recoverable errors, continue to create stepout
+ current_error = f"Error: {output.error} - data: {output.data}, raw_response: {output.raw_response}"
+ # wrap the error in a RawResponsesStreamEvent
+ wrapped_event = RawResponsesStreamEvent(
+ data=None, # no data in this case
+ error=output.error,
+ )
+ streaming_result.put_nowait(wrapped_event)
+
+
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=current_error,
+ )
+ # emit the step complete event with error which matches the step_output
+ step_item = StepRunItem(data=step_output)
+ step_complete_event = RunItemStreamEvent(
+ name="agent.step_complete",
+ item=step_item,
+ )
+ streaming_result.put_nowait(step_complete_event)
+
+ # Ensure event is processed before continuing
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
+ self.step_history.append(step_output)
+
+ if output.error is not None:
+
+ if _is_unrecoverable_error(output.error): # e.g. context too long or rate limit, not recoverable
+ # e.g. 404 model not found
+ # create a final output item with error and stop the loop
+ final_output_item = FinalOutputItem(
+ error=output.error,
+ )
+ workflow_status = "stream_failed"
+ current_error = output.error
+ step_output = StepOutput(
+ step=step_count,
+ action=None,
+ function=None,
+ observation=f"Unrecoverable error: {output.error}",
+ )
+ self.step_history.append(step_output)
+ step_count += 1
+ break
+ step_count += 1
+ continue # continue to next step
+
+
+ # normal functions
+ wrapped_event = RawResponsesStreamEvent(
+ data=output.data,
+ input=planner_prompt,
+ ) # wrap the data field as the final output; data might be None
+ streaming_result.put_nowait(wrapped_event)
+
+ # asynchronously consuming the raw response will
+ # update the data field of output with the result of the output processor
+
+ # handle function output
+
+ function = output.data # recoverable errors land here and should still produce a StepOutput
+ thinking = output.thinking # check the reasoning model response
+ if thinking is not None and self.is_thinking_model:
+ # if the thinking is not None, we will add it to the function
+ if function is not None and isinstance(function, Function):
+ function.thought = thinking
+
+ function.id = str(uuid.uuid4()) # add function id
+ function_result = None
+ function_output_observation = None
+
+ # TODO: simplify this
+ tool_call_id = function.id
+ tool_call_name = function.name
+ log.debug(f"function: {function}")
+
+ if self._check_last_step(function): # skip stepoutput
+ try:
+ answer = self._get_final_answer(function)
+ except Exception as e:
+ # If processing the final answer fails, use the raw answer
+ log.warning(f"Failed to process final answer: {e}. Using raw answer.")
+ answer = function._answer if hasattr(function, "_answer") else str(e)
+
+ final_output_item = await self._process_stream_final_step(
+ answer=answer,
+ step_count=step_count,
+ streaming_result=streaming_result,
+ runner_span_instance=runner_span_instance,
+ turn_id=turn_id,
+ )
+ workflow_status = "stream_completed"
+
+ # Ensure the queue has processed the execution_complete event
+ # Add a small yield to allow the event loop to process the queued events
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
+ break
+
+ # Check if permission is required and emit permission event
+ # TODO: trace the permission event
+
+ function_output_observation = None
+ function_result = None
+ print("function name", function.name)
+ complete_step = False
+ if (
+ self.permission_manager
+ and self.permission_manager.is_approval_required(
+ function.name
+ )
+ ):
+ permission_event = (
+ self.permission_manager.create_permission_event(
+ function
+ )
+ )
+ # there is an error
+ if isinstance(permission_event, ToolOutput):
+ # need a tool complete event
+ function_result = FunctionOutput(
+ name=function.name,
+ input=function,
+ output=permission_event,
+ )
+ tool_complete_event = RunItemStreamEvent(
+ name="agent.tool_call_complete",
+ # error is already tracked in output
+ # TODO: error tracking is not needed in RunItem, it is tracked in the tooloutput status.
+ item=ToolOutputRunItem(
+ data=function_result,
+ id=tool_call_id,
+ error=permission_event.observation if permission_event.status == "error" else None, # error message sent to the frontend
+ ),
+ )
+ streaming_result.put_nowait(tool_complete_event)
+ function_output_observation = permission_event.observation
+ complete_step = True
+ else:
+ permission_stream_event = RunItemStreamEvent(
+ name="agent.tool_permission_request",
+ item=permission_event,
+ )
+ streaming_result.put_nowait(permission_stream_event)
+ if not complete_step:
+ # Execute the tool with streaming support
+ function_result, function_output, function_output_observation = await self.stream_tool_execution(
+ function=function,
+ tool_call_id=tool_call_id,
+ tool_call_name=tool_call_name,
+ streaming_result=streaming_result,
+ )
+
+ # Add step to history for approved tools (same as non-permission branch)
+ step_output: StepOutput = StepOutput(
+ step=step_count,
+ action=function,
+ function=function,
+ observation=function_output_observation,
+ )
+ self.step_history.append(step_output)
+
+ # Update step span with results (for both recoverable errors and normal function execution)
+ step_span_instance.span_data.update_attributes(
+ {
+ "tool_name": function.name if function else None,
+ "tool_output": function_result,
+ "is_final": self._check_last_step(function),
+ "observation": function_output_observation,
+ }
+ )
+
+ # Emit step completion event (with error if any)
+ step_item = StepRunItem(data=step_output)
+ step_event = RunItemStreamEvent(
+ name="agent.step_complete", item=step_item
+ )
+ streaming_result.put_nowait(step_event)
+
+ # Ensure event is processed before continuing
+ await asyncio.sleep(0) # Yield control to allow queue processing
+
+ step_count += 1
+
+ except asyncio.CancelledError:
+ # Handle cancellation gracefully
+ cancel_msg = "Execution cancelled by user"
+ log.info(cancel_msg)
+
+ # Emit cancellation event so frontend/logs can see it
+ cancel_event = RunItemStreamEvent(
+ name="runner.cancelled",
+ item=FinalOutputItem(data={
+ "status": "cancelled",
+ "message": cancel_msg,
+ "step_count": step_count,
+ })
+ )
+ streaming_result.put_nowait(cancel_event)
+
+ # Store cancellation result
+ streaming_result.answer = cancel_msg
+ streaming_result.step_history = self.step_history.copy()
+ streaming_result._is_complete = True
+
+ # Add cancellation response to conversation memory
+ if self.use_conversation_memory and turn_id is not None:
+ self._safe_add_assistant_response(
+ AssistantResponse(
+ response_str="I apologize, but the execution was cancelled by the user.",
+ metadata={
+ "step_history": self.step_history.copy(),
+ "status": "cancelled",
+ "timestamp": datetime.now().isoformat()
+ }
+ ),
+ turn_id=turn_id
+ )
+
+ # Signal completion and break
+ streaming_result.put_nowait(QueueCompleteSentinel())
+ break
+
+ except Exception as e:
+ # this except branch should almost never be reached
+ error_msg = f"Error in step {step_count}: {str(e)}"
+ log.error(error_msg)
+
+ workflow_status = "stream_failed"
+ streaming_result._exception = error_msg
+
+ # Emit error as FinalOutputItem to queue
+ final_output_item = FinalOutputItem(error=error_msg)
+ # error_event = RunItemStreamEvent(
+ # name="runner_finished", item=error_final_item
+ # )
+ current_error = error_msg
+ break
+
+ # If loop terminated without creating a final output item, create our own
+ # TODO this might be redundant
+ if final_output_item is None:
+ # Create a RunnerResult with incomplete status
+ runner_result = RunnerResult(
+ answer=f"No output generated after {step_count} steps (max_steps: {self.max_steps})",
+ error=current_error,
+ step_history=self.step_history.copy(),
+ )
+ final_output_item = FinalOutputItem(data=runner_result)
+
+ workflow_status = "stream_incomplete"
+ current_error = f"No output generated after {step_count} steps (max_steps: {self.max_steps})"
+
+ # Only emit execution_complete if we created a new final_output_item
+ # (i.e., when the loop ended without a final answer)
+ self._create_execution_complete_stream_event(
+ streaming_result=streaming_result,
+ final_output_item=final_output_item,
+ )
+
+ runner_span_instance.span_data.update_attributes(
+ {
+ "steps_executed": step_count,
+ "final_answer": final_output_item.data.answer if final_output_item.data else None,
+ "workflow_status": workflow_status,
+ }
+ )
+
+ # create runner result with or without error
+
+ runner_result = RunnerResult(
+ answer=final_output_item.data.answer if final_output_item.data else None,
+ step_history=self.step_history.copy(),
+ error=current_error,
+ )
+
+ # create response span for final output
+ # if workflow_status in ["stream_incomplete", "stream_failed"]:
+ self.create_response_span(
+ runner_result=runner_result,
+ step_count=step_count,
+ streaming_result=streaming_result,
+ runner_span_instance=runner_span_instance,
+ workflow_status=workflow_status,
+ )
+
+ # Signal completion of streaming
+ streaming_result.put_nowait(QueueCompleteSentinel())
+
+ async def _tool_execute_async(
+ self,
+ func: Function,
+ streaming_result: Optional[RunnerStreamingResult] = None,
+ ) -> Union[FunctionOutput, Parameter]:
+ """
+ Call this in the acall method.
+ Handles both sync and async functions.
+ Note: streaming_result, when provided, is only used to emit tool_call_start events; tool outputs themselves are not streamed here.
+ Includes permission checking if permission_manager is configured.
+ """
+
+ # Check permission before execution
+ if self.permission_manager:
+ result = await self.permission_manager.check_permission(func)
+ # Handle both old (2 values) and new (3 values) return formats
+ if len(result) == 3:
+ allowed, modified_func, _ = result
+ else:
+ allowed, modified_func = result
+
+ if not allowed:
+ return FunctionOutput(
+ name=func.name,
+ input=func,
+ output=ToolOutput(
+ output="Tool execution cancelled by user",
+ observation="Tool execution cancelled by user",
+ display="Permission denied",
+ status="cancelled",
+ ),
+ )
+
+ # Use modified function if user edited it
+ func = modified_func or func
+
+ # Emit tool call event
+ if streaming_result is not None:
+ tool_call_item = ToolCallRunItem(data=func, id=func.id)
+ tool_call_event = RunItemStreamEvent(
+ name="agent.tool_call_start", item=tool_call_item
+ )
+ streaming_result.put_nowait(tool_call_event)
+
+ # if streaming_result is not None:
+ # result = await self.agent.tool_manager.execute_func_astream(func=func)
+ # else:
+ result = await self.agent.tool_manager.execute_func_async(func=func)
+
+ if not isinstance(result, FunctionOutput):
+ raise ValueError("Result is not a FunctionOutput")
+ return result
+
+ async def stream_tool_execution(
+ self,
+ function: Function,
+ tool_call_id: str,
+ tool_call_name: str,
+ streaming_result: RunnerStreamingResult,
+ ) -> tuple[Any, Any, Any]:
+ """
+ Execute a tool/function call with streaming support and proper event handling.
+
+ This method handles:
+ - Tool span creation for tracing
+ - Async generator support for streaming results
+ - Tool activity events
+ - Tool completion events
+ - Error handling and observation extraction
+
+ Args:
+ function: The Function object to execute
+ tool_call_id: Unique identifier for this tool call
+ tool_call_name: Name of the tool being called
+ streaming_result: Queue for streaming events
+
+ Returns:
+ tuple: (function_result, function_output, function_output_observation)
+ """
+ # Create tool span for streaming function execution
+ with tool_span(
+ tool_name=tool_call_name,
+ function_name=function.name, # TODO fix attributes
+ function_args=function.args,
+ function_kwargs=function.kwargs,
+ ) as tool_span_instance:
+
+ # TODO: FunctionTool execution should ensure that async generator items are
+ # either ToolCallActivityRunItem or (possibly) ToolOutput;
+ # the call-activity design could be improved
+
+ function_result = await self._tool_execute_async(
+ func=function, streaming_result=streaming_result
+ ) # everything must be wrapped in FunctionOutput
+
+ if not isinstance(function_result, FunctionOutput):
+ raise ValueError(
+ f"Result must be wrapped in FunctionOutput, got {type(function_result)}"
+ )
+
+ function_output = function_result.output
+ real_function_output = None
+
+ # TODO: validate when the function is a generator
+
+ if inspect.iscoroutine(function_output):
+ real_function_output = await function_output
+ elif inspect.isasyncgen(function_output):
+ async for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # add the tool_call_id to the item
+ item.id = tool_call_id
+ tool_call_event = RunItemStreamEvent(
+ name="agent.tool_call_activity", item=item
+ )
+ streaming_result.put_nowait(tool_call_event)
+ else:
+ real_function_output = item
+
+ elif inspect.isgenerator(function_output):
+ for item in function_output:
+ if isinstance(item, ToolCallActivityRunItem):
+ # add the tool_call_id to the item
+ item.id = tool_call_id
+ tool_call_event = RunItemStreamEvent(
+ name="agent.tool_call_activity", item=item
+ )
+ streaming_result.put_nowait(tool_call_event)
+ else:
+ real_function_output = item
+ else:
+ real_function_output = function_output
+
+ # create call complete
+ call_complete_event = RunItemStreamEvent(
+ name="agent.tool_call_complete",
+ item=ToolOutputRunItem(
+ id=tool_call_id,
+ data=FunctionOutput(
+ name=function.name,
+ input=function,
+ output=real_function_output,
+ ),
+ ),
+ )
+ streaming_result.put_nowait(call_complete_event)
+
+ function_output = real_function_output
+ function_output_observation = function_output
+
+ if isinstance(function_output, ToolOutput) and hasattr(
+ function_output, "observation"
+ ):
+ function_output_observation = (
+ function_output.observation
+ )
+ # Update tool span attributes using update_attributes for MLflow compatibility
+
+ tool_span_instance.span_data.update_attributes(
+ {"output_result": real_function_output}
+ )
+
+ return function_result, function_output, function_output_observation
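+ # Illustrative sketch (assumption, not part of this module): a streaming-friendly tool can
+ # be an async generator that yields ToolCallActivityRunItem progress updates and yields its
+ # real result last; stream_tool_execution above forwards the activity items as
+ # "agent.tool_call_activity" events and keeps the final item as the observation.
+ # The field names on ToolCallActivityRunItem and the helper below are illustrative only.
+ #
+ #   async def search_web(query: str):
+ #       yield ToolCallActivityRunItem(data=f"searching for {query}...")
+ #       results = await fetch_results(query)  # hypothetical helper
+ #       yield ToolOutput(output=results, observation=str(results))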
diff --git a/adalflow/adalflow/components/memory/flexible_memory.py b/adalflow/adalflow/components/memory/flexible_memory.py
index d6adbec5..30f51973 100644
--- a/adalflow/adalflow/components/memory/flexible_memory.py
+++ b/adalflow/adalflow/components/memory/flexible_memory.py
@@ -21,7 +21,7 @@ class Message(DataClass):
id: str = field(default_factory=lambda: str(uuid4()))
role: Literal["user", "assistant", "system"] = "user"
content: str = ""
- metadata: Optional[Dict[str, Any]] = None
+ metadata: Optional[Dict[str, Any]] = None # example metadata: "step_history": List[Dict[str, Any]] for agent response, "images" for user query with images
timestamp: datetime = field(default_factory=datetime.now)
@classmethod
@@ -42,7 +42,7 @@ def from_system(cls, content: str, metadata: Optional[Dict] = None):
@dataclass
class Conversation(DataClass):
- """A conversation organized as turns, where each turn can have multiple messages."""
+ """A conversation organized as turns, where each turn can have multiple messages and is one agent loop."""
id: str = field(default_factory=lambda: str(uuid4()))
user_id: Optional[str] = None
turns: OrderedDict = field(default_factory=OrderedDict) # turn_id -> List[Message]
@@ -61,9 +61,15 @@ def add_message_to_turn(self, turn_id: str, message: Message) -> str:
str: The message ID
"""
if turn_id not in self.turns:
- self.turns[turn_id] = []
- self.turns[turn_id].append(message)
- return message.id
+ raise ValueError(f"Turn '{turn_id}' does not exist. Please create the turn first.")
+ if not isinstance(message, Message):
+ raise TypeError("message must be an instance of Message")
+
+ try:
+ self.turns[turn_id].append(message)
+ return message.id
+ except Exception as e:
+ raise RuntimeError(f"Failed to add message to turn '{turn_id}': {e}")
def get_turn_messages(self, turn_id: str) -> List[Message]:
"""Get all messages in a specific turn."""
@@ -101,10 +107,13 @@ def get_last_assistant_message(self) -> Optional[Message]:
def create_turn(self) -> str:
"""Create a new turn and return its ID."""
- turn_id = str(uuid4())
- self.turns[turn_id] = []
- self._current_turn_id = turn_id
- return turn_id
+ try:
+ turn_id = str(uuid4())
+ self.turns[turn_id] = []
+ self._current_turn_id = turn_id
+ return turn_id
+ except Exception as e:
+ raise RuntimeError(f"Failed to create turn: {e}")
# Template for conversation formatting
@@ -176,29 +185,29 @@ def create_turn(self) -> str:
Returns:
str: The new turn ID
"""
- return self.current_conversation.create_turn()
+ turn_id = self.current_conversation.create_turn()
+ return turn_id
- def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str:
+ def add_user_query(self, content: str, turn_id: str, metadata: Optional[Dict] = None) -> str:
"""Add a user message to a turn.
Args:
content: The user's message content
+ turn_id: The turn ID to add the message to (required, must exist)
metadata: Optional metadata
- turn_id: Optional turn ID. If None, creates a new turn.
Returns:
str: The turn ID the message was added to
- """
- # Use provided turn_id or create new turn
- if turn_id is None:
- turn_id = self.create_turn()
- elif turn_id not in self.current_conversation.turns:
- # Turn doesn't exist, create it
- self.current_conversation.turns[turn_id] = []
- # Track as current turn
- self.current_conversation._current_turn_id = turn_id
-
+ Raises:
+ ValueError: If the turn_id doesn't exist
+ """
+ # Check if turn exists
+ if turn_id not in self.current_conversation.turns:
+ raise ValueError(
+ f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() "
+ f"or provide an existing turn ID. Available turns: {list(self.current_conversation.turns.keys())}"
+ )
# Create and add the user message
message = Message.from_user(content, metadata)
self.current_conversation.add_message_to_turn(turn_id, message)
@@ -218,29 +227,28 @@ def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id:
def add_assistant_response(
self,
content: str,
+ turn_id: str,
metadata: Optional[Dict] = None,
- turn_id: Optional[str] = None
) -> str:
"""Add an assistant message to a turn.
Args:
content: The assistant's message content
+ turn_id: The turn ID to add the message to (required, must exist)
metadata: Optional metadata
- turn_id: Optional turn ID. If None, uses current turn or creates new.
Returns:
str: The turn ID the message was added to
+
+ Raises:
+ ValueError: If the turn_id doesn't exist
"""
- # Determine which turn to use
- if turn_id is None:
- if self.current_conversation._current_turn_id:
- turn_id = self.current_conversation._current_turn_id
- else:
- # No active turn, create new one for standalone response
- turn_id = self.create_turn()
- elif turn_id not in self.current_conversation.turns:
- # Turn doesn't exist, create it
- self.current_conversation.turns[turn_id] = []
+ # Check if turn exists
+ if turn_id not in self.current_conversation.turns:
+ raise ValueError(
+ f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() "
+ f"or provide an existing turn ID. Available turns: {list(self.current_conversation.turns.keys())}"
+ )
# Create and add the assistant message
message = Message.from_assistant(content, metadata)
@@ -258,23 +266,27 @@ def add_assistant_response(
return turn_id
- def add_system_message(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str:
+ def add_system_message(self, content: str, turn_id: str, metadata: Optional[Dict] = None) -> str:
"""Add a system message to a turn.
Args:
content: The system message content
+ turn_id: The turn ID to add the message to (required, must exist)
metadata: Optional metadata
- turn_id: Optional turn ID. If None, creates a new turn.
Returns:
str: The turn ID the message was added to
- """
- # Use provided turn_id or create new turn
- if turn_id is None:
- turn_id = self.create_turn()
- elif turn_id not in self.current_conversation.turns:
- self.current_conversation.turns[turn_id] = []
+ Raises:
+ ValueError: If the turn_id doesn't exist
+ """
+ # Check if turn exists
+ if turn_id not in self.current_conversation.turns:
+ raise ValueError(
+ f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() "
+ f"or provide an existing turn ID. Available turns: {list(self.current_conversation.turns.keys())}"
+ )
+
# Create and add the system message
message = Message.from_system(content, metadata)
self.current_conversation.add_message_to_turn(turn_id, message)
diff --git a/adalflow/tests/test_flexible_memory.py b/adalflow/tests/test_flexible_memory.py
new file mode 100644
index 00000000..08bd8fe6
--- /dev/null
+++ b/adalflow/tests/test_flexible_memory.py
@@ -0,0 +1,1121 @@
+"""Comprehensive test suite for FlexibleConversationMemory.
+
+This test module covers:
+1. Basic message operations (creation, addition, retrieval)
+2. Turn management (creation, message assignment, ordering)
+3. Conversation management (new conversations, clearing, persistence)
+4. Error handling and validation
+5. Database integration
+6. Metadata handling and filtering
+7. Edge cases and boundary conditions
+"""
+
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from unittest.mock import Mock, patch
+from adalflow.components.memory.flexible_memory import (
+ Message,
+ Conversation,
+ FlexibleConversationMemory,
+)
+from adalflow.core.db import LocalDB
+
+
+class TestMessage:
+ """Test the Message dataclass functionality."""
+
+ def test_message_creation_default(self):
+ """Test creating a message with default values.
+
+ Tests:
+ - Default role is 'user'
+ - Content defaults to empty string
+ - ID is automatically generated
+ - Timestamp is automatically set
+ - Metadata is None by default
+ """
+ msg = Message()
+ assert msg.role == "user"
+ assert msg.content == ""
+ assert msg.id is not None
+ assert len(msg.id) > 0
+ assert isinstance(msg.timestamp, datetime)
+ assert msg.metadata is None
+
+ def test_message_creation_with_values(self):
+ """Test creating a message with custom values.
+
+ Tests:
+ - Custom role, content, and metadata are properly set
+ - Values are stored correctly
+ """
+ metadata = {"key": "value", "number": 42}
+ msg = Message(role="assistant", content="Hello", metadata=metadata)
+ assert msg.role == "assistant"
+ assert msg.content == "Hello"
+ assert msg.metadata == metadata
+
+ def test_message_from_user(self):
+ """Test creating a user message using class method.
+
+ Tests:
+ - from_user() creates message with role='user'
+ - Content and metadata are properly set
+ """
+ metadata = {"context": "test"}
+ msg = Message.from_user("User query", metadata)
+ assert msg.role == "user"
+ assert msg.content == "User query"
+ assert msg.metadata == metadata
+
+ def test_message_from_assistant(self):
+ """Test creating an assistant message using class method.
+
+ Tests:
+ - from_assistant() creates message with role='assistant'
+ - Content and metadata are properly set
+ """
+ msg = Message.from_assistant("Assistant response", {"confidence": 0.9})
+ assert msg.role == "assistant"
+ assert msg.content == "Assistant response"
+ assert msg.metadata == {"confidence": 0.9}
+
+ def test_message_from_system(self):
+ """Test creating a system message using class method.
+
+ Tests:
+ - from_system() creates message with role='system'
+ - Content and metadata are properly set
+ """
+ msg = Message.from_system("System prompt")
+ assert msg.role == "system"
+ assert msg.content == "System prompt"
+ assert msg.metadata is None
+
+ def test_message_unique_ids(self):
+ """Test that each message gets a unique ID.
+
+ Tests:
+ - Multiple messages have different IDs
+ - IDs are valid UUID strings
+ """
+ msg1 = Message()
+ msg2 = Message()
+ msg3 = Message()
+ assert msg1.id != msg2.id
+ assert msg2.id != msg3.id
+ assert msg1.id != msg3.id
+
+
+class TestConversation:
+ """Test the Conversation dataclass functionality."""
+
+ def test_conversation_creation(self):
+ """Test creating a conversation with default values.
+
+ Tests:
+ - Conversation ID is generated
+ - turns is an OrderedDict
+ - user_id can be set
+ - Metadata defaults to None
+ - Timestamp is set
+ """
+ conv = Conversation()
+ assert conv.id is not None
+ assert isinstance(conv.turns, OrderedDict)
+ assert len(conv.turns) == 0
+ assert conv.user_id is None
+ assert conv.metadata is None
+ assert isinstance(conv.created_at, datetime)
+
+ def test_conversation_with_user_id(self):
+ """Test creating a conversation with a specific user ID.
+
+ Tests:
+ - User ID is properly stored
+ """
+ conv = Conversation(user_id="user123")
+ assert conv.user_id == "user123"
+
+ def test_create_turn(self):
+ """Test creating a new turn in the conversation.
+
+ Tests:
+ - Turn is created with unique ID
+ - Turn is added to OrderedDict
+ - Current turn ID is tracked
+ - Multiple turns can be created
+ """
+ conv = Conversation()
+
+ turn_id1 = conv.create_turn()
+ assert turn_id1 is not None
+ assert turn_id1 in conv.turns
+ assert conv.turns[turn_id1] == []
+ assert conv._current_turn_id == turn_id1
+
+ turn_id2 = conv.create_turn()
+ assert turn_id2 != turn_id1
+ assert turn_id2 in conv.turns
+ assert conv._current_turn_id == turn_id2
+ assert len(conv.turns) == 2
+
+ def test_add_message_to_turn(self):
+ """Test adding messages to a specific turn.
+
+ Tests:
+ - Messages can be added to existing turns
+ - Messages are stored in order
+ - Message IDs are returned
+ - Multiple messages can be added to same turn
+ """
+ conv = Conversation()
+ turn_id = conv.create_turn()
+
+ msg1 = Message.from_user("First message")
+ msg_id1 = conv.add_message_to_turn(turn_id, msg1)
+ assert msg_id1 == msg1.id
+ assert len(conv.turns[turn_id]) == 1
+ assert conv.turns[turn_id][0] == msg1
+
+ msg2 = Message.from_assistant("Second message")
+ msg_id2 = conv.add_message_to_turn(turn_id, msg2)
+ assert msg_id2 == msg2.id
+ assert len(conv.turns[turn_id]) == 2
+ assert conv.turns[turn_id][1] == msg2
+
+ def test_add_message_to_nonexistent_turn(self):
+ """Test error handling when adding message to non-existent turn.
+
+ Tests:
+ - ValueError is raised with appropriate message
+ - Error message includes available turns
+ """
+ conv = Conversation()
+ msg = Message.from_user("Test")
+
+ with pytest.raises(ValueError) as exc_info:
+ conv.add_message_to_turn("fake_turn_id", msg)
+ assert "Turn 'fake_turn_id' does not exist" in str(exc_info.value)
+
+ def test_add_non_message_object(self):
+ """Test error handling when adding non-Message object.
+
+ Tests:
+ - TypeError is raised when non-Message object is provided
+ """
+ conv = Conversation()
+ turn_id = conv.create_turn()
+
+ with pytest.raises(TypeError) as exc_info:
+ conv.add_message_to_turn(turn_id, "not a message")
+ assert "must be an instance of Message" in str(exc_info.value)
+
+ def test_get_turn_messages(self):
+ """Test retrieving messages from a specific turn.
+
+ Tests:
+ - Can retrieve all messages from a turn
+ - Empty list returned for non-existent turn
+ - Messages maintain order
+ """
+ conv = Conversation()
+ turn_id = conv.create_turn()
+
+ msg1 = Message.from_user("User msg")
+ msg2 = Message.from_assistant("Assistant msg")
+ conv.add_message_to_turn(turn_id, msg1)
+ conv.add_message_to_turn(turn_id, msg2)
+
+ messages = conv.get_turn_messages(turn_id)
+ assert len(messages) == 2
+ assert messages[0] == msg1
+ assert messages[1] == msg2
+
+ # Non-existent turn returns empty list
+ assert conv.get_turn_messages("fake_id") == []
+
+ def test_get_all_messages(self):
+ """Test retrieving all messages across all turns.
+
+ Tests:
+ - Messages from all turns are returned
+ - Order is maintained (turn order and message order within turns)
+ - Empty list for empty conversation
+ """
+ conv = Conversation()
+
+ # Empty conversation
+ assert conv.get_all_messages() == []
+
+ # Add messages to multiple turns
+ turn1 = conv.create_turn()
+ msg1 = Message.from_user("Turn 1 User")
+ msg2 = Message.from_assistant("Turn 1 Assistant")
+ conv.add_message_to_turn(turn1, msg1)
+ conv.add_message_to_turn(turn1, msg2)
+
+ turn2 = conv.create_turn()
+ msg3 = Message.from_user("Turn 2 User")
+ msg4 = Message.from_assistant("Turn 2 Assistant")
+ conv.add_message_to_turn(turn2, msg3)
+ conv.add_message_to_turn(turn2, msg4)
+
+ all_messages = conv.get_all_messages()
+ assert len(all_messages) == 4
+ assert all_messages == [msg1, msg2, msg3, msg4]
+
+ def test_get_messages_by_role(self):
+ """Test filtering messages by role.
+
+ Tests:
+ - Can filter messages by user/assistant/system role
+ - Returns messages from all turns
+ - Empty list for roles with no messages
+ """
+ conv = Conversation()
+
+ turn1 = conv.create_turn()
+ conv.add_message_to_turn(turn1, Message.from_user("User 1"))
+ conv.add_message_to_turn(turn1, Message.from_assistant("Assistant 1"))
+
+ turn2 = conv.create_turn()
+ conv.add_message_to_turn(turn2, Message.from_user("User 2"))
+ conv.add_message_to_turn(turn2, Message.from_system("System 1"))
+
+ user_messages = conv.get_messages_by_role("user")
+ assert len(user_messages) == 2
+ assert all(msg.role == "user" for msg in user_messages)
+
+ assistant_messages = conv.get_messages_by_role("assistant")
+ assert len(assistant_messages) == 1
+ assert assistant_messages[0].content == "Assistant 1"
+
+ system_messages = conv.get_messages_by_role("system")
+ assert len(system_messages) == 1
+ assert system_messages[0].content == "System 1"
+
+ def test_get_last_user_message(self):
+ """Test retrieving the most recent user message.
+
+ Tests:
+ - Returns the last user message across all turns
+ - Returns None if no user messages exist
+ - Skips over non-user messages
+ """
+ conv = Conversation()
+
+ # No messages
+ assert conv.get_last_user_message() is None
+
+ turn1 = conv.create_turn()
+ user_msg1 = Message.from_user("First user")
+ conv.add_message_to_turn(turn1, user_msg1)
+ conv.add_message_to_turn(turn1, Message.from_assistant("Assistant"))
+
+ turn2 = conv.create_turn()
+ user_msg2 = Message.from_user("Second user")
+ conv.add_message_to_turn(turn2, user_msg2)
+ conv.add_message_to_turn(turn2, Message.from_assistant("Another assistant"))
+
+ last_user = conv.get_last_user_message()
+ assert last_user == user_msg2
+
+ def test_get_last_assistant_message(self):
+ """Test retrieving the most recent assistant message.
+
+ Tests:
+ - Returns the last assistant message across all turns
+ - Returns None if no assistant messages exist
+ - Skips over non-assistant messages
+ """
+ conv = Conversation()
+
+ # No messages
+ assert conv.get_last_assistant_message() is None
+
+ turn1 = conv.create_turn()
+ conv.add_message_to_turn(turn1, Message.from_user("User"))
+ assistant_msg1 = Message.from_assistant("First assistant")
+ conv.add_message_to_turn(turn1, assistant_msg1)
+
+ turn2 = conv.create_turn()
+ conv.add_message_to_turn(turn2, Message.from_user("Another user"))
+ assistant_msg2 = Message.from_assistant("Second assistant")
+ conv.add_message_to_turn(turn2, assistant_msg2)
+
+ last_assistant = conv.get_last_assistant_message()
+ assert last_assistant == assistant_msg2
+
+
+class TestFlexibleConversationMemory:
+ """Test the FlexibleConversationMemory component."""
+
+ def test_memory_initialization(self):
+ """Test initializing memory with default and custom settings.
+
+ Tests:
+ - Default initialization creates empty conversation
+ - Custom database can be provided
+ - User ID is properly set
+ """
+ # Default initialization
+ memory = FlexibleConversationMemory()
+ assert memory.current_conversation is not None
+ assert memory.user_id is None
+ assert memory.message_db is not None
+ assert memory.conver_db is not None
+
+ # With custom settings
+ custom_db = LocalDB()
+ memory = FlexibleConversationMemory(turn_db=custom_db, user_id="test_user")
+ assert memory.message_db == custom_db
+ assert memory.user_id == "test_user"
+ assert memory.current_conversation.user_id == "test_user"
+
+ def test_create_turn(self):
+ """Test creating turns through memory interface.
+
+ Tests:
+ - Turn creation returns valid ID
+ - Multiple turns can be created
+ - Turn IDs are unique
+ """
+ memory = FlexibleConversationMemory()
+
+ turn_id1 = memory.create_turn()
+ assert turn_id1 is not None
+ assert turn_id1 in memory.current_conversation.turns
+
+ turn_id2 = memory.create_turn()
+ assert turn_id2 != turn_id1
+ assert len(memory.current_conversation.turns) == 2
+
+ def test_add_user_query(self):
+ """Test adding user queries to turns.
+
+ Tests:
+ - User query is added to specified turn
+ - Message is stored in database
+ - Metadata is properly handled
+ - Error raised for non-existent turn
+ """
+ memory = FlexibleConversationMemory()
+ turn_id = memory.create_turn()
+
+ # Add user query
+ returned_id = memory.add_user_query("Hello", turn_id, {"context": "greeting"})
+ assert returned_id == turn_id
+
+ # Check message was added
+ messages = memory.get_turn_messages(turn_id)
+ assert len(messages) == 1
+ assert messages[0].role == "user"
+ assert messages[0].content == "Hello"
+ assert messages[0].metadata == {"context": "greeting"}
+
+ # Check database storage
+ assert len(memory.message_db.items) == 1
+ db_item = memory.message_db.items[0]
+ assert db_item["role"] == "user"
+ assert db_item["content"] == "Hello"
+ assert db_item["turn_id"] == turn_id
+
+ def test_add_user_query_nonexistent_turn(self):
+ """Test error handling when adding user query to non-existent turn.
+
+ Tests:
+ - ValueError is raised with helpful message
+ - Message includes available turns
+ """
+ memory = FlexibleConversationMemory()
+
+ with pytest.raises(ValueError) as exc_info:
+ memory.add_user_query("Hello", "fake_turn")
+ assert "Turn 'fake_turn' does not exist" in str(exc_info.value)
+ assert "create_turn()" in str(exc_info.value)
+
+ def test_add_assistant_response(self):
+ """Test adding assistant responses to turns.
+
+ Tests:
+ - Assistant response is added to specified turn
+ - Message is stored in database
+ - Metadata is properly handled
+ - Error raised for non-existent turn
+ """
+ memory = FlexibleConversationMemory()
+ turn_id = memory.create_turn()
+
+ # Add assistant response
+ returned_id = memory.add_assistant_response(
+ "Hi there!", turn_id, {"confidence": 0.95}
+ )
+ assert returned_id == turn_id
+
+ # Check message was added
+ messages = memory.get_turn_messages(turn_id)
+ assert len(messages) == 1
+ assert messages[0].role == "assistant"
+ assert messages[0].content == "Hi there!"
+ assert messages[0].metadata == {"confidence": 0.95}
+
+ # Check database storage
+ assert len(memory.message_db.items) == 1
+ db_item = memory.message_db.items[0]
+ assert db_item["role"] == "assistant"
+ assert db_item["content"] == "Hi there!"
+
+ def test_add_system_message(self):
+ """Test adding system messages to turns.
+
+ Tests:
+ - System message is added to specified turn
+ - Message is stored in database
+ - Metadata is properly handled
+ - Error raised for non-existent turn
+ """
+ memory = FlexibleConversationMemory()
+ turn_id = memory.create_turn()
+
+ # Add system message
+ returned_id = memory.add_system_message(
+ "System initialized", turn_id, {"version": "1.0"}
+ )
+ assert returned_id == turn_id
+
+ # Check message was added
+ messages = memory.get_turn_messages(turn_id)
+ assert len(messages) == 1
+ assert messages[0].role == "system"
+ assert messages[0].content == "System initialized"
+ assert messages[0].metadata == {"version": "1.0"}
+
+ def test_complete_conversation_flow(self):
+ """Test a complete conversation flow with multiple turns.
+
+ Tests:
+ - Multiple turns with user/assistant exchanges
+ - Messages maintain order within and across turns
+ - All retrieval methods work correctly
+ """
+ memory = FlexibleConversationMemory()
+
+ # Turn 1: Initial greeting
+ turn1 = memory.create_turn()
+ memory.add_user_query("Hello, how are you?", turn1)
+ memory.add_assistant_response("I'm doing well, thank you!", turn1)
+
+ # Turn 2: Follow-up question
+ turn2 = memory.create_turn()
+ memory.add_user_query("What can you help me with?", turn2)
+ memory.add_assistant_response("I can help with many things!", turn2)
+ memory.add_user_query("Can you be more specific?", turn2) # Multiple user queries in same turn
+ memory.add_assistant_response("I can help with coding, analysis, and more.", turn2)
+
+ # Check all messages
+ all_messages = memory.get_all_messages()
+ assert len(all_messages) == 6
+
+ # Check message ordering
+ assert all_messages[0].content == "Hello, how are you?"
+ assert all_messages[1].content == "I'm doing well, thank you!"
+ assert all_messages[2].content == "What can you help me with?"
+ assert all_messages[3].content == "I can help with many things!"
+ assert all_messages[4].content == "Can you be more specific?"
+ assert all_messages[5].content == "I can help with coding, analysis, and more."
+
+ def test_clear_conversation(self):
+ """Test clearing the conversation.
+
+ Tests:
+ - All turns are cleared
+ - Current turn ID is reset
+ - Database is not affected
+ """
+ memory = FlexibleConversationMemory()
+
+ # Add some content
+ turn1 = memory.create_turn()
+ memory.add_user_query("Message 1", turn1)
+ turn2 = memory.create_turn()
+ memory.add_user_query("Message 2", turn2)
+
+ # Verify content exists
+ assert len(memory.current_conversation.turns) == 2
+ assert memory.current_conversation._current_turn_id == turn2
+
+ # Clear conversation
+ memory.clear_conversation()
+
+ # Verify cleared
+ assert len(memory.current_conversation.turns) == 0
+ assert memory.current_conversation._current_turn_id is None
+
+ # Database should still have items
+ assert len(memory.message_db.items) == 2
+
+ def test_clear_conversation_turns_alias(self):
+ """Test that clear_conversation_turns is an alias for clear_conversation.
+
+ Tests:
+ - Both methods have the same effect
+ """
+ memory = FlexibleConversationMemory()
+ turn = memory.create_turn()
+ memory.add_user_query("Test", turn)
+
+ memory.clear_conversation_turns()
+ assert len(memory.current_conversation.turns) == 0
+
+ def test_new_conversation(self):
+ """Test starting a new conversation.
+
+ Tests:
+ - Current conversation is saved to conver_db
+ - New empty conversation is created
+ - User ID is preserved
+ """
+ memory = FlexibleConversationMemory(user_id="test_user")
+
+ # Add content to first conversation
+ turn = memory.create_turn()
+ memory.add_user_query("First conversation", turn)
+ first_conv_id = memory.current_conversation.id
+
+ # Start new conversation
+ memory.new_conversation()
+
+ # Check new conversation is empty and different
+ assert len(memory.current_conversation.turns) == 0
+ assert memory.current_conversation.id != first_conv_id
+ assert memory.current_conversation.user_id == "test_user"
+
+ # Check first conversation was saved
+ assert len(memory.conver_db.items) == 1
+ saved_conv = memory.conver_db.items[0]
+ assert saved_conv.id == first_conv_id
+
+ def test_get_current_turn_id(self):
+ """Test retrieving the current turn ID.
+
+ Tests:
+ - Returns None when no turns exist
+ - Returns correct turn ID after creation
+ - Updates when new turn is created
+ """
+ memory = FlexibleConversationMemory()
+
+ # No turns initially
+ assert memory.get_current_turn_id() is None
+
+ # After creating turn
+ turn1 = memory.create_turn()
+ assert memory.get_current_turn_id() == turn1
+
+ # After creating another turn
+ turn2 = memory.create_turn()
+ assert memory.get_current_turn_id() == turn2
+
+ def test_call_empty_conversation(self):
+ """Test calling memory with empty conversation.
+
+ Tests:
+ - Returns empty string for empty conversation
+ """
+ memory = FlexibleConversationMemory()
+ assert memory() == ""
+ assert memory.call() == ""
+
+ def test_call_with_messages(self):
+ """Test formatting conversation for output.
+
+ Tests:
+ - Messages are formatted correctly
+ - Roles are properly labeled
+ - Multiple turns are handled
+ - Metadata is included when present
+ """
+ memory = FlexibleConversationMemory()
+
+ turn1 = memory.create_turn()
+ memory.add_user_query("What is Python?", turn1)
+ memory.add_assistant_response("Python is a programming language.", turn1)
+
+ turn2 = memory.create_turn()
+ memory.add_user_query("Tell me more", turn2, {"priority": "high"})
+ memory.add_assistant_response("It's known for its simplicity.", turn2)
+
+ output = memory()
+
+ # Check basic content
+ assert "User: What is Python?" in output
+ assert "Assistant: Python is a programming language." in output
+ assert "User: Tell me more" in output
+ assert "Assistant: It's known for its simplicity." in output
+
+ # Check metadata
+ assert "priority: high" in output
+
+ def test_call_with_metadata_filter(self):
+ """Test filtering metadata in conversation output.
+
+ Tests:
+ - Only specified metadata keys are included
+ - Other metadata is filtered out
+ - Works across multiple messages
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query("Query", turn, {
+ "public": "visible",
+ "private": "hidden",
+ "context": "test"
+ })
+ memory.add_assistant_response("Response", turn, {
+ "confidence": 0.9,
+ "internal": "secret"
+ })
+
+ # Filter to only show 'public' and 'confidence'
+ output = memory(metadata_filter=["public", "confidence"])
+
+ assert "public: visible" in output
+ assert "confidence: 0.9" in output
+ assert "private: hidden" not in output
+ assert "context: test" not in output
+ assert "internal: secret" not in output
+
+ def test_call_with_system_messages(self):
+ """Test that system messages are included in output.
+
+ Tests:
+ - System messages are formatted correctly
+ - Mixed message types are handled
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_system_message("System: Initializing", turn)
+ memory.add_user_query("Hello", turn)
+ memory.add_assistant_response("Hi there", turn)
+
+ output = memory()
+ assert "System: System: Initializing" in output
+ assert "User: Hello" in output
+ assert "Assistant: Hi there" in output
+
+ def test_get_last_n_messages(self):
+ """Test retrieving the last N messages.
+
+ Tests:
+ - Returns correct number of messages
+ - Returns all messages if N > total
+ - Maintains correct order
+ """
+ memory = FlexibleConversationMemory()
+
+ # Add 5 messages
+ turn = memory.create_turn()
+ for i in range(5):
+ if i % 2 == 0:
+ memory.add_user_query(f"Message {i}", turn)
+ else:
+ memory.add_assistant_response(f"Message {i}", turn)
+
+ # Get last 3
+ last_3 = memory.get_last_n_messages(3)
+ assert len(last_3) == 3
+ assert last_3[0].content == "Message 2"
+ assert last_3[1].content == "Message 3"
+ assert last_3[2].content == "Message 4"
+
+ # Get more than available
+ last_10 = memory.get_last_n_messages(10)
+ assert len(last_10) == 5
+
+ def test_count_messages(self):
+ """Test counting messages by role.
+
+ Tests:
+ - Counts are accurate for each role
+ - Handles empty conversation
+ - Works across multiple turns
+ """
+ memory = FlexibleConversationMemory()
+
+ # Empty conversation
+ counts = memory.count_messages()
+ assert counts == {"user": 0, "assistant": 0, "system": 0}
+
+ # Add messages
+ turn1 = memory.create_turn()
+ memory.add_user_query("User 1", turn1)
+ memory.add_assistant_response("Assistant 1", turn1)
+
+ turn2 = memory.create_turn()
+ memory.add_user_query("User 2", turn2)
+ memory.add_user_query("User 3", turn2)
+ memory.add_system_message("System 1", turn2)
+
+ counts = memory.count_messages()
+ assert counts["user"] == 3
+ assert counts["assistant"] == 1
+ assert counts["system"] == 1
+
+ def test_count_turns(self):
+ """Test counting the number of turns.
+
+ Tests:
+ - Returns 0 for empty conversation
+ - Counts turns correctly
+ """
+ memory = FlexibleConversationMemory()
+
+ assert memory.count_turns() == 0
+
+ memory.create_turn()
+ assert memory.count_turns() == 1
+
+ memory.create_turn()
+ memory.create_turn()
+ assert memory.count_turns() == 3
+
+ def test_reset_pending_query(self):
+ """Test resetting pending query (compatibility method).
+
+ Tests:
+ - Current turn ID is unset
+ - Turns are not deleted
+ - Messages remain intact
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query("Test", turn)
+
+ # Current turn is set
+ assert memory.get_current_turn_id() == turn
+
+ # Reset pending query
+ memory.reset_pending_query()
+
+ # Current turn is unset but turn still exists
+ assert memory.get_current_turn_id() is None
+ assert turn in memory.current_conversation.turns
+ assert len(memory.get_turn_messages(turn)) == 1
+
+ def test_len_magic_method(self):
+ """Test the __len__ magic method.
+
+ Tests:
+ - Returns 0 for empty conversation
+ - Returns correct count of all messages
+ """
+ memory = FlexibleConversationMemory()
+
+ assert len(memory) == 0
+
+ turn = memory.create_turn()
+ memory.add_user_query("1", turn)
+ memory.add_assistant_response("2", turn)
+ memory.add_system_message("3", turn)
+
+ assert len(memory) == 3
+
+ def test_multiple_queries_same_turn(self):
+ """Test adding multiple user queries to the same turn.
+
+ Tests:
+ - Multiple user queries can be added to one turn
+ - Order is maintained
+ - Common in clarification scenarios
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query("First question", turn)
+ memory.add_user_query("Follow-up question", turn)
+ memory.add_user_query("Another clarification", turn)
+ memory.add_assistant_response("Comprehensive answer", turn)
+
+ messages = memory.get_turn_messages(turn)
+ assert len(messages) == 4
+ assert messages[0].content == "First question"
+ assert messages[1].content == "Follow-up question"
+ assert messages[2].content == "Another clarification"
+ assert messages[3].content == "Comprehensive answer"
+
+ def test_complex_metadata_handling(self):
+ """Test handling complex metadata structures.
+
+ Tests:
+ - Nested dictionaries in metadata
+ - Lists in metadata
+ - Mixed data types
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ complex_metadata = {
+ "step_history": [
+ {"action": "search", "result": "found"},
+ {"action": "analyze", "result": "complete"}
+ ],
+ "images": ["image1.png", "image2.jpg"],
+ "confidence": 0.95,
+ "nested": {
+ "level1": {
+ "level2": "deep value"
+ }
+ }
+ }
+
+ memory.add_user_query("Complex query", turn, complex_metadata)
+
+ messages = memory.get_turn_messages(turn)
+ assert messages[0].metadata == complex_metadata
+ assert messages[0].metadata["step_history"][0]["action"] == "search"
+ assert messages[0].metadata["nested"]["level1"]["level2"] == "deep value"
+
+ def test_conversation_persistence(self):
+ """Test that conversations are properly persisted.
+
+ Tests:
+ - Conversations are saved when starting new one
+ - Empty conversations are not saved
+ - Multiple conversations can be saved
+ """
+ memory = FlexibleConversationMemory()
+
+ # Empty conversation not saved
+ memory.new_conversation()
+ assert len(memory.conver_db.items) == 0
+
+ # Add content and start new conversation
+ turn = memory.create_turn()
+ memory.add_user_query("First conv", turn)
+ memory.new_conversation()
+ assert len(memory.conver_db.items) == 1
+
+ # Add more conversations
+ turn = memory.create_turn()
+ memory.add_user_query("Second conv", turn)
+ memory.new_conversation()
+ assert len(memory.conver_db.items) == 2
+
+ def test_error_handling_in_create_turn(self):
+ """Test error handling in turn creation.
+
+ Tests:
+ - RuntimeError is raised if turn creation fails
+ - Error message is informative
+ """
+ memory = FlexibleConversationMemory()
+
+ # Mock a failure in UUID generation
+ with patch('adalflow.components.memory.flexible_memory.uuid4', side_effect=Exception("UUID error")):
+ with pytest.raises(RuntimeError) as exc_info:
+ memory.create_turn()
+ assert "Failed to create turn" in str(exc_info.value)
+
+ def test_database_integration(self):
+ """Test integration with LocalDB for message storage.
+
+ Tests:
+ - Messages are stored with correct fields
+ - Turn IDs are properly tracked
+ - Timestamps are preserved
+ - Multiple databases work independently
+ """
+ db1 = LocalDB()
+ db2 = LocalDB()
+
+ memory1 = FlexibleConversationMemory(turn_db=db1)
+ memory2 = FlexibleConversationMemory(turn_db=db2)
+
+ # Add to memory1
+ turn1 = memory1.create_turn()
+ memory1.add_user_query("Memory 1 message", turn1)
+
+ # Add to memory2
+ turn2 = memory2.create_turn()
+ memory2.add_user_query("Memory 2 message", turn2)
+
+ # Check databases are independent
+ assert len(db1.items) == 1
+ assert len(db2.items) == 1
+ assert db1.items[0]["content"] == "Memory 1 message"
+ assert db2.items[0]["content"] == "Memory 2 message"
+
+ def test_edge_cases(self):
+ """Test various edge cases and boundary conditions.
+
+ Tests:
+ - Empty strings as content
+ - Very long content
+ - Special characters in content
+ - None values where applicable
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+
+ # Empty string content
+ memory.add_user_query("", turn)
+ messages = memory.get_turn_messages(turn)
+ assert messages[0].content == ""
+
+ # Very long content
+ long_content = "x" * 10000
+ memory.add_assistant_response(long_content, turn)
+ assert len(memory.get_turn_messages(turn)[1].content) == 10000
+
+ # Special characters
+ special = "Hello\n\tWorld!@#$%^&*()[]{}|\\<>?/~`"
+ memory.add_user_query(special, turn)
+ assert memory.get_turn_messages(turn)[2].content == special
+
+ # None metadata (should work fine)
+ memory.add_user_query("Test", turn, None)
+ assert memory.get_turn_messages(turn)[3].metadata is None
+
+
+class TestIntegration:
+ """Integration tests for the complete memory system."""
+
+ def test_realistic_conversation_flow(self):
+ """Test a realistic multi-turn conversation with all features.
+
+ Tests:
+ - Complete conversation flow
+ - Mixed message types
+ - Metadata handling
+ - Turn management
+ - Output formatting
+ """
+ memory = FlexibleConversationMemory(user_id="test_user")
+
+ # Turn 1: Initial setup
+ turn1 = memory.create_turn()
+ memory.add_system_message("You are a helpful assistant.", turn1)
+ memory.add_user_query("Hi, I need help with Python", turn1, {"source": "web_ui"})
+ memory.add_assistant_response(
+ "Hello! I'd be happy to help you with Python. What specific topic?",
+ turn1,
+ {"confidence": 0.95}
+ )
+
+ # Turn 2: Follow-up
+ turn2 = memory.create_turn()
+ memory.add_user_query("How do I read a file?", turn2)
+ memory.add_assistant_response(
+ "You can use the open() function with a context manager.",
+ turn2,
+ {"confidence": 0.98, "sources": ["python_docs"]}
+ )
+ memory.add_user_query("Can you show an example?", turn2) # Clarification in same turn
+ memory.add_assistant_response(
+ "with open('file.txt', 'r') as f:\n content = f.read()",
+ turn2,
+ {"code_block": True}
+ )
+
+ # Verify conversation state
+ assert memory.count_turns() == 2
+ counts = memory.count_messages()
+ assert counts["user"] == 3
+ assert counts["assistant"] == 3
+ assert counts["system"] == 1
+
+ # Check output formatting
+ output = memory()
+ assert "You are a helpful assistant" in output
+ assert "How do I read a file?" in output
+ assert "with open('file.txt', 'r')" in output
+
+ # Test filtered output
+ filtered = memory(metadata_filter=["confidence"])
+ assert "confidence: 0.95" in filtered
+ assert "confidence: 0.98" in filtered
+ assert "source: web_ui" not in filtered
+
+ # Get last messages
+ last_2 = memory.get_last_n_messages(2)
+ assert last_2[0].content == "Can you show an example?"
+ assert "with open" in last_2[1].content
+
+ # Start new conversation but preserve the old one
+ memory.new_conversation()
+ assert len(memory.conver_db.items) == 1
+ assert memory.count_messages() == {"user": 0, "assistant": 0, "system": 0}
+
+ def test_agent_style_conversation(self):
+ """Test conversation flow typical of an AI agent with step tracking.
+
+ Tests:
+ - Agent-style metadata with step history
+ - Multiple assistant messages in sequence
+ - Complex metadata structures
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+
+ # User query with context
+ memory.add_user_query(
+ "Find all Python files in the project",
+ turn,
+ {"request_type": "search", "timestamp": "2024-01-01T10:00:00"}
+ )
+
+ # Agent thinking steps
+ memory.add_assistant_response(
+ "I'll search for Python files in the project.",
+ turn,
+ {
+ "step": 1,
+ "action": "plan",
+ "confidence": 0.9
+ }
+ )
+
+ memory.add_assistant_response(
+ "Found 15 Python files in 3 directories.",
+ turn,
+ {
+ "step": 2,
+ "action": "search",
+ "results": {
+ "count": 15,
+ "directories": ["src", "tests", "scripts"]
+ }
+ }
+ )
+
+ memory.add_assistant_response(
+ "Here are the Python files I found: [list of files]",
+ turn,
+ {
+ "step": 3,
+ "action": "report",
+ "completion": True
+ }
+ )
+
+ # Verify step tracking
+ messages = memory.get_turn_messages(turn)
+ assert len(messages) == 4
+
+ # Check assistant messages have increasing step numbers
+ assistant_messages = [m for m in messages if m.role == "assistant"]
+ assert assistant_messages[0].metadata["step"] == 1
+ assert assistant_messages[1].metadata["step"] == 2
+ assert assistant_messages[2].metadata["step"] == 3
+
+ # Verify complex metadata structure
+ assert assistant_messages[1].metadata["results"]["count"] == 15
+ assert "src" in assistant_messages[1].metadata["results"]["directories"]
\ No newline at end of file
diff --git a/adalflow/tests/test_flexible_memory_template.py b/adalflow/tests/test_flexible_memory_template.py
new file mode 100644
index 00000000..cfdb3a00
--- /dev/null
+++ b/adalflow/tests/test_flexible_memory_template.py
@@ -0,0 +1,538 @@
+"""Test the Jinja2 template rendering in FlexibleConversationMemory.
+
+This module specifically tests:
+1. Template rendering with call() method
+2. Metadata filtering in templates
+3. Multiple turns and messages formatting
+4. System, user, and assistant message formatting
+5. Edge cases in template rendering
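+
+A minimal usage sketch of the API these tests exercise (method names and the
+rendered "User:"/"Assistant:" format are taken from the assertions below):
+
+ memory = FlexibleConversationMemory()
+ turn = memory.create_turn()
+ memory.add_user_query("Hello", turn, metadata={"source": "web"})
+ memory.add_assistant_response("Hi there!", turn)
+ output = memory.call() # contains "User: Hello", "source: web", "Assistant: Hi there!"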
+"""
+
+import pytest
+from adalflow.components.memory.flexible_memory import (
+ Message,
+ Conversation,
+ FlexibleConversationMemory,
+ CONVERSATION_TEMPLATE,
+)
+from adalflow.core.prompt_builder import Prompt
+
+
+class TestFlexibleMemoryTemplateRendering:
+ """Test the Jinja2 template rendering functionality."""
+
+ def test_basic_template_rendering(self):
+ """Test basic conversation rendering with the template.
+
+ Tests:
+ - Single turn with user and assistant messages
+ - Proper formatting with "User:" and "Assistant:" prefixes
+ - Newline handling
+ """
+ memory = FlexibleConversationMemory()
+
+ turn_id = memory.create_turn()
+ memory.add_user_query("Hello, how are you?", turn_id)
+ memory.add_assistant_response("I'm doing well, thank you!", turn_id)
+
+ # Call the memory to render the template
+ output = memory.call()
+
+ # Check the formatted output
+ assert "User: Hello, how are you?" in output
+ assert "Assistant: I'm doing well, thank you!" in output
+
+ # Check order
+ lines = output.strip().split('\n')
+ user_line_idx = next(i for i, line in enumerate(lines) if "User:" in line)
+ assistant_line_idx = next(i for i, line in enumerate(lines) if "Assistant:" in line)
+ assert user_line_idx < assistant_line_idx
+
+ def test_multiple_turns_rendering(self):
+ """Test rendering multiple conversation turns.
+
+ Tests:
+ - Multiple turns are rendered in order
+ - Each turn's messages are grouped together
+ - Proper spacing between turns
+ """
+ memory = FlexibleConversationMemory()
+
+ # Turn 1
+ turn1 = memory.create_turn()
+ memory.add_user_query("What is Python?", turn1)
+ memory.add_assistant_response("Python is a programming language.", turn1)
+
+ # Turn 2
+ turn2 = memory.create_turn()
+ memory.add_user_query("What can I do with it?", turn2)
+ memory.add_assistant_response("You can build web apps, analyze data, and more.", turn2)
+
+ output = memory.call()
+
+ # Check all messages are present
+ assert "What is Python?" in output
+ assert "Python is a programming language" in output
+ assert "What can I do with it?" in output
+ assert "You can build web apps" in output
+
+ # Check order is maintained
+ python_idx = output.index("What is Python?")
+ python_answer_idx = output.index("Python is a programming")
+ what_can_idx = output.index("What can I do with it?")
+ web_apps_idx = output.index("You can build web apps")
+
+ assert python_idx < python_answer_idx < what_can_idx < web_apps_idx
+
+ def test_metadata_rendering(self):
+ """Test metadata rendering in the template.
+
+ Tests:
+ - Metadata is rendered below the message
+ - Metadata keys and values are formatted correctly
+ - Multiple metadata items are shown
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Search for Python tutorials",
+ turn,
+ metadata={
+ "source": "web_interface",
+ "priority": "high",
+ "timestamp": "2024-01-01T10:00:00"
+ }
+ )
+ memory.add_assistant_response(
+ "Here are some Python tutorials",
+ turn,
+ metadata={
+ "confidence": 0.95,
+ "sources_count": 5
+ }
+ )
+
+ output = memory.call()
+
+ # Check metadata is rendered
+ assert "source: web_interface" in output
+ assert "priority: high" in output
+ assert "timestamp: 2024-01-01T10:00:00" in output
+ assert "confidence: 0.95" in output
+ assert "sources_count: 5" in output
+
+ def test_metadata_filtering(self):
+ """Test metadata filtering in template rendering.
+
+ Tests:
+ - Only specified metadata keys are shown
+ - Other metadata is filtered out
+ - Filtering works across multiple messages
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Query with metadata",
+ turn,
+ metadata={
+ "public": "visible",
+ "private": "hidden",
+ "internal": "secret"
+ }
+ )
+ memory.add_assistant_response(
+ "Response with metadata",
+ turn,
+ metadata={
+ "public": "also visible",
+ "confidence": "high",
+ "debug": "hidden info"
+ }
+ )
+
+ # Call with metadata filter
+ output = memory.call(metadata_filter=["public", "confidence"])
+
+ # Check filtered metadata
+ assert "public: visible" in output
+ assert "public: also visible" in output
+ assert "confidence: high" in output
+
+ # Check filtered out metadata is not present
+ assert "private: hidden" not in output
+ assert "internal: secret" not in output
+ assert "debug: hidden info" not in output
+
+ def test_system_messages_rendering(self):
+ """Test system message rendering in the template.
+
+ Tests:
+ - System messages are properly formatted
+ - System messages appear with "System:" prefix
+ - Mixed message types work correctly
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_system_message("You are a helpful assistant.", turn)
+ memory.add_user_query("Hello", turn)
+ memory.add_assistant_response("Hi! How can I help you?", turn)
+
+ output = memory.call()
+
+ # Check system message formatting
+ assert "System: You are a helpful assistant." in output
+ assert "User: Hello" in output
+ assert "Assistant: Hi! How can I help you?" in output
+
+ # Check order
+ system_idx = output.index("System:")
+ user_idx = output.index("User:")
+ assistant_idx = output.index("Assistant:")
+ assert system_idx < user_idx < assistant_idx
+
+ def test_multiple_messages_same_role_in_turn(self):
+ """Test multiple messages from the same role in one turn.
+
+ Tests:
+ - Multiple user queries in one turn
+ - Multiple assistant responses in one turn
+ - Proper ordering and formatting
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query("First question", turn)
+ memory.add_user_query("Follow-up question", turn)
+ memory.add_user_query("Another clarification", turn)
+ memory.add_assistant_response("Comprehensive answer addressing all questions", turn)
+
+ output = memory.call()
+
+ # All messages should be present
+ assert "User: First question" in output
+ assert "User: Follow-up question" in output
+ assert "User: Another clarification" in output
+ assert "Assistant: Comprehensive answer" in output
+
+ # Check ordering
+ first_idx = output.index("First question")
+ followup_idx = output.index("Follow-up question")
+ clarification_idx = output.index("Another clarification")
+ answer_idx = output.index("Comprehensive answer")
+
+ assert first_idx < followup_idx < clarification_idx < answer_idx
+
+ def test_empty_conversation_rendering(self):
+ """Test rendering an empty conversation.
+
+ Tests:
+ - Empty conversation returns empty string
+ - No errors are raised
+ """
+ memory = FlexibleConversationMemory()
+
+ output = memory.call()
+ assert output == ""
+
+ # Also test with metadata filter
+ output_filtered = memory.call(metadata_filter=["some_key"])
+ assert output_filtered == ""
+
+ def test_special_characters_in_content(self):
+ """Test rendering with special characters in content.
+
+ Tests:
+ - Newlines in content
+ - Special characters (quotes, brackets, etc.)
+ - Unicode characters
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Can you explain:\n1. Lists\n2. Dictionaries\n3. Sets",
+ turn
+ )
+ memory.add_assistant_response(
+ 'Sure! Here\'s an explanation:\n• Lists: [1, 2, 3]\n• Dicts: {"key": "value"}\n• Sets: {1, 2, 3}',
+ turn
+ )
+
+ output = memory.call()
+
+ # Check special characters are preserved
+ assert "1. Lists" in output
+ assert "2. Dictionaries" in output
+ assert '{"key": "value"}' in output
+ assert "• Lists:" in output
+
+ def test_complex_metadata_in_template(self):
+ """Test rendering complex metadata structures.
+
+ Tests:
+ - Nested dictionaries in metadata
+ - Lists in metadata
+ - Numbers and booleans in metadata
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Complex query",
+ turn,
+ metadata={
+ "nested": {
+ "level1": {
+ "level2": "deep value"
+ }
+ },
+ "list_data": ["item1", "item2", "item3"],
+ "numeric": 42,
+ "boolean": True
+ }
+ )
+
+ output = memory.call()
+
+ # Check complex metadata is rendered (as string representations)
+ assert "Complex query" in output
+ # The template will render these as strings
+ assert "nested:" in output
+ assert "list_data:" in output
+ assert "numeric: 42" in output
+ assert "boolean: True" in output
+
+ def test_template_with_none_metadata(self):
+ """Test template handling of None metadata.
+
+ Tests:
+ - Messages with None metadata don't show metadata section
+ - Messages with empty dict metadata don't show metadata section
+ - Mixed messages with and without metadata
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ # Message with None metadata
+ memory.add_user_query("No metadata query", turn, metadata=None)
+ # Message with empty metadata
+ memory.add_assistant_response("No metadata response", turn, metadata={})
+ # Message with actual metadata
+ memory.add_user_query("With metadata", turn, metadata={"key": "value"})
+
+ output = memory.call()
+
+ # Check messages are present
+ assert "User: No metadata query" in output
+ assert "Assistant: No metadata response" in output
+ assert "User: With metadata" in output
+ assert "key: value" in output
+
+ # Count occurrences of "key:" - should only be one
+ assert output.count("key:") == 1
+
+ def test_callable_interface(self):
+ """Test that the callable interface works the same as call().
+
+ Tests:
+ - __call__ produces same output as call()
+ - Metadata filtering works with __call__
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query("Test query", turn, metadata={"meta1": "value1"})
+ memory.add_assistant_response("Test response", turn, metadata={"meta2": "value2"})
+
+ # Test call() and __call__() produce same output
+ output_call = memory.call()
+ output_callable = memory()
+ assert output_call == output_callable
+
+ # Test metadata filtering with both
+ output_call_filtered = memory.call(metadata_filter=["meta1"])
+ output_callable_filtered = memory(metadata_filter=["meta1"])
+ assert output_call_filtered == output_callable_filtered
+ assert "meta1: value1" in output_callable_filtered
+ assert "meta2: value2" not in output_callable_filtered
+
+ def test_template_direct_rendering(self):
+ """Test the template can be rendered directly with Prompt.
+
+ Tests:
+ - CONVERSATION_TEMPLATE constant is valid
+ - Can create Prompt with the template
+ - Direct rendering matches memory.call()
+ """
+ from collections import OrderedDict
+
+ # Create data structure manually
+ turns = OrderedDict()
+ turn_id = "test_turn"
+ turns[turn_id] = [
+ Message.from_user("Hello"),
+ Message.from_assistant("Hi there!")
+ ]
+
+ # Render with Prompt directly
+ prompt = Prompt(
+ template=CONVERSATION_TEMPLATE,
+ prompt_kwargs={
+ "turns": turns,
+ "metadata_filter": None
+ }
+ )
+ direct_output = prompt.call().strip()
+
+ # Create same conversation with memory
+ memory = FlexibleConversationMemory()
+ memory_turn = memory.create_turn()
+ memory.add_user_query("Hello", memory_turn)
+ memory.add_assistant_response("Hi there!", memory_turn)
+ memory_output = memory.call().strip()
+
+ # Should produce same output
+ assert direct_output == memory_output
+
+ def test_long_conversation_rendering(self):
+ """Test rendering of long conversations.
+
+ Tests:
+ - Many turns (10+)
+ - Many messages per turn
+ - Performance doesn't degrade
+ """
+ memory = FlexibleConversationMemory()
+
+ # Create 10 turns with multiple messages each
+ for i in range(10):
+ turn = memory.create_turn()
+ memory.add_user_query(f"Question {i+1}", turn)
+ memory.add_user_query(f"Clarification {i+1}", turn)
+ memory.add_assistant_response(f"Answer {i+1}", turn)
+ if i % 2 == 0:
+ memory.add_system_message(f"System note {i+1}", turn)
+
+ output = memory.call()
+
+ # Check all messages are present
+ for i in range(10):
+ assert f"Question {i+1}" in output
+ assert f"Clarification {i+1}" in output
+ assert f"Answer {i+1}" in output
+ if i % 2 == 0:
+ assert f"System note {i+1}" in output
+
+ # Check it's not empty and has reasonable length
+ assert len(output) > 500 # Should be a long conversation
+
+ # Check ordering is maintained
+ q1_idx = output.index("Question 1")
+ q10_idx = output.index("Question 10")
+ assert q1_idx < q10_idx
+
+
+class TestTemplateEdgeCases:
+ """Test edge cases and error conditions in template rendering."""
+
+ def test_template_with_jinja2_special_chars_in_content(self):
+ """Test content that contains Jinja2 special characters.
+
+ Tests:
+ - Content with {{ }} doesn't break template
+ - Content with {% %} doesn't break template
+ - Content is passed through literally rather than re-interpreted as template syntax
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Show me a template: {{ variable }} and {% if condition %}",
+ turn
+ )
+ memory.add_assistant_response(
+ "Here's the template: {{ var }} and {% for item in items %}",
+ turn
+ )
+
+ output = memory.call()
+
+ # Special characters should be preserved in output
+ assert "{{ variable }}" in output
+ assert "{% if condition %}" in output
+ assert "{{ var }}" in output
+ assert "{% for item in items %}" in output
+
+ def test_metadata_with_jinja2_chars(self):
+ """Test metadata containing Jinja2 special characters.
+
+ Tests:
+ - Metadata with template syntax doesn't break rendering
+ - Values are passed through literally
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(
+ "Query",
+ turn,
+ metadata={
+ "template": "{{ user_input }}",
+ "condition": "{% if x > 0 %}",
+ "normal": "regular value"
+ }
+ )
+
+ output = memory.call()
+
+ # Metadata should be rendered safely
+ assert "template: {{ user_input }}" in output
+ assert "condition: {% if x > 0 %}" in output
+ assert "normal: regular value" in output
+
+ def test_very_long_messages(self):
+ """Test rendering very long messages.
+
+ Tests:
+ - Long messages don't break template
+ - Full content is preserved
+ """
+ memory = FlexibleConversationMemory()
+
+ # Create a very long message
+ long_content = "This is a very long message. " * 100 # ~3000 characters
+
+ turn = memory.create_turn()
+ memory.add_user_query(long_content, turn)
+ memory.add_assistant_response("Short response", turn)
+
+ output = memory.call()
+
+ # Check long content is fully preserved
+ assert long_content in output
+ assert "Short response" in output
+
+ def test_whitespace_preservation(self):
+ """Test that whitespace in messages is preserved.
+
+ Tests:
+ - Leading/trailing spaces
+ - Multiple spaces
+ - Tabs and newlines
+ """
+ memory = FlexibleConversationMemory()
+
+ turn = memory.create_turn()
+ memory.add_user_query(" Message with spaces ", turn)
+ memory.add_assistant_response("\tTabbed\tmessage\t", turn)
+ memory.add_system_message("Line1\nLine2\nLine3", turn)
+
+ output = memory.call()
+
+ # Check whitespace is preserved
+ assert " Message with spaces " in output
+ assert "\tTabbed\tmessage\t" in output
+ assert "Line1\nLine2\nLine3" in output
\ No newline at end of file
diff --git a/adalflow/tests/test_runner_flexible.py b/adalflow/tests/test_runner_flexible.py
new file mode 100644
index 00000000..043985ef
--- /dev/null
+++ b/adalflow/tests/test_runner_flexible.py
@@ -0,0 +1,517 @@
+"""Test suite for RunnerFlexible with FlexibleConversationMemory integration.
+
+This module tests:
+1. Runner initialization with flexible memory
+2. Safe memory operations that never fail the runner
+3. Turn management during execution
+4. Error handling and recovery
+5. Memory persistence across multiple runs
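+
+A minimal wiring sketch (constructor and call signature mirror the fixtures and
+tests below; ``agent`` stands in for a fully configured Agent):
+
+ memory = FlexibleConversationMemory()
+ runner = RunnerFlexible(agent=agent, conversation_memory=memory)
+ result = runner.call({"input_str": "Hello"}) # RunnerResult; memory errors only log warnings
+ history = memory.call() # conversation history accumulated across runs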
+"""
+
+import pytest
+import asyncio
+from unittest.mock import Mock, patch, MagicMock
+from datetime import datetime
+
+from adalflow.components.agent.runner_flexible import RunnerFlexible
+from adalflow.components.memory.flexible_memory import FlexibleConversationMemory
+from adalflow.components.agent.agent import Agent
+from adalflow.core.types import (
+ GeneratorOutput,
+ Function,
+ FunctionOutput,
+ RunnerResult,
+ UserQuery,
+ AssistantResponse,
+ ToolOutput,
+)
+
+
+class TestRunnerFlexibleMemoryIntegration:
+ """Test RunnerFlexible with FlexibleConversationMemory."""
+
+ @pytest.fixture
+ def mock_agent(self):
+ """Create a mock agent for testing."""
+ agent = Mock(spec=Agent)
+ agent.max_steps = 3
+ agent.answer_data_type = str
+ agent.is_thinking_model = False
+
+ # Mock planner
+ agent.planner = Mock()
+ agent.planner.call = Mock()
+ agent.planner.acall = Mock()
+ agent.planner.get_prompt = Mock(return_value="test prompt")
+ agent.planner.estimated_token_count = 100
+
+ # Mock tool manager
+ agent.tool_manager = Mock()
+ agent.tool_manager.tools = []
+ agent.tool_manager.execute_func = Mock()
+ agent.tool_manager.execute_func_async = Mock()
+
+ return agent
+
+ @pytest.fixture
+ def flexible_memory(self):
+ """Create a flexible memory instance."""
+ return FlexibleConversationMemory()
+
+ @pytest.fixture
+ def runner_with_memory(self, mock_agent, flexible_memory):
+ """Create a runner with flexible memory."""
+ return RunnerFlexible(
+ agent=mock_agent,
+ conversation_memory=flexible_memory
+ )
+
+ def test_runner_initialization_with_memory(self, mock_agent, flexible_memory):
+ """Test that runner initializes correctly with flexible memory.
+
+ Tests:
+ - Runner accepts FlexibleConversationMemory
+ - use_conversation_memory flag is set correctly
+ - Safe memory operation methods are available
+ """
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory)
+
+ assert runner.conversation_memory == flexible_memory
+ assert runner.use_conversation_memory is True
+ assert hasattr(runner, '_safe_create_turn')
+ assert hasattr(runner, '_safe_add_user_query')
+ assert hasattr(runner, '_safe_add_assistant_response')
+ assert hasattr(runner, '_safe_get_conversation_history')
+
+ def test_safe_create_turn(self, runner_with_memory):
+ """Test safe turn creation that never fails.
+
+ Tests:
+ - Normal turn creation works
+ - Returns None when memory fails
+ - Logs warning on failure
+ """
+ # Normal operation
+ turn_id = runner_with_memory._safe_create_turn()
+ assert turn_id is not None
+ assert isinstance(turn_id, str)
+
+ # Test failure handling
+ runner_with_memory.conversation_memory.create_turn = Mock(
+ side_effect=Exception("Memory error")
+ )
+
+ with patch('adalflow.components.agent.runner_flexible.log') as mock_log:
+ result = runner_with_memory._safe_create_turn()
+ assert result is None
+ mock_log.warning.assert_called_once()
+ assert "Failed to create turn" in str(mock_log.warning.call_args)
+
+ def test_safe_add_user_query(self, runner_with_memory):
+ """Test safe user query addition.
+
+ Tests:
+ - Adding string queries
+ - Adding UserQuery objects
+ - Handling memory failures gracefully
+ """
+ turn_id = runner_with_memory._safe_create_turn()
+
+ # Test string query
+ result = runner_with_memory._safe_add_user_query(
+ "Test query", turn_id, {"meta": "data"}
+ )
+ assert result == turn_id
+
+ # Test UserQuery object
+ user_query = UserQuery(query_str="Object query", metadata={"key": "value"})
+ result = runner_with_memory._safe_add_user_query(user_query, turn_id)
+ assert result == turn_id
+
+ # Test with None turn_id
+ result = runner_with_memory._safe_add_user_query("Query", None)
+ assert result is None
+
+ # Test failure handling
+ runner_with_memory.conversation_memory.add_user_query = Mock(
+ side_effect=Exception("Add failed")
+ )
+
+ with patch('adalflow.components.agent.runner_flexible.log') as mock_log:
+ result = runner_with_memory._safe_add_user_query("Query", turn_id)
+ assert result is None
+ mock_log.warning.assert_called_once()
+
+ def test_safe_add_assistant_response(self, runner_with_memory):
+ """Test safe assistant response addition.
+
+ Tests:
+ - Adding string responses
+ - Adding AssistantResponse objects
+ - Handling memory failures gracefully
+ """
+ turn_id = runner_with_memory._safe_create_turn()
+
+ # Test string response
+ result = runner_with_memory._safe_add_assistant_response(
+ "Test response", turn_id, {"meta": "data"}
+ )
+ assert result == turn_id
+
+ # Test AssistantResponse object
+ assistant_response = AssistantResponse(
+ response_str="Object response",
+ metadata={"key": "value"}
+ )
+ result = runner_with_memory._safe_add_assistant_response(assistant_response, turn_id)
+ assert result == turn_id
+
+ # Test with None turn_id
+ result = runner_with_memory._safe_add_assistant_response("Response", None)
+ assert result is None
+
+ # Test failure handling
+ runner_with_memory.conversation_memory.add_assistant_response = Mock(
+ side_effect=Exception("Add failed")
+ )
+
+ with patch('adalflow.components.agent.runner_flexible.log') as mock_log:
+ result = runner_with_memory._safe_add_assistant_response("Response", turn_id)
+ assert result is None
+ mock_log.warning.assert_called_once()
+
+ def test_safe_get_conversation_history(self, runner_with_memory):
+ """Test safe conversation history retrieval.
+
+ Tests:
+ - Normal history retrieval
+ - Returns empty string on failure
+ - Never raises exceptions
+ """
+ # Add some conversation
+ turn_id = runner_with_memory._safe_create_turn()
+ runner_with_memory._safe_add_user_query("Hello", turn_id)
+ runner_with_memory._safe_add_assistant_response("Hi there", turn_id)
+
+ # Get history
+ history = runner_with_memory._safe_get_conversation_history()
+ assert "Hello" in history
+ assert "Hi there" in history
+
+ # Test failure handling - mock the call method directly
+ def mock_call():
+ raise Exception("History error")
+
+ original_call = runner_with_memory.conversation_memory.call
+ runner_with_memory.conversation_memory.call = mock_call
+ # __call__ is looked up on the class, so the class method must be patched as well
+ original_dunder_call = runner_with_memory.conversation_memory.__class__.__call__
+ runner_with_memory.conversation_memory.__class__.__call__ = lambda self: mock_call()
+
+ try:
+ with patch('adalflow.components.agent.runner_flexible.log') as mock_log:
+ result = runner_with_memory._safe_get_conversation_history()
+ assert result == ""
+ mock_log.warning.assert_called_once()
+ finally:
+ # Restore original methods
+ runner_with_memory.conversation_memory.call = original_call
+ runner_with_memory.conversation_memory.__class__.__call__ = original_dunder_call
+
+ def test_runner_continues_on_memory_failure(self, mock_agent, flexible_memory):
+ """Test that runner continues execution even when memory operations fail.
+
+ Tests:
+ - Runner completes successfully despite memory errors
+ - Warnings are logged but execution continues
+ - Final result is still produced
+ """
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory)
+
+ # Make all memory operations fail
+ flexible_memory.create_turn = Mock(side_effect=Exception("Memory failed"))
+ flexible_memory.add_user_query = Mock(side_effect=Exception("Memory failed"))
+ flexible_memory.add_assistant_response = Mock(side_effect=Exception("Memory failed"))
+ # Special methods like __call__ are looked up on the class, so mock the public call method instead
+ flexible_memory.call = Mock(side_effect=Exception("Memory failed"))
+
+ # Setup successful agent execution
+ mock_function = Function(name="finish")
+ mock_function._is_answer_final = True
+ mock_function._answer = "Final answer"
+
+ mock_agent.planner.call.return_value = GeneratorOutput(
+ data=mock_function,
+ error=None,
+ raw_response="raw"
+ )
+
+ # Run should complete successfully despite memory failures
+ with patch('adalflow.components.agent.runner_flexible.log') as mock_log:
+ result = runner.call({"input_str": "Test query"})
+
+ # Check runner completed successfully
+ assert isinstance(result, RunnerResult)
+ assert result.answer == "Final answer"
+
+ # Check warnings were logged for memory failures
+ warning_calls = mock_log.warning.call_args_list
+ assert len(warning_calls) > 0
+ assert any("Failed to" in str(call) for call in warning_calls)
+
+ def test_async_runner_with_memory(self, mock_agent, flexible_memory):
+ """Test async runner execution with memory.
+
+ Tests:
+ - Async execution works with memory
+ - Turn management in async context
+ - Memory operations are safe in async
+ """
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory)
+
+ # Setup async mock
+ mock_function = Function(name="finish")
+ mock_function._is_answer_final = True
+ mock_function._answer = "Async answer"
+
+ async def mock_acall(*args, **kwargs):
+ return GeneratorOutput(
+ data=mock_function,
+ error=None,
+ raw_response="raw"
+ )
+
+ mock_agent.planner.acall = mock_acall
+
+ # Run async
+ async def run_test():
+ result = await runner.acall({"input_str": "Async query"})
+ return result
+
+ result = asyncio.run(run_test())
+
+ assert isinstance(result, RunnerResult)
+ assert result.answer == "Async answer"
+
+ # Check memory has content
+ history = flexible_memory()
+ assert "Async query" in history
+ assert "Async answer" in history
+
+ def test_memory_persistence_across_runs(self, runner_with_memory, mock_agent):
+ """Test that memory persists across multiple runner executions.
+
+ Tests:
+ - Memory accumulates across runs
+ - Each run creates a new turn
+ - History is available in subsequent runs
+ """
+ # Setup mock responses
+ def create_mock_response(answer):
+ mock_function = Function(name="finish")
+ mock_function._is_answer_final = True
+ mock_function._answer = answer
+ return GeneratorOutput(data=mock_function, error=None, raw_response="raw")
+
+ # First run
+ mock_agent.planner.call.return_value = create_mock_response("Answer 1")
+ result1 = runner_with_memory.call({"input_str": "Query 1"})
+ assert result1.answer == "Answer 1"
+
+ # Second run
+ mock_agent.planner.call.return_value = create_mock_response("Answer 2")
+ result2 = runner_with_memory.call({"input_str": "Query 2"})
+ assert result2.answer == "Answer 2"
+
+ # Check memory has both conversations
+ history = runner_with_memory.conversation_memory()
+ assert "Query 1" in history
+ assert "Answer 1" in history
+ assert "Query 2" in history
+ assert "Answer 2" in history
+
+ # Check we have 2 turns
+ assert runner_with_memory.conversation_memory.count_turns() == 2
+
+ def test_memory_with_complex_metadata(self, runner_with_memory, mock_agent):
+ """Test memory handling of complex metadata.
+
+ Tests:
+ - Step history is properly stored
+ - Complex nested metadata works
+ - Metadata survives memory errors
+ """
+ # Setup mock with multiple steps
+ mock_function1 = Function(name="search", args=["query"])
+ mock_agent.planner.call.side_effect = [
+ GeneratorOutput(data=mock_function1, error=None, raw_response="raw"),
+ GeneratorOutput(
+ data=Function(name="finish", _is_answer_final=True, _answer="Done"),
+ error=None,
+ raw_response="raw"
+ )
+ ]
+
+ mock_agent.tool_manager.execute_func.return_value = FunctionOutput(
+ name="search",
+ input=mock_function1,
+ output=ToolOutput(output="Search results", observation="Found items")
+ )
+
+ # Run with multiple steps
+ result = runner_with_memory.call({"input_str": "Multi-step query"})
+
+ # Check step history in metadata
+ all_messages = runner_with_memory.conversation_memory.get_all_messages()
+ assistant_messages = [m for m in all_messages if m.role == "assistant"]
+
+ if assistant_messages:
+ last_assistant = assistant_messages[-1]
+ assert last_assistant.metadata is not None
+ assert "step_history" in last_assistant.metadata
+ assert len(last_assistant.metadata["step_history"]) > 0
+
+ def test_runner_without_memory(self, mock_agent):
+ """Test that runner works fine without memory.
+
+ Tests:
+ - Runner can be initialized without memory
+ - All safe memory operations handle None memory
+ - Execution completes normally
+ """
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=None)
+
+ assert runner.conversation_memory is None
+ assert runner.use_conversation_memory is False
+
+ # All safe operations should return appropriate defaults
+ assert runner._safe_create_turn() is None
+ assert runner._safe_add_user_query("query", "turn_id") is None
+ assert runner._safe_add_assistant_response("response", "turn_id") is None
+ assert runner._safe_get_conversation_history() == ""
+
+ # Runner should work normally
+ mock_function = Function(name="finish")
+ mock_function._is_answer_final = True
+ mock_function._answer = "No memory answer"
+
+ mock_agent.planner.call.return_value = GeneratorOutput(
+ data=mock_function,
+ error=None,
+ raw_response="raw"
+ )
+
+ result = runner.call({"input_str": "Test"})
+ assert result.answer == "No memory answer"
+
+
+class TestMemoryErrorRecovery:
+ """Test error recovery scenarios with memory failures."""
+
+ @pytest.fixture
+ def mock_agent(self):
+ """Create a mock agent for testing."""
+ agent = Mock(spec=Agent)
+ agent.max_steps = 3
+ agent.answer_data_type = str
+ agent.is_thinking_model = False
+
+ # Mock planner
+ agent.planner = Mock()
+ agent.planner.call = Mock()
+ agent.planner.acall = Mock()
+ agent.planner.get_prompt = Mock(return_value="test prompt")
+ agent.planner.estimated_token_count = 100
+
+ # Mock tool manager
+ agent.tool_manager = Mock()
+ agent.tool_manager.tools = []
+ agent.tool_manager.execute_func = Mock()
+ agent.tool_manager.execute_func_async = Mock()
+
+ return agent
+
+ def test_memory_failure_during_turn_creation(self, mock_agent):
+ """Test recovery when turn creation fails.
+
+ Tests:
+ - Runner continues without turn_id
+ - Subsequent operations handle None turn_id
+ - Execution completes successfully
+ """
+ memory = FlexibleConversationMemory()
+ memory.create_turn = Mock(side_effect=Exception("Turn creation failed"))
+
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory)
+
+ mock_function = Function(name="finish", _is_answer_final=True, _answer="Success")
+ mock_agent.planner.call.return_value = GeneratorOutput(
+ data=mock_function, error=None, raw_response="raw"
+ )
+
+ with patch('adalflow.components.agent.runner_flexible.log'):
+ result = runner.call({"input_str": "Test"})
+ assert result.answer == "Success"
+
+ def test_partial_memory_failure(self, mock_agent):
+ """Test when some memory operations fail but others succeed.
+
+ Tests:
+ - Turn creation succeeds
+ - User query fails
+ - Assistant response fails
+ - Runner still completes
+ """
+ memory = FlexibleConversationMemory()
+ original_create = memory.create_turn
+ memory.add_user_query = Mock(side_effect=Exception("User query failed"))
+ memory.add_assistant_response = Mock(side_effect=Exception("Assistant response failed"))
+
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory)
+
+ mock_function = Function(name="finish", _is_answer_final=True, _answer="Partial success")
+ mock_agent.planner.call.return_value = GeneratorOutput(
+ data=mock_function, error=None, raw_response="raw"
+ )
+
+ with patch('adalflow.components.agent.runner_flexible.log'):
+ result = runner.call({"input_str": "Test"})
+ assert result.answer == "Partial success"
+
+ # Turn should have been created successfully
+ assert memory.count_turns() == 1
+
+ def test_memory_recovery_between_runs(self, mock_agent):
+ """Test that memory can recover between runs.
+
+ Tests:
+ - First run with memory failure
+ - Memory recovers
+ - Second run succeeds with memory
+ """
+ memory = FlexibleConversationMemory()
+ runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory)
+
+ mock_function = Function(name="finish", _is_answer_final=True, _answer="Answer")
+ mock_agent.planner.call.return_value = GeneratorOutput(
+ data=mock_function, error=None, raw_response="raw"
+ )
+
+ # First run with failing memory
+ original_create = memory.create_turn
+ memory.create_turn = Mock(side_effect=Exception("Temporary failure"))
+
+ with patch('adalflow.components.agent.runner_flexible.log'):
+ result1 = runner.call({"input_str": "Query 1"})
+ assert result1.answer == "Answer"
+
+ # Restore memory functionality
+ memory.create_turn = original_create
+
+ # Second run should work with memory
+ result2 = runner.call({"input_str": "Query 2"})
+ assert result2.answer == "Answer"
+
+ # Second run should have created a turn
+ assert memory.count_turns() == 1
+ history = memory()
+ assert "Query 2" in history
\ No newline at end of file