From b697d271f59f56b6df6ff4b19f8ddfb1f7f85f0b Mon Sep 17 00:00:00 2001 From: Li Yin Date: Sun, 17 Aug 2025 00:12:55 -0700 Subject: [PATCH 1/3] experimenting --- compare_structured_output.py | 314 ++++++++++++++++++++++++++ dataset.py | 121 ++++++++++ experimental.py | 426 +++++++++++++++++++++++++++++++++++ 3 files changed, 861 insertions(+) create mode 100644 compare_structured_output.py create mode 100644 dataset.py create mode 100644 experimental.py diff --git a/compare_structured_output.py b/compare_structured_output.py new file mode 100644 index 00000000..ea1d83db --- /dev/null +++ b/compare_structured_output.py @@ -0,0 +1,314 @@ +from dataclasses import dataclass, field +from typing import List, Optional +import adalflow as adal +from adalflow.core import DataClass, required_field +from adalflow.components.model_client.openai_client import OpenAIClient +from adalflow.components.model_client.anthropic_client import AnthropicAPIClient + +from adalflow.components.output_parsers.outputs import JsonOutputParser, JsonOutputParserPydanticModel +from adalflow.utils import setup_env + +# OpenAI's structured output approach +from openai import OpenAI +from pydantic import BaseModel + +setup_env() + +# =============================== +# SHARED DATA MODELS +# =============================== + +# Pydantic models for OpenAI structured output +class Participants(BaseModel): + names: List[str] + addresses: List[str] + +class CalendarEvent(BaseModel): + name: str + date: str + participants: Participants + +# AdalFlow DataClass models +@dataclass +class ParticipantsData(DataClass): + names: List[str] = field( + metadata={"desc": "List of participant names"}, + default_factory=list + ) + addresses: Optional[List[str]] = field( + metadata={"desc": "List of participant addresses"}, + default_factory=list + ) + +@dataclass +class CalendarEventData(DataClass): + name: str = field( + metadata={"desc": "Name of the calendar event"}, + default_factory=required_field() + ) + date: str = field( + metadata={"desc": "Date of the event"}, + default_factory=required_field() + ) + participants: ParticipantsData = field( + metadata={"desc": "Event participants information"}, + default_factory=required_field() + ) + +# =============================== +# OPENAI STRUCTURED OUTPUT +# =============================== + +def test_openai_structured_output(): + """Test OpenAI's native structured output parsing.""" + print("\n=== Testing OpenAI Structured Output ===") + + client = OpenAI() + + try: + response = client.responses.parse( + model="gpt-4o-2024-08-06", + input=[ + {"role": "system", "content": "Extract the event information. 
Use a synthetic, very complicated string as the address for the participant to test complicated JSON parsing"},
+            {
+                "role": "user",
+                "content": "Alice and Bob are going to a science fair on Friday.",
+            },
+        ],
+        text_format=CalendarEvent,
+        )
+
+        event = response.output_parsed
+        print(f"OpenAI Response: {event}")
+
+        return event
+
+    except Exception as e:
+        print(f"OpenAI structured output failed: {e}")
+        return None
+
+# ===============================
+# ADALFLOW GENERATOR + JSON PARSER
+# ===============================
+
+class AdalFlowEventExtractor(adal.Component):
+    """AdalFlow component using Generator with JsonOutputParser."""
+
+    def __init__(self, model_client: adal.ModelClient, model_kwargs: dict):
+        super().__init__()
+
+        # Set up output parsers
+        self.output_parser = JsonOutputParser(
+            data_class=CalendarEventData,
+            return_data_class=True
+        )
+        self.output_parser_pydantic = JsonOutputParserPydanticModel(
+            pydantic_model=CalendarEvent,
+            return_pydantic_object=True
+        )
+        # Template for the Generator
+        self.template = r"""
+
+{{system_prompt}}
+
+{{output_format_str}}
+
+
+{{user_input}}
+
+        """.strip()
+
+        # Set up Generator
+        self.llm = adal.Generator(
+            model_client=model_client,
+            model_kwargs=model_kwargs,
+            template=self.template,
+            output_processors=self.output_parser,
+            use_cache=False,
+        )
+
+    def call(self, user_input: str, system_prompt: str) -> adal.GeneratorOutput:
+        """Extract event information using AdalFlow Generator + JsonOutputParser."""
+        # print(f" output_format_str: {self.output_parser.format_instructions()}")
+        prompt_kwargs = {
+            "system_prompt": system_prompt,
+            "output_format_str": self.output_parser.format_instructions(),
+            "user_input": user_input
+        }
+
+        return self.llm(prompt_kwargs=prompt_kwargs)
+
+def test_adalflow_json_parser():
+    """Test AdalFlow's Generator with JsonOutputParser."""
+    print("\n=== Testing AdalFlow Generator + JsonOutputParser ===")
+
+    model_config = {
+        "model_client": OpenAIClient(),
+        "model_kwargs": {
+            "model": "gpt-4o-mini",
+            "temperature": 0.7,
+            "max_tokens": 500,
+        }
+    }
+
+    extractor = AdalFlowEventExtractor(
+        model_client=model_config["model_client"],
+        model_kwargs=model_config["model_kwargs"]
+    )
+
+    system_prompt = "Extract the event information."
+    user_input = "Alice and Bob are going to a science fair on Friday."
+
+    try:
+        response = extractor.call(user_input, system_prompt)
+
+        # print(f"AdalFlow Response: {response}")
+
+        if response.data:
+            event_data = response.data
+            print(f"Event Name: {event_data.name}")
+            print(f"Date: {event_data.date}")
+            print(f"Participants: {event_data.participants.names}")
+            print(f"Addresses: {event_data.participants.addresses}")
+        else:
+            print(f"Failed to parse, raw response: {response.raw_response}")
+
+        return response
+
+    except Exception as e:
+        print(f"AdalFlow extraction failed: {e}")
+        return None
+
+# ===============================
+# COMPARISON AND BENCHMARKING
+# ===============================
+
+def compare_approaches():
+    """Compare both approaches side by side."""
+    print("\n" + "="*60)
+    print("STRUCTURED OUTPUT COMPARISON")
+    print("="*60)
+
+    # Test cases with increasingly complex scenarios
+    test_cases = [
+        "Alice and Bob are going to a science fair on Friday.",
+        "John, Mary, and David will attend the birthday party on December 25th.",
+        "The team meeting with Sarah, Mike, Tom, and Lisa is scheduled for next Monday at the conference room.",
+        # Complex address challenge
+        "Dr. Katherine Martinez-O'Sullivan and Prof. 
Ahmed bin Rashid Al-Maktoum will attend the International Conference on AI Ethics on March 15th, 2024. Katherine lives at 1247-B Château de Malmaison, Apt. #47/C, Neuilly-sur-Seine, Île-de-France 92200, France (GPS: 48.8738°N, 2.1667°E) and Ahmed resides at Building 47/Tower C/Floor 23/Unit 2301-A, Sheikh Zayed Road Complex, Near Dubai Mall Metro Station, Dubai, United Arab Emirates, P.O. Box 112233-ABUDHABI-UAE (Emergency Contact: +971-4-XXX-YYYY).", + # JSON-breaking characters and edge cases + "Meeting participants: Alex \"The Coder\" Johnson (address: 123 Main St., Unit #5-B\nSecond Floor\n\"Special Building\"\nCity: Austin\"Texas\" 78701\nCountry: USA\"America\"), Maria José Rodríguez-Pérez (Calle de José Ortega y Gasset, 29\n28006 Madrid\nEspaña), and Zhang Wei 张伟 (北京市朝阳区建国门外大街1号\nBeijing 100001\n中华人民共和国). Event: Tech Summit 2024 on February 29th.", + # Unicode, emojis, and special formatting + "🎉 Grand Opening Party 🎊 on Saturday, Jan 20th! Attendees: Σωκράτης Παπαδόπουλος (Πλατεία Συντάγματος 1, Αθήνα 10563, Ελλάδα 🇬🇷), Владимир Иванович Петров (Красная площадь, дом 1, Москва 109012, Российская Федерация 🇷🇺), and السيد أحمد محمد علي (شارع التحرير رقم 15، القاهرة 11511، جمهورية مصر العربية 🇪🇬).", + # Nested quotes and complex punctuation + "Annual \"Best Practices\" Workshop featuring: Robert \"Bob\" O'Malley-Smith Jr., Ph.D., M.D. (address: \"The Heights\", Building A-1, Suite 100-C, Floor 2.5, 987 Oak Tree Lane & Maple Street Intersection, South Hampton, NY 11968-1234 [Note: Use side entrance during construction]), Dr. Mary-Catherine Van Der Berg-Williams III (Château \"Les Trois Rosés\", 45 Rue de la Paix & Boulevard Saint-Germain, 7ème Arrondissement, Paris 75007, République Française [Buzzer code: \"2024-SECRET\"]), scheduled for December 31st, 2024 at 11:59 PM.", + # JSON structure breakers and extreme formatting + "Emergency meeting tomorrow! 
Participants: {\"name\": \"John Smith\", \"role\": \"CEO\"} living at [Address Object]: {\"street\": \"123 {Main} Street\", \"apt\": \"#[456-B]\", \"city\": \"New {York}\", \"state\": \"NY\", \"zip\": \"[10001-2345]\", \"coordinates\": {\"lat\": 40.7589, \"lng\": -73.9851}, \"special_notes\": \"Ring bell 3x, say 'pizza delivery', wait 30sec, then ring 2x more\"}, and Jane Doe at \"[CLASSIFIED LOCATION]\" {GPS: [REDACTED], Building: [UNKNOWN], Floor: [N/A]}" + ] + adalflow_count = 0 + openai_count = 0 + + for i, test_case in enumerate(test_cases, 1): + print(f"\n--- Test Case {i} ---") + print(f"Input: {test_case}") + + # Test OpenAI + openai_result = test_openai_structured_output_simple(test_case) + + # Test AdalFlow + adalflow_result = test_adalflow_json_parser_simple(test_case) + + print("adalflow_result:", adalflow_result.data) + print("openai_result:", openai_result) + + + + # Compare results + print("\n--- Comparison ---") + if openai_result and isinstance(openai_result, CalendarEvent): + print("✅Openai") + openai_count += 1 + if adalflow_result and adalflow_result.data and (isinstance(adalflow_result.data, CalendarEvent) or isinstance(adalflow_result.data, CalendarEventData)): + print("✅AdalFlow") + adalflow_count += 1 + + + print(f"OpenAI Count: {openai_count}, AdalFlow Count: {adalflow_count}") + +def test_openai_structured_output_simple(user_input: str): + """Simplified OpenAI test for comparison.""" + client = OpenAI() + + try: + response = client.responses.parse( + model="gpt-4o-2024-08-06", + input=[ + {"role": "system", "content": "Extract the event information."}, + { + "role": "user", + "content": user_input, + }, + ], + text_format=CalendarEvent, + ) + + event = response.output_parsed + return event + except Exception as e: + print(f"OpenAI error: {e}") + return None + +def test_adalflow_json_parser_simple(user_input: str): + """Simplified AdalFlow test for comparison.""" + + + + model_config={ + "model_client": AnthropicAPIClient(), + "model_kwargs": { + "model": "claude-sonnet-4-20250514" + } + } + model_config={ + "model_client": OpenAIClient(), + "model_kwargs": { + "model": "gpt-4o-mini", + } + } + + extractor = AdalFlowEventExtractor( + model_client=model_config["model_client"], + model_kwargs=model_config["model_kwargs"] + ) + + system_prompt = "Extract the event information. 
" + + try: + return extractor.call(user_input, system_prompt) + except Exception as e: + print(f"AdalFlow error: {e}") + return None + +if __name__ == "__main__": + # Run individual tests + openai_result = test_openai_structured_output() + adalflow_result = test_adalflow_json_parser() + + # Run comparison + compare_approaches() + + print("\n" + "="*60) + print("SUMMARY") + print("="*60) + print("OpenAI Structured Output:") + print(" + Native JSON schema validation") + print(" + Guaranteed structure compliance") + print(" - Requires specific OpenAI models") + print(" - Less flexibility in processing pipeline") + + print("\nAdalFlow Generator + JsonOutputParser:") + print(" + Model-agnostic approach") + print(" + Flexible processing pipeline") + print(" + Integration with optimization framework") + print(" - May require retry logic for parsing failures") + print(" + Better integration with AdalFlow ecosystem") \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 00000000..c87af0c8 --- /dev/null +++ b/dataset.py @@ -0,0 +1,121 @@ +dataset = [ + { + "pages": [ + "NOVALUX SERVICES LLC — internal memo v3\nrev 03/16/2024 | cycle Q1 | ticket 004271\nnotes: last spring round 03-14-2024; prev annual 2019-03-17; audit 15 Mar 2024\nthread list: Liam Chen; Maria Gomez; Henry Stone; Noah Patel; Andrea Ruiz; David Roe\nrooftop set: AC-7C, AC-9F; site: Northwind Offices; zone: PINE-WALK\nroster excerpt: Ruiz, Andrea; Patel, Noah; Chen, Olivia; Long, Peter; Gomez, Maria\nline items:\n- 7C belts swapped 03/15/24\n- 9F filters swapped 03/14/24\n- motor check 2023-03-17\nnames mentioned elsewhere: Omar Diaz, Jason Miller, Karen Bell, Raj Patel\n--- mid-block ------------------------------------------------\nroute grid R2: Chen, Liam | Gomez, Maria | Stone, Emily | Roe, David | Patel, Noah\ndispatch key [core-record: 2024-03-15] ref A1-7C\nnorth list crossref: Henry Stone; Andrea Ruiz\n----------------------------------------------------------------\nmisc dates: 03/15/2024 09:40; 2024/03/15; 15-03-2024\nfooter tag id NLX-INT-004271\n" + ], + "date_rule": "Extract the ISO date appearing between the exact marker '[core-record: ' and the closing ']'.", + "expected_output": { + "document_main_date": "2024-03-15", + "client_information": { + "first_name": "Emily", + "last_name": "Stone" + } + } + }, + { + "pages": [ + "ORION INSURANCE COMPANY — coverage packet\nprint 2025/01/02; cycle close 2024-12-29; sample form rev 12-2024\nreference codes: AX-77-193; BQ-11-004\nnames in circulation: Ethan Li, Priya Nair, Omar Diaz, Laurel Kim, Julia Park\npage markers: P1/2\nlist A (mail sweep): Mendes, Ana; Park, Julia; Li, Ethan; Singh, Maya; Patel, Raj\n--- center band ----------------------------------------------\nwindow text row: Mendes, Carlos | addr token G-441B | @when=2025-02-01@\n--------------------------------------------------------------\nreminders: renewal cycle hits 2026-02-01; example date 02/01/2025 (US)\ntrailing mentions: Raj Patel; Maya Singh; Laurel Kim\n" + ], + "date_rule": "Select the YYYY-MM-DD date enclosed by the markers '@when=' and '@'.", + "expected_output": { + "document_main_date": "2025-02-01", + "client_information": { + "first_name": "Carlos", + "last_name": "Mendes" + } + } + }, + { + "pages": [ + "NORTH RIVER BANK — consolidated lines\nhdr date 11/05/2023; period 10/01/23–10/31/23; ref NRB-STAT-1180\nledger notes: auditor Olivia Chen 02-Nov-2023; prior msg 2023-10-29; contact Jason Miller\nmasked acct **** 4831 | manager Tomas Rivera | approver Ellen 
Wu\npeople mentioned: Peter Sand; Daniel Cho; Cathy Nguyen; Henry Stone\nflow:\n- ACH in 10/30/23\n- fee waive 11-02-2023\n--- center ledger --------------------------------------------\nparticipants: Sand, Peter | Nair, Priya | Chen, Olivia | Miller, Jason\nseal batch 44A\n--------------------------------------------------------------\nfooter fragments: 05-11-2023; Nov 1, 2023; 2023/11/01\n" + ], + "date_rule": "Use the date contained between '' exactly as YYYY-MM-DD.", + "expected_output": { + "document_main_date": "2023-11-01", + "client_information": { + "first_name": "Priya", + "last_name": "Nair" + } + } + }, + { + "pages": [ + "OAK CREST PROPERTY MGMT. — unit packet\nbldg: Lakeview Rd., unit 5A; intercom map rev 08/23; parking memo 08-23-2022\nnames across building: Amy Tran; Daniel Ortiz; Raj Patel; Sarah Onu; Michael Lin\nstack A roster: Blake, Sarah (prev); Grant, Oliver (current); Ruiz, Andrea; Gomez, Maria\npage 1/3\n--- carryover data -------------------------------------------\nmailbox panel: 5A GRANT O | 5B TRAN A | 6C ORTIZ D | 3D PATEL R\ninspection mentions: 08/22/2022; utilities 08/25/2022; move target 09/01/2022\n--------------------------------------------------------------\nnext page\n", + "OAK CREST PROPERTY MGMT. — notes\nvisitors seen: Andrea Ruiz; Maria Gomez; David Roe; Jason Miller; Olivia Chen\nrandom dates: 20-08-2022; 2022/08/20; 08/20/22\n--- mid-strip -------------------------------------------------\nkey timeline |dt|2022-08-20|dt| for archive tag LC-5A\n--------------------------------------------------------------\nother: parking review 2022-08-23; form edits 08-18-2022\npage 2/3\n", + "OAK CREST PROPERTY MGMT. — misc\nunit map checksum 5A-7F-2C; contact index L.Park; badge review 2022-08-21\nfooter copy ids: OCP-5A-AG\npage 3/3\n" + ], + "date_rule": "Extract the date between the exact tokens '|dt|' and '|dt|' on page 2.", + "expected_output": { + "document_main_date": "2022-08-20", + "client_information": { + "first_name": "Oliver", + "last_name": "Grant" + } + } + }, + { + "pages": [ + "CITYCARE CLINIC — visit archive\nprint 2021-07-14; prior vaccination 2021-06-10; relative visit 06/30/2021\nstaff roll: Mark Holloway; Eva Burns; Nora Lee; Raj Patel\nname scatter: Ramirez, Luis; Mendes, Carlos; Stone, Emily; Lee, Marcus; Petrova, Sofia\nsymptoms log id CC-7781\n--- middle row ------------------------------------------------\nconsent line: Ramirez, Zoe {iso:2021-07-20} sig on file\n--------------------------------------------------------------\nother timestamps: 2021/07/18; 07-20-21; 2021-01-01 (policy)\n" + ], + "date_rule": "Choose the YYYY-MM-DD value inside the braces after 'iso:' on the consent line.", + "expected_output": { + "document_main_date": "2021-07-20", + "client_information": { + "first_name": "Zoe", + "last_name": "Ramirez" + } + } + }, + { + "pages": [ + "XELTRONICS CORPORATION — hiring bundle\nrev 05/11/2020 meeting; start target 06/01/2020; approvals 2020-05-09\nnames across shortlist: Bell, Karen; Park, Julia; Diaz, Omar; Young, Samuel; Novak, Diana\npipeline list: Chen, Olivia; Patel, Raj; Lee, Marcus; Ali, Hassan; Brooks, Natalie\n--- middle band ----------------------------------------------\nroster slot SE-I: Ali, Hassan #on:2020-05-12# marker SE1-B\n--------------------------------------------------------------\nother dates: 05/12/20; 2020/05/12; 12-05-2020\n" + ], + "date_rule": "Extract the YYYY-MM-DD date between '#on:' and '#'.", + "expected_output": { + "document_main_date": "2020-05-12", + "client_information": { + 
"first_name": "Hassan", + "last_name": "Ali" + } + } + }, + { + "pages": [ + "GREENFIELD UNIVERSITY — decision file\nprint 2019-03-10; committee 11/03/2019; orientation 2019-08-26; deadline 04/15/2019\nstaff: Harold King; Maya Singh; Raj Patel; Olivia Chen\nnames in cohort: Cole, Jason; Lee, Marcus; Brooks, Natalie; Mendes, Carlos; Nair, Priya\n--- central cut ----------------------------------------------\nfile tag GU-2019-4412: Petrova, Sofia <<2019-03-12>> status ADMIT\n--------------------------------------------------------------\nmirrored dates: 03-12-2019; 2019/03/12\n" + ], + "date_rule": "Use the date inside the double angle brackets '<<' and '>>'.", + "expected_output": { + "document_main_date": "2019-03-12", + "client_information": { + "first_name": "Sofia", + "last_name": "Petrova" + } + } + }, + { + "pages": [ + "SKYQUEST TRAVEL — booking ledger\nprint 2018-09-04; quote 09/02/2018; depart 2018-10-05 07:25; return 2018-10-12 19:40\nagent: Irene Zhao; group coord: Hannah Park\nnames log: Cho, Daniel; Nguyen, Cathy; Cole, Jason; Petrova, Sofia; Mendes, Carlos\nPNR refs: LEE/MARCUS; CHO/DANIEL; NGUYEN/CATHY\n--- mid block -------------------------------------------------\nrecord LEE/MARCUS (iso) 2018-09-03 (iso) JFK leg confirm\n--------------------------------------------------------------\nother styles: 09-03-2018; 2018/09/03\n" + ], + "date_rule": "Extract the YYYY-MM-DD that appears between the tokens '(iso) ' and ' (iso)'.", + "expected_output": { + "document_main_date": "2018-09-03", + "ClientInformation": { + "first_name": "Marcus", + "last_name": "Lee" + } + } + }, + { + "pages": [ + "RIVERSIDE ENERGY — statement dump\nperiod 11/01/2017–11/30/2017; read 2017-11-28; prior payment 2017-11-10; rebate 2017-10-05\ntouchpoints: Olga Ivanov; Sean Murphy; Emily Stone; Oliver Grant; Priya Nair\naccount roster sample: Murphy, Sean; Grant, Oliver; Nair, Priya; Brooks, Natalie; Lee, Marcus\n--- center line ----------------------------------------------\nacct row Orchard Ave: Brooks, Natalie [bill.iso=2017-12-01] due 12/21/2017\n--------------------------------------------------------------\nfooter echoes: 2017/12/01; 01-12-2017; Dec 1, 2017\n" + ], + "date_rule": "Use the date that appears after 'bill.iso=' inside the square brackets.", + "expected_output": { + "document_main_date": "2017-12-01", + "client_information": { + "first_name": "Natalie", + "last_name": "Brooks" + } + } + } +] \ No newline at end of file diff --git a/experimental.py b/experimental.py new file mode 100644 index 00000000..810c8375 --- /dev/null +++ b/experimental.py @@ -0,0 +1,426 @@ +from dataclasses import dataclass, field +from typing import Literal +from adalflow.datasets.types import DataClass, BaseData +from adalflow.core import DataClass, required_field +from dataset import dataset +from adalflow.utils import setup_env + +setup_env() +@dataclass +class DataExtractionInput(BaseData): + pages: list[str] = field( + metadata={"desc": "The pages of the document"}, + default_factory=required_field() + ) + date_rule: str = field( + metadata={"desc": "The rule for extracting the date from the document"}, + default_factory=required_field() + ) + expected_output: dict = field( + metadata={"desc": "The ground truth of the data extraction"}, + default=None + ) + + + __input_fields__ = ["pages", "date_rule"] + + +@dataclass +class ClientInformation(DataClass): + + primary_client_reasoning_scratch_pad: str = field( + metadata={"desc": "The reasoning process for determining the primary client name. 
This is the name that should be used for the client information."}, + default_factory=required_field() + ) + first_name_reasoning_scratch_pad: str = field( + metadata={"desc": "The reasoning process for determining the client's first name."}, + default_factory=required_field() + ) + first_name: str = field( + metadata={"desc": "The client's given first name, parsed according to rules."}, + default_factory=required_field() + ) + middle_name: str = field( + metadata={"desc": "The client's middle name, if present as a full word after parsing."}, + default_factory=required_field() + ) + last_name_reasoning_scratch_pad: str = field( + metadata={"desc": "The reasoning process for determining the client's last name."}, + default_factory=required_field() + ) + last_name: str = field( + metadata={"desc": "The client's surname or family name, parsed according to rules."}, + default_factory=required_field() + ) + + +@dataclass +class DataExtractionOutput(BaseData): + + document_dates: list[str] = field( + metadata={"desc": "The list of dates found in the document"}, + default_factory=list, + ) + document_main_date: str = field( + metadata={"desc": "The main date of the document, extracted from the list document_dates"}, + default_factory=required_field() + ) + client_information: ClientInformation = field( + metadata={"desc": "The client information of the document"}, + default_factory=required_field(), + ) + + __output_fields__ = ["document_dates", "document_main_date", "client_information"] + +input_dataclass_list = [] + +for item in dataset: + dataset_item = DataExtractionInput.from_dict(item) + # dataset_item_1 = DataExtractionInput( + # pages=item["pages"], + # date_rule=item["date_rule"], + # expected_output=item["expected_output"] + # ) + # assert dataset_item == dataset_item_1, "Dataset item does not match the expected structure" + input_dataclass_list.append(dataset_item) + + +train_dataset = input_dataclass_list[:2] +val_dataset = input_dataclass_list[3:5] +test_dataset = input_dataclass_list[6:8] + +len(train_dataset), len(val_dataset), len(test_dataset) + +from adalflow.components.model_client.openai_client import OpenAIClient +import adalflow as adal + +model_openai_o4_mini = { + "model_client": OpenAIClient(), + "model_kwargs": { + "model": "o4-mini", # or "o1" + "reasoning": { + "effort": "medium", # low, medium, high + } + } +} + +template = r""" + + {{system_prompt}} + + {{output_format_str}} + + {% if few_shot_demos is not none %} + Here are some examples: + {{few_shot_demos}} + {% endif %} + + + {{input_str}} + + """.strip() + + +task_prompt_document_extraction = r""" +You are a helpful assistant specialized in data processing and extraction. 
+""".strip()
+
+
+from typing import Union, Optional
+import adalflow as adal
+
+class DataExtractor(adal.Component):
+
+    def __init__(self, model_client: adal.ModelClient, model_kwargs: dict):
+        super().__init__()
+
+        # INPUT
+        self.data_class_input = DataExtractionInput
+        self.parser_input = adal.DataClassParser(
+            data_class=self.data_class_input, return_data_class=True, format_type="json"
+        )
+
+        # OUTPUT
+        task_desc_str = adal.Prompt(
+            template=task_prompt_document_extraction,
+            # prompt_kwargs={"classes": label_desc}  # prompt variables to be hydrated
+        )()
+        self.data_class_output = DataExtractionOutput
+        self.data_class_output.set_task_desc(task_desc_str)
+
+        self.parser_output = adal.DataClassParser(
+            data_class=self.data_class_output, return_data_class=True, format_type="json"
+        )
+
+        print(f"output format: {self.parser_output.get_output_format_str()}")
+
+        # GENERATOR PARAMS
+        prompt_kwargs = {
+            "system_prompt": adal.Parameter(
+                data=self.parser_output.get_task_desc_str(),
+                role_desc="Task description",
+                requires_opt=True,
+                param_type=adal.ParameterType.PROMPT,
+            ),
+            "output_format_str": adal.Parameter(
+                data=self.parser_output.get_output_format_str(),
+                role_desc="Output format requirements",
+                requires_opt=True,
+                param_type=adal.ParameterType.PROMPT,
+            ),
+
+            # I didn't enable few-shot demos, to avoid overcomplicating things, but it would be nice to get this working too. :D
+
+            # "few_shot_demos": adal.Parameter(
+            #     data=None,
+            #     requires_opt=True,
+            #     role_desc="Few shot examples to help the model",
+            #     param_type=adal.ParameterType.DEMOS,
+            # ),
+        }
+
+        self.llm = adal.Generator(
+            model_client=model_client,
+            model_kwargs=model_kwargs,
+            prompt_kwargs=prompt_kwargs,
+            template=template,
+            output_processors=self.parser_output,
+            use_cache=False,
+        )
+
+        print(f"system prompt: {self.llm.get_prompt()}")
+
+    def _prepare_input(self, dataset_item: DataExtractionInput):
+
+        # QUESTION:
+        # In my use case, I have some arguments to pass to the prompt that are different for each dataset item.
+        # Normally I would put them in the system prompt, but here I was not sure whether changing the system prompt
+        # inside _prepare_input could break something.
+        # So in this example code, I'm just passing them as a dump of the input data.
+
+        input_data = self.data_class_input(pages=dataset_item.pages, date_rule=dataset_item.date_rule)
+        input_str = self.parser_input.get_input_str(input_data)
+
+        prompt_kwargs = {
+            "input_str": adal.Parameter(
+                data=input_str, requires_opt=False, role_desc="input to the LLM"
+            )
+        }
+        return prompt_kwargs
+
+    def bicall(self,
+        dataset_item: DataExtractionInput,
+        id: Optional[str] = None
+    ) -> Union[adal.GeneratorOutput, adal.Parameter]:
+        prompt_kwargs = self._prepare_input(dataset_item)
+        output = self.llm(prompt_kwargs=prompt_kwargs, id=id)
+
+        return output
+
+task = DataExtractor(
+    model_client=model_openai_o4_mini["model_client"],
+    model_kwargs=model_openai_o4_mini["model_kwargs"],
+)
+print(task)
+
+from typing import Dict, Callable, Any, Tuple
+
+from adalflow.eval.answer_match_acc import AnswerMatchAcc
+
+def eval_fn_data_extraction(
+    y: DataExtractionOutput,
+    y_gt: DataExtractionInput
+) -> float:
+    # I got some runs where the LLM output failed to parse correctly.
+    # TODO: try to add some retry logic somewhere in the code.
+
+    # QUESTION:
+    # I don't know whether the AdalFlow framework supports this.
+    # Maybe if the framework used the API's response format it would increase the JSON output accuracy?
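+    # Editor's note (hedged): OpenAI's API does expose a response_format /
+    # structured-output option that usually reduces malformed-JSON failures;
+    # whether OpenAIClient forwards it from model_kwargs here is an assumption
+    # to verify. A simple framework-agnostic fallback is to retry the Generator
+    # call a couple of times whenever the parsed .data comes back as None.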
+
+    try:
+        patient_first_name_pred = y.client_information.first_name
+        patient_first_name_gt = y_gt.expected_output["client_information"]["first_name"]
+        return AnswerMatchAcc(type="exact_match").compute_single_item(patient_first_name_pred, patient_first_name_gt)
+
+    except Exception as e:
+        print(f"Parse error: {e}")
+        return 0
+
+class DataExtractorTrainer(adal.AdalComponent):
+    def __init__(
+        self,
+        model_client: adal.ModelClient,
+        model_kwargs: Dict,
+        teacher_model_config: Dict,
+        backward_engine_model_config: Dict,
+        text_optimizer_model_config: Dict,
+    ):
+        task = DataExtractor(model_client, model_kwargs)
+        # eval_fn = AnswerMatchAcc(type="exact_match").compute_single_item
+        eval_fn = eval_fn_data_extraction
+        loss_fn = adal.EvalFnToTextLoss(
+            eval_fn=eval_fn,
+            eval_fn_desc="exact_match: 1 if str(y) == str(y_gt) else 0. When the LLM prediction fails format parsing, which results in errors, we set y_pred = -1",
+        )
+        super().__init__(
+            task=task,
+            eval_fn=eval_fn,
+            loss_fn=loss_fn,
+            backward_engine_model_config=backward_engine_model_config,
+            text_optimizer_model_config=text_optimizer_model_config,
+            teacher_model_config=teacher_model_config,
+        )
+
+    def prepare_task(self, sample: DataExtractionInput):
+        return self.task.call, {"dataset_item": sample, "id": sample.id}
+
+    def prepare_eval(
+        self, sample: DataExtractionInput, y_pred: adal.GeneratorOutput
+    ) -> float:
+
+        prediction = -1
+        if y_pred and y_pred.data is not None:
+            prediction = y_pred.data
+        return self.eval_fn, {"y": prediction, "y_gt": sample}
+
+    def prepare_loss(
+        self, dataset_item: DataExtractionInput, y_pred: adal.Parameter, *args, **kwargs
+    ) -> Tuple[Callable[..., Any], Dict]:
+
+        # QUESTION:
+        # How can I create a system that has multiple target variables to compute loss?
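+        # Editor's note (a hedged sketch, not verified against the AdalFlow API):
+        # one option is to keep a single loss_fn but pack several targets into one
+        # eval_input, e.g. a dict like {"first_name": ..., "document_main_date": ...}
+        # compared against the matching ground-truth dict so eval_fn scores all
+        # targets at once; alternatively, build one loss call per target and combine
+        # them as separate feedback signals.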
+
+
+        full_response = y_pred.data
+        y_label = -1  # default value for failed prediction
+        try:
+            first_name = full_response.data.client_information.first_name
+        except Exception as e:
+            print(f"Parse error: {e}")
+            first_name = -1
+
+        if isinstance(first_name, str) and first_name.strip():
+            y_label = first_name.strip()
+
+        y_pred.eval_input = y_label
+
+        y_gt = adal.Parameter(
+            name="y_gt",
+            data=dataset_item.expected_output["client_information"]["first_name"],
+            eval_input=dataset_item.expected_output["client_information"]["first_name"],
+            requires_opt=False,
+        )
+
+        return self.loss_fn, {
+            "kwargs": {
+                "y": y_pred,
+                "y_gt": y_gt
+            },
+            "id": dataset_item.id,
+            "gt": y_gt.eval_input,
+        }
+
+def train(
+    model_client: adal.ModelClient,
+    model_kwargs: Dict,
+    train_batch_size=4,
+    raw_shots: int = 0,
+    bootstrap_shots: int = 1,
+    max_steps=12,
+    num_workers=4,
+    strategy="constrained",
+    optimization_order="sequential",
+    debug=False,
+):
+    print("Starting training process...")
+
+    # Define the model configuration for all components
+    # gpt_4o_model = {
+    #     "model_client": OpenAIClient(),
+    #     "model_kwargs": {
+    #         "model": "gpt-4o-mini",
+    #         "temperature": 1,
+    #         "top_p": 0.99,
+    #         "max_tokens": 1000,
+    #         # "frequency_penalty": 1,  # high for not repeating the prompt
+    #     },
+    # }
+    model_openai_o4_mini = {
+        "model_client": OpenAIClient(),
+        "model_kwargs": {
+            "model": "o4-mini",  # or "o1"
+            "reasoning": {
+                "effort": "medium",  # low, medium, high
+            }
+        }
+    }
+
+    model_openai_gpt_5 = {
+        "model_client": OpenAIClient(),
+        "model_kwargs": {
+            "model": "gpt-5",  # or "o1"
+            "reasoning": {
+                "effort": "medium",  # low, medium, high
+            }
+        }
+    }
+
+    print(f"Component model configuration: {model_openai_o4_mini}")
+
+    try:
+        print("Initializing ADAL component...")
+        adal_component = DataExtractorTrainer(
+            model_client=model_client,
+            model_kwargs=model_kwargs,
+            text_optimizer_model_config=model_openai_gpt_5,
+            backward_engine_model_config=model_openai_o4_mini,
+            teacher_model_config=model_openai_o4_mini,
+        )
+        print("ADAL component initialized successfully")
+
+        print("Initializing trainer...")
+        trainer = adal.Trainer(
+            train_batch_size=train_batch_size,
+            adaltask=adal_component,
+            strategy=strategy,
+            max_steps=max_steps,
+            num_workers=num_workers,
+            raw_shots=raw_shots,
+            bootstrap_shots=bootstrap_shots,
+            debug=debug,
+            weighted_sampling=True,
+            optimization_order=optimization_order,
+            exclude_input_fields_from_bootstrap_demos=True,
+        )
+        print("Trainer initialized successfully")
+
+        print("Loading datasets...")
+        # train_dataset, val_dataset, test_dataset = load_datasets()
+        print(
+            f"Datasets loaded - Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}"
+        )
+
+        print("Starting model training...")
+        trainer.fit(
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            test_dataset=test_dataset,
+            debug=debug,
+        )
+        print("Training completed successfully")
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+        raise
+
+
+model_openai_o4_mini = {
+    "model_client": OpenAIClient(),
+    "model_kwargs": {
+        "model": "o4-mini",  # or "o1"
+        "reasoning": {
+            "effort": "medium",  # low, medium, high
+        }
+    }
+}
+
+train(**model_openai_o4_mini)
\ No newline at end of file

From 87f9b59aa0346383de3421506cbb4a79742dba2b Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Tue, 19 Aug 2025 12:26:01 -0700
Subject: [PATCH 2/3] quick fix of memory add issue causing no execution_complete events

---
 adalflow/adalflow/components/agent/agent.py | 16 +-
adalflow/adalflow/components/agent/runner.py | 127 ++++-- .../components/memory/flexible_memory.py | 362 ++++++++++++++++++ adalflow/adalflow/components/memory/memory.py | 14 +- .../output_parsers/dataclass_parser.py | 2 +- .../components/output_parsers/outputs.py | 1 - adalflow/adalflow/core/types.py | 136 ++++++- adalflow/tests/test_memory.py | 12 +- adalflow/tests/test_openai_client.py | 202 ++++++++++ adalflow/tests/test_runner.py | 227 +++++++++++ 10 files changed, 1037 insertions(+), 62 deletions(-) create mode 100644 adalflow/adalflow/components/memory/flexible_memory.py diff --git a/adalflow/adalflow/components/agent/agent.py b/adalflow/adalflow/components/agent/agent.py index 96d3cc4e..ad96d9bc 100644 --- a/adalflow/adalflow/components/agent/agent.py +++ b/adalflow/adalflow/components/agent/agent.py @@ -18,6 +18,8 @@ from adalflow.core.tool_manager import ToolManager from adalflow.core.prompt_builder import Prompt from adalflow.core.types import GeneratorOutput, ModelType, Function +from adalflow.core.base_data_class import DataClass, DataClassFormatType + from adalflow.optim.parameter import Parameter, ParameterType from adalflow.components.output_parsers import JsonOutputParser from adalflow.utils import printc @@ -54,7 +56,7 @@ # ' - +# TODO: replace the agent to pydantic Function Model. But it cant control the fields # the context will wrap the whole component def create_default_tool_manager( # Tool manager parameters @@ -154,9 +156,17 @@ def create_default_planner( include_fields = ["name", "kwargs", "_is_answer_final", "_answer"] else: include_fields = ["thought", "name", "kwargs", "_is_answer_final", "_answer"] + + examples = [ + Function( + name="example_function", + kwargs={"param1": "value1", "param2": "value2"}, + _is_answer_final=False, + _answer=None,) + ] output_parser = JsonOutputParser( data_class=ouput_data_class, - examples=None, + examples=examples, # examples=self._examples, # TODO: add examples return_data_class=True, include_fields=include_fields, @@ -169,7 +179,7 @@ def create_default_planner( prompt_kwargs = { "tools": tool_manager.yaml_definitions, - "output_format_str": output_parser.format_instructions(), + "output_format_str": output_parser.format_instructions(format_type=DataClassFormatType.SIGNATURE_JSON), "task_desc": Parameter( name="react_agent_task_desc", data=task_desc, diff --git a/adalflow/adalflow/components/agent/runner.py b/adalflow/adalflow/components/agent/runner.py index 2fc486c2..5ecfd068 100644 --- a/adalflow/adalflow/components/agent/runner.py +++ b/adalflow/adalflow/components/agent/runner.py @@ -20,7 +20,7 @@ AsyncIterable, ) from typing_extensions import TypeAlias -import sys +import uuid from adalflow.optim.parameter import Parameter @@ -323,7 +323,7 @@ def _check_last_step(self, step: Function) -> bool: def _get_final_answer(self, function: Function) -> Any: """Get and process the final answer from the function.""" - if hasattr(function, "_answer"): + if hasattr(function, "_answer") or (hasattr(function, "_is_answer_final") and function._is_answer_final): return self._process_data(function._answer) return None @@ -338,30 +338,44 @@ def _create_runner_result(self, answer: Any, step_history, error: Optional[str] ) def _create_execution_complete_stream_event(self, streaming_result: RunnerStreamingResult, final_output_item: FinalOutputItem): """Complete the streaming execution by adding a sentinel.""" - final_output_event = RunItemStreamEvent( - name="agent.execution_complete", - item=final_output_item, - ) - 
streaming_result.put_nowait(final_output_event) + try: + final_output_event = RunItemStreamEvent( + name="agent.execution_complete", + item=final_output_item, + ) + streaming_result.put_nowait(final_output_event) - runner_result: RunnerResult = final_output_item.data + runner_result: RunnerResult = final_output_item.data - # set up the final answer - streaming_result.answer = runner_result.answer if runner_result else None - streaming_result.step_history = self.step_history.copy() - streaming_result._is_complete = True + # set up the final answer + streaming_result.answer = runner_result.answer if runner_result else None + streaming_result.step_history = self.step_history.copy() + streaming_result._is_complete = True + + except Exception as e: + log.error(f"Failed to create execution complete stream event: {e}") + raise e - def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem): + def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem, turn_id: Any = None): # add the assistant response to the conversation memory - if self.use_conversation_memory and self.conversation_memory._pending_user_query is not None: - self.conversation_memory.add_assistant_response( - AssistantResponse( - response_str=final_output_item.data.answer, - metadata={ - "step_history": final_output_item.data.step_history.copy() - }, + try: + if self.use_conversation_memory and turn_id is not None: + # Only add if we have a valid turn_id + self.conversation_memory.add_assistant_response( + AssistantResponse( + response_str=final_output_item.data.answer, + metadata={ + "step_history": final_output_item.data.step_history.copy() + }, + ), + turn_id=turn_id ) - ) + elif self.use_conversation_memory: + log.warning("Skipping add_assistant_response - no turn_id available") + except Exception as e: + log.error(f"Failed to add assistant response to memory: {e}") + # Don't re-raise, just log the error + return def create_response_span(self, runner_result, step_count: int, streaming_result: RunnerStreamingResult, runner_span_instance, workflow_status: str = "stream_completed"): @@ -395,7 +409,8 @@ async def _process_stream_final_step( answer: Any, step_count: int, streaming_result, - runner_span_instance + runner_span_instance, + turn_id: Any = None, ) -> FinalOutputItem: """Process the final step and trace it.""" # processed_data = self._get_final_answer(function) @@ -422,8 +437,13 @@ async def _process_stream_final_step( self._create_execution_complete_stream_event( streaming_result, final_output_item ) + + # Ensure event is in queue by yielding control briefly + # This allows the event loop to process the put_nowait before we return + await asyncio.sleep(0) # Small delay to ensure event is properly queued + # add the assistant response to the conversation memory - self._add_assistant_response_to_memory(final_output_item) + self._add_assistant_response_to_memory(final_output_item, turn_id) return final_output_item # TODO: improved after the finish function is refactored @@ -561,6 +581,7 @@ def call( self.step_history ) # a reference to the step history + turn_id = None if self.use_conversation_memory: # Reset any pending query state before starting a new query self.conversation_memory.reset_pending_query() @@ -570,7 +591,7 @@ def call( # meta data is all keys in the list of context_str query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} - self.conversation_memory.add_user_query( + turn_id = self.conversation_memory.add_user_query( UserQuery( 
query_str=prompt_kwargs.get("input_str", None), metadata=query_metadata, @@ -672,14 +693,15 @@ def call( ) # Add assistant response to conversation memory - if self.use_conversation_memory: + if self.use_conversation_memory and turn_id is not None: self.conversation_memory.add_assistant_response( AssistantResponse( response_str=processed_data, metadata={ "step_history": self.step_history.copy() }, - ) + ), + turn_id=turn_id ) step_count += 1 # Increment step count before breaking @@ -917,6 +939,7 @@ async def acall( self.step_history ) # a reference to the step history + turn_id = None if self.use_conversation_memory: # Reset any pending query state before starting a new query self.conversation_memory.reset_pending_query() @@ -926,7 +949,7 @@ async def acall( # meta data is all keys in the list of context_str query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} - self.conversation_memory.add_user_query( + turn_id = self.conversation_memory.add_user_query( UserQuery( query_str=prompt_kwargs.get("input_str", None), metadata=query_metadata, @@ -1043,14 +1066,15 @@ async def acall( ) # Add assistant response to conversation memory - if self.use_conversation_memory: + if self.use_conversation_memory and turn_id is not None: self.conversation_memory.add_assistant_response( AssistantResponse( response_str=answer, metadata={ "step_history": self.step_history.copy() }, - ) + ), + turn_id=turn_id ) @@ -1268,6 +1292,7 @@ async def impl_astream( """ workflow_status: Literal["streaming", "stream_completed", "stream_failed", "stream_incomplete"] = "streaming" # Create runner span for tracing streaming execution + turn_id = None with runner_span( runner_id=id or f"stream_runner_{hash(str(prompt_kwargs))}", max_steps=self.max_steps, @@ -1288,7 +1313,7 @@ async def impl_astream( # meta data is all keys in the list of context_str query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} - self.conversation_memory.add_user_query( + turn_id = self.conversation_memory.add_user_query( UserQuery( query_str=prompt_kwargs.get("input_str", None), metadata=query_metadata, @@ -1386,10 +1411,10 @@ async def impl_astream( else: # non-streaming cases # yield the final planner response - if output.data is None: + if output.data is None or (not isinstance(output.data, Function)): # recoverable errors, continue to create stepout - current_error = output.error + current_error = f"Error: {output.error} - data: {output.data}, raw_response: {output.raw_response}" # wrap the error in a RawResponsesStreamEvent wrapped_event = RawResponsesStreamEvent( data=None, # no data in this case @@ -1411,6 +1436,10 @@ async def impl_astream( item=step_item, ) streaming_result.put_nowait(step_complete_event) + + # Ensure event is processed before continuing + await asyncio.sleep(0) # Yield control to allow queue processing + self.step_history.append(step_output) if output.error is not None: @@ -1467,14 +1496,26 @@ async def impl_astream( log.debug(f"function: {function}") if self._check_last_step(function): # skip stepoutput - answer = self._get_final_answer(function) + try: + answer = self._get_final_answer(function) + except Exception as e: + # If processing the final answer fails, use the raw answer + log.warning(f"Failed to process final answer: {e}. 
Using raw answer.") + answer = function._answer if hasattr(function, "_answer") else str(e) + final_output_item = await self._process_stream_final_step( answer=answer, step_count=step_count, streaming_result=streaming_result, runner_span_instance=runner_span_instance, + turn_id=turn_id, ) workflow_status = "stream_completed" + + # Ensure the queue has processed the execution_complete event + # Add a small yield to allow the event loop to process the queued events + await asyncio.sleep(0) # Yield control to allow queue processing + break # Check if permission is required and emit permission event @@ -1556,6 +1597,10 @@ async def impl_astream( name="agent.step_complete", item=step_item ) streaming_result.put_nowait(step_event) + + # Ensure event is processed before continuing + await asyncio.sleep(0) # Yield control to allow queue processing + step_count += 1 except asyncio.CancelledError: @@ -1580,7 +1625,7 @@ async def impl_astream( streaming_result._is_complete = True # Add cancellation response to conversation memory - if self.use_conversation_memory: + if self.use_conversation_memory and turn_id is not None: self.conversation_memory.add_assistant_response( AssistantResponse( response_str="I apologize, but the execution was cancelled by the user.", @@ -1589,7 +1634,8 @@ async def impl_astream( "status": "cancelled", "timestamp": datetime.now().isoformat() } - ) + ), + turn_id=turn_id ) # Signal completion and break @@ -1627,6 +1673,12 @@ async def impl_astream( workflow_status = "stream_incomplete" current_error = f"No output generated after {step_count} steps (max_steps: {self.max_steps})" + # Only emit execution_complete if we created a new final_output_item + # (i.e., when the loop ended without a final answer) + self._create_execution_complete_stream_event( + streaming_result=streaming_result, + final_output_item=final_output_item, + ) runner_span_instance.span_data.update_attributes( { @@ -1644,11 +1696,6 @@ async def impl_astream( error=current_error, ) - self._create_execution_complete_stream_event( - streaming_result=streaming_result, - final_output_item=final_output_item, - ) - # create response span for final output # if workflow_status in ["stream_incomplete", "stream_failed"]: self.create_response_span( diff --git a/adalflow/adalflow/components/memory/flexible_memory.py b/adalflow/adalflow/components/memory/flexible_memory.py new file mode 100644 index 00000000..d6adbec5 --- /dev/null +++ b/adalflow/adalflow/components/memory/flexible_memory.py @@ -0,0 +1,362 @@ +"""Flexible conversation memory with turns containing multiple messages. + +This memory design uses an OrderedDict where each turn_id maps to a list of messages. +This allows multiple user queries and assistant responses within the same turn. 
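+
+Illustrative usage (a minimal sketch based on the methods defined below):
+
+    memory = FlexibleConversationMemory(user_id="user-1")
+    turn_id = memory.add_user_query("What's the weather like?")
+    memory.add_assistant_response("It looks sunny today.", turn_id=turn_id)
+    print(memory.call())  # "User: ..." / "Assistant: ..." formatted history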
+""" + +from uuid import uuid4 +from typing import Optional, List, Dict, Any, Literal +from dataclasses import dataclass, field +from datetime import datetime +from collections import OrderedDict +from adalflow.core.component import Component +from adalflow.core.db import LocalDB +from adalflow.core.types import DataClass +from adalflow.core.prompt_builder import Prompt + + +@dataclass +class Message(DataClass): + """A single message in a conversation.""" + id: str = field(default_factory=lambda: str(uuid4())) + role: Literal["user", "assistant", "system"] = "user" + content: str = "" + metadata: Optional[Dict[str, Any]] = None + timestamp: datetime = field(default_factory=datetime.now) + + @classmethod + def from_user(cls, content: str, metadata: Optional[Dict] = None): + """Create a user message.""" + return cls(role="user", content=content, metadata=metadata) + + @classmethod + def from_assistant(cls, content: str, metadata: Optional[Dict] = None): + """Create an assistant message.""" + return cls(role="assistant", content=content, metadata=metadata) + + @classmethod + def from_system(cls, content: str, metadata: Optional[Dict] = None): + """Create a system message.""" + return cls(role="system", content=content, metadata=metadata) + + +@dataclass +class Conversation(DataClass): + """A conversation organized as turns, where each turn can have multiple messages.""" + id: str = field(default_factory=lambda: str(uuid4())) + user_id: Optional[str] = None + turns: OrderedDict = field(default_factory=OrderedDict) # turn_id -> List[Message] + metadata: Optional[Dict[str, Any]] = None + created_at: datetime = field(default_factory=datetime.now) + _current_turn_id: Optional[str] = field(default=None, init=False) + + def add_message_to_turn(self, turn_id: str, message: Message) -> str: + """Add a message to a specific turn. 
+ + Args: + turn_id: The turn identifier + message: The message to add + + Returns: + str: The message ID + """ + if turn_id not in self.turns: + self.turns[turn_id] = [] + self.turns[turn_id].append(message) + return message.id + + def get_turn_messages(self, turn_id: str) -> List[Message]: + """Get all messages in a specific turn.""" + return self.turns.get(turn_id, []) + + def get_all_messages(self) -> List[Message]: + """Get all messages in order across all turns.""" + messages = [] + for turn_messages in self.turns.values(): + messages.extend(turn_messages) + return messages + + def get_messages_by_role(self, role: str) -> List[Message]: + """Get all messages from a specific role.""" + messages = [] + for turn_messages in self.turns.values(): + messages.extend([msg for msg in turn_messages if msg.role == role]) + return messages + + def get_last_user_message(self) -> Optional[Message]: + """Get the most recent user message.""" + for turn_messages in reversed(list(self.turns.values())): + for msg in reversed(turn_messages): + if msg.role == "user": + return msg + return None + + def get_last_assistant_message(self) -> Optional[Message]: + """Get the most recent assistant message.""" + for turn_messages in reversed(list(self.turns.values())): + for msg in reversed(turn_messages): + if msg.role == "assistant": + return msg + return None + + def create_turn(self) -> str: + """Create a new turn and return its ID.""" + turn_id = str(uuid4()) + self.turns[turn_id] = [] + self._current_turn_id = turn_id + return turn_id + + +# Template for conversation formatting +CONVERSATION_TEMPLATE = r""" +{% for turn_id, messages in turns.items() -%} +{% for message in messages -%} +{% if message.role == "user" -%} +User: {{ message.content }} +{% if message.metadata -%} +{% for key, value in message.metadata.items() -%} +{% if not metadata_filter or key in metadata_filter -%} +{{ key }}: {{ value }} +{% endif -%} +{% endfor -%} +{% endif -%} +{% elif message.role == "assistant" -%} +Assistant: {{ message.content }} +{% if message.metadata -%} +{% for key, value in message.metadata.items() -%} +{% if not metadata_filter or key in metadata_filter -%} +{{ key }}: {{ value }} +{% endif -%} +{% endfor -%} +{% endif -%} +{% elif message.role == "system" -%} +System: {{ message.content }} +{% endif -%} +{% endfor -%} +{% endfor -%}""" + + +class FlexibleConversationMemory(Component): + """A flexible conversation memory with turns containing multiple messages.""" + + def __init__(self, turn_db: LocalDB = None, user_id: str = None): + """Initialize the flexible memory component. 
+ + Args: + turn_db: Database for storing messages + user_id: Optional user identifier + """ + super().__init__() + self.current_conversation = Conversation(user_id=user_id) + self.message_db = turn_db or LocalDB() # Store all messages + self.conver_db = LocalDB() # Store complete conversations + self.user_id = user_id + + def clear_conversation(self): + """Clear all turns and messages in the current conversation.""" + self.current_conversation.turns.clear() + self.current_conversation._current_turn_id = None + + def clear_conversation_turns(self): + """Alias for clear_conversation for compatibility.""" + self.clear_conversation() + + def new_conversation(self): + """Start a new conversation, saving the current one.""" + # Save current conversation if it has messages + if self.current_conversation.turns: + self.conver_db.add(self.current_conversation) + + # Create new conversation + self.current_conversation = Conversation(user_id=self.user_id) + + def create_turn(self) -> str: + """Create a new turn and return its ID. + + Returns: + str: The new turn ID + """ + return self.current_conversation.create_turn() + + def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str: + """Add a user message to a turn. + + Args: + content: The user's message content + metadata: Optional metadata + turn_id: Optional turn ID. If None, creates a new turn. + + Returns: + str: The turn ID the message was added to + """ + # Use provided turn_id or create new turn + if turn_id is None: + turn_id = self.create_turn() + elif turn_id not in self.current_conversation.turns: + # Turn doesn't exist, create it + self.current_conversation.turns[turn_id] = [] + + # Track as current turn + self.current_conversation._current_turn_id = turn_id + + # Create and add the user message + message = Message.from_user(content, metadata) + self.current_conversation.add_message_to_turn(turn_id, message) + + # Store in database + self.message_db.add({ + "message_id": message.id, + "turn_id": turn_id, + "role": "user", + "content": content, + "metadata": metadata, + "timestamp": message.timestamp + }) + + return turn_id + + def add_assistant_response( + self, + content: str, + metadata: Optional[Dict] = None, + turn_id: Optional[str] = None + ) -> str: + """Add an assistant message to a turn. + + Args: + content: The assistant's message content + metadata: Optional metadata + turn_id: Optional turn ID. If None, uses current turn or creates new. + + Returns: + str: The turn ID the message was added to + """ + # Determine which turn to use + if turn_id is None: + if self.current_conversation._current_turn_id: + turn_id = self.current_conversation._current_turn_id + else: + # No active turn, create new one for standalone response + turn_id = self.create_turn() + elif turn_id not in self.current_conversation.turns: + # Turn doesn't exist, create it + self.current_conversation.turns[turn_id] = [] + + # Create and add the assistant message + message = Message.from_assistant(content, metadata) + self.current_conversation.add_message_to_turn(turn_id, message) + + # Store in database + self.message_db.add({ + "message_id": message.id, + "turn_id": turn_id, + "role": "assistant", + "content": content, + "metadata": metadata, + "timestamp": message.timestamp + }) + + return turn_id + + def add_system_message(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str: + """Add a system message to a turn. 
+ + Args: + content: The system message content + metadata: Optional metadata + turn_id: Optional turn ID. If None, creates a new turn. + + Returns: + str: The turn ID the message was added to + """ + # Use provided turn_id or create new turn + if turn_id is None: + turn_id = self.create_turn() + elif turn_id not in self.current_conversation.turns: + self.current_conversation.turns[turn_id] = [] + + # Create and add the system message + message = Message.from_system(content, metadata) + self.current_conversation.add_message_to_turn(turn_id, message) + + # Store in database + self.message_db.add({ + "message_id": message.id, + "turn_id": turn_id, + "role": "system", + "content": content, + "metadata": metadata, + "timestamp": message.timestamp + }) + + return turn_id + + def get_turn_messages(self, turn_id: str) -> List[Message]: + """Get all messages for a specific turn. + + Args: + turn_id: The turn identifier + + Returns: + List of messages in that turn + """ + return self.current_conversation.get_turn_messages(turn_id) + + def get_current_turn_id(self) -> Optional[str]: + """Get the current turn ID if any.""" + return self.current_conversation._current_turn_id + + def call(self, metadata_filter: Optional[List[str]] = None) -> str: + """Get the conversation history as a formatted string. + + Args: + metadata_filter: Optional list of metadata keys to include + + Returns: + str: Formatted conversation history + """ + if not self.current_conversation.turns: + return "" + + prompt = Prompt( + template=CONVERSATION_TEMPLATE, + prompt_kwargs={ + "turns": self.current_conversation.turns, + "metadata_filter": metadata_filter, + }, + ) + return prompt.call().strip() + + def get_all_messages(self) -> List[Message]: + """Get all messages in order across all turns.""" + return self.current_conversation.get_all_messages() + + def get_last_n_messages(self, n: int) -> List[Message]: + """Get the last n messages across all turns.""" + messages = self.get_all_messages() + return messages[-n:] if len(messages) >= n else messages + + def count_messages(self) -> Dict[str, int]: + """Count messages by role.""" + counts = {"user": 0, "assistant": 0, "system": 0} + for msg in self.get_all_messages(): + counts[msg.role] = counts.get(msg.role, 0) + 1 + return counts + + def count_turns(self) -> int: + """Count the number of turns.""" + return len(self.current_conversation.turns) + + def reset_pending_query(self): + """Reset current turn tracking. Included for compatibility.""" + # Don't clear the current turn, just unset it as "current" + # This allows adding more messages to existing turns if needed + self.current_conversation._current_turn_id = None + + def __call__(self, metadata_filter: Optional[List[str]] = None) -> str: + """Make the memory callable to get conversation history.""" + return self.call(metadata_filter) + + def __len__(self) -> int: + """Return the number of messages.""" + return len(self.get_all_messages()) \ No newline at end of file diff --git a/adalflow/adalflow/components/memory/memory.py b/adalflow/adalflow/components/memory/memory.py index e30743a3..3b65cdd0 100644 --- a/adalflow/adalflow/components/memory/memory.py +++ b/adalflow/adalflow/components/memory/memory.py @@ -177,15 +177,17 @@ def add_assistant_response( assistant_response (Union[str, AssistantResponse]): The assistant's response message. Returns: - str: The ID of the completed dialog turn. - - Raises: - ValueError: If there's no pending user query to respond to. 
+ str: The ID of the completed dialog turn, or None if no pending query exists. """ if self._pending_user_query is None: - raise ValueError( - "No pending user query found. Please add a user query first." + # Log a warning instead of raising an error + import logging + logging.warning( + "No pending user query found when adding assistant response. " + "This might happen if the response was already added or the conversation was reset. " + "Ignoring this assistant response to avoid duplication." ) + return None # Return None to indicate no turn was added assistant_response = ( assistant_response diff --git a/adalflow/adalflow/components/output_parsers/dataclass_parser.py b/adalflow/adalflow/components/output_parsers/dataclass_parser.py index 32c80bc9..7e4647e5 100644 --- a/adalflow/adalflow/components/output_parsers/dataclass_parser.py +++ b/adalflow/adalflow/components/output_parsers/dataclass_parser.py @@ -135,7 +135,7 @@ def get_output_format_str(self) -> str: if self._format_type == "yaml": schema = self._data_class.to_yaml_signature(include=self._output_fields) output_format_str = Prompt(template=YAML_OUTPUT_FORMAT)(schema=schema) - else: + elif self._format_type == "json": schema = self._data_class.to_json_signature(include=self._output_fields) output_format_str = Prompt(template=JSON_OUTPUT_FORMAT)(schema=schema) return output_format_str diff --git a/adalflow/adalflow/components/output_parsers/outputs.py b/adalflow/adalflow/components/output_parsers/outputs.py index 1378b428..6286a60e 100644 --- a/adalflow/adalflow/components/output_parsers/outputs.py +++ b/adalflow/adalflow/components/output_parsers/outputs.py @@ -57,7 +57,6 @@ - Properly escape special characters: use \\" for quotes, \\\\ for backslashes - For multiline strings, keep them on a single line with \\n characters **WARNING:** The JSON must be parseable by standard JSON parsers. Malformed JSON will cause parsing failures. When handling complex text with special characters, quotes, or formatting, prioritize proper escaping over readability. - """ """**CRITICAL JSON FORMATTING REQUIREMENTS:** diff --git a/adalflow/adalflow/core/types.py b/adalflow/adalflow/core/types.py index 8d7b4f2b..cea05b36 100644 --- a/adalflow/adalflow/core/types.py +++ b/adalflow/adalflow/core/types.py @@ -27,6 +27,7 @@ field, InitVar, ) +from pydantic import BaseModel, Field as PydanticField from uuid import UUID from datetime import datetime import uuid @@ -449,6 +450,128 @@ def add(a, b): __output_fields__ = ["thought", "name", "kwargs", "_is_answer_final", "_answer"] +class FunctionPydantic(BaseModel): + """The data modeling of a function call using Pydantic BaseModel. + + This is a Pydantic-based version of the Function class that uses BaseModel + and Field for validation and serialization. + + Example: + >>> def add(a, b): + ... return a + b + >>> + >>> # Create function call with kwargs + >>> func = FunctionPydantic(name="add", kwargs={"a": 1, "b": 2}) + >>> # Evaluate the function + >>> result = context_map[func.name](**func.kwargs) + >>> + >>> # Create function call with positional args + >>> func = FunctionPydantic(name="add", args=[1, 2]) + >>> result = context_map[func.name](*func.args) + """ + + id: Optional[str] = PydanticField( + default=None, + description="The id of the function call" + ) + thought: Optional[str] = PydanticField( + default=None, + description="Your reasoning for this step. Be short for simple queries. For complex queries, provide a clear chain of thought." 
+ ) + name: str = PydanticField( + default="", + description="The name of the function" + ) + args: Optional[List[Any]] = PydanticField( + default_factory=list, + description="The positional arguments of the function" + ) + kwargs: Optional[Dict[str, Any]] = PydanticField( + default_factory=dict, + description="The keyword arguments of the function" + ) + is_answer_final: Optional[bool] = PydanticField( + default=None, + alias="_is_answer_final", + description="Whether this current output is the final answer" + ) + answer: Optional[Any] = PydanticField( + default=None, + alias="_answer", + description="The final answer if this is the final output" + ) + + class Config: + populate_by_name = True # Allow population by field name or alias + json_encoders = { + datetime: lambda v: v.isoformat(), + } + + @classmethod + def from_function( + cls, + func: Union[Callable[..., Any], AsyncCallable], + thought: Optional[str] = None, + *args, + **kwargs, + ) -> "FunctionPydantic": + """Create a FunctionPydantic object from a function. + + Args: + func: The function to be converted + thought: Optional reasoning for this function call + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + FunctionPydantic: The FunctionPydantic object + + Example: + >>> def add(a, b): + ... return a + b + >>> + >>> # Create with positional arguments + >>> func = FunctionPydantic.from_function(add, "Add two numbers", 1, 2) + >>> print(func) + >>> # FunctionPydantic(thought='Add two numbers', name='add', args=[1, 2]) + """ + return cls( + thought=thought, + name=func.__name__, + args=list(args) if args else [], + kwargs=kwargs if kwargs else {}, + ) + + def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]: + """Convert to dictionary representation. + + Args: + exclude_none: Whether to exclude None values + + Returns: + Dictionary representation of the function call + """ + data = self.model_dump(exclude_none=exclude_none) + # Include aliased fields with their original names if present + if self.is_answer_final is not None: + data['_is_answer_final'] = self.is_answer_final + if self.answer is not None: + data['_answer'] = self.answer + return data + + def to_json(self, exclude_none: bool = True, indent: int = 2) -> str: + """Convert to JSON string representation. + + Args: + exclude_none: Whether to exclude None values + indent: JSON indentation level + + Returns: + JSON string representation + """ + return self.model_dump_json(exclude_none=exclude_none, indent=indent) + + _action_desc = """FuncName() \ Valid function call expression. 
\ Example: "FuncName(a=1, b=2)" \ @@ -1634,12 +1757,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: yield event # mark the task as done self._event_queue.task_done() - # if the event is a RunItemStreamEvent and the name is agent.execution_complete then additionally break the loop - if ( - isinstance(event, RunItemStreamEvent) - and event.name == "agent.execution_complete" - ): - break + # Don't break on agent.execution_complete - let QueueCompleteSentinel handle the break + # This ensures the execution_complete event is properly processed by all consumers + # if ( + # isinstance(event, RunItemStreamEvent) + # and event.name == "agent.execution_complete" + # ): + # break except asyncio.CancelledError: # Clean up and re-raise to allow proper cancellation diff --git a/adalflow/tests/test_memory.py b/adalflow/tests/test_memory.py index 3680f683..20093b5e 100644 --- a/adalflow/tests/test_memory.py +++ b/adalflow/tests/test_memory.py @@ -98,13 +98,15 @@ def test_error_double_user_query(): def test_error_assistant_response_without_query(): - """Test error when adding assistant response without user query.""" + """Test that adding assistant response without user query returns None and logs warning.""" memory = Memory() - with pytest.raises(ValueError) as exc_info: - memory.add_assistant_response("Random response") - - assert "No pending user query found" in str(exc_info.value) + # Should return None instead of raising error + result = memory.add_assistant_response("Random response") + assert result is None + + # The conversation should remain empty + assert len(memory.current_conversation.dialog_turns) == 0 def test_mixed_methods(): diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index e0b50a39..3946320f 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -761,6 +761,208 @@ def test_multiple_images_input(self): self.assertEqual(image_contents[1]["image_url"], "https://example.com/image2.jpg") self.assertTrue(image_contents[2]["image_url"].startswith("data:image/png;base64,")) + @patch( + "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" + ) + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_sync_streaming_response_format(self, MockSyncOpenAI, mock_init_sync_client): + """Test that sync streaming returns raw_response as generator and data field contains complete result.""" + mock_sync_client = Mock() + MockSyncOpenAI.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Create streaming events generator + def mock_stream(): + for event in self.streaming_events: + yield event + + mock_sync_client.responses.create.return_value = mock_stream() + self.client.sync_client = mock_sync_client + + # Test streaming API kwargs + api_kwargs = { + "model": "gpt-4", + "input": "Tell me a story", + "stream": True + } + + # Call sync streaming method + result = self.client.call(api_kwargs, ModelType.LLM) + + # Verify raw response is a generator + self.assertTrue(hasattr(result, '__iter__'), "raw_response should be iterable/generator") + + # Parse the response through the client's parser + parsed_result = self.client.parse_chat_completion(result) + + # Verify structure: raw_response should be the generator, data should be None for streaming + self.assertIsNotNone(parsed_result.raw_response, "raw_response should contain the generator") + self.assertIsNone(parsed_result.data, "data should be None for streaming responses") 
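The assertions above pin down the streaming contract: with stream=True, parse_chat_completion keeps the live event generator in raw_response and leaves data as None until the caller consumes the stream. A rough caller-side sketch of that contract follows; the event shapes (response.output_text.delta carrying delta, response.completed carrying response.output_text) mirror the mocks in this test, and the ModelType import path is an assumption.

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType  # assumed import path

    client = OpenAIClient()
    api_kwargs = {"model": "gpt-4", "input": "Tell me a story", "stream": True}

    # For streaming calls, call() hands back a generator of Responses API events.
    events = client.call(api_kwargs, ModelType.LLM)

    text = ""
    for event in events:
        event_type = getattr(event, "type", "")
        if event_type == "response.output_text.delta":
            text += event.delta                    # accumulate incremental fragments
        elif event_type == "response.completed":
            text = event.response.output_text      # prefer the final complete text
    print(text)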
+ + # Consume the generator to verify it works + events_collected = list(result) + self.assertEqual(len(events_collected), 3, "Should collect all streaming events") + + # Verify the final completed event contains the full result + final_event = events_collected[-1] + self.assertEqual(final_event.type, "response.completed") + self.assertEqual(final_event.response.output_text, "Once upon ") + + async def test_async_streaming_response_format(self): + """Test that async streaming returns raw_response as async generator and data field handling.""" + mock_async_client = AsyncMock() + + # Create async streaming events generator + async def mock_async_stream(): + for event in self.streaming_events: + yield event + await asyncio.sleep(0.001) # Small delay to simulate real streaming + + mock_async_client.responses.create.return_value = mock_async_stream() + self.client.async_client = mock_async_client + + # Test streaming API kwargs + api_kwargs = { + "model": "gpt-4", + "input": "Tell me a story", + "stream": True + } + + # Call async streaming method + result = await self.client.acall(api_kwargs, ModelType.LLM) + + # Verify raw response is an async generator + self.assertTrue(hasattr(result, '__aiter__'), "raw_response should be async iterable") + + # Parse the response through the client's parser + parsed_result = self.client.parse_chat_completion(result) + + # Verify structure: raw_response should contain the generator, data should be None for streaming + self.assertIsNotNone(parsed_result.raw_response, "raw_response should contain the async generator") + self.assertIsNone(parsed_result.data, "data should be None for streaming responses") + + # Consume the async generator to verify it works and extract complete text + events_collected = [] + complete_text = "" + + async for event in result: + events_collected.append(event) + # Extract text from streaming events + if hasattr(event, 'delta'): + complete_text += event.delta + elif hasattr(event, 'response') and hasattr(event.response, 'output_text'): + complete_text = event.response.output_text # Final complete result + + self.assertEqual(len(events_collected), 3, "Should collect all streaming events") + self.assertEqual(complete_text, "Once upon ", "Should extract complete text from stream") + + # Verify the final completed event contains the full result + final_event = events_collected[-1] + self.assertEqual(final_event.type, "response.completed") + self.assertEqual(final_event.response.output_text, "Once upon ") + + async def test_streaming_text_extraction_and_final_result(self): + """Test that streaming properly extracts incremental text and provides final complete result.""" + mock_async_client = AsyncMock() + + # Create more comprehensive streaming events with incremental text + async def comprehensive_mock_stream(): + # Start event + start_event = Mock() + start_event.type = "response.created" + yield start_event + + # Multiple delta events building up text + delta_texts = ["Hello", " there!", " How", " are", " you?"] + for delta_text in delta_texts: + delta_event = Mock() + delta_event.type = "response.output_text.delta" + delta_event.delta = delta_text + yield delta_event + await asyncio.sleep(0.001) + + # Final completion event with complete text + complete_response = Mock() + complete_response.id = "resp-final" + complete_response.model = "gpt-4" + complete_response.output_text = "Hello there! How are you?" 
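+            # Note: this mock mirrors the real Responses API shape: the completed event's
+            # response.output_text ("Hello there! How are you?") equals the concatenation of
+            # the five delta fragments, which is what the assertions below verify.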
+ complete_response.usage = ResponseUsage( + input_tokens=5, output_tokens=6, total_tokens=11, + input_tokens_details={"cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0} + ) + + completion_event = Mock() + completion_event.type = "response.completed" + completion_event.response = complete_response + yield completion_event + + mock_async_client.responses.create.return_value = comprehensive_mock_stream() + self.client.async_client = mock_async_client + + api_kwargs = { + "model": "gpt-4", + "input": "Say hello", + "stream": True + } + + # Get streaming result + result = await self.client.acall(api_kwargs, ModelType.LLM) + + # Verify it's an async generator + self.assertTrue(hasattr(result, '__aiter__')) + + # Process stream and extract text using the utility function + incremental_text = "" + final_complete_text = None + event_count = 0 + + async for event in result: + event_count += 1 + + # Use the utility function to extract text + text_fragment = extract_text_from_response_stream(event) + if text_fragment: + incremental_text += text_fragment + + # Check for final completion + if hasattr(event, 'type') and event.type == "response.completed": + if hasattr(event, 'response') and hasattr(event.response, 'output_text'): + final_complete_text = event.response.output_text + + # Verify results + self.assertEqual(event_count, 7, "Should have start + 5 deltas + completion events") + self.assertEqual(incremental_text, "Hello there! How are you?", "Incremental text should match") + self.assertEqual(final_complete_text, "Hello there! How are you?", "Final complete text should match") + self.assertEqual(incremental_text, final_complete_text, "Incremental and final text should be identical") + + def test_streaming_vs_non_streaming_data_field_behavior(self): + """Test that data field behavior differs correctly between streaming and non-streaming responses.""" + # Test non-streaming: data field should contain the result + non_streaming_result = self.client.parse_chat_completion(self.mock_response) + self.assertIsNotNone(non_streaming_result.data, "Non-streaming should have data field populated") + self.assertEqual(non_streaming_result.data, "Hello, world!", "Data should contain response text") + self.assertEqual(non_streaming_result.raw_response, "Hello, world!", "Raw response should match data for non-streaming") + + # Test streaming: simulate the actual client behavior for streaming + # Set the parser to streaming mode first + original_parser = self.client.response_parser + self.client.response_parser = self.client.streaming_response_parser_sync + + def mock_generator(): + yield from self.streaming_events + + streaming_result = self.client.parse_chat_completion(mock_generator()) + self.assertIsNone(streaming_result.data, "Streaming should have data field as None") + self.assertIsNotNone(streaming_result.raw_response, "Streaming should have raw_response as generator") + + # Verify we can iterate over the raw_response + events_from_raw = list(streaming_result.raw_response) + self.assertEqual(len(events_from_raw), 3, "Should be able to consume raw_response generator") + + # Restore original parser + self.client.response_parser = original_parser + if __name__ == "__main__": unittest.main() diff --git a/adalflow/tests/test_runner.py b/adalflow/tests/test_runner.py index cb098504..08e289a9 100644 --- a/adalflow/tests/test_runner.py +++ b/adalflow/tests/test_runner.py @@ -1082,6 +1082,233 @@ class TestModel(BaseModel): runner._process_data('"just a string"') self.assertIn("Expected dict after 
JSON parsing", str(cm.exception)) + def test_execution_complete_event_emitted_once(self): + """Test that agent.execution_complete event is emitted exactly once. + + Edge Case: Previously, execution_complete was emitted twice - once in + _process_stream_final_step and again outside the loop. This test ensures + it's only emitted once. + """ + async def async_test(): + from adalflow.core.types import FunctionOutput + + # Create a function with _is_answer_final=True + fn = DummyFunction( + name="answer_output", + _is_answer_final=True, + _answer="test_complete" + ) + agent = DummyAgent( + planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]), + answer_data_type=None, + ) + runner = Runner(agent=agent) + + async def mock_tool_execute_async(func, streaming_result=None): + return FunctionOutput(name=func.name, input=func, output="test_complete") + + runner._tool_execute_async = mock_tool_execute_async + + # Start streaming + streaming_result = runner.astream(prompt_kwargs={}) + + # Collect all execution_complete events + execution_complete_events = [] + all_events = [] + + async for event in streaming_result.stream_events(): + all_events.append(event) + if (isinstance(event, RunItemStreamEvent) and + event.name == "agent.execution_complete"): + execution_complete_events.append(event) + + # Should have exactly one execution_complete event + self.assertEqual(len(execution_complete_events), 1, + f"Expected 1 execution_complete event, got {len(execution_complete_events)}") + + # Verify the event contains correct data + final_event = execution_complete_events[0] + self.assertIsInstance(final_event.item, FinalOutputItem) + self.assertIsInstance(final_event.item.data, RunnerResult) + self.assertEqual(final_event.item.data.answer, "test_complete") + + asyncio.run(async_test()) + + def test_execution_complete_event_properly_consumed(self): + """Test that execution_complete event is properly consumed by stream_to_json. + + Edge Case: Previously, stream_events() would break immediately after yielding + execution_complete, preventing proper consumption by stream_to_json and other + consumers. This test ensures the event is properly consumed. 
+ """ + async def async_test(): + from adalflow.core.types import FunctionOutput + import json + import os + import tempfile + + # Create a function with _is_answer_final=True + fn = DummyFunction( + name="answer_output", + _is_answer_final=True, + _answer="stream_to_json_test" + ) + agent = DummyAgent( + planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]), + answer_data_type=None, + ) + runner = Runner(agent=agent) + + async def mock_tool_execute_async(func, streaming_result=None): + return FunctionOutput(name=func.name, input=func, output="stream_to_json_test") + + runner._tool_execute_async = mock_tool_execute_async + + # Create temp file for JSON output + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + temp_file = f.name + + try: + # Start streaming to JSON + streaming_result = runner.astream(prompt_kwargs={}) + + # Stream to JSON file + events_consumed = [] + async for event in streaming_result.stream_to_json(temp_file): + if isinstance(event, RunItemStreamEvent): + events_consumed.append(event.name) + + # Verify execution_complete was consumed + self.assertIn("agent.execution_complete", events_consumed, + "execution_complete event was not consumed by stream_to_json") + + # Verify JSON file contains the execution_complete event + with open(temp_file, 'r') as f: + json_content = json.load(f) + + # Check that execution_complete event is in the JSON + execution_complete_found = False + for event_entry in json_content: + if (event_entry.get("event_type") == "RunItemStreamEvent" and + "agent.execution_complete" in str(event_entry.get("event_data", {}))): + execution_complete_found = True + break + + self.assertTrue(execution_complete_found, + "execution_complete event not found in JSON file") + + finally: + # Clean up temp file + if os.path.exists(temp_file): + os.remove(temp_file) + + asyncio.run(async_test()) + + def test_no_duplicate_execution_complete_on_incomplete(self): + """Test that execution_complete is only emitted once even when max_steps reached. + + Edge Case: When the loop ends without a final answer (max_steps reached), + execution_complete should be emitted only once in the cleanup section. 
+ """ + async def async_test(): + from adalflow.core.types import FunctionOutput + + # Create functions that are NOT final + functions = [ + DummyFunction(name=f"step_{i}", _is_answer_final=False) + for i in range(3) + ] + outputs = [GeneratorOutput(data=fn) for fn in functions] + + agent = DummyAgent( + planner=FakeStreamingPlanner(outputs), + answer_data_type=None, + max_steps=3 # Will reach max_steps without final answer + ) + runner = Runner(agent=agent) + + async def mock_tool_execute_async(func, streaming_result=None): + return FunctionOutput(name=func.name, input=func, output=f"output_{func.name}") + + runner._tool_execute_async = mock_tool_execute_async + + # Start streaming + streaming_result = runner.astream(prompt_kwargs={}) + + # Collect all execution_complete events + execution_complete_events = [] + + async for event in streaming_result.stream_events(): + if (isinstance(event, RunItemStreamEvent) and + event.name == "agent.execution_complete"): + execution_complete_events.append(event) + + # Should have exactly one execution_complete event even when incomplete + self.assertEqual(len(execution_complete_events), 1, + f"Expected 1 execution_complete event for incomplete run, got {len(execution_complete_events)}") + + # Verify it contains the "No output generated" message + final_event = execution_complete_events[0] + self.assertIsInstance(final_event.item, FinalOutputItem) + self.assertIsInstance(final_event.item.data, RunnerResult) + self.assertIn("No output generated", final_event.item.data.answer) + + asyncio.run(async_test()) + + def test_execution_complete_with_early_break_scenario(self): + """Test execution_complete event when consumer breaks early from stream. + + Edge Case: Test that even if a consumer breaks early from the stream, + the execution_complete event is available for consumption if they + resume iteration (after our fix). 
+ """ + async def async_test(): + from adalflow.core.types import FunctionOutput + + fn = DummyFunction( + name="answer_output", + _is_answer_final=True, + _answer="early_break_test" + ) + agent = DummyAgent( + planner=FakeStreamingPlanner([GeneratorOutput(data=fn)]), + answer_data_type=None, + ) + runner = Runner(agent=agent) + + async def mock_tool_execute_async(func, streaming_result=None): + return FunctionOutput(name=func.name, input=func, output="early_break_test") + + runner._tool_execute_async = mock_tool_execute_async + + # Start streaming + streaming_result = runner.astream(prompt_kwargs={}) + + # First consumer breaks after raw event + first_consumer_events = [] + async for event in streaming_result.stream_events(): + first_consumer_events.append(event) + if isinstance(event, RawResponsesStreamEvent): + break # Break early after first event + + # Second consumer continues from where first left off + second_consumer_events = [] + async for event in streaming_result.stream_events(): + second_consumer_events.append(event) + if isinstance(event, RunItemStreamEvent): + second_consumer_events.append(event) + + # Check if execution_complete was available to second consumer + execution_complete_found = any( + isinstance(e, RunItemStreamEvent) and e.name == "agent.execution_complete" + for e in second_consumer_events + ) + + self.assertTrue(execution_complete_found, + "execution_complete should be available even after early break") + + asyncio.run(async_test()) + if __name__ == "__main__": unittest.main() From 268737881086286bae8e1cfba0055db99bc7cbc3 Mon Sep 17 00:00:00 2001 From: Li Yin Date: Tue, 19 Aug 2025 15:33:44 -0700 Subject: [PATCH 3/3] Add flexible memory components and runner with tests --- adalflow/adalflow/components/agent/prompts.py | 3 +- .../components/agent/runner_flexible.py | 1957 +++++++++++++++++ .../components/memory/flexible_memory.py | 96 +- adalflow/tests/test_flexible_memory.py | 1121 ++++++++++ .../tests/test_flexible_memory_template.py | 538 +++++ adalflow/tests/test_runner_flexible.py | 517 +++++ 6 files changed, 4189 insertions(+), 43 deletions(-) create mode 100644 adalflow/adalflow/components/agent/runner_flexible.py create mode 100644 adalflow/tests/test_flexible_memory.py create mode 100644 adalflow/tests/test_flexible_memory_template.py create mode 100644 adalflow/tests/test_runner_flexible.py diff --git a/adalflow/adalflow/components/agent/prompts.py b/adalflow/adalflow/components/agent/prompts.py index dc37eb2b..d35c38c4 100644 --- a/adalflow/adalflow/components/agent/prompts.py +++ b/adalflow/adalflow/components/agent/prompts.py @@ -28,7 +28,8 @@ - If the last observation starts with "Run into error", you should try to fix the error in the next step. """ -# TODO: access the max steps in the agent prompt or not +# Chat history should be user: message, assistant: message + meta data {step_history} +# step_history is the observations. DEFAULT_ADALFLOW_AGENT_SYSTEM_PROMPT = r""" {{task_desc}} - You cant use more than {{max_steps}} steps. At the {{max_steps}}th current step, must set `_is_answer_final` to True and provide the answer. 
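Before the long runner_flexible.py diff below, a short orientation sketch of how the new pieces are meant to be wired together. The class and module names come from this patch; build_agent is a hypothetical helper standing in for however the Agent is constructed in your application, so this is a sketch rather than a drop-in script.

    from adalflow.components.agent.runner_flexible import RunnerFlexible
    from adalflow.components.memory.flexible_memory import FlexibleConversationMemory

    agent = build_agent()  # hypothetical: an adalflow Agent with planner, tools, max_steps

    memory = FlexibleConversationMemory()
    runner = RunnerFlexible(agent=agent, conversation_memory=memory, max_steps=6)

    # Each call opens a new turn, injects chat_history_str and max_steps into the prompt,
    # and writes the final answer back into memory for the next turn.
    first = runner.call(prompt_kwargs={"input_str": "Summarize the design."})
    followup = runner.call(prompt_kwargs={"input_str": "Now shorten that to one sentence."})
    print(followup.answer, followup.error)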
diff --git a/adalflow/adalflow/components/agent/runner_flexible.py b/adalflow/adalflow/components/agent/runner_flexible.py new file mode 100644 index 00000000..a942d7cf --- /dev/null +++ b/adalflow/adalflow/components/agent/runner_flexible.py @@ -0,0 +1,1957 @@ +"""Agent runner component with flexible memory for managing and executing agent workflows.""" + +from pydantic import BaseModel +import logging +import inspect +import asyncio +import uuid +import json +from datetime import datetime + +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Type, + TypeVar, + Union, + AsyncIterable, +) +from typing_extensions import TypeAlias +import uuid + + +from adalflow.optim.parameter import Parameter +from adalflow.utils import printc +from adalflow.core.component import Component +from adalflow.components.agent.agent import Agent + +from adalflow.core.types import ( + GeneratorOutput, + FunctionOutput, + Function, + StepOutput, + RawResponsesStreamEvent, + RunItemStreamEvent, + ToolCallRunItem, + ToolOutputRunItem, + StepRunItem, + FinalOutputItem, + RunnerStreamingResult, + RunnerResult, + QueueCompleteSentinel, + ToolOutput, + ToolCallActivityRunItem, + UserQuery, + AssistantResponse, +) +from adalflow.apps.permission_manager import PermissionManager +from adalflow.components.memory.flexible_memory import FlexibleConversationMemory +from adalflow.core.functional import _is_pydantic_dataclass, _is_adalflow_dataclass +from adalflow.tracing import ( + runner_span, + tool_span, + response_span, + step_span, +) + + +__all__ = ["RunnerFlexible"] + +log = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) # Changed to use Pydantic BaseModel + + +def _is_unrecoverable_error(error: Optional[str]) -> bool: # pragma: no cover + """Check if an error string indicates an unrecoverable error. + + Unrecoverable errors include: + - HTTP 400: Bad request (e.g., context too long) + - HTTP 429: Rate limit exceeded + - HTTP 404: Model not found + - "Connection error": Network connection issues + + This is marked as uncoverable for testing purposes. + + Args: + error: Error string to check + + Returns: + True if the error is unrecoverable, False otherwise + """ + if not error: + return False + + # Check for connection error string pattern (case insensitive) + if "connection error" in error.lower(): + return True + + # Check for HTTP error codes + if "400" in error or "429" in error or "404" in error: + return True + + return False + +BuiltInType: TypeAlias = Union[str, int, float, bool, list, dict, tuple, set, None] +PydanticDataClass: TypeAlias = Type[BaseModel] +AdalflowDataClass: TypeAlias = Type[ + Any +] # Replace with your actual Adalflow dataclass type if available + + +# The runner will create tool call request, add a unique call id. +# Runner with flexible memory and robust error handling +class RunnerFlexible(Component): + """Executes Agent instances with multi-step iterative planning and tool execution. + + The Runner orchestrates the execution of an Agent through multiple reasoning and action + cycles. It manages the step-by-step execution loop where the Agent's planner generates + Function calls that get executed by the ToolManager, with results fed back into the + planning context for the next iteration. + + Execution Flow: + 1. Initialize step history and prompt context + 2. For each step (up to max_steps): + a. Call Agent's planner to get next Function + b. Execute the Function using ToolManager + c. Add step result to history + d. 
Check if Function is "finish" to terminate + 3. Process final answer to expected output type + + The Runner supports both synchronous and asynchronous execution modes, as well as + streaming execution with real-time event emission. It includes comprehensive tracing + and error handling throughout the execution pipeline. + + Attributes: + agent (Agent): The Agent instance to execute + max_steps (int): Maximum number of execution steps allowed + answer_data_type (Type): Expected type for final answer processing + step_history (List[StepOutput]): History of all execution steps + ctx (Optional[Dict]): Additional context passed to tools + """ + + def __init__( + self, + agent: Agent, + ctx: Optional[Dict] = None, + max_steps: Optional[int] = None, # this will overwrite the agent's max_steps + permission_manager: Optional[PermissionManager] = None, + conversation_memory: Optional[FlexibleConversationMemory] = None, + **kwargs, + ) -> None: + """Initialize runner with an agent and configuration. + + Args: + agent: The agent instance to execute + stream_parser: Optional stream parser + output_type: Optional Pydantic data class type + max_steps: Maximum number of steps to execute + permission_manager: Optional permission manager for tool approval + conversation_memory: Optional conversation memory + """ + super().__init__(**kwargs) + self.agent = agent + self.tool_manager = agent.tool_manager + self.permission_manager = permission_manager + # pass the tool_manager to the permission_manager + if permission_manager is not None: + permission_manager.set_tool_manager(self.tool_manager) + + self.conversation_memory = conversation_memory + + self.use_conversation_memory = conversation_memory is not None + + # get agent requirements + self.max_steps = max_steps + if max_steps is None: + self.max_steps = agent.max_steps + else: + # overwrite the agent's max_steps + self.agent.max_steps = max_steps + self.answer_data_type = agent.answer_data_type or str + + self.step_history: List[StepOutput] = [] + + # add ctx (it is just a reference, and only get added to the final response) + # assume intermediate tool is gonna modify the ctx + self.ctx = ctx + + # Initialize permission manager + self._init_permission_manager() + + # Initialize cancellation flag + self._cancelled = False + self._cancel_callbacks = [] + self._current_task = None # Track the current running task + self._current_streaming_result = None # Track the current streaming result + + # support thinking model + self.is_thinking_model = agent.is_thinking_model if hasattr(agent, 'is_thinking_model') else False + + # Token tracking + self._token_consumption: Dict[str, Any] = { + 'total_prompt_tokens': 0, + 'current_step_tokens': 0, + 'steps_token_history': [], + 'last_total_tokens': 0 # Track last total to calculate step difference + } + + # ============== Safe Memory Operations ============== + def _safe_create_turn(self) -> Optional[str]: + """Safely create a new turn in memory. + + Returns: + Turn ID if successful, None if failed + """ + if not self.use_conversation_memory: + return None + + try: + return self.conversation_memory.create_turn() + except Exception as e: + log.warning(f"Failed to create turn in memory: {e}") + return None + + def _safe_add_user_query(self, query: Union[str, UserQuery], turn_id: Optional[str], metadata: Optional[Dict] = None) -> Optional[str]: + """Safely add user query to memory. 
+ + Returns: + Turn ID if successful, None if failed + """ + if not self.use_conversation_memory or turn_id is None: + return None + + try: + if isinstance(query, str): + return self.conversation_memory.add_user_query(query, turn_id, metadata) + else: + return self.conversation_memory.add_user_query(query.query_str, turn_id, metadata or query.metadata) + except Exception as e: + log.warning(f"Failed to add user query to memory: {e}") + return None + + def _safe_add_assistant_response(self, response: Union[str, AssistantResponse], turn_id: Optional[str], metadata: Optional[Dict] = None) -> Optional[str]: + """Safely add assistant response to memory. + + Returns: + Turn ID if successful, None if failed + """ + if not self.use_conversation_memory or turn_id is None: + return None + + try: + if isinstance(response, str): + return self.conversation_memory.add_assistant_response(response, turn_id, metadata) + else: + return self.conversation_memory.add_assistant_response( + response.response_str, + turn_id, + metadata or response.metadata + ) + except Exception as e: + log.warning(f"Failed to add assistant response to memory: {e}") + return None + + def _safe_get_conversation_history(self) -> str: + """Safely get conversation history from memory. + + Returns: + Conversation history string, empty string if failed + """ + if not self.use_conversation_memory: + return "" + + try: + return self.conversation_memory() or "" + except Exception as e: + log.warning(f"Failed to get conversation history: {e}") + return "" + + def _safe_reset_pending_query(self) -> None: + """Safely reset pending query in memory.""" + if not self.use_conversation_memory: + return + + try: + self.conversation_memory.reset_pending_query() + except Exception as e: + log.warning(f"Failed to reset pending query: {e}") + # Continue execution even if reset fails + + def _init_permission_manager(self): + """Initialize the permission manager and register tools that require approval.""" + if self.permission_manager and hasattr(self.agent, "tool_manager"): + # Iterate through tools in the ComponentList + for tool in self.agent.tool_manager.tools: + if hasattr(tool, "definition") and hasattr(tool, "require_approval"): + tool_name = tool.definition.func_name + self.permission_manager.register_tool( + tool_name, tool.require_approval + ) + + def set_permission_manager( + self, permission_manager: Optional[PermissionManager] + ) -> None: + """Set or update the permission manager after runner initialization. + + Args: + permission_manager: The permission manager instance to use for tool approval + """ + self.permission_manager = permission_manager + # Re-initialize to register tools with the new permission manager + self._init_permission_manager() + + # pass the tool_manager to the permission_manager + if permission_manager is not None: + permission_manager.set_tool_manager(self.tool_manager) + + + + def is_cancelled(self) -> bool: + """Check if execution has been cancelled.""" + return self._cancelled + + def reset_cancellation(self) -> None: + """Reset the cancellation flag for a new execution.""" + self._cancelled = False + + def get_token_consumption(self) -> Dict[str, Any]: + """Get the current token consumption statistics. 
+ + Returns: + Dict containing token consumption data: + - total_prompt_tokens: Total tokens consumed across all steps + - current_step_tokens: Tokens from the most recent step + - steps_token_history: List of token counts per step + """ + return self._token_consumption.copy() + + def _update_token_consumption(self) -> None: + """Update token consumption statistics by checking the planner's accumulated token count. + + Since the generator accumulates tokens, we calculate the step tokens as the difference + from the last recorded total. + """ + if hasattr(self.agent.planner, 'estimated_token_count'): + current_total = self.agent.planner.estimated_token_count + step_tokens = current_total - self._token_consumption['last_total_tokens'] + + self._token_consumption['current_step_tokens'] = step_tokens + self._token_consumption['total_prompt_tokens'] = current_total + self._token_consumption['steps_token_history'].append(step_tokens) + self._token_consumption['last_total_tokens'] = current_total + + return step_tokens + return 0 + + def register_cancel_callback(self, callback) -> None: + """Register a callback to be called when execution is cancelled.""" + self._cancel_callbacks.append(callback) + + async def cancel(self) -> None: + """Cancel the current execution. + + This will stop the current execution but preserve state like memory. + """ + log.info("Runner.cancel() called - setting cancelled flag") + self._cancelled = True + + # Try to emit a test event if we have a streaming result + if hasattr(self, '_current_streaming_result') and self._current_streaming_result: + try: + cancel_received_event = RunItemStreamEvent( + name="runner.cancel_received", + item=FinalOutputItem( + data={ + "status": "cancel_received", + "message": "Cancel request received", + }) + ) + self._current_streaming_result.put_nowait(cancel_received_event) + log.info("Emitted cancel_received event") + except Exception as e: + log.error(f"Failed to emit cancel_received event: {e}") + + # Cancel the current streaming task if it exists + if self._current_task and not self._current_task.done(): + log.info(f"Cancelling runner task: {self._current_task}") + self._current_task.cancel() + + # Create a task to wait for cancellation to complete + await self._wait_for_cancellation() + + async def _wait_for_cancellation(self): + """Wait for task to be cancelled with timeout.""" + if self._current_task: + try: + # Wait up to 1 second for task to cancel gracefully + await asyncio.wait_for( + self._current_task, + timeout=1.0 + ) + except (asyncio.TimeoutError, asyncio.CancelledError): + # Task didn't cancel in time or was cancelled - that's ok + pass + + def _check_last_step(self, step: Function) -> bool: + """Check if the last step has is_answer_final set to True.""" + if hasattr(step, "_is_answer_final") and step._is_answer_final: + return True + + return False + + def _get_final_answer(self, function: Function) -> Any: + """Get and process the final answer from the function.""" + if hasattr(function, "_answer") or (hasattr(function, "_is_answer_final") and function._is_answer_final): + return self._process_data(function._answer) + return None + + + def _create_runner_result(self, answer: Any, step_history, error: Optional[str] = None, ) -> RunnerResult: + """Create a RunnerResult object with the final answer and error.""" + return RunnerResult( + answer=answer, + step_history=step_history.copy(), + error=error, + # ctx=self.ctx, + ) + def _create_execution_complete_stream_event(self, streaming_result: RunnerStreamingResult, 
final_output_item: FinalOutputItem): + """Complete the streaming execution by adding a sentinel.""" + try: + final_output_event = RunItemStreamEvent( + name="agent.execution_complete", + item=final_output_item, + ) + streaming_result.put_nowait(final_output_event) + + runner_result: RunnerResult = final_output_item.data + + # set up the final answer + streaming_result.answer = runner_result.answer if runner_result else None + streaming_result.step_history = self.step_history.copy() + streaming_result._is_complete = True + + except Exception as e: + log.error(f"Failed to create execution complete stream event: {e}") + raise e + + def _add_assistant_response_to_memory(self, final_output_item: FinalOutputItem, turn_id: Any = None): + # add the assistant response to the conversation memory using safe wrapper + if self.use_conversation_memory and turn_id is not None: + # Only add if we have a valid turn_id + self._safe_add_assistant_response( + AssistantResponse( + response_str=final_output_item.data.answer, + metadata={ + "step_history": final_output_item.data.step_history.copy() + }, + ), + turn_id=turn_id + ) + elif self.use_conversation_memory: + log.warning("Skipping add_assistant_response - no turn_id available") + + def create_response_span(self, runner_result, step_count: int, streaming_result: RunnerStreamingResult, runner_span_instance, workflow_status: str = "stream_completed"): + + runner_span_instance.span_data.update_attributes( + { + "steps_executed": step_count + 1, + "final_answer": runner_result.answer, + "workflow_status": workflow_status, + } + ) + + # Create response span for tracking final streaming result + with response_span( + answer=runner_result.answer, + result_type=type(runner_result.answer).__name__, + execution_metadata={ + "steps_executed": step_count + 1, + "max_steps": self.max_steps, + "workflow_status": workflow_status, + "streaming": True, + }, + response=runner_result, + ): + pass + + + + + async def _process_stream_final_step( + self, + answer: Any, + step_count: int, + streaming_result, + runner_span_instance, + turn_id: Any = None, + ) -> FinalOutputItem: + """Process the final step and trace it.""" + + # Runner result is the same as the sync/async call result + + runner_result = self._create_runner_result( + answer=answer, + step_history=self.step_history, + ) + + # Emit execution complete event + final_output_item = FinalOutputItem(data=runner_result) + self._create_execution_complete_stream_event( + streaming_result, final_output_item + ) + + # Ensure event is in queue by yielding control briefly + # This allows the event loop to process the put_nowait before we return + await asyncio.sleep(0) # Small delay to ensure event is properly queued + + # add the assistant response to the conversation memory + self._add_assistant_response_to_memory(final_output_item, turn_id) + return final_output_item + + # TODO: improved after the finish function is refactored + def _process_data( + self, + data: Union[BuiltInType, PydanticDataClass, AdalflowDataClass], + id: Optional[str] = None, + ) -> T: + """Process the generator output data field and convert to the specified pydantic data class of output_type. 
+ + Args: + data: The data to process + id: Optional identifier for the output + + Returns: + str: The processed data as a string + """ + + try: + model_output = None + log.info(f"answer_data_type: {type(self.answer_data_type)}") + + # returns a dictionary in this case + if _is_pydantic_dataclass(self.answer_data_type): + log.info( + f"initial answer returned by finish when user passed a pydantic type: {data}, type: {type(data)}" + ) + # if it has not yet been deserialized then deserialize into dictionary using json loads + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in data: {e}") + if not isinstance(data, dict): + raise ValueError(f"Expected dict after JSON parsing, got {type(data)}") + log.info( + f"initial answer after being evaluated using json: {data}, type: {type(data)}" + ) + # data should be a string that represents a dictionary + model_output = self.answer_data_type(**data) + elif _is_adalflow_dataclass(self.answer_data_type): + log.info( + f"initial answer returned by finish when user passed a adalflow type: {data}, type: {type(data)}" + ) + + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in data: {e}") + if not isinstance(data, dict): + raise ValueError(f"Expected dict after JSON parsing, got {type(data)}") + log.info( + f"initial answer after being evaluated using json: {data}, type: {type(data)}" + ) + # data should be a string that represents a dictionary + model_output = self.answer_data_type.from_dict(data) + else: # expect data to be a python built_in_type + log.info( + f"type of answer is neither a pydantic dataclass or adalflow dataclass, answer before being casted again for safety: {data}, type: {type(data)}" + ) + data = self.answer_data_type( + data + ) # directly cast using the answer_data_type + if not isinstance(data, self.answer_data_type): + raise ValueError( + f"Expected data of type {self.answer_data_type}, but got {type(data)}" + ) + model_output = data + + if not model_output: + raise ValueError(f"Failed to parse output: {data}") + + return model_output + + except Exception as e: + log.error(f"Error processing output: {str(e)}") + raise ValueError(f"Error processing output: {str(e)}") + + @classmethod + def _get_planner_function(self, output: GeneratorOutput) -> Optional[Function]: + """Check the planner output and return the function. + + Args: + output: The planner output + """ + if not isinstance(output, GeneratorOutput): + raise ValueError( + f"Expected GeneratorOutput, but got {type(output)}, value: {output}" + ) + + function = output.data + + if not isinstance(function, Function): + # can still self-recover in the agent for formatting. + # raise ValueError( + # f"Expected Function in the data field of the GeneratorOutput, but got {type(function)}, value: {function}" + # ) + return None + + return function + + def call( + self, + prompt_kwargs: Dict[str, Any], + model_kwargs: Optional[Dict[str, Any]] = None, + use_cache: Optional[bool] = None, + id: Optional[str] = None, # global run id + ) -> RunnerResult: + """Execute the planner synchronously for multiple steps with function calling support. 
+ + At the last step the action should be set to "finish" instead which terminates the sequence + + Args: + prompt_kwargs: Dictionary of prompt arguments for the generator + model_kwargs: Optional model parameters to override defaults + use_cache: Whether to use cached results if available + id: Optional unique identifier for the request + + Returns: + RunnerResult containing step history and final processed output + """ + # Create runner span for tracing + with runner_span( + runner_id=id or f"runner_{hash(str(prompt_kwargs))}", + max_steps=self.max_steps, + workflow_status="starting", + ) as runner_span_instance: + # reset the step history + self.step_history = [] + + # take in the query in prompt_kwargs + prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {} + prompt_kwargs["step_history"] = ( + self.step_history + ) # a reference to the step history + + turn_id = None + if self.use_conversation_memory: + # Reset any pending query state before starting a new query + self._safe_reset_pending_query() + + # Create new turn + turn_id = self._safe_create_turn() + + prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history() + # save the user query to the conversation memory + + # meta data is all keys in the list of context_str + query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} + if turn_id: + self._safe_add_user_query( + UserQuery( + query_str=prompt_kwargs.get("input_str", None), + metadata=query_metadata, + ), + turn_id=turn_id + ) + + # set maximum number of steps for the planner into the prompt + prompt_kwargs["max_steps"] = self.max_steps + + model_kwargs = model_kwargs.copy() if model_kwargs else {} + + step_count = 0 + last_output = None + current_error = None + + while step_count < self.max_steps: + try: + log.debug(f"Running step {step_count + 1}/{self.max_steps} with prompt_kwargs: {prompt_kwargs}") + # Create step span for each iteration + with step_span( + step_number=step_count, action_type="planning" + ) as step_span_instance: + + # Call the planner first to get the output + output = self.agent.planner.call( + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + use_cache=use_cache, + id=id, + ) + + # Track token usage + step_tokens = self._update_token_consumption() + if step_tokens > 0: + log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}") + + log.debug(f"planner output: {output}") + + # consistency with impl_astream, break if output is not a Generator Output + if not isinstance(output, GeneratorOutput): + # Create runner finish event with error and stop the loop + current_error = ( + f"Expected GeneratorOutput, but got {output}" + ) + # add this to the step history + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + self.step_history.append(step_output) + break + + function = output.data + + log.debug(f"function: {function}") + if function is None: + error_msg = f"Run into error: {output.error}, raw response: {output.raw_response}" + # Handle recoverable vs unrecoverable errors + if output.error is not None: + if _is_unrecoverable_error(output.error): + # Unrecoverable errors: context too long, rate limit, model not found + current_error = output.error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=f"Unrecoverable error: {output.error}", + ) + self.step_history.append(step_output) + break # Stop execution for unrecoverable errors + # Recoverable 
errors: JSON format errors, parsing errors, etc. + current_error = output.error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + self.step_history.append(step_output) + step_count += 1 + continue # Continue to next step for recoverable errors + + # start to process correct function + function.id = str(uuid.uuid4()) # add function id + thinking = output.thinking if hasattr(output, 'thinking') else None + if thinking is not None and self.is_thinking_model: + function.thought = thinking + + + if self._check_last_step(function): + processed_data = self._process_data(function._answer) + # Wrap final output in RunnerResult + last_output = RunnerResult( + answer=processed_data, + step_history=self.step_history.copy(), + # ctx=self.ctx, + ) + + # Add assistant response to conversation memory + if self.use_conversation_memory and turn_id is not None: + self._safe_add_assistant_response( + AssistantResponse( + response_str=processed_data, + metadata={ + "step_history": self.step_history.copy() + }, + ), + turn_id=turn_id + ) + + step_count += 1 # Increment step count before breaking + break + + step_output: Optional[StepOutput] = None + + # Create tool span for function execution + with tool_span( + tool_name=function.name, + function_name=function.name, + function_args=function.args, + function_kwargs=function.kwargs, + ) as tool_span_instance: + function_results = self._tool_execute_sync(function) + # Update span attributes using update_attributes for MLflow compatibility + tool_span_instance.span_data.update_attributes( + {"output_result": function_results.output} + ) + + function_output = function_results.output + real_function_output = None + + # Handle generator outputs in sync call + if inspect.iscoroutine(function_output): + # For sync call, we need to run the coroutine + real_function_output = asyncio.run(function_output) + elif inspect.isasyncgen(function_output): + # Collect all values from async generator + async def collect_async_gen(): + collected_items = [] + async for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # Skip activity items + continue + else: + collected_items.append(item) + return collected_items + real_function_output = asyncio.run(collect_async_gen()) + elif inspect.isgenerator(function_output): + # Collect all values from sync generator + collected_items = [] + for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # Skip activity items + continue + else: + collected_items.append(item) + real_function_output = collected_items + else: + real_function_output = function_output + + # Use the processed output + function_output = real_function_output + function_output_observation = function_output + if isinstance(function_output, ToolOutput) and hasattr( + function_output, "observation" + ): + function_output_observation = function_output.observation + + # create a step output + step_output = StepOutput( + step=step_count, + action=function, + function=function, + observation=function_output_observation, + ) + + # Update step span with results + step_span_instance.span_data.update_attributes( + { + "tool_name": function.name, + "tool_output": function_results, + "is_final": self._check_last_step(function), + "observation": function_output_observation, + } + ) + + log.debug( + "The prompt with the prompt template is {}".format( + self.agent.planner.get_prompt(**prompt_kwargs) + ) + ) + self.step_history.append(step_output) + step_count += 1 + + except 
Exception as e: + error_msg = f"Error in step {step_count}: {str(e)}" + log.error(error_msg) + + # Create response span for error tracking + with response_span( + answer=error_msg, + result_type="error", + execution_metadata={ + "steps_executed": step_count, + "max_steps": self.max_steps, + "workflow_status": "failed", + }, + response=None, + ): + pass + + # Continue to next step instead of returning + step_count += 1 + current_error = error_msg + break + + # Update runner span with final results + # Update runner span with completion info using update_attributes + runner_span_instance.span_data.update_attributes( + { + "steps_executed": step_count, + "final_answer": last_output.answer if last_output else None, + "workflow_status": "completed", + } + ) + + # Create response span for tracking final result + with response_span( + answer=( + last_output.answer + if last_output + else f"No output generated after {step_count} steps (max_steps: {self.max_steps})" + ), + result_type=( + type(last_output.answer).__name__ if last_output else "no_output" + ), + execution_metadata={ + "steps_executed": step_count, + "max_steps": self.max_steps, + "workflow_status": "completed" if last_output else "incomplete", + }, + response=last_output, # can be None if Runner has not finished in the max steps + ): + pass + + # Always return a RunnerResult, even if no successful completion + return last_output or RunnerResult( + answer=current_error or f"No output generated after {step_count} steps (max_steps: {self.max_steps})", + step_history=self.step_history.copy(), + error=current_error, + ) + + def _tool_execute_sync( + self, + func: Function, + ) -> Union[FunctionOutput, Parameter]: + """ + Call this in the call method. + Handles both sync and async functions by running async ones in event loop. + Includes permission checking if permission_manager is configured. + """ + + # execute permission and blocking mechanism in check_permission + # TODO: permission manager might be better to be put inside of tool manager + if self.permission_manager: + + result = asyncio.run(self.permission_manager.check_permission(func)) + + # Handle both old (2 values) and new (3 values) return formats + if len(result) == 3: + allowed, modified_func, _ = result + else: + allowed, modified_func = result + + if not allowed: + return FunctionOutput( + name=func.name, + input=func, + output=ToolOutput( + output="Tool execution cancelled by user", + observation="Tool execution cancelled by user", + display="Permission denied", + status="cancelled", + ), + ) + + # Use modified function if user edited it + func = modified_func or func + + result = self.agent.tool_manager.execute_func(func=func) + + if not isinstance(result, FunctionOutput): + raise ValueError("Result is not a FunctionOutput") + + # check error + if result.error is not None: + log.warning(f"Error in tool execution: {result.error}") + # TODO: specify how to handle this error + + return result + + # support both astream and non-stream + async def acall( + self, + prompt_kwargs: Dict[str, Any], + model_kwargs: Optional[Dict[str, Any]] = None, + use_cache: Optional[bool] = None, + id: Optional[str] = None, + ) -> Optional[RunnerResult]: + """Execute the planner asynchronously for multiple steps with function calling support. 
+ + At the last step the action should be set to "finish" instead which terminates the sequence + + Args: + prompt_kwargs: Dictionary of prompt arguments for the generator + model_kwargs: Optional model parameters to override defaults + use_cache: Whether to use cached results if available + id: Optional unique identifier for the request + + Returns: + RunnerResponse containing step history and final processed output + """ + + + workflow_status = "starting" + runner_id = id or f"async_runner_{hash(str(prompt_kwargs))}" + + + + # Create runner span for tracing + with runner_span( + runner_id=runner_id, + max_steps=self.max_steps, + workflow_status= workflow_status, + ) as runner_span_instance: + # Reset cancellation flag at start of new execution + self.reset_cancellation() + + self.step_history = [] + prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {} + + prompt_kwargs["step_history"] = ( + self.step_history + ) # a reference to the step history + + turn_id = None + if self.use_conversation_memory: + # Reset any pending query state before starting a new query + self._safe_reset_pending_query() + + # Create new turn + turn_id = self._safe_create_turn() + + prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history() + # save the user query to the conversation memory + + # meta data is all keys in the list of context_str + query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} + if turn_id: + self._safe_add_user_query( + UserQuery( + query_str=prompt_kwargs.get("input_str", None), + metadata=query_metadata, + ), + turn_id=turn_id + ) + + # set maximum number of steps for the planner into the prompt + prompt_kwargs["max_steps"] = self.max_steps + + model_kwargs = model_kwargs.copy() if model_kwargs else {} + + step_count = 0 + last_output = None + current_error = None + + while step_count < self.max_steps and not self.is_cancelled(): + try: + log.debug(f"Running async step {step_count + 1}/{self.max_steps} with prompt_kwargs: {prompt_kwargs}") + + # Create step span for each iteration + with step_span( + step_number=step_count, action_type="async_planning" + ) as step_span_instance: + + log.debug(f"Running async step {step_count + 1}/{self.max_steps} with prompt_kwargs: {prompt_kwargs}") + + if self.is_cancelled(): + raise asyncio.CancelledError("Execution cancelled by user") + + # Call the planner first to get the output + output: GeneratorOutput = await self.agent.planner.acall( + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + use_cache=use_cache, + id=id, + ) + + # Track token usage + step_tokens = self._update_token_consumption() + if step_tokens > 0: + log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}") + + log.debug(f"planner output: {output}") + + if not isinstance(output, GeneratorOutput): + # Create runner finish event with error and stop the loop + current_error = ( + f"Expected GeneratorOutput, but got {type(output)}" + ) + # create a step output for the error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + self.step_history.append(step_output) + step_count += 1 + break + + + + function = output.data + + log.debug(f"function: {function}") + if function is None: + error_msg = f"Run into error: {output.error}, raw response: {output.raw_response}" + # Handle recoverable vs unrecoverable errors + if output.error is not None: + if _is_unrecoverable_error(output.error): + # Unrecoverable errors: 
context too long, rate limit, model not found + current_error = output.error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=f"Unrecoverable error: {output.error}", + ) + self.step_history.append(step_output) + step_count += 1 + break # Stop execution for unrecoverable errors + # Recoverable errors: JSON format errors, parsing errors, etc. + current_error = output.error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + self.step_history.append(step_output) + step_count += 1 + + continue # Continue to next step for recoverable errors` + + + + thinking = output.thinking if hasattr(output, 'thinking') else None + if function is not None: + # add a function id + function.id = str(uuid.uuid4()) + if thinking is not None and self.is_thinking_model: + function.thought = thinking + + + + if self._check_last_step(function): + answer = self._get_final_answer(function) + # Wrap final output in RunnerResult + last_output = RunnerResult( + answer=answer, + step_history=self.step_history.copy(), + error=current_error, + # ctx=self.ctx, + ) + + # Add assistant response to conversation memory + if self.use_conversation_memory and turn_id is not None: + self._safe_add_assistant_response( + AssistantResponse( + response_str=answer, + metadata={ + "step_history": self.step_history.copy() + }, + ), + turn_id=turn_id + ) + + + + step_count += 1 # Increment step count before breaking + + break + + # Create tool span for function execution + with tool_span( + tool_name=function.name, + function_name=function.name, + function_args=function.args, + function_kwargs=function.kwargs, + ) as tool_span_instance: + function_results = await self._tool_execute_async( + func=function + ) + function_output = function_results.output + # add the process of the generator and async generator + real_function_output = None + + # Handle generator outputs similar to astream implementation + if inspect.iscoroutine(function_output): + real_function_output = await function_output + elif inspect.isasyncgen(function_output): + # Collect all values from async generator + collected_items = [] + async for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # Skip activity items in acall + continue + else: + collected_items.append(item) + # Use collected items as output + real_function_output = collected_items + elif inspect.isgenerator(function_output): + # Collect all values from sync generator + collected_items = [] + for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # Skip activity items in acall + continue + else: + collected_items.append(item) + # Use collected items as output + real_function_output = collected_items + else: + real_function_output = function_output + + # Use the processed output + function_output = real_function_output + function_output_observation = function_output + + if isinstance(function_output, ToolOutput) and hasattr( + function_output, "observation" + ): + function_output_observation = ( + function_output.observation + ) + + # Update tool span attributes using update_attributes for MLflow compatibility + tool_span_instance.span_data.update_attributes( + {"output_result": function_output} + ) + + step_output: StepOutput = StepOutput( + step=step_count, + action=function, + function=function, + observation=function_output_observation, + ) + self.step_history.append(step_output) + + # Update step span with results + 
step_span_instance.span_data.update_attributes( + { + "tool_name": function.name, + "tool_output": function_results, + "is_final": self._check_last_step(function), + "observation": function_output_observation, + } + ) + + log.debug( + "The prompt with the prompt template is {}".format( + self.agent.planner.get_prompt(**prompt_kwargs) + ) + ) + + step_count += 1 + + except Exception as e: + error_msg = f"Error in step {step_count}: {str(e)}" + log.error(error_msg) + + # Create response span for error tracking + with response_span( + answer=error_msg, + result_type="error", + execution_metadata={ + "steps_executed": step_count, + "max_steps": self.max_steps, + "workflow_status": "failed", + }, + response=None, + ): + pass + + # Continue to next step instead of returning + step_count += 1 + current_error = error_msg + break + + # Update runner span with final results + # Update runner span with completion info using update_attributes + runner_span_instance.span_data.update_attributes( + { + "steps_executed": step_count, + "final_answer": last_output.answer if last_output else None, + "workflow_status": "completed", + } + ) + + # Create response span for tracking final result + with response_span( + answer=( + last_output.answer + if last_output + else f"No output generated after {step_count} steps (max_steps: {self.max_steps})" + ), + result_type=( + type(last_output.answer).__name__ if last_output else "no_output" + ), + execution_metadata={ + "steps_executed": step_count, + "max_steps": self.max_steps, + "workflow_status": "completed" if last_output else "incomplete", + }, + response=last_output, # can be None if Runner has not finished in the max steps + ): + pass + + # Always return a RunnerResult, even if no successful completion + return last_output or RunnerResult( + answer=current_error or f"No output generated after {step_count} steps (max_steps: {self.max_steps})", + step_history=self.step_history.copy(), + error=current_error, + ) + + + def astream( + self, + prompt_kwargs: Dict[str, Any], + model_kwargs: Optional[Dict[str, Any]] = None, + use_cache: Optional[bool] = None, + id: Optional[str] = None, + ) -> RunnerStreamingResult: + """ + Execute the runner asynchronously with streaming support. + + Returns: + RunnerStreamingResult: A streaming result object with stream_events() method + """ + # Cancel any previous task that might still be running + # TODO might have problems of overwriting and cancelling other tasks if we call await astream two times asychronously with the same runner / agent instance. 
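+        # Illustrative consumption sketch, not part of this patch. It only uses
+        # names already present in this module (astream, stream_events,
+        # RawResponsesStreamEvent, RunItemStreamEvent); the surrounding call
+        # site and the final `.answer` access are assumptions:
+        #
+        #     result = runner.astream(prompt_kwargs={"input_str": "..."})
+        #     async for event in result.stream_events():
+        #         if isinstance(event, RawResponsesStreamEvent):
+        #             ...  # raw planner output (token deltas when streaming)
+        #         elif isinstance(event, RunItemStreamEvent):
+        #             ...  # e.g. "agent.step_complete", "agent.tool_call_complete"
+        #     final_answer = result.answer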
+ if self._current_task and not self._current_task.done(): + self._current_task.cancel() + log.info("Cancelled previous streaming task") + # Don't wait for cancellation here - just cancel and move on + self._current_task = None + + # Reset cancellation flag for new execution + self._cancelled = False + + result = RunnerStreamingResult() + # Store the streaming result so we can emit events to it during cancellation + self._current_streaming_result = result + + self.reset_cancellation() + + # Store the task so we can cancel it if needed + self._current_task = asyncio.get_event_loop().create_task( + self.impl_astream(prompt_kwargs, model_kwargs, use_cache, id, result) + ) + result._run_task = self._current_task + return result + + async def impl_astream( + self, + prompt_kwargs: Dict[str, Any], + model_kwargs: Optional[Dict[str, Any]] = None, + use_cache: Optional[bool] = None, + id: Optional[str] = None, + streaming_result: Optional[RunnerStreamingResult] = None, + ) -> None: + """ + Behave exactly the same as `acall` but with streaming support. + + - GeneratorOutput will be emitted as RawResponsesStreamEvent + + - StepOutput will be emitted as RunItemStreamEvent with name "agent.step_complete". + + - Finally, there will be a FinalOutputItem with the final answer or error. + + Execute the planner asynchronously for multiple steps with function calling support. + + At the last step the action should be set to "finish" instead which terminates the sequence + + Args: + prompt_kwargs: Dictionary of prompt arguments for the generator + model_kwargs: Optional model parameters to override defaults + use_cache: Whether to use cached results if available + id: Optional unique identifier for the request + """ + workflow_status: Literal["streaming", "stream_completed", "stream_failed", "stream_incomplete"] = "streaming" + # Create runner span for tracing streaming execution + turn_id = None + with runner_span( + runner_id=id or f"stream_runner_{hash(str(prompt_kwargs))}", + max_steps=self.max_steps, + workflow_status= workflow_status, + ) as runner_span_instance: + + # Reset cancellation flag at start of new execution + self.step_history = [] + prompt_kwargs = prompt_kwargs.copy() if prompt_kwargs else {} + + prompt_kwargs["step_history"] = self.step_history + if self.use_conversation_memory: + # Reset any pending query state before starting a new query + self._safe_reset_pending_query() + + # Create new turn + turn_id = self._safe_create_turn() + + prompt_kwargs["chat_history_str"] = self._safe_get_conversation_history() + # save the user query to the conversation memory + # meta data is all keys in the list of context_str + query_metadata = {"context_str": prompt_kwargs.get("context_str", None)} + if turn_id: + self._safe_add_user_query( + UserQuery( + query_str=prompt_kwargs.get("input_str", None), + metadata=query_metadata, + ), + turn_id + ) + # a reference to the step history + # set maximum number of steps for the planner into the prompt + prompt_kwargs["max_steps"] = self.max_steps + + model_kwargs = model_kwargs.copy() if model_kwargs else {} + step_count = 0 + final_output_item = None + current_error = None + + # whenever we have the final output, we break the loop, this includes + # (1) final_answer (check final step) + # (2) unrecoverable error in llm planner + # (3) any exception + + # for normal, we will have raw_response_event, request_permission, tool_call_start, tool_call_activity, tool_call_complete, step_complete + # for error, we can skip any step but will always have step_complete 
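+            # Illustrative mapping of the informal names above to the events
+            # actually emitted in this method (for reference only):
+            #   raw_response_event  -> RawResponsesStreamEvent
+            #   request_permission  -> RunItemStreamEvent("agent.tool_permission_request")
+            #   tool_call_start     -> RunItemStreamEvent("agent.tool_call_start")
+            #   tool_call_activity  -> RunItemStreamEvent("agent.tool_call_activity")
+            #   tool_call_complete  -> RunItemStreamEvent("agent.tool_call_complete")
+            #   step_complete       -> RunItemStreamEvent("agent.step_complete")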
[] + + + # ToolOutput + # has three status: success, error, canceled + while step_count < self.max_steps and not self.is_cancelled(): + try: + # Create step span for each streaming iteration + # error handing: when run into any error, it creates a runner finish event. and stops the loop + # it should directly sent the execution complete with error event + with step_span( + step_number=step_count, action_type="stream_planning" + ) as step_span_instance: + # important to ensure the prompt at each step is correct + log.debug( + "The prompt with the prompt template is {}".format( + self.agent.planner.get_prompt(**prompt_kwargs) + ) + ) + + # Check cancellation before calling planner + # TODO seems slightly unnecessary we are calling .cancel on the task in cancel which will raise this exception regardless unless we want to terminate earlier by checking the cancelled field + if self.is_cancelled(): + raise asyncio.CancelledError("Execution cancelled by user") + + # when it's streaming, the output will be an async generator + output: GeneratorOutput = await self.agent.planner.acall( + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + use_cache=use_cache, + id=id, + ) + planner_prompt = self.agent.planner.get_prompt(**prompt_kwargs) + log.debug(f"Planner output: {output}, prompt: {planner_prompt}") + + # Track token usage + step_tokens = self._update_token_consumption() + if step_tokens > 0: + log.debug(f"Step {step_count} - Prompt tokens: {step_tokens}, Total: {self._token_consumption['total_prompt_tokens']}") + + if not isinstance(output, GeneratorOutput): + # Create runner finish event with error and stop the loop + error_msg = ( + f"Expected GeneratorOutput, but got {type(output)}" + ) + final_output_item = FinalOutputItem( + error=error_msg, + ) + workflow_status = "stream_failed" + current_error = error_msg + # create a step output for the error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + self.step_history.append(step_output) + step_count += 1 + break + + + # handle the generator output data and error + wrapped_event = None + + if isinstance(output.raw_response, AsyncIterable): + log.debug( + f"Streaming raw response from planner: {output.raw_response}" + ) + # Streaming llm call - iterate through the async generator + async for event in output.raw_response: + # TODO seems slightly unnecessary we are calling .cancel on the task in cancel which will raise this exception regardless + if self.is_cancelled(): + raise asyncio.CancelledError("Execution cancelled by user") + wrapped_event = RawResponsesStreamEvent(data=event) + streaming_result.put_nowait(wrapped_event) + + else: # non-streaming cases + # yield the final planner response + if output.data is None or (not isinstance(output.data, Function)): + + # recoverable errors, continue to create stepout + current_error = f"Error: {output.error} - data: {output.data}, raw_response: {output.raw_response}" + # wrap the error in a RawResponsesStreamEvent + wrapped_event = RawResponsesStreamEvent( + data=None, # no data in this case + error= output.error, + ) + streaming_result.put_nowait(wrapped_event) + + + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=current_error, + ) + # emit the step complete event with error which matches the step_output + step_item = StepRunItem(data=step_output) + step_complete_event = RunItemStreamEvent( + name="agent.step_complete", + item=step_item, + ) + 
streaming_result.put_nowait(step_complete_event) + + # Ensure event is processed before continuing + await asyncio.sleep(0) # Yield control to allow queue processing + + self.step_history.append(step_output) + + if output.error is not None: + + if _is_unrecoverable_error(output.error): # context too long or rate limite, not recoverable + # 404 model not exist + # create a final output item with error and stop the loop + final_output_item = FinalOutputItem( + error=output.error, + ) + workflow_status = "stream_failed" + current_error = output.error + step_output = StepOutput( + step=step_count, + action=None, + function=None, + observation=f"Unrecoverable error: {output.error}", + ) + self.step_history.append(step_output) + step_count += 1 + break + step_count += 1 + continue # continue to next step + + + # normal functions + wrapped_event = RawResponsesStreamEvent( + data=output.data, + input=planner_prompt, + ) # wrap on the data field to be the final output, the data might be null + streaming_result.put_nowait(wrapped_event) + + # asychronously consuming the raw response will + # update the data field of output with the result of the output processor + + # handle function output + + function = output.data # here are the recoverable errors, should continue to step output + thinking = output.thinking # check the reasoning model response + if thinking is not None and self.is_thinking_model: + # if the thinking is not None, we will add it to the function + if function is not None and isinstance(function, Function): + function.thought = thinking + + function.id = str(uuid.uuid4()) # add function id + function_result = None + function_output_observation = None + + if thinking is not None and self.is_thinking_model: + function.thought = thinking + + # TODO: simplify this + tool_call_id = function.id + tool_call_name = function.name + log.debug(f"function: {function}") + + if self._check_last_step(function): # skip stepoutput + try: + answer = self._get_final_answer(function) + except Exception as e: + # If processing the final answer fails, use the raw answer + log.warning(f"Failed to process final answer: {e}. Using raw answer.") + answer = function._answer if hasattr(function, "_answer") else str(e) + + final_output_item = await self._process_stream_final_step( + answer=answer, + step_count=step_count, + streaming_result=streaming_result, + runner_span_instance=runner_span_instance, + turn_id=turn_id, + ) + workflow_status = "stream_completed" + + # Ensure the queue has processed the execution_complete event + # Add a small yield to allow the event loop to process the queued events + await asyncio.sleep(0) # Yield control to allow queue processing + + break + + # Check if permission is required and emit permission event + # TODO: trace the permission event + + function_output_observation = None + function_result = None + print("function name", function.name) + complete_step = False + if ( + self.permission_manager + and self.permission_manager.is_approval_required( + function.name + ) + ): + permission_event = ( + self.permission_manager.create_permission_event( + function + ) + ) + # there is an error + if isinstance(permission_event, ToolOutput): + # need a tool complete event + function_result = FunctionOutput( + name=function.name, + input=function, + output=permission_event, + ) + tool_complete_event = RunItemStreamEvent( + name="agent.tool_call_complete", + # error is already tracked in output + # TODO: error tracking is not needed in RunItem, it is tracked in the tooloutput status. 
+ item=ToolOutputRunItem( + data=function_result, + id=tool_call_id, + error=permission_event.observation if permission_event.status == "error" else None, # error message sent to the frontend + ), + ) + streaming_result.put_nowait(tool_complete_event) + function_output_observation = permission_event.observation + complete_step = True + else: + permission_stream_event = RunItemStreamEvent( + name="agent.tool_permission_request", + item=permission_event, + ) + streaming_result.put_nowait(permission_stream_event) + if not complete_step: + # Execute the tool with streaming support + function_result, function_output, function_output_observation = await self.stream_tool_execution( + function=function, + tool_call_id=tool_call_id, + tool_call_name=tool_call_name, + streaming_result=streaming_result, + ) + + # Add step to history for approved tools (same as non-permission branch) + step_output: StepOutput = StepOutput( + step=step_count, + action=function, + function=function, + observation=function_output_observation, + ) + self.step_history.append(step_output) + + # Update step span with results (for both recoverable errors and normal function execution) + step_span_instance.span_data.update_attributes( + { + "tool_name": function.name if function else None, + "tool_output": function_result, + "is_final": self._check_last_step(function), + "observation": function_output_observation, + } + ) + + # Emit step completion event (with error if any) + step_item = StepRunItem(data=step_output) + step_event = RunItemStreamEvent( + name="agent.step_complete", item=step_item + ) + streaming_result.put_nowait(step_event) + + # Ensure event is processed before continuing + await asyncio.sleep(0) # Yield control to allow queue processing + + step_count += 1 + + except asyncio.CancelledError: + # Handle cancellation gracefully + cancel_msg = "Execution cancelled by user" + log.info(cancel_msg) + + # Emit cancellation event so frontend/logs can see it + cancel_event = RunItemStreamEvent( + name="runner.cancelled", + item=FinalOutputItem(data={ + "status": "cancelled", + "message": cancel_msg, + "step_count": step_count, + }) + ) + streaming_result.put_nowait(cancel_event) + + # Store cancellation result + streaming_result.answer = cancel_msg + streaming_result.step_history = self.step_history.copy() + streaming_result._is_complete = True + + # Add cancellation response to conversation memory + if self.use_conversation_memory and turn_id is not None: + self._safe_add_assistant_response( + AssistantResponse( + response_str="I apologize, but the execution was cancelled by the user.", + metadata={ + "step_history": self.step_history.copy(), + "status": "cancelled", + "timestamp": datetime.now().isoformat() + } + ), + turn_id=turn_id + ) + + # Signal completion and break + streaming_result.put_nowait(QueueCompleteSentinel()) + break + + except Exception as e: + # these excepts should almost never happen + error_msg = f"Error in step {step_count}: {str(e)}" + log.error(error_msg) + + workflow_status = "stream_failed" + streaming_result._exception = error_msg + + # Emit error as FinalOutputItem to queue + final_output_item = FinalOutputItem(error=error_msg) + # error_event = RunItemStreamEvent( + # name="runner_finished", item=error_final_item + # ) + current_error = error_msg + break + + # If loop terminated without creating a final output item, create our own + # TODO this might be redundant + if final_output_item is None: + # Create a RunnerResult with incomplete status + runner_result = RunnerResult( + answer=f"No 
output generated after {step_count} steps (max_steps: {self.max_steps})", + error=current_error, + step_history=self.step_history.copy(), + + ) + final_output_item = FinalOutputItem(data=runner_result) + + workflow_status = "stream_incomplete" + current_error = f"No output generated after {step_count} steps (max_steps: {self.max_steps})" + + # Only emit execution_complete if we created a new final_output_item + # (i.e., when the loop ended without a final answer) + self._create_execution_complete_stream_event( + streaming_result=streaming_result, + final_output_item=final_output_item, + ) + + runner_span_instance.span_data.update_attributes( + { + "steps_executed": step_count, + "final_answer": final_output_item.data.answer if final_output_item.data else None, + "workflow_status": workflow_status, + } + ) + + # create runner result with or without error + + runner_result = RunnerResult( + answer=final_output_item.data.answer if final_output_item.data else None, + step_history=self.step_history.copy(), + error=current_error, + ) + + # create response span for final output + # if workflow_status in ["stream_incomplete", "stream_failed"]: + self.create_response_span( + runner_result=runner_result, + step_count=step_count, + streaming_result=streaming_result, + runner_span_instance=runner_span_instance, + workflow_status=workflow_status, + ) + + # Signal completion of streaming + streaming_result.put_nowait(QueueCompleteSentinel()) + + async def _tool_execute_async( + self, + func: Function, + streaming_result: Optional[RunnerStreamingResult] = None, + ) -> Union[FunctionOutput, Parameter]: + """ + Call this in the acall method. + Handles both sync and async functions. + Note: this version has no support for streaming. + Includes permission checking if permission_manager is configured. + """ + + # Check permission before execution + if self.permission_manager: + result = await self.permission_manager.check_permission(func) + # Handle both old (2 values) and new (3 values) return formats + if len(result) == 3: + allowed, modified_func, _ = result + else: + allowed, modified_func = result + + if not allowed: + return FunctionOutput( + name=func.name, + input=func, + output=ToolOutput( + output="Tool execution cancelled by user", + observation="Tool execution cancelled by user", + display="Permission denied", + status="cancelled", + ), + ) + + # Use modified function if user edited it + func = modified_func or func + + # Emit tool call event + if streaming_result is not None: + tool_call_item = ToolCallRunItem(data=func, id=func.id) + tool_call_event = RunItemStreamEvent( + name="agent.tool_call_start", item=tool_call_item + ) + streaming_result.put_nowait(tool_call_event) + + # if streaming_result is not None: + # result = await self.agent.tool_manager.execute_func_astream(func=func) + # else: + result = await self.agent.tool_manager.execute_func_async(func=func) + + if not isinstance(result, FunctionOutput): + raise ValueError("Result is not a FunctionOutput") + return result + + async def stream_tool_execution( + self, + function: Function, + tool_call_id: str, + tool_call_name: str, + streaming_result: RunnerStreamingResult, + ) -> tuple[Any, Any, Any]: + """ + Execute a tool/function call with streaming support and proper event handling. 
+ + This method handles: + - Tool span creation for tracing + - Async generator support for streaming results + - Tool activity events + - Tool completion events + - Error handling and observation extraction + + Args: + function: The Function object to execute + tool_call_id: Unique identifier for this tool call + tool_call_name: Name of the tool being called + streaming_result: Queue for streaming events + + Returns: + tuple: (function_output, function_output_observation) + """ + # Create tool span for streaming function execution + with tool_span( + tool_name=tool_call_name, + function_name=function.name, # TODO fix attributes + function_args=function.args, + function_kwargs=function.kwargs, + ) as tool_span_instance: + + # TODO: inside of FunctionTool execution, it should ensure the types of async generator item + # to be either ToolCallActivityRunItem or ToolOutput(maybe) + # Call activity might be better designed + + function_result = await self._tool_execute_async( + func=function, streaming_result=streaming_result + ) # everything must be wrapped in FunctionOutput + + if not isinstance(function_result, FunctionOutput): + raise ValueError( + f"Result must be wrapped in FunctionOutput, got {type(function_result)}" + ) + + function_output = function_result.output + real_function_output = None + + # TODO: validate when the function is a generator + + if inspect.iscoroutine(function_output): + real_function_output = await function_output + elif inspect.isasyncgen(function_output): + async for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # add the tool_call_id to the item + item.id = tool_call_id + tool_call_event = RunItemStreamEvent( + name="agent.tool_call_activity", item=item + ) + streaming_result.put_nowait(tool_call_event) + else: + real_function_output = item + + elif inspect.isgenerator(function_output): + for item in function_output: + if isinstance(item, ToolCallActivityRunItem): + # add the tool_call_id to the item + item.id = tool_call_id + tool_call_event = RunItemStreamEvent( + name="agent.tool_call_activity", item=item + ) + streaming_result.put_nowait(tool_call_event) + else: + real_function_output = item + else: + real_function_output = function_output + + # create call complete + call_complete_event = RunItemStreamEvent( + name="agent.tool_call_complete", + item=ToolOutputRunItem( + id=tool_call_id, + data=FunctionOutput( + name=function.name, + input=function, + output=real_function_output, + ), + ), + ) + streaming_result.put_nowait(call_complete_event) + + function_output = real_function_output + function_output_observation = function_output + + if isinstance(function_output, ToolOutput) and hasattr( + function_output, "observation" + ): + function_output_observation = ( + function_output.observation + ) + # Update tool span attributes using update_attributes for MLflow compatibility + + tool_span_instance.span_data.update_attributes( + {"output_result": real_function_output} + ) + + return function_result, function_output, function_output_observation diff --git a/adalflow/adalflow/components/memory/flexible_memory.py b/adalflow/adalflow/components/memory/flexible_memory.py index d6adbec5..30f51973 100644 --- a/adalflow/adalflow/components/memory/flexible_memory.py +++ b/adalflow/adalflow/components/memory/flexible_memory.py @@ -21,7 +21,7 @@ class Message(DataClass): id: str = field(default_factory=lambda: str(uuid4())) role: Literal["user", "assistant", "system"] = "user" content: str = "" - metadata: Optional[Dict[str, Any]] = None + 
metadata: Optional[Dict[str, Any]] = None # example metadata: "step_history": List[Dict[str, Any]] for agent response, "images" for user query with images timestamp: datetime = field(default_factory=datetime.now) @classmethod @@ -42,7 +42,7 @@ def from_system(cls, content: str, metadata: Optional[Dict] = None): @dataclass class Conversation(DataClass): - """A conversation organized as turns, where each turn can have multiple messages.""" + """A conversation organized as turns, where each turn can have multiple messages and is one agent loop.""" id: str = field(default_factory=lambda: str(uuid4())) user_id: Optional[str] = None turns: OrderedDict = field(default_factory=OrderedDict) # turn_id -> List[Message] @@ -61,9 +61,15 @@ def add_message_to_turn(self, turn_id: str, message: Message) -> str: str: The message ID """ if turn_id not in self.turns: - self.turns[turn_id] = [] - self.turns[turn_id].append(message) - return message.id + raise ValueError(f"Turn '{turn_id}' does not exist. Please create the turn first.") + if not isinstance(message, Message): + raise TypeError("message must be an instance of Message") + + try: + self.turns[turn_id].append(message) + return message.id + except Exception as e: + raise RuntimeError(f"Failed to add message to turn '{turn_id}': {e}") def get_turn_messages(self, turn_id: str) -> List[Message]: """Get all messages in a specific turn.""" @@ -101,10 +107,13 @@ def get_last_assistant_message(self) -> Optional[Message]: def create_turn(self) -> str: """Create a new turn and return its ID.""" - turn_id = str(uuid4()) - self.turns[turn_id] = [] - self._current_turn_id = turn_id - return turn_id + try: + turn_id = str(uuid4()) + self.turns[turn_id] = [] + self._current_turn_id = turn_id + return turn_id + except Exception as e: + raise RuntimeError(f"Failed to create turn: {e}") # Template for conversation formatting @@ -176,29 +185,29 @@ def create_turn(self) -> str: Returns: str: The new turn ID """ - return self.current_conversation.create_turn() + turn_id= self.current_conversation.create_turn() + return turn_id - def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str: + def add_user_query(self, content: str, turn_id: str, metadata: Optional[Dict] = None) -> str: """Add a user message to a turn. Args: content: The user's message content + turn_id: The turn ID to add the message to (required, must exist) metadata: Optional metadata - turn_id: Optional turn ID. If None, creates a new turn. Returns: str: The turn ID the message was added to - """ - # Use provided turn_id or create new turn - if turn_id is None: - turn_id = self.create_turn() - elif turn_id not in self.current_conversation.turns: - # Turn doesn't exist, create it - self.current_conversation.turns[turn_id] = [] - # Track as current turn - self.current_conversation._current_turn_id = turn_id - + Raises: + ValueError: If the turn_id doesn't exist + """ + # Check if turn exists + if turn_id not in self.current_conversation.turns: + raise ValueError( + f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() " + f"or provide an existing turn ID. 
Available turns: {list(self.current_conversation.turns.keys())}" + ) # Create and add the user message message = Message.from_user(content, metadata) self.current_conversation.add_message_to_turn(turn_id, message) @@ -218,29 +227,28 @@ def add_user_query(self, content: str, metadata: Optional[Dict] = None, turn_id: def add_assistant_response( self, content: str, + turn_id: str, metadata: Optional[Dict] = None, - turn_id: Optional[str] = None ) -> str: """Add an assistant message to a turn. Args: content: The assistant's message content + turn_id: The turn ID to add the message to (required, must exist) metadata: Optional metadata - turn_id: Optional turn ID. If None, uses current turn or creates new. Returns: str: The turn ID the message was added to + + Raises: + ValueError: If the turn_id doesn't exist """ - # Determine which turn to use - if turn_id is None: - if self.current_conversation._current_turn_id: - turn_id = self.current_conversation._current_turn_id - else: - # No active turn, create new one for standalone response - turn_id = self.create_turn() - elif turn_id not in self.current_conversation.turns: - # Turn doesn't exist, create it - self.current_conversation.turns[turn_id] = [] + # Check if turn exists + if turn_id not in self.current_conversation.turns: + raise ValueError( + f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() " + f"or provide an existing turn ID. Available turns: {list(self.current_conversation.turns.keys())}" + ) # Create and add the assistant message message = Message.from_assistant(content, metadata) @@ -258,23 +266,27 @@ def add_assistant_response( return turn_id - def add_system_message(self, content: str, metadata: Optional[Dict] = None, turn_id: Optional[str] = None) -> str: + def add_system_message(self, content: str, turn_id: str, metadata: Optional[Dict] = None) -> str: """Add a system message to a turn. Args: content: The system message content + turn_id: The turn ID to add the message to (required, must exist) metadata: Optional metadata - turn_id: Optional turn ID. If None, creates a new turn. Returns: str: The turn ID the message was added to - """ - # Use provided turn_id or create new turn - if turn_id is None: - turn_id = self.create_turn() - elif turn_id not in self.current_conversation.turns: - self.current_conversation.turns[turn_id] = [] + Raises: + ValueError: If the turn_id doesn't exist + """ + # Check if turn exists + if turn_id not in self.current_conversation.turns: + raise ValueError( + f"Turn '{turn_id}' does not exist. Please create the turn first using create_turn() " + f"or provide an existing turn ID. Available turns: {list(self.current_conversation.turns.keys())}" + ) + # Create and add the system message message = Message.from_system(content, metadata) self.current_conversation.add_message_to_turn(turn_id, message) diff --git a/adalflow/tests/test_flexible_memory.py b/adalflow/tests/test_flexible_memory.py new file mode 100644 index 00000000..08bd8fe6 --- /dev/null +++ b/adalflow/tests/test_flexible_memory.py @@ -0,0 +1,1121 @@ +"""Comprehensive test suite for FlexibleConversationMemory. + +This test module covers: +1. Basic message operations (creation, addition, retrieval) +2. Turn management (creation, message assignment, ordering) +3. Conversation management (new conversations, clearing, persistence) +4. Error handling and validation +5. Database integration +6. Metadata handling and filtering +7. 
Edge cases and boundary conditions +""" + +import pytest +from datetime import datetime +from collections import OrderedDict +from unittest.mock import Mock, patch +from adalflow.components.memory.flexible_memory import ( + Message, + Conversation, + FlexibleConversationMemory, +) +from adalflow.core.db import LocalDB + + +class TestMessage: + """Test the Message dataclass functionality.""" + + def test_message_creation_default(self): + """Test creating a message with default values. + + Tests: + - Default role is 'user' + - Content defaults to empty string + - ID is automatically generated + - Timestamp is automatically set + - Metadata is None by default + """ + msg = Message() + assert msg.role == "user" + assert msg.content == "" + assert msg.id is not None + assert len(msg.id) > 0 + assert isinstance(msg.timestamp, datetime) + assert msg.metadata is None + + def test_message_creation_with_values(self): + """Test creating a message with custom values. + + Tests: + - Custom role, content, and metadata are properly set + - Values are stored correctly + """ + metadata = {"key": "value", "number": 42} + msg = Message(role="assistant", content="Hello", metadata=metadata) + assert msg.role == "assistant" + assert msg.content == "Hello" + assert msg.metadata == metadata + + def test_message_from_user(self): + """Test creating a user message using class method. + + Tests: + - from_user() creates message with role='user' + - Content and metadata are properly set + """ + metadata = {"context": "test"} + msg = Message.from_user("User query", metadata) + assert msg.role == "user" + assert msg.content == "User query" + assert msg.metadata == metadata + + def test_message_from_assistant(self): + """Test creating an assistant message using class method. + + Tests: + - from_assistant() creates message with role='assistant' + - Content and metadata are properly set + """ + msg = Message.from_assistant("Assistant response", {"confidence": 0.9}) + assert msg.role == "assistant" + assert msg.content == "Assistant response" + assert msg.metadata == {"confidence": 0.9} + + def test_message_from_system(self): + """Test creating a system message using class method. + + Tests: + - from_system() creates message with role='system' + - Content and metadata are properly set + """ + msg = Message.from_system("System prompt") + assert msg.role == "system" + assert msg.content == "System prompt" + assert msg.metadata is None + + def test_message_unique_ids(self): + """Test that each message gets a unique ID. + + Tests: + - Multiple messages have different IDs + - IDs are valid UUID strings + """ + msg1 = Message() + msg2 = Message() + msg3 = Message() + assert msg1.id != msg2.id + assert msg2.id != msg3.id + assert msg1.id != msg3.id + + +class TestConversation: + """Test the Conversation dataclass functionality.""" + + def test_conversation_creation(self): + """Test creating a conversation with default values. + + Tests: + - Conversation ID is generated + - turns is an OrderedDict + - user_id can be set + - Metadata defaults to None + - Timestamp is set + """ + conv = Conversation() + assert conv.id is not None + assert isinstance(conv.turns, OrderedDict) + assert len(conv.turns) == 0 + assert conv.user_id is None + assert conv.metadata is None + assert isinstance(conv.created_at, datetime) + + def test_conversation_with_user_id(self): + """Test creating a conversation with a specific user ID. 
+ + Tests: + - User ID is properly stored + """ + conv = Conversation(user_id="user123") + assert conv.user_id == "user123" + + def test_create_turn(self): + """Test creating a new turn in the conversation. + + Tests: + - Turn is created with unique ID + - Turn is added to OrderedDict + - Current turn ID is tracked + - Multiple turns can be created + """ + conv = Conversation() + + turn_id1 = conv.create_turn() + assert turn_id1 is not None + assert turn_id1 in conv.turns + assert conv.turns[turn_id1] == [] + assert conv._current_turn_id == turn_id1 + + turn_id2 = conv.create_turn() + assert turn_id2 != turn_id1 + assert turn_id2 in conv.turns + assert conv._current_turn_id == turn_id2 + assert len(conv.turns) == 2 + + def test_add_message_to_turn(self): + """Test adding messages to a specific turn. + + Tests: + - Messages can be added to existing turns + - Messages are stored in order + - Message IDs are returned + - Multiple messages can be added to same turn + """ + conv = Conversation() + turn_id = conv.create_turn() + + msg1 = Message.from_user("First message") + msg_id1 = conv.add_message_to_turn(turn_id, msg1) + assert msg_id1 == msg1.id + assert len(conv.turns[turn_id]) == 1 + assert conv.turns[turn_id][0] == msg1 + + msg2 = Message.from_assistant("Second message") + msg_id2 = conv.add_message_to_turn(turn_id, msg2) + assert msg_id2 == msg2.id + assert len(conv.turns[turn_id]) == 2 + assert conv.turns[turn_id][1] == msg2 + + def test_add_message_to_nonexistent_turn(self): + """Test error handling when adding message to non-existent turn. + + Tests: + - ValueError is raised with appropriate message + - Error message includes available turns + """ + conv = Conversation() + msg = Message.from_user("Test") + + with pytest.raises(ValueError) as exc_info: + conv.add_message_to_turn("fake_turn_id", msg) + assert "Turn 'fake_turn_id' does not exist" in str(exc_info.value) + + def test_add_non_message_object(self): + """Test error handling when adding non-Message object. + + Tests: + - TypeError is raised when non-Message object is provided + """ + conv = Conversation() + turn_id = conv.create_turn() + + with pytest.raises(TypeError) as exc_info: + conv.add_message_to_turn(turn_id, "not a message") + assert "must be an instance of Message" in str(exc_info.value) + + def test_get_turn_messages(self): + """Test retrieving messages from a specific turn. + + Tests: + - Can retrieve all messages from a turn + - Empty list returned for non-existent turn + - Messages maintain order + """ + conv = Conversation() + turn_id = conv.create_turn() + + msg1 = Message.from_user("User msg") + msg2 = Message.from_assistant("Assistant msg") + conv.add_message_to_turn(turn_id, msg1) + conv.add_message_to_turn(turn_id, msg2) + + messages = conv.get_turn_messages(turn_id) + assert len(messages) == 2 + assert messages[0] == msg1 + assert messages[1] == msg2 + + # Non-existent turn returns empty list + assert conv.get_turn_messages("fake_id") == [] + + def test_get_all_messages(self): + """Test retrieving all messages across all turns. 
+ + Tests: + - Messages from all turns are returned + - Order is maintained (turn order and message order within turns) + - Empty list for empty conversation + """ + conv = Conversation() + + # Empty conversation + assert conv.get_all_messages() == [] + + # Add messages to multiple turns + turn1 = conv.create_turn() + msg1 = Message.from_user("Turn 1 User") + msg2 = Message.from_assistant("Turn 1 Assistant") + conv.add_message_to_turn(turn1, msg1) + conv.add_message_to_turn(turn1, msg2) + + turn2 = conv.create_turn() + msg3 = Message.from_user("Turn 2 User") + msg4 = Message.from_assistant("Turn 2 Assistant") + conv.add_message_to_turn(turn2, msg3) + conv.add_message_to_turn(turn2, msg4) + + all_messages = conv.get_all_messages() + assert len(all_messages) == 4 + assert all_messages == [msg1, msg2, msg3, msg4] + + def test_get_messages_by_role(self): + """Test filtering messages by role. + + Tests: + + - Can filter messages by user/assistant/system role + - Returns messages from all turns + - Empty list for roles with no messages + """ + conv = Conversation() + + turn1 = conv.create_turn() + conv.add_message_to_turn(turn1, Message.from_user("User 1")) + conv.add_message_to_turn(turn1, Message.from_assistant("Assistant 1")) + + turn2 = conv.create_turn() + conv.add_message_to_turn(turn2, Message.from_user("User 2")) + conv.add_message_to_turn(turn2, Message.from_system("System 1")) + + user_messages = conv.get_messages_by_role("user") + assert len(user_messages) == 2 + assert all(msg.role == "user" for msg in user_messages) + + assistant_messages = conv.get_messages_by_role("assistant") + assert len(assistant_messages) == 1 + assert assistant_messages[0].content == "Assistant 1" + + system_messages = conv.get_messages_by_role("system") + assert len(system_messages) == 1 + assert system_messages[0].content == "System 1" + + def test_get_last_user_message(self): + """Test retrieving the most recent user message. + + Tests: + - Returns the last user message across all turns + - Returns None if no user messages exist + - Skips over non-user messages + """ + conv = Conversation() + + # No messages + assert conv.get_last_user_message() is None + + turn1 = conv.create_turn() + user_msg1 = Message.from_user("First user") + conv.add_message_to_turn(turn1, user_msg1) + conv.add_message_to_turn(turn1, Message.from_assistant("Assistant")) + + turn2 = conv.create_turn() + user_msg2 = Message.from_user("Second user") + conv.add_message_to_turn(turn2, user_msg2) + conv.add_message_to_turn(turn2, Message.from_assistant("Another assistant")) + + last_user = conv.get_last_user_message() + assert last_user == user_msg2 + + def test_get_last_assistant_message(self): + """Test retrieving the most recent assistant message. 
+ + Tests: + - Returns the last assistant message across all turns + - Returns None if no assistant messages exist + - Skips over non-assistant messages + """ + conv = Conversation() + + # No messages + assert conv.get_last_assistant_message() is None + + turn1 = conv.create_turn() + conv.add_message_to_turn(turn1, Message.from_user("User")) + assistant_msg1 = Message.from_assistant("First assistant") + conv.add_message_to_turn(turn1, assistant_msg1) + + turn2 = conv.create_turn() + conv.add_message_to_turn(turn2, Message.from_user("Another user")) + assistant_msg2 = Message.from_assistant("Second assistant") + conv.add_message_to_turn(turn2, assistant_msg2) + + last_assistant = conv.get_last_assistant_message() + assert last_assistant == assistant_msg2 + + +class TestFlexibleConversationMemory: + """Test the FlexibleConversationMemory component.""" + + def test_memory_initialization(self): + """Test initializing memory with default and custom settings. + + Tests: + - Default initialization creates empty conversation + - Custom database can be provided + - User ID is properly set + """ + # Default initialization + memory = FlexibleConversationMemory() + assert memory.current_conversation is not None + assert memory.user_id is None + assert memory.message_db is not None + assert memory.conver_db is not None + + # With custom settings + custom_db = LocalDB() + memory = FlexibleConversationMemory(turn_db=custom_db, user_id="test_user") + assert memory.message_db == custom_db + assert memory.user_id == "test_user" + assert memory.current_conversation.user_id == "test_user" + + def test_create_turn(self): + """Test creating turns through memory interface. + + Tests: + - Turn creation returns valid ID + - Multiple turns can be created + - Turn IDs are unique + """ + memory = FlexibleConversationMemory() + + turn_id1 = memory.create_turn() + assert turn_id1 is not None + assert turn_id1 in memory.current_conversation.turns + + turn_id2 = memory.create_turn() + assert turn_id2 != turn_id1 + assert len(memory.current_conversation.turns) == 2 + + def test_add_user_query(self): + """Test adding user queries to turns. + + Tests: + - User query is added to specified turn + - Message is stored in database + - Metadata is properly handled + - Error raised for non-existent turn + """ + memory = FlexibleConversationMemory() + turn_id = memory.create_turn() + + # Add user query + returned_id = memory.add_user_query("Hello", turn_id, {"context": "greeting"}) + assert returned_id == turn_id + + # Check message was added + messages = memory.get_turn_messages(turn_id) + assert len(messages) == 1 + assert messages[0].role == "user" + assert messages[0].content == "Hello" + assert messages[0].metadata == {"context": "greeting"} + + # Check database storage + assert len(memory.message_db.items) == 1 + db_item = memory.message_db.items[0] + assert db_item["role"] == "user" + assert db_item["content"] == "Hello" + assert db_item["turn_id"] == turn_id + + def test_add_user_query_nonexistent_turn(self): + """Test error handling when adding user query to non-existent turn. + + Tests: + - ValueError is raised with helpful message + - Message includes available turns + """ + memory = FlexibleConversationMemory() + + with pytest.raises(ValueError) as exc_info: + memory.add_user_query("Hello", "fake_turn") + assert "Turn 'fake_turn' does not exist" in str(exc_info.value) + assert "create_turn()" in str(exc_info.value) + + def test_add_assistant_response(self): + """Test adding assistant responses to turns. 
+ + Tests: + - Assistant response is added to specified turn + - Message is stored in database + - Metadata is properly handled + - Error raised for non-existent turn + """ + memory = FlexibleConversationMemory() + turn_id = memory.create_turn() + + # Add assistant response + returned_id = memory.add_assistant_response( + "Hi there!", turn_id, {"confidence": 0.95} + ) + assert returned_id == turn_id + + # Check message was added + messages = memory.get_turn_messages(turn_id) + assert len(messages) == 1 + assert messages[0].role == "assistant" + assert messages[0].content == "Hi there!" + assert messages[0].metadata == {"confidence": 0.95} + + # Check database storage + assert len(memory.message_db.items) == 1 + db_item = memory.message_db.items[0] + assert db_item["role"] == "assistant" + assert db_item["content"] == "Hi there!" + + def test_add_system_message(self): + """Test adding system messages to turns. + + Tests: + - System message is added to specified turn + - Message is stored in database + - Metadata is properly handled + - Error raised for non-existent turn + """ + memory = FlexibleConversationMemory() + turn_id = memory.create_turn() + + # Add system message + returned_id = memory.add_system_message( + "System initialized", turn_id, {"version": "1.0"} + ) + assert returned_id == turn_id + + # Check message was added + messages = memory.get_turn_messages(turn_id) + assert len(messages) == 1 + assert messages[0].role == "system" + assert messages[0].content == "System initialized" + assert messages[0].metadata == {"version": "1.0"} + + def test_complete_conversation_flow(self): + """Test a complete conversation flow with multiple turns. + + Tests: + - Multiple turns with user/assistant exchanges + - Messages maintain order within and across turns + - All retrieval methods work correctly + """ + memory = FlexibleConversationMemory() + + # Turn 1: Initial greeting + turn1 = memory.create_turn() + memory.add_user_query("Hello, how are you?", turn1) + memory.add_assistant_response("I'm doing well, thank you!", turn1) + + # Turn 2: Follow-up question + turn2 = memory.create_turn() + memory.add_user_query("What can you help me with?", turn2) + memory.add_assistant_response("I can help with many things!", turn2) + memory.add_user_query("Can you be more specific?", turn2) # Multiple user queries in same turn + memory.add_assistant_response("I can help with coding, analysis, and more.", turn2) + + # Check all messages + all_messages = memory.get_all_messages() + assert len(all_messages) == 6 + + # Check message ordering + assert all_messages[0].content == "Hello, how are you?" + assert all_messages[1].content == "I'm doing well, thank you!" + assert all_messages[2].content == "What can you help me with?" + assert all_messages[3].content == "I can help with many things!" + assert all_messages[4].content == "Can you be more specific?" + assert all_messages[5].content == "I can help with coding, analysis, and more." + + def test_clear_conversation(self): + """Test clearing the conversation. 
+ + Tests: + - All turns are cleared + - Current turn ID is reset + - Database is not affected + """ + memory = FlexibleConversationMemory() + + # Add some content + turn1 = memory.create_turn() + memory.add_user_query("Message 1", turn1) + turn2 = memory.create_turn() + memory.add_user_query("Message 2", turn2) + + # Verify content exists + assert len(memory.current_conversation.turns) == 2 + assert memory.current_conversation._current_turn_id == turn2 + + # Clear conversation + memory.clear_conversation() + + # Verify cleared + assert len(memory.current_conversation.turns) == 0 + assert memory.current_conversation._current_turn_id is None + + # Database should still have items + assert len(memory.message_db.items) == 2 + + def test_clear_conversation_turns_alias(self): + """Test that clear_conversation_turns is an alias for clear_conversation. + + Tests: + - Both methods have the same effect + """ + memory = FlexibleConversationMemory() + turn = memory.create_turn() + memory.add_user_query("Test", turn) + + memory.clear_conversation_turns() + assert len(memory.current_conversation.turns) == 0 + + def test_new_conversation(self): + """Test starting a new conversation. + + Tests: + - Current conversation is saved to conver_db + - New empty conversation is created + - User ID is preserved + """ + memory = FlexibleConversationMemory(user_id="test_user") + + # Add content to first conversation + turn = memory.create_turn() + memory.add_user_query("First conversation", turn) + first_conv_id = memory.current_conversation.id + + # Start new conversation + memory.new_conversation() + + # Check new conversation is empty and different + assert len(memory.current_conversation.turns) == 0 + assert memory.current_conversation.id != first_conv_id + assert memory.current_conversation.user_id == "test_user" + + # Check first conversation was saved + assert len(memory.conver_db.items) == 1 + saved_conv = memory.conver_db.items[0] + assert saved_conv.id == first_conv_id + + def test_get_current_turn_id(self): + """Test retrieving the current turn ID. + + Tests: + - Returns None when no turns exist + - Returns correct turn ID after creation + - Updates when new turn is created + """ + memory = FlexibleConversationMemory() + + # No turns initially + assert memory.get_current_turn_id() is None + + # After creating turn + turn1 = memory.create_turn() + assert memory.get_current_turn_id() == turn1 + + # After creating another turn + turn2 = memory.create_turn() + assert memory.get_current_turn_id() == turn2 + + def test_call_empty_conversation(self): + """Test calling memory with empty conversation. + + Tests: + - Returns empty string for empty conversation + """ + memory = FlexibleConversationMemory() + assert memory() == "" + assert memory.call() == "" + + def test_call_with_messages(self): + """Test formatting conversation for output. + + Tests: + - Messages are formatted correctly + - Roles are properly labeled + - Multiple turns are handled + - Metadata is included when present + """ + memory = FlexibleConversationMemory() + + turn1 = memory.create_turn() + memory.add_user_query("What is Python?", turn1) + memory.add_assistant_response("Python is a programming language.", turn1) + + turn2 = memory.create_turn() + memory.add_user_query("Tell me more", turn2, {"priority": "high"}) + memory.add_assistant_response("It's known for its simplicity.", turn2) + + output = memory() + + # Check basic content + assert "User: What is Python?" in output + assert "Assistant: Python is a programming language." 
in output + assert "User: Tell me more" in output + assert "Assistant: It's known for its simplicity." in output + + # Check metadata + assert "priority: high" in output + + def test_call_with_metadata_filter(self): + """Test filtering metadata in conversation output. + + Tests: + - Only specified metadata keys are included + - Other metadata is filtered out + - Works across multiple messages + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query("Query", turn, { + "public": "visible", + "private": "hidden", + "context": "test" + }) + memory.add_assistant_response("Response", turn, { + "confidence": 0.9, + "internal": "secret" + }) + + # Filter to only show 'public' and 'confidence' + output = memory(metadata_filter=["public", "confidence"]) + + assert "public: visible" in output + assert "confidence: 0.9" in output + assert "private: hidden" not in output + assert "context: test" not in output + assert "internal: secret" not in output + + def test_call_with_system_messages(self): + """Test that system messages are included in output. + + Tests: + - System messages are formatted correctly + - Mixed message types are handled + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_system_message("System: Initializing", turn) + memory.add_user_query("Hello", turn) + memory.add_assistant_response("Hi there", turn) + + output = memory() + assert "System: System: Initializing" in output + assert "User: Hello" in output + assert "Assistant: Hi there" in output + + def test_get_last_n_messages(self): + """Test retrieving the last N messages. + + Tests: + - Returns correct number of messages + - Returns all messages if N > total + - Maintains correct order + """ + memory = FlexibleConversationMemory() + + # Add 5 messages + turn = memory.create_turn() + for i in range(5): + if i % 2 == 0: + memory.add_user_query(f"Message {i}", turn) + else: + memory.add_assistant_response(f"Message {i}", turn) + + # Get last 3 + last_3 = memory.get_last_n_messages(3) + assert len(last_3) == 3 + assert last_3[0].content == "Message 2" + assert last_3[1].content == "Message 3" + assert last_3[2].content == "Message 4" + + # Get more than available + last_10 = memory.get_last_n_messages(10) + assert len(last_10) == 5 + + def test_count_messages(self): + """Test counting messages by role. + + Tests: + - Counts are accurate for each role + - Handles empty conversation + - Works across multiple turns + """ + memory = FlexibleConversationMemory() + + # Empty conversation + counts = memory.count_messages() + assert counts == {"user": 0, "assistant": 0, "system": 0} + + # Add messages + turn1 = memory.create_turn() + memory.add_user_query("User 1", turn1) + memory.add_assistant_response("Assistant 1", turn1) + + turn2 = memory.create_turn() + memory.add_user_query("User 2", turn2) + memory.add_user_query("User 3", turn2) + memory.add_system_message("System 1", turn2) + + counts = memory.count_messages() + assert counts["user"] == 3 + assert counts["assistant"] == 1 + assert counts["system"] == 1 + + def test_count_turns(self): + """Test counting the number of turns. 
+ + Tests: + - Returns 0 for empty conversation + - Counts turns correctly + """ + memory = FlexibleConversationMemory() + + assert memory.count_turns() == 0 + + memory.create_turn() + assert memory.count_turns() == 1 + + memory.create_turn() + memory.create_turn() + assert memory.count_turns() == 3 + + def test_reset_pending_query(self): + """Test resetting pending query (compatibility method). + + Tests: + - Current turn ID is unset + - Turns are not deleted + - Messages remain intact + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query("Test", turn) + + # Current turn is set + assert memory.get_current_turn_id() == turn + + # Reset pending query + memory.reset_pending_query() + + # Current turn is unset but turn still exists + assert memory.get_current_turn_id() is None + assert turn in memory.current_conversation.turns + assert len(memory.get_turn_messages(turn)) == 1 + + def test_len_magic_method(self): + """Test the __len__ magic method. + + Tests: + - Returns 0 for empty conversation + - Returns correct count of all messages + """ + memory = FlexibleConversationMemory() + + assert len(memory) == 0 + + turn = memory.create_turn() + memory.add_user_query("1", turn) + memory.add_assistant_response("2", turn) + memory.add_system_message("3", turn) + + assert len(memory) == 3 + + def test_multiple_queries_same_turn(self): + """Test adding multiple user queries to the same turn. + + Tests: + - Multiple user queries can be added to one turn + - Order is maintained + - Common in clarification scenarios + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query("First question", turn) + memory.add_user_query("Follow-up question", turn) + memory.add_user_query("Another clarification", turn) + memory.add_assistant_response("Comprehensive answer", turn) + + messages = memory.get_turn_messages(turn) + assert len(messages) == 4 + assert messages[0].content == "First question" + assert messages[1].content == "Follow-up question" + assert messages[2].content == "Another clarification" + assert messages[3].content == "Comprehensive answer" + + def test_complex_metadata_handling(self): + """Test handling complex metadata structures. + + Tests: + - Nested dictionaries in metadata + - Lists in metadata + - Mixed data types + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + complex_metadata = { + "step_history": [ + {"action": "search", "result": "found"}, + {"action": "analyze", "result": "complete"} + ], + "images": ["image1.png", "image2.jpg"], + "confidence": 0.95, + "nested": { + "level1": { + "level2": "deep value" + } + } + } + + memory.add_user_query("Complex query", turn, complex_metadata) + + messages = memory.get_turn_messages(turn) + assert messages[0].metadata == complex_metadata + assert messages[0].metadata["step_history"][0]["action"] == "search" + assert messages[0].metadata["nested"]["level1"]["level2"] == "deep value" + + def test_conversation_persistence(self): + """Test that conversations are properly persisted. 
+ + Tests: + - Conversations are saved when starting new one + - Empty conversations are not saved + - Multiple conversations can be saved + """ + memory = FlexibleConversationMemory() + + # Empty conversation not saved + memory.new_conversation() + assert len(memory.conver_db.items) == 0 + + # Add content and start new conversation + turn = memory.create_turn() + memory.add_user_query("First conv", turn) + memory.new_conversation() + assert len(memory.conver_db.items) == 1 + + # Add more conversations + turn = memory.create_turn() + memory.add_user_query("Second conv", turn) + memory.new_conversation() + assert len(memory.conver_db.items) == 2 + + def test_error_handling_in_create_turn(self): + """Test error handling in turn creation. + + Tests: + - RuntimeError is raised if turn creation fails + - Error message is informative + """ + memory = FlexibleConversationMemory() + + # Mock a failure in UUID generation + with patch('adalflow.components.memory.flexible_memory.uuid4', side_effect=Exception("UUID error")): + with pytest.raises(RuntimeError) as exc_info: + memory.create_turn() + assert "Failed to create turn" in str(exc_info.value) + + def test_database_integration(self): + """Test integration with LocalDB for message storage. + + Tests: + - Messages are stored with correct fields + - Turn IDs are properly tracked + - Timestamps are preserved + - Multiple databases work independently + """ + db1 = LocalDB() + db2 = LocalDB() + + memory1 = FlexibleConversationMemory(turn_db=db1) + memory2 = FlexibleConversationMemory(turn_db=db2) + + # Add to memory1 + turn1 = memory1.create_turn() + memory1.add_user_query("Memory 1 message", turn1) + + # Add to memory2 + turn2 = memory2.create_turn() + memory2.add_user_query("Memory 2 message", turn2) + + # Check databases are independent + assert len(db1.items) == 1 + assert len(db2.items) == 1 + assert db1.items[0]["content"] == "Memory 1 message" + assert db2.items[0]["content"] == "Memory 2 message" + + def test_edge_cases(self): + """Test various edge cases and boundary conditions. + + Tests: + - Empty strings as content + - Very long content + - Special characters in content + - None values where applicable + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + + # Empty string content + memory.add_user_query("", turn) + messages = memory.get_turn_messages(turn) + assert messages[0].content == "" + + # Very long content + long_content = "x" * 10000 + memory.add_assistant_response(long_content, turn) + assert len(memory.get_turn_messages(turn)[1].content) == 10000 + + # Special characters + special = "Hello\n\tWorld!@#$%^&*()[]{}|\\<>?/~`" + memory.add_user_query(special, turn) + assert memory.get_turn_messages(turn)[2].content == special + + # None metadata (should work fine) + memory.add_user_query("Test", turn, None) + assert memory.get_turn_messages(turn)[3].metadata is None + + +class TestIntegration: + """Integration tests for the complete memory system.""" + + def test_realistic_conversation_flow(self): + """Test a realistic multi-turn conversation with all features. + + Tests: + - Complete conversation flow + - Mixed message types + - Metadata handling + - Turn management + - Output formatting + """ + memory = FlexibleConversationMemory(user_id="test_user") + + # Turn 1: Initial setup + turn1 = memory.create_turn() + memory.add_system_message("You are a helpful assistant.", turn1) + memory.add_user_query("Hi, I need help with Python", turn1, {"source": "web_ui"}) + memory.add_assistant_response( + "Hello! 
I'd be happy to help you with Python. What specific topic?", + turn1, + {"confidence": 0.95} + ) + + # Turn 2: Follow-up + turn2 = memory.create_turn() + memory.add_user_query("How do I read a file?", turn2) + memory.add_assistant_response( + "You can use the open() function with a context manager.", + turn2, + {"confidence": 0.98, "sources": ["python_docs"]} + ) + memory.add_user_query("Can you show an example?", turn2) # Clarification in same turn + memory.add_assistant_response( + "with open('file.txt', 'r') as f:\n content = f.read()", + turn2, + {"code_block": True} + ) + + # Verify conversation state + assert memory.count_turns() == 2 + counts = memory.count_messages() + assert counts["user"] == 3 + assert counts["assistant"] == 3 + assert counts["system"] == 1 + + # Check output formatting + output = memory() + assert "You are a helpful assistant" in output + assert "How do I read a file?" in output + assert "with open('file.txt', 'r')" in output + + # Test filtered output + filtered = memory(metadata_filter=["confidence"]) + assert "confidence: 0.95" in filtered + assert "confidence: 0.98" in filtered + assert "source: web_ui" not in filtered + + # Get last messages + last_2 = memory.get_last_n_messages(2) + assert last_2[0].content == "Can you show an example?" + assert "with open" in last_2[1].content + + # Start new conversation but preserve the old one + memory.new_conversation() + assert len(memory.conver_db.items) == 1 + assert memory.count_messages() == {"user": 0, "assistant": 0, "system": 0} + + def test_agent_style_conversation(self): + """Test conversation flow typical of an AI agent with step tracking. + + Tests: + - Agent-style metadata with step history + - Multiple assistant messages in sequence + - Complex metadata structures + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + + # User query with context + memory.add_user_query( + "Find all Python files in the project", + turn, + {"request_type": "search", "timestamp": "2024-01-01T10:00:00"} + ) + + # Agent thinking steps + memory.add_assistant_response( + "I'll search for Python files in the project.", + turn, + { + "step": 1, + "action": "plan", + "confidence": 0.9 + } + ) + + memory.add_assistant_response( + "Found 15 Python files in 3 directories.", + turn, + { + "step": 2, + "action": "search", + "results": { + "count": 15, + "directories": ["src", "tests", "scripts"] + } + } + ) + + memory.add_assistant_response( + "Here are the Python files I found: [list of files]", + turn, + { + "step": 3, + "action": "report", + "completion": True + } + ) + + # Verify step tracking + messages = memory.get_turn_messages(turn) + assert len(messages) == 4 + + # Check assistant messages have increasing step numbers + assistant_messages = [m for m in messages if m.role == "assistant"] + assert assistant_messages[0].metadata["step"] == 1 + assert assistant_messages[1].metadata["step"] == 2 + assert assistant_messages[2].metadata["step"] == 3 + + # Verify complex metadata structure + assert assistant_messages[1].metadata["results"]["count"] == 15 + assert "src" in assistant_messages[1].metadata["results"]["directories"] \ No newline at end of file diff --git a/adalflow/tests/test_flexible_memory_template.py b/adalflow/tests/test_flexible_memory_template.py new file mode 100644 index 00000000..cfdb3a00 --- /dev/null +++ b/adalflow/tests/test_flexible_memory_template.py @@ -0,0 +1,538 @@ +"""Test the Jinja2 template rendering in FlexibleConversationMemory. + +This module specifically tests: +1. 
Template rendering with call() method +2. Metadata filtering in templates +3. Multiple turns and messages formatting +4. System, user, and assistant message formatting +5. Edge cases in template rendering +""" + +import pytest +from adalflow.components.memory.flexible_memory import ( + Message, + Conversation, + FlexibleConversationMemory, + CONVERSATION_TEMPLATE, +) +from adalflow.core.prompt_builder import Prompt + + +class TestFlexibleMemoryTemplateRendering: + """Test the Jinja2 template rendering functionality.""" + + def test_basic_template_rendering(self): + """Test basic conversation rendering with the template. + + Tests: + - Single turn with user and assistant messages + - Proper formatting with "User:" and "Assistant:" prefixes + - Newline handling + """ + memory = FlexibleConversationMemory() + + turn_id = memory.create_turn() + memory.add_user_query("Hello, how are you?", turn_id) + memory.add_assistant_response("I'm doing well, thank you!", turn_id) + + # Call the memory to render the template + output = memory.call() + + # Check the formatted output + assert "User: Hello, how are you?" in output + assert "Assistant: I'm doing well, thank you!" in output + + # Check order + lines = output.strip().split('\n') + user_line_idx = next(i for i, line in enumerate(lines) if "User:" in line) + assistant_line_idx = next(i for i, line in enumerate(lines) if "Assistant:" in line) + assert user_line_idx < assistant_line_idx + + def test_multiple_turns_rendering(self): + """Test rendering multiple conversation turns. + + Tests: + - Multiple turns are rendered in order + - Each turn's messages are grouped together + - Proper spacing between turns + """ + memory = FlexibleConversationMemory() + + # Turn 1 + turn1 = memory.create_turn() + memory.add_user_query("What is Python?", turn1) + memory.add_assistant_response("Python is a programming language.", turn1) + + # Turn 2 + turn2 = memory.create_turn() + memory.add_user_query("What can I do with it?", turn2) + memory.add_assistant_response("You can build web apps, analyze data, and more.", turn2) + + output = memory.call() + + # Check all messages are present + assert "What is Python?" in output + assert "Python is a programming language" in output + assert "What can I do with it?" in output + assert "You can build web apps" in output + + # Check order is maintained + python_idx = output.index("What is Python?") + python_answer_idx = output.index("Python is a programming") + what_can_idx = output.index("What can I do with it?") + web_apps_idx = output.index("You can build web apps") + + assert python_idx < python_answer_idx < what_can_idx < web_apps_idx + + def test_metadata_rendering(self): + """Test metadata rendering in the template. 
+ + Tests: + - Metadata is rendered below the message + - Metadata keys and values are formatted correctly + - Multiple metadata items are shown + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Search for Python tutorials", + turn, + metadata={ + "source": "web_interface", + "priority": "high", + "timestamp": "2024-01-01T10:00:00" + } + ) + memory.add_assistant_response( + "Here are some Python tutorials", + turn, + metadata={ + "confidence": 0.95, + "sources_count": 5 + } + ) + + output = memory.call() + + # Check metadata is rendered + assert "source: web_interface" in output + assert "priority: high" in output + assert "timestamp: 2024-01-01T10:00:00" in output + assert "confidence: 0.95" in output + assert "sources_count: 5" in output + + def test_metadata_filtering(self): + """Test metadata filtering in template rendering. + + Tests: + - Only specified metadata keys are shown + - Other metadata is filtered out + - Filtering works across multiple messages + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Query with metadata", + turn, + metadata={ + "public": "visible", + "private": "hidden", + "internal": "secret" + } + ) + memory.add_assistant_response( + "Response with metadata", + turn, + metadata={ + "public": "also visible", + "confidence": "high", + "debug": "hidden info" + } + ) + + # Call with metadata filter + output = memory.call(metadata_filter=["public", "confidence"]) + + # Check filtered metadata + assert "public: visible" in output + assert "public: also visible" in output + assert "confidence: high" in output + + # Check filtered out metadata is not present + assert "private: hidden" not in output + assert "internal: secret" not in output + assert "debug: hidden info" not in output + + def test_system_messages_rendering(self): + """Test system message rendering in the template. + + Tests: + - System messages are properly formatted + - System messages appear with "System:" prefix + - Mixed message types work correctly + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_system_message("You are a helpful assistant.", turn) + memory.add_user_query("Hello", turn) + memory.add_assistant_response("Hi! How can I help you?", turn) + + output = memory.call() + + # Check system message formatting + assert "System: You are a helpful assistant." in output + assert "User: Hello" in output + assert "Assistant: Hi! How can I help you?" in output + + # Check order + system_idx = output.index("System:") + user_idx = output.index("User:") + assistant_idx = output.index("Assistant:") + assert system_idx < user_idx < assistant_idx + + def test_multiple_messages_same_role_in_turn(self): + """Test multiple messages from the same role in one turn. 
+ + Tests: + - Multiple user queries in one turn + - Multiple assistant responses in one turn + - Proper ordering and formatting + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query("First question", turn) + memory.add_user_query("Follow-up question", turn) + memory.add_user_query("Another clarification", turn) + memory.add_assistant_response("Comprehensive answer addressing all questions", turn) + + output = memory.call() + + # All messages should be present + assert "User: First question" in output + assert "User: Follow-up question" in output + assert "User: Another clarification" in output + assert "Assistant: Comprehensive answer" in output + + # Check ordering + first_idx = output.index("First question") + followup_idx = output.index("Follow-up question") + clarification_idx = output.index("Another clarification") + answer_idx = output.index("Comprehensive answer") + + assert first_idx < followup_idx < clarification_idx < answer_idx + + def test_empty_conversation_rendering(self): + """Test rendering an empty conversation. + + Tests: + - Empty conversation returns empty string + - No errors are raised + """ + memory = FlexibleConversationMemory() + + output = memory.call() + assert output == "" + + # Also test with metadata filter + output_filtered = memory.call(metadata_filter=["some_key"]) + assert output_filtered == "" + + def test_special_characters_in_content(self): + """Test rendering with special characters in content. + + Tests: + - Newlines in content + - Special characters (quotes, brackets, etc.) + - Unicode characters + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Can you explain:\n1. Lists\n2. Dictionaries\n3. Sets", + turn + ) + memory.add_assistant_response( + 'Sure! Here\'s an explanation:\n• Lists: [1, 2, 3]\n• Dicts: {"key": "value"}\n• Sets: {1, 2, 3}', + turn + ) + + output = memory.call() + + # Check special characters are preserved + assert "1. Lists" in output + assert "2. Dictionaries" in output + assert '{"key": "value"}' in output + assert "• Lists:" in output + + def test_complex_metadata_in_template(self): + """Test rendering complex metadata structures. + + Tests: + - Nested dictionaries in metadata + - Lists in metadata + - Numbers and booleans in metadata + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Complex query", + turn, + metadata={ + "nested": { + "level1": { + "level2": "deep value" + } + }, + "list_data": ["item1", "item2", "item3"], + "numeric": 42, + "boolean": True + } + ) + + output = memory.call() + + # Check complex metadata is rendered (as string representations) + assert "Complex query" in output + # The template will render these as strings + assert "nested:" in output + assert "list_data:" in output + assert "numeric: 42" in output + assert "boolean: True" in output + + def test_template_with_none_metadata(self): + """Test template handling of None metadata. 
+ + Tests: + - Messages with None metadata don't show metadata section + - Messages with empty dict metadata don't show metadata section + - Mixed messages with and without metadata + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + # Message with None metadata + memory.add_user_query("No metadata query", turn, metadata=None) + # Message with empty metadata + memory.add_assistant_response("No metadata response", turn, metadata={}) + # Message with actual metadata + memory.add_user_query("With metadata", turn, metadata={"key": "value"}) + + output = memory.call() + + # Check messages are present + assert "User: No metadata query" in output + assert "Assistant: No metadata response" in output + assert "User: With metadata" in output + assert "key: value" in output + + # Count occurrences of "key:" - should only be one + assert output.count("key:") == 1 + + def test_callable_interface(self): + """Test that the callable interface works the same as call(). + + Tests: + - __call__ produces same output as call() + - Metadata filtering works with __call__ + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query("Test query", turn, metadata={"meta1": "value1"}) + memory.add_assistant_response("Test response", turn, metadata={"meta2": "value2"}) + + # Test call() and __call__() produce same output + output_call = memory.call() + output_callable = memory() + assert output_call == output_callable + + # Test metadata filtering with both + output_call_filtered = memory.call(metadata_filter=["meta1"]) + output_callable_filtered = memory(metadata_filter=["meta1"]) + assert output_call_filtered == output_callable_filtered + assert "meta1: value1" in output_callable_filtered + assert "meta2: value2" not in output_callable_filtered + + def test_template_direct_rendering(self): + """Test the template can be rendered directly with Prompt. + + Tests: + - CONVERSATION_TEMPLATE constant is valid + - Can create Prompt with the template + - Direct rendering matches memory.call() + """ + from collections import OrderedDict + + # Create data structure manually + turns = OrderedDict() + turn_id = "test_turn" + turns[turn_id] = [ + Message.from_user("Hello"), + Message.from_assistant("Hi there!") + ] + + # Render with Prompt directly + prompt = Prompt( + template=CONVERSATION_TEMPLATE, + prompt_kwargs={ + "turns": turns, + "metadata_filter": None + } + ) + direct_output = prompt.call().strip() + + # Create same conversation with memory + memory = FlexibleConversationMemory() + memory_turn = memory.create_turn() + memory.add_user_query("Hello", memory_turn) + memory.add_assistant_response("Hi there!", memory_turn) + memory_output = memory.call().strip() + + # Should produce same output + assert direct_output == memory_output + + def test_long_conversation_rendering(self): + """Test rendering of long conversations. 
+ + Tests: + - Many turns (10+) + - Many messages per turn + - Performance doesn't degrade + """ + memory = FlexibleConversationMemory() + + # Create 10 turns with multiple messages each + for i in range(10): + turn = memory.create_turn() + memory.add_user_query(f"Question {i+1}", turn) + memory.add_user_query(f"Clarification {i+1}", turn) + memory.add_assistant_response(f"Answer {i+1}", turn) + if i % 2 == 0: + memory.add_system_message(f"System note {i+1}", turn) + + output = memory.call() + + # Check all messages are present + for i in range(10): + assert f"Question {i+1}" in output + assert f"Clarification {i+1}" in output + assert f"Answer {i+1}" in output + if i % 2 == 0: + assert f"System note {i+1}" in output + + # Check it's not empty and has reasonable length + assert len(output) > 500 # Should be a long conversation + + # Check ordering is maintained + q1_idx = output.index("Question 1") + q10_idx = output.index("Question 10") + assert q1_idx < q10_idx + + +class TestTemplateEdgeCases: + """Test edge cases and error conditions in template rendering.""" + + def test_template_with_jinja2_special_chars_in_content(self): + """Test content that contains Jinja2 special characters. + + Tests: + - Content with {{ }} doesn't break template + - Content with {% %} doesn't break template + - Content is properly escaped + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Show me a template: {{ variable }} and {% if condition %}", + turn + ) + memory.add_assistant_response( + "Here's the template: {{ var }} and {% for item in items %}", + turn + ) + + output = memory.call() + + # Special characters should be preserved in output + assert "{{ variable }}" in output + assert "{% if condition %}" in output + assert "{{ var }}" in output + assert "{% for item in items %}" in output + + def test_metadata_with_jinja2_chars(self): + """Test metadata containing Jinja2 special characters. + + Tests: + - Metadata with template syntax doesn't break rendering + - Values are properly escaped + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query( + "Query", + turn, + metadata={ + "template": "{{ user_input }}", + "condition": "{% if x > 0 %}", + "normal": "regular value" + } + ) + + output = memory.call() + + # Metadata should be rendered safely + assert "template: {{ user_input }}" in output + assert "condition: {% if x > 0 %}" in output + assert "normal: regular value" in output + + def test_very_long_messages(self): + """Test rendering very long messages. + + Tests: + - Long messages don't break template + - Full content is preserved + """ + memory = FlexibleConversationMemory() + + # Create a very long message + long_content = "This is a very long message. " * 100 # ~2500 characters + + turn = memory.create_turn() + memory.add_user_query(long_content, turn) + memory.add_assistant_response("Short response", turn) + + output = memory.call() + + # Check long content is fully preserved + assert long_content in output + assert "Short response" in output + + def test_whitespace_preservation(self): + """Test that whitespace in messages is preserved. 
+ + Tests: + - Leading/trailing spaces + - Multiple spaces + - Tabs and newlines + """ + memory = FlexibleConversationMemory() + + turn = memory.create_turn() + memory.add_user_query(" Message with spaces ", turn) + memory.add_assistant_response("\tTabbed\tmessage\t", turn) + memory.add_system_message("Line1\nLine2\nLine3", turn) + + output = memory.call() + + # Check whitespace is preserved + assert " Message with spaces " in output + assert "\tTabbed\tmessage\t" in output + assert "Line1\nLine2\nLine3" in output \ No newline at end of file diff --git a/adalflow/tests/test_runner_flexible.py b/adalflow/tests/test_runner_flexible.py new file mode 100644 index 00000000..043985ef --- /dev/null +++ b/adalflow/tests/test_runner_flexible.py @@ -0,0 +1,517 @@ +"""Test suite for RunnerFlexible with FlexibleConversationMemory integration. + +This module tests: +1. Runner initialization with flexible memory +2. Safe memory operations that never fail the runner +3. Turn management during execution +4. Error handling and recovery +5. Memory persistence across multiple runs +""" + +import pytest +import asyncio +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime + +from adalflow.components.agent.runner_flexible import RunnerFlexible +from adalflow.components.memory.flexible_memory import FlexibleConversationMemory +from adalflow.components.agent.agent import Agent +from adalflow.core.types import ( + GeneratorOutput, + Function, + FunctionOutput, + RunnerResult, + UserQuery, + AssistantResponse, + ToolOutput, +) + + +class TestRunnerFlexibleMemoryIntegration: + """Test RunnerFlexible with FlexibleConversationMemory.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock agent for testing.""" + agent = Mock(spec=Agent) + agent.max_steps = 3 + agent.answer_data_type = str + agent.is_thinking_model = False + + # Mock planner + agent.planner = Mock() + agent.planner.call = Mock() + agent.planner.acall = Mock() + agent.planner.get_prompt = Mock(return_value="test prompt") + agent.planner.estimated_token_count = 100 + + # Mock tool manager + agent.tool_manager = Mock() + agent.tool_manager.tools = [] + agent.tool_manager.execute_func = Mock() + agent.tool_manager.execute_func_async = Mock() + + return agent + + @pytest.fixture + def flexible_memory(self): + """Create a flexible memory instance.""" + return FlexibleConversationMemory() + + @pytest.fixture + def runner_with_memory(self, mock_agent, flexible_memory): + """Create a runner with flexible memory.""" + return RunnerFlexible( + agent=mock_agent, + conversation_memory=flexible_memory + ) + + def test_runner_initialization_with_memory(self, mock_agent, flexible_memory): + """Test that runner initializes correctly with flexible memory. + + Tests: + - Runner accepts FlexibleConversationMemory + - use_conversation_memory flag is set correctly + - Safe memory operation methods are available + """ + runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory) + + assert runner.conversation_memory == flexible_memory + assert runner.use_conversation_memory is True + assert hasattr(runner, '_safe_create_turn') + assert hasattr(runner, '_safe_add_user_query') + assert hasattr(runner, '_safe_add_assistant_response') + assert hasattr(runner, '_safe_get_conversation_history') + + def test_safe_create_turn(self, runner_with_memory): + """Test safe turn creation that never fails. 
+ + Tests: + - Normal turn creation works + - Returns None when memory fails + - Logs warning on failure + """ + # Normal operation + turn_id = runner_with_memory._safe_create_turn() + assert turn_id is not None + assert isinstance(turn_id, str) + + # Test failure handling + runner_with_memory.conversation_memory.create_turn = Mock( + side_effect=Exception("Memory error") + ) + + with patch('adalflow.components.agent.runner_flexible.log') as mock_log: + result = runner_with_memory._safe_create_turn() + assert result is None + mock_log.warning.assert_called_once() + assert "Failed to create turn" in str(mock_log.warning.call_args) + + def test_safe_add_user_query(self, runner_with_memory): + """Test safe user query addition. + + Tests: + - Adding string queries + - Adding UserQuery objects + - Handling memory failures gracefully + """ + turn_id = runner_with_memory._safe_create_turn() + + # Test string query + result = runner_with_memory._safe_add_user_query( + "Test query", turn_id, {"meta": "data"} + ) + assert result == turn_id + + # Test UserQuery object + user_query = UserQuery(query_str="Object query", metadata={"key": "value"}) + result = runner_with_memory._safe_add_user_query(user_query, turn_id) + assert result == turn_id + + # Test with None turn_id + result = runner_with_memory._safe_add_user_query("Query", None) + assert result is None + + # Test failure handling + runner_with_memory.conversation_memory.add_user_query = Mock( + side_effect=Exception("Add failed") + ) + + with patch('adalflow.components.agent.runner_flexible.log') as mock_log: + result = runner_with_memory._safe_add_user_query("Query", turn_id) + assert result is None + mock_log.warning.assert_called_once() + + def test_safe_add_assistant_response(self, runner_with_memory): + """Test safe assistant response addition. + + Tests: + - Adding string responses + - Adding AssistantResponse objects + - Handling memory failures gracefully + """ + turn_id = runner_with_memory._safe_create_turn() + + # Test string response + result = runner_with_memory._safe_add_assistant_response( + "Test response", turn_id, {"meta": "data"} + ) + assert result == turn_id + + # Test AssistantResponse object + assistant_response = AssistantResponse( + response_str="Object response", + metadata={"key": "value"} + ) + result = runner_with_memory._safe_add_assistant_response(assistant_response, turn_id) + assert result == turn_id + + # Test with None turn_id + result = runner_with_memory._safe_add_assistant_response("Response", None) + assert result is None + + # Test failure handling + runner_with_memory.conversation_memory.add_assistant_response = Mock( + side_effect=Exception("Add failed") + ) + + with patch('adalflow.components.agent.runner_flexible.log') as mock_log: + result = runner_with_memory._safe_add_assistant_response("Response", turn_id) + assert result is None + mock_log.warning.assert_called_once() + + def test_safe_get_conversation_history(self, runner_with_memory): + """Test safe conversation history retrieval. 
+ + Tests: + - Normal history retrieval + - Returns empty string on failure + - Never raises exceptions + """ + # Add some conversation + turn_id = runner_with_memory._safe_create_turn() + runner_with_memory._safe_add_user_query("Hello", turn_id) + runner_with_memory._safe_add_assistant_response("Hi there", turn_id) + + # Get history + history = runner_with_memory._safe_get_conversation_history() + assert "Hello" in history + assert "Hi there" in history + + # Test failure handling - mock the call method directly + def mock_call(): + raise Exception("History error") + + original_call = runner_with_memory.conversation_memory.call + runner_with_memory.conversation_memory.call = mock_call + # Also need to replace __call__ since Python uses that + original_dunder_call = runner_with_memory.conversation_memory.__class__.__call__ + runner_with_memory.conversation_memory.__class__.__call__ = lambda self: mock_call() + + try: + with patch('adalflow.components.agent.runner_flexible.log') as mock_log: + result = runner_with_memory._safe_get_conversation_history() + assert result == "" + mock_log.warning.assert_called_once() + finally: + # Restore original methods + runner_with_memory.conversation_memory.call = original_call + runner_with_memory.conversation_memory.__class__.__call__ = original_dunder_call + + def test_runner_continues_on_memory_failure(self, mock_agent, flexible_memory): + """Test that runner continues execution even when memory operations fail. + + Tests: + - Runner completes successfully despite memory errors + - Warnings are logged but execution continues + - Final result is still produced + """ + runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory) + + # Make all memory operations fail + flexible_memory.create_turn = Mock(side_effect=Exception("Memory failed")) + flexible_memory.add_user_query = Mock(side_effect=Exception("Memory failed")) + flexible_memory.add_assistant_response = Mock(side_effect=Exception("Memory failed")) + flexible_memory.__call__ = Mock(side_effect=Exception("Memory failed")) + + # Setup successful agent execution + mock_function = Function(name="finish") + mock_function._is_answer_final = True + mock_function._answer = "Final answer" + + mock_agent.planner.call.return_value = GeneratorOutput( + data=mock_function, + error=None, + raw_response="raw" + ) + + # Run should complete successfully despite memory failures + with patch('adalflow.components.agent.runner_flexible.log') as mock_log: + result = runner.call({"input_str": "Test query"}) + + # Check runner completed successfully + assert isinstance(result, RunnerResult) + assert result.answer == "Final answer" + + # Check warnings were logged for memory failures + warning_calls = mock_log.warning.call_args_list + assert len(warning_calls) > 0 + assert any("Failed to" in str(call) for call in warning_calls) + + def test_async_runner_with_memory(self, mock_agent, flexible_memory): + """Test async runner execution with memory. 
+ + Tests: + - Async execution works with memory + - Turn management in async context + - Memory operations are safe in async + """ + runner = RunnerFlexible(agent=mock_agent, conversation_memory=flexible_memory) + + # Setup async mock + mock_function = Function(name="finish") + mock_function._is_answer_final = True + mock_function._answer = "Async answer" + + async def mock_acall(*args, **kwargs): + return GeneratorOutput( + data=mock_function, + error=None, + raw_response="raw" + ) + + mock_agent.planner.acall = mock_acall + + # Run async + async def run_test(): + result = await runner.acall({"input_str": "Async query"}) + return result + + result = asyncio.run(run_test()) + + assert isinstance(result, RunnerResult) + assert result.answer == "Async answer" + + # Check memory has content + history = flexible_memory() + assert "Async query" in history + assert "Async answer" in history + + def test_memory_persistence_across_runs(self, runner_with_memory, mock_agent): + """Test that memory persists across multiple runner executions. + + Tests: + - Memory accumulates across runs + - Each run creates a new turn + - History is available in subsequent runs + """ + # Setup mock responses + def create_mock_response(answer): + mock_function = Function(name="finish") + mock_function._is_answer_final = True + mock_function._answer = answer + return GeneratorOutput(data=mock_function, error=None, raw_response="raw") + + # First run + mock_agent.planner.call.return_value = create_mock_response("Answer 1") + result1 = runner_with_memory.call({"input_str": "Query 1"}) + assert result1.answer == "Answer 1" + + # Second run + mock_agent.planner.call.return_value = create_mock_response("Answer 2") + result2 = runner_with_memory.call({"input_str": "Query 2"}) + assert result2.answer == "Answer 2" + + # Check memory has both conversations + history = runner_with_memory.conversation_memory() + assert "Query 1" in history + assert "Answer 1" in history + assert "Query 2" in history + assert "Answer 2" in history + + # Check we have 2 turns + assert runner_with_memory.conversation_memory.count_turns() == 2 + + def test_memory_with_complex_metadata(self, runner_with_memory, mock_agent): + """Test memory handling of complex metadata. + + Tests: + - Step history is properly stored + - Complex nested metadata works + - Metadata survives memory errors + """ + # Setup mock with multiple steps + mock_function1 = Function(name="search", args=["query"]) + mock_agent.planner.call.side_effect = [ + GeneratorOutput(data=mock_function1, error=None, raw_response="raw"), + GeneratorOutput( + data=Function(name="finish", _is_answer_final=True, _answer="Done"), + error=None, + raw_response="raw" + ) + ] + + mock_agent.tool_manager.execute_func.return_value = FunctionOutput( + name="search", + input=mock_function1, + output=ToolOutput(output="Search results", observation="Found items") + ) + + # Run with multiple steps + result = runner_with_memory.call({"input_str": "Multi-step query"}) + + # Check step history in metadata + all_messages = runner_with_memory.conversation_memory.get_all_messages() + assistant_messages = [m for m in all_messages if m.role == "assistant"] + + if assistant_messages: + last_assistant = assistant_messages[-1] + assert last_assistant.metadata is not None + assert "step_history" in last_assistant.metadata + assert len(last_assistant.metadata["step_history"]) > 0 + + def test_runner_without_memory(self, mock_agent): + """Test that runner works fine without memory. 
+ + Tests: + - Runner can be initialized without memory + - All safe memory operations handle None memory + - Execution completes normally + """ + runner = RunnerFlexible(agent=mock_agent, conversation_memory=None) + + assert runner.conversation_memory is None + assert runner.use_conversation_memory is False + + # All safe operations should return appropriate defaults + assert runner._safe_create_turn() is None + assert runner._safe_add_user_query("query", "turn_id") is None + assert runner._safe_add_assistant_response("response", "turn_id") is None + assert runner._safe_get_conversation_history() == "" + + # Runner should work normally + mock_function = Function(name="finish") + mock_function._is_answer_final = True + mock_function._answer = "No memory answer" + + mock_agent.planner.call.return_value = GeneratorOutput( + data=mock_function, + error=None, + raw_response="raw" + ) + + result = runner.call({"input_str": "Test"}) + assert result.answer == "No memory answer" + + +class TestMemoryErrorRecovery: + """Test error recovery scenarios with memory failures.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock agent for testing.""" + agent = Mock(spec=Agent) + agent.max_steps = 3 + agent.answer_data_type = str + agent.is_thinking_model = False + + # Mock planner + agent.planner = Mock() + agent.planner.call = Mock() + agent.planner.acall = Mock() + agent.planner.get_prompt = Mock(return_value="test prompt") + agent.planner.estimated_token_count = 100 + + # Mock tool manager + agent.tool_manager = Mock() + agent.tool_manager.tools = [] + agent.tool_manager.execute_func = Mock() + agent.tool_manager.execute_func_async = Mock() + + return agent + + def test_memory_failure_during_turn_creation(self, mock_agent): + """Test recovery when turn creation fails. + + Tests: + - Runner continues without turn_id + - Subsequent operations handle None turn_id + - Execution completes successfully + """ + memory = FlexibleConversationMemory() + memory.create_turn = Mock(side_effect=Exception("Turn creation failed")) + + runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory) + + mock_function = Function(name="finish", _is_answer_final=True, _answer="Success") + mock_agent.planner.call.return_value = GeneratorOutput( + data=mock_function, error=None, raw_response="raw" + ) + + with patch('adalflow.components.agent.runner_flexible.log'): + result = runner.call({"input_str": "Test"}) + assert result.answer == "Success" + + def test_partial_memory_failure(self, mock_agent): + """Test when some memory operations fail but others succeed. 
+ + Tests: + - Turn creation succeeds + - User query fails + - Assistant response fails + - Runner still completes + """ + memory = FlexibleConversationMemory() + original_create = memory.create_turn + memory.add_user_query = Mock(side_effect=Exception("User query failed")) + memory.add_assistant_response = Mock(side_effect=Exception("Assistant response failed")) + + runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory) + + mock_function = Function(name="finish", _is_answer_final=True, _answer="Partial success") + mock_agent.planner.call.return_value = GeneratorOutput( + data=mock_function, error=None, raw_response="raw" + ) + + with patch('adalflow.components.agent.runner_flexible.log'): + result = runner.call({"input_str": "Test"}) + assert result.answer == "Partial success" + + # Turn should have been created successfully + assert memory.count_turns() == 1 + + def test_memory_recovery_between_runs(self, mock_agent): + """Test that memory can recover between runs. + + Tests: + - First run with memory failure + - Memory recovers + - Second run succeeds with memory + """ + memory = FlexibleConversationMemory() + runner = RunnerFlexible(agent=mock_agent, conversation_memory=memory) + + mock_function = Function(name="finish", _is_answer_final=True, _answer="Answer") + mock_agent.planner.call.return_value = GeneratorOutput( + data=mock_function, error=None, raw_response="raw" + ) + + # First run with failing memory + original_create = memory.create_turn + memory.create_turn = Mock(side_effect=Exception("Temporary failure")) + + with patch('adalflow.components.agent.runner_flexible.log'): + result1 = runner.call({"input_str": "Query 1"}) + assert result1.answer == "Answer" + + # Restore memory functionality + memory.create_turn = original_create + + # Second run should work with memory + result2 = runner.call({"input_str": "Query 2"}) + assert result2.answer == "Answer" + + # Second run should have created a turn + assert memory.count_turns() == 1 + history = memory() + assert "Query 2" in history \ No newline at end of file
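
For quick reference, a minimal usage sketch of the FlexibleConversationMemory API that the tests above exercise; it only uses calls that appear in the test cases (create_turn, add_user_query, add_assistant_response, add_system_message, the callable render with metadata_filter, count_turns, count_messages, get_last_n_messages, new_conversation) and assumes the flexible_memory module introduced by this patch is importable. It is an illustrative sketch, not part of the patch itself.

# Minimal usage sketch based on the API exercised by the tests above.
# Assumes adalflow.components.memory.flexible_memory from this patch is available.
from adalflow.components.memory.flexible_memory import FlexibleConversationMemory

memory = FlexibleConversationMemory()

# A turn groups related messages; it may hold several queries and responses.
turn = memory.create_turn()
memory.add_system_message("You are a helpful assistant.", turn)
memory.add_user_query("How do I read a file?", turn, {"source": "web_ui"})
memory.add_assistant_response(
    "Use open() with a context manager.", turn, {"confidence": 0.98}
)

# Render the conversation; metadata_filter limits which metadata keys are shown.
print(memory(metadata_filter=["confidence"]))

# Inspect state the same way the assertions in the tests do.
assert memory.count_turns() == 1
assert memory.count_messages()["user"] == 1
last_two = memory.get_last_n_messages(2)

# Persist the current (non-empty) conversation and start a fresh one.
memory.new_conversation()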