Skip to content

Commit 013d0ad

Browse files
authored
feat: enhance retriever, reorganizer & NebulaGraph handling (#200)
1 parent 2a6a0f2 commit 013d0ad

File tree

10 files changed

+357
-157
lines changed

10 files changed

+357
-157
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def _format_datetime(value: str | datetime) -> str:
4141
return str(value)
4242

4343

44+
def _normalize_datetime(val):
45+
"""
46+
Normalize datetime to ISO 8601 UTC string with +00:00.
47+
- If val is datetime object -> keep isoformat() (Neo4j)
48+
- If val is string without timezone -> append +00:00 (Nebula)
49+
- Otherwise just str()
50+
"""
51+
if hasattr(val, "isoformat"):
52+
return val.isoformat()
53+
if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
54+
return val + "+08:00"
55+
return str(val)
56+
57+
4458
class SessionPoolError(Exception):
4559
pass
4660

@@ -62,6 +76,7 @@ def __init__(
6276
self.hosts = hosts
6377
self.user = user
6478
self.password = password
79+
self.minsize = minsize
6580
self.maxsize = maxsize
6681
self.pool = Queue(maxsize)
6782
self.lock = Lock()
@@ -79,13 +94,13 @@ def _create_and_add_client(self):
7994
self.clients.append(client)
8095

8196
def get_client(self, timeout: float = 5.0):
82-
from nebulagraph_python import NebulaClient
83-
8497
try:
8598
return self.pool.get(timeout=timeout)
8699
except Empty:
87100
with self.lock:
88101
if len(self.clients) < self.maxsize:
102+
from nebulagraph_python import NebulaClient
103+
89104
client = NebulaClient(self.hosts, self.user, self.password)
90105
self.clients.append(client)
91106
return client
@@ -120,6 +135,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
120135

121136
return _ClientContext(self)
122137

138+
def reset_pool(self):
    """⚠️ Emergency reset: Close all clients and clear the pool."""
    logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
    with self.lock:
        # Best-effort close of every tracked client; a failing close must not
        # abort the reset, so the error is logged and swallowed.
        for client in self.clients:
            try:
                client.close()
            except Exception:
                logger.error("Fail to close!!!")
        self.clients.clear()
        # Drain idle clients still sitting in the queue — they were closed
        # above and must never be handed out again.
        while not self.pool.empty():
            try:
                self.pool.get_nowait()
            except Empty:
                break
        # Re-seed the pool back up to its configured minimum size.
        # NOTE(review): this runs while self.lock is held — confirm that
        # _create_and_add_client does not also acquire self.lock (a plain
        # threading.Lock is not reentrant and would deadlock here).
        for _ in range(self.minsize):
            self._create_and_add_client()
    logger.info("[Pool] Pool has been reset successfully.")
156+
123157

124158
class NebulaGraphDB(BaseGraphDB):
125159
"""
@@ -181,12 +215,27 @@ def __init__(self, config: NebulaGraphDBConfig):
181215

182216
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
    """Execute a GQL statement on a client borrowed from the connection pool.

    Args:
        gql: The GQL statement to run.
        timeout: Per-query timeout in seconds, forwarded to the client.
        auto_set_db: When True and ``self.db_name`` is set, issue
            ``SESSION SET GRAPH`` first so the query targets that graph.

    Returns:
        The client's result for ``gql``.

    Raises:
        Exception: the original client error is re-raised, unless it was a
            "Session not found" failure handled by the replace-and-retry
            path below.
    """
    with self.pool.get() as client:
        try:
            if auto_set_db and self.db_name:
                client.execute(f"SESSION SET GRAPH `{self.db_name}`")
            return client.execute(gql, timeout=timeout)
        except Exception as e:
            logger.error(f"Fail to run gql {gql} trace: {traceback.format_exc()}")
            # Server-side session expiry makes this pooled client unusable:
            # discard it and retry the query once on a fresh connection.
            if "Session not found" in str(e):
                logger.warning("[execute_query] Session expired, replacing client.")
                try:
                    client.close()
                except Exception:
                    logger.error("Fail to close!!!!!")
                finally:
                    # Remove the dead client from the pool's bookkeeping list
                    # regardless of whether close() succeeded.
                    if client in self.pool.clients:
                        self.pool.clients.remove(client)
                from nebulagraph_python import NebulaClient

                new_client = NebulaClient(self.pool.hosts, self.pool.user, self.pool.password)
                self.pool.clients.append(new_client)
                # NOTE(review): the replacement client is appended to
                # pool.clients but never put into the pool's queue, and the
                # retried query skips the SESSION SET GRAPH step — confirm
                # both are intentional.
                return new_client.execute(gql, timeout=timeout)
            raise
190239

191240
def close(self):
192241
self.pool.close()
@@ -923,9 +972,11 @@ def clear(self) -> None:
923972
except Exception as e:
924973
logger.error(f"[ERROR] Failed to clear database: {e}")
925974

926-
def export_graph(self) -> dict[str, Any]:
975+
def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
927976
"""
928977
Export all graph nodes and edges in a structured form.
978+
Args:
979+
include_embedding (bool): Whether to include the large embedding field.
929980
930981
Returns:
931982
{
@@ -942,12 +993,41 @@ def export_graph(self) -> dict[str, Any]:
942993
edge_query += f' WHERE r.user_name = "{username}"'
943994

944995
try:
945-
full_node_query = f"{node_query} RETURN n"
946-
node_result = self.execute_query(full_node_query)
996+
if include_embedding:
997+
return_fields = "n"
998+
else:
999+
return_fields = ",".join(
1000+
[
1001+
"n.id AS id",
1002+
"n.memory AS memory",
1003+
"n.user_name AS user_name",
1004+
"n.user_id AS user_id",
1005+
"n.session_id AS session_id",
1006+
"n.status AS status",
1007+
"n.key AS key",
1008+
"n.confidence AS confidence",
1009+
"n.tags AS tags",
1010+
"n.created_at AS created_at",
1011+
"n.updated_at AS updated_at",
1012+
"n.memory_type AS memory_type",
1013+
"n.sources AS sources",
1014+
"n.source AS source",
1015+
"n.node_type AS node_type",
1016+
"n.visibility AS visibility",
1017+
"n.usage AS usage",
1018+
"n.background AS background",
1019+
]
1020+
)
1021+
1022+
full_node_query = f"{node_query} RETURN {return_fields}"
1023+
node_result = self.execute_query(full_node_query, timeout=20)
9471024
nodes = []
1025+
logger.debug(f"Debugging: {node_result}")
9481026
for row in node_result:
949-
node_wrapper = row.values()[0].as_node()
950-
props = node_wrapper.get_properties()
1027+
if include_embedding:
1028+
props = row.values()[0].as_node().get_properties()
1029+
else:
1030+
props = {k: v.value for k, v in row.items()}
9511031

9521032
node = self._parse_node(props)
9531033
nodes.append(node)
@@ -956,7 +1036,7 @@ def export_graph(self) -> dict[str, Any]:
9561036

9571037
try:
9581038
full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
959-
edge_result = self.execute_query(full_edge_query)
1039+
edge_result = self.execute_query(full_edge_query, timeout=20)
9601040
edges = [
9611041
{
9621042
"source": row.values()[0].value,
@@ -1023,6 +1103,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
10231103
MATCH (n@Memory)
10241104
{where_clause}
10251105
RETURN n
1106+
LIMIT 100
10261107
"""
10271108
nodes = []
10281109
try:
@@ -1065,7 +1146,7 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
10651146
node_props = rec["n"].as_node().get_properties()
10661147
candidates.append(self._parse_node(node_props))
10671148
except Exception as e:
1068-
logger.error(f"Failed : {e}")
1149+
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
10691150
return candidates
10701151

10711152
def drop_database(self) -> None:
@@ -1318,15 +1399,17 @@ def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
13181399
parsed = {k: self._parse_value(v) for k, v in props.items()}
13191400

13201401
for tf in ("created_at", "updated_at"):
1321-
if tf in parsed and hasattr(parsed[tf], "isoformat"):
1322-
parsed[tf] = parsed[tf].isoformat()
1402+
if tf in parsed and parsed[tf] is not None:
1403+
parsed[tf] = _normalize_datetime(parsed[tf])
13231404

13241405
node_id = parsed.pop("id")
13251406
memory = parsed.pop("memory", "")
13261407
parsed.pop("user_name", None)
13271408
metadata = parsed
13281409
metadata["type"] = metadata.pop("node_type")
1329-
metadata["embedding"] = metadata.pop(self.dim_field)
1410+
1411+
if self.dim_field in metadata:
1412+
metadata["embedding"] = metadata.pop(self.dim_field)
13301413

13311414
return {"id": node_id, "memory": memory, "metadata": metadata}
13321415

src/memos/mem_os/utils/format_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,15 +570,23 @@ def convert_graph_to_tree_forworkmem(
570570
else:
571571
other_roots.append(root_id)
572572

573-
def build_tree(node_id: str) -> dict[str, Any]:
574-
"""Recursively build tree structure"""
573+
def build_tree(node_id: str, visited=None) -> dict[str, Any] | None:
574+
"""Recursively build tree structure with cycle detection"""
575+
if visited is None:
576+
visited = set()
577+
578+
if node_id in visited:
579+
logger.warning(f"[build_tree] Detected cycle at node {node_id}, skipping.")
580+
return None
581+
visited.add(node_id)
582+
575583
if node_id not in node_map:
576584
return None
577585

578586
children_ids = children_map.get(node_id, [])
579587
children = []
580588
for child_id in children_ids:
581-
child_tree = build_tree(child_id)
589+
child_tree = build_tree(child_id, visited)
582590
if child_tree:
583591
children.append(child_tree)
584592

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def __init__(
3939
if not memory_size:
4040
self.memory_size = {
4141
"WorkingMemory": 20,
42-
"LongTermMemory": 10000,
43-
"UserMemory": 10000,
42+
"LongTermMemory": 1500,
43+
"UserMemory": 480,
4444
}
4545
self._threshold = threshold
4646
self.is_reorganize = is_reorganize

src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5
7373
results["sequence_links"].extend(seq)
7474
"""
7575

76+
"""
7677
# 4) Aggregate
7778
agg = self._detect_aggregate_node_for_group(node, nearest, min_group_size=5)
7879
if agg:
7980
results["aggregate_nodes"].append(agg)
81+
"""
8082

8183
except Exception as e:
8284
logger.error(

0 commit comments

Comments
 (0)