Skip to content

Commit a175bf1

Browse files
tangg555 and CaralHsi authored
feat & refactor: add search function feature to test scheduler on a modified locome benchmark, and slightly change the logic of query consume and query monitors (#204)
* feat & refactor: add search function feature to test scheduler on a modified locome benchmark, and slightly change the logic of query consume and query monitors * refactor: not showing userinput to working memory * fix bugs: fix a bug in retriever, and add new auth info for neo4j db * fix bugs & new feat: fix bugs in mem_scheduler examples, and remove initialize working memories (logically unnecessary). extend the parameters of the search call with an additional `info` input --------- Co-authored-by: CaralHsi <caralhsi@gmail.com>
1 parent 013d0ad commit a175bf1

File tree

15 files changed: +207 / -133 lines changed

15 files changed

+207
-133
lines changed

examples/mem_scheduler/memos_w_scheduler.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def run_with_scheduler_init():
9191
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
9292

9393
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
94+
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
95+
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
96+
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
97+
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
98+
auth_config.graph_db.auto_create
99+
)
94100

95101
# Initialization
96102
mos = MOS(mos_config)
@@ -118,8 +124,7 @@ def run_with_scheduler_init():
118124
query = item["question"]
119125
print(f"Query:\n {query}\n")
120126
response = mos.chat(query=query, user_id=user_id)
121-
print(f"Answer:\n {response}")
122-
print("===== Chat End =====")
127+
print(f"Answer:\n {response}\n")
123128

124129
show_web_logs(mem_scheduler=mos.mem_scheduler)
125130

examples/mem_scheduler/memos_w_scheduler_for_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ def init_task():
103103
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
104104

105105
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
106+
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
107+
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
108+
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
109+
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
110+
auth_config.graph_db.auto_create
111+
)
106112

107113
# Initialization
108114
mos = MOSForTestScheduler(mos_config)

examples/mem_scheduler/try_schedule_modules.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
151151
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
152152

153153
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
154+
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
155+
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
156+
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
157+
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
158+
auth_config.graph_db.auto_create
159+
)
154160

155161
# Initialization
156162
mos = MOSForTestScheduler(mos_config)

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __init__(self, config: BaseSchedulerConfig):
8484
# other attributes
8585
self._context_lock = threading.Lock()
8686
self.current_user_id: UserID | str | None = None
87+
self.current_mem_cube_id: MemCubeID | str | None = None
88+
self.current_mem_cube: GeneralMemCube | None = None
8789
self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None)
8890
self.auth_config = None
8991
self.rabbitmq_config = None
@@ -130,8 +132,8 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
130132
self.current_mem_cube_id = msg.mem_cube_id
131133
self.current_mem_cube = msg.mem_cube
132134

133-
def transform_memories_to_monitors(
134-
self, memories: list[TextualMemoryItem]
135+
def transform_working_memories_to_monitors(
136+
self, query_keywords, memories: list[TextualMemoryItem]
135137
) -> list[MemoryMonitorItem]:
136138
"""
137139
Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects
@@ -143,10 +145,6 @@ def transform_memories_to_monitors(
143145
Returns:
144146
List of MemoryMonitorItem objects with computed importance scores.
145147
"""
146-
query_keywords = self.monitor.query_monitors.get_keywords_collections()
147-
logger.debug(
148-
f"Processing {len(memories)} memories with {len(query_keywords)} query keywords"
149-
)
150148

151149
result = []
152150
mem_length = len(memories)
@@ -195,7 +193,8 @@ def replace_working_memory(
195193
text_mem_base: TreeTextMemory = text_mem_base
196194

197195
# process rerank memories with llm
198-
query_history = self.monitor.query_monitors.get_queries_with_timesort()
196+
query_monitor = self.monitor.query_monitors[user_id][mem_cube_id]
197+
query_history = query_monitor.get_queries_with_timesort()
199198
memories_with_new_order, rerank_success_flag = (
200199
self.retriever.process_and_rerank_memories(
201200
queries=query_history,
@@ -206,8 +205,13 @@ def replace_working_memory(
206205
)
207206

208207
# update working memory monitors
209-
new_working_memory_monitors = self.transform_memories_to_monitors(
210-
memories=memories_with_new_order
208+
query_keywords = query_monitor.get_keywords_collections()
209+
logger.debug(
210+
f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
211+
)
212+
new_working_memory_monitors = self.transform_working_memories_to_monitors(
213+
query_keywords=query_keywords,
214+
memories=memories_with_new_order,
211215
)
212216

213217
if not rerank_success_flag:
@@ -245,25 +249,6 @@ def replace_working_memory(
245249

246250
return memories_with_new_order
247251

248-
def initialize_working_memory_monitors(
249-
self,
250-
user_id: UserID | str,
251-
mem_cube_id: MemCubeID | str,
252-
mem_cube: GeneralMemCube,
253-
):
254-
text_mem_base: TreeTextMemory = mem_cube.text_mem
255-
working_memories = text_mem_base.get_working_memory()
256-
257-
working_memory_monitors = self.transform_memories_to_monitors(
258-
memories=working_memories,
259-
)
260-
self.monitor.update_working_memory_monitors(
261-
new_working_memory_monitors=working_memory_monitors,
262-
user_id=user_id,
263-
mem_cube_id=mem_cube_id,
264-
mem_cube=mem_cube,
265-
)
266-
267252
def update_activation_memory(
268253
self,
269254
new_memories: list[str | TextualMemoryItem],
@@ -367,13 +352,9 @@ def update_activation_memory_periodically(
367352
or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0
368353
):
369354
logger.warning(
370-
"No memories found in working_memory_monitors, initializing from current working_memories"
371-
)
372-
self.initialize_working_memory_monitors(
373-
user_id=user_id,
374-
mem_cube_id=mem_cube_id,
375-
mem_cube=mem_cube,
355+
"No memories found in working_memory_monitors, activation memory update is skipped"
376356
)
357+
return
377358

378359
self.monitor.update_activation_memory_monitors(
379360
user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 77 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ANSWER_LABEL,
1010
DEFAULT_MAX_QUERY_KEY_WORDS,
1111
QUERY_LABEL,
12+
WORKING_MEMORY_TYPE,
1213
MemCubeID,
1314
UserID,
1415
)
@@ -35,11 +36,12 @@ def __init__(self, config: GeneralSchedulerConfig):
3536

3637
# for evaluation
3738
def search_for_eval(
38-
self,
39-
query: str,
40-
user_id: UserID | str,
41-
top_k: int,
42-
) -> list[str]:
39+
self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
40+
) -> (list[str], bool):
41+
self.monitor.register_query_monitor_if_not_exists(
42+
user_id=user_id, mem_cube_id=self.current_mem_cube_id
43+
)
44+
4345
query_keywords = self.monitor.extract_query_keywords(query=query)
4446
logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
4547

@@ -48,35 +50,61 @@ def search_for_eval(
4850
keywords=query_keywords,
4951
max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
5052
)
51-
self.monitor.query_monitors.put(item=item)
52-
logger.debug(
53-
f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}."
54-
)
53+
query_monitor = self.monitor.query_monitors[user_id][self.current_mem_cube_id]
54+
query_monitor.put(item=item)
55+
logger.debug(f"Queries in monitor are {query_monitor.get_queries_with_timesort()}.")
5556

5657
queries = [query]
5758

5859
# recall
59-
cur_working_memory, new_candidates = self.process_session_turn(
60-
queries=queries,
61-
user_id=user_id,
62-
mem_cube_id=self.current_mem_cube_id,
63-
mem_cube=self.current_mem_cube,
64-
top_k=self.top_k,
65-
)
66-
logger.info(f"Processed {queries} and get {len(new_candidates)} new candidate memories.")
67-
68-
# rerank
69-
new_order_working_memory = self.replace_working_memory(
70-
user_id=user_id,
71-
mem_cube_id=self.current_mem_cube_id,
72-
mem_cube=self.current_mem_cube,
73-
original_memory=cur_working_memory,
74-
new_memory=new_candidates,
60+
mem_cube = self.current_mem_cube
61+
text_mem_base = mem_cube.text_mem
62+
63+
cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
64+
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
65+
intent_result = self.monitor.detect_intent(
66+
q_list=queries, text_working_memory=text_working_memory
7567
)
76-
new_order_working_memory = new_order_working_memory[:top_k]
77-
logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")
7868

79-
return [m.memory for m in new_order_working_memory]
69+
if not scheduler_flag:
70+
return text_working_memory, intent_result["trigger_retrieval"]
71+
else:
72+
if intent_result["trigger_retrieval"]:
73+
missing_evidences = intent_result["missing_evidences"]
74+
num_evidence = len(missing_evidences)
75+
k_per_evidence = max(1, top_k // max(1, num_evidence))
76+
new_candidates = []
77+
for item in missing_evidences:
78+
logger.info(f"missing_evidences: {item}")
79+
results: list[TextualMemoryItem] = self.retriever.search(
80+
query=item,
81+
mem_cube=mem_cube,
82+
top_k=k_per_evidence,
83+
method=self.search_method,
84+
)
85+
logger.info(
86+
f"search results for {missing_evidences}: {[one.memory for one in results]}"
87+
)
88+
new_candidates.extend(results)
89+
print(
90+
f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories."
91+
)
92+
else:
93+
new_candidates = []
94+
print(f"intent_result: {intent_result}. not triggered")
95+
96+
# rerank
97+
new_order_working_memory = self.replace_working_memory(
98+
user_id=user_id,
99+
mem_cube_id=self.current_mem_cube_id,
100+
mem_cube=self.current_mem_cube,
101+
original_memory=cur_working_memory,
102+
new_memory=new_candidates,
103+
)
104+
new_order_working_memory = new_order_working_memory[:top_k]
105+
logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")
106+
107+
return [m.memory for m in new_order_working_memory], intent_result["trigger_retrieval"]
80108

81109
def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
82110
"""
@@ -105,6 +133,10 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
105133

106134
# update query monitors
107135
for msg in messages:
136+
self.monitor.register_query_monitor_if_not_exists(
137+
user_id=user_id, mem_cube_id=mem_cube_id
138+
)
139+
108140
query = msg.content
109141
query_keywords = self.monitor.extract_query_keywords(query=query)
110142
logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
@@ -114,9 +146,11 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
114146
keywords=query_keywords,
115147
max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
116148
)
117-
self.monitor.query_monitors.put(item=item)
149+
150+
self.monitor.query_monitors[user_id][mem_cube_id].put(item=item)
118151
logger.debug(
119-
f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}."
152+
f"Queries in monitor are "
153+
f"{self.monitor.query_monitors[user_id][mem_cube_id].get_queries_with_timesort()}."
120154
)
121155

122156
queries = [msg.content for msg in messages]
@@ -215,6 +249,9 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
215249
mem_type = mem_item.metadata.memory_type
216250
mem_content = mem_item.memory
217251

252+
if mem_type == WORKING_MEMORY_TYPE:
253+
continue
254+
218255
self.log_adding_memory(
219256
memory=mem_content,
220257
memory_type=mem_type,
@@ -289,18 +326,20 @@ def process_session_turn(
289326
new_candidates = []
290327
for item in missing_evidences:
291328
logger.info(f"missing_evidences: {item}")
329+
info = {
330+
"user_id": user_id,
331+
"session_id": "",
332+
}
333+
292334
results: list[TextualMemoryItem] = self.retriever.search(
293-
query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method
335+
query=item,
336+
mem_cube=mem_cube,
337+
top_k=k_per_evidence,
338+
method=self.search_method,
339+
info=info,
294340
)
295341
logger.info(
296342
f"search results for {missing_evidences}: {[one.memory for one in results]}"
297343
)
298344
new_candidates.extend(results)
299-
300-
if len(new_candidates) == 0:
301-
logger.warning(
302-
f"As new_candidates is empty, new_candidates is set same to working_memory.\n"
303-
f"time_trigger_flag: {time_trigger_flag}; intent_result: {intent_result}"
304-
)
305-
new_candidates = cur_working_memory
306345
return cur_working_memory, new_candidates

src/memos/mem_scheduler/modules/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ def __init__(self):
1717

1818
self._chat_llm = None
1919
self._process_llm = None
20-
self.current_mem_cube_id: str | None = None
21-
self.current_mem_cube: GeneralMemCube | None = None
20+
2221
self.mem_cubes: dict[str, GeneralMemCube] = {}
2322

2423
def load_template(self, template_name: str) -> str:

0 commit comments

Comments
 (0)