@@ -41,6 +41,20 @@ def _format_datetime(value: str | datetime) -> str:
4141 return str (value )
4242
4343
44+ def _normalize_datetime (val ):
45+ """
46+ Normalize datetime to ISO 8601 UTC string with +00:00.
47+ - If val is datetime object -> keep isoformat() (Neo4j)
48+ - If val is string without timezone -> append +00:00 (Nebula)
49+ - Otherwise just str()
50+ """
51+ if hasattr (val , "isoformat" ):
52+ return val .isoformat ()
53+ if isinstance (val , str ) and not val .endswith (("+00:00" , "Z" , "+08:00" )):
54+ return val + "+08:00"
55+ return str (val )
56+
57+
class SessionPoolError(Exception):
    """Raised when the client session pool cannot satisfy a request."""
4660
@@ -62,6 +76,7 @@ def __init__(
6276 self .hosts = hosts
6377 self .user = user
6478 self .password = password
79+ self .minsize = minsize
6580 self .maxsize = maxsize
6681 self .pool = Queue (maxsize )
6782 self .lock = Lock ()
@@ -79,13 +94,13 @@ def _create_and_add_client(self):
7994 self .clients .append (client )
8095
8196 def get_client (self , timeout : float = 5.0 ):
82- from nebulagraph_python import NebulaClient
83-
8497 try :
8598 return self .pool .get (timeout = timeout )
8699 except Empty :
87100 with self .lock :
88101 if len (self .clients ) < self .maxsize :
102+ from nebulagraph_python import NebulaClient
103+
89104 client = NebulaClient (self .hosts , self .user , self .password )
90105 self .clients .append (client )
91106 return client
@@ -120,6 +135,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
120135
121136 return _ClientContext (self )
122137
    def reset_pool(self):
        """⚠️ Emergency reset: Close all clients and clear the pool."""
        logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
        with self.lock:
            # Best-effort close of every tracked client: a failed close must
            # not abort the reset, so each close is individually guarded.
            for client in self.clients:
                try:
                    client.close()
                except Exception:
                    logger.error("Fail to close!!!")
            self.clients.clear()
            # Drain the idle queue; the queued clients were closed above and
            # must not be handed out again.
            while not self.pool.empty():
                try:
                    self.pool.get_nowait()
                except Empty:
                    break
            # Re-seed the pool to its configured minimum size.
            # NOTE(review): this runs while self.lock is held — confirm
            # _create_and_add_client does not also acquire self.lock
            # (threading.Lock is not reentrant; that would deadlock).
            for _ in range(self.minsize):
                self._create_and_add_client()
        logger.info("[Pool] Pool has been reset successfully.")
156+
123157
124158class NebulaGraphDB (BaseGraphDB ):
125159 """
@@ -181,12 +215,27 @@ def __init__(self, config: NebulaGraphDBConfig):
181215
182216 def execute_query (self , gql : str , timeout : float = 5.0 , auto_set_db : bool = True ):
183217 with self .pool .get () as client :
184- if auto_set_db and self .db_name :
185- client .execute (f"SESSION SET GRAPH `{ self .db_name } `" )
186218 try :
219+ if auto_set_db and self .db_name :
220+ client .execute (f"SESSION SET GRAPH `{ self .db_name } `" )
187221 return client .execute (gql , timeout = timeout )
188- except Exception :
222+ except Exception as e :
189223 logger .error (f"Fail to run gql { gql } trace: { traceback .format_exc ()} " )
224+ if "Session not found" in str (e ):
225+ logger .warning ("[execute_query] Session expired, replacing client." )
226+ try :
227+ client .close ()
228+ except Exception :
229+ logger .error ("Fail to close!!!!!" )
230+ finally :
231+ if client in self .pool .clients :
232+ self .pool .clients .remove (client )
233+ from nebulagraph_python import NebulaClient
234+
235+ new_client = NebulaClient (self .pool .hosts , self .pool .user , self .pool .password )
236+ self .pool .clients .append (new_client )
237+ return new_client .execute (gql , timeout = timeout )
238+ raise
190239
    def close(self):
        """Close the underlying session pool and all of its clients."""
        self.pool.close()
@@ -923,9 +972,11 @@ def clear(self) -> None:
923972 except Exception as e :
924973 logger .error (f"[ERROR] Failed to clear database: { e } " )
925974
926- def export_graph (self ) -> dict [str , Any ]:
975+ def export_graph (self , include_embedding : bool = False ) -> dict [str , Any ]:
927976 """
928977 Export all graph nodes and edges in a structured form.
978+ Args:
979+ include_embedding (bool): Whether to include the large embedding field.
929980
930981 Returns:
931982 {
@@ -942,12 +993,41 @@ def export_graph(self) -> dict[str, Any]:
942993 edge_query += f' WHERE r.user_name = "{ username } "'
943994
944995 try :
945- full_node_query = f"{ node_query } RETURN n"
946- node_result = self .execute_query (full_node_query )
996+ if include_embedding :
997+ return_fields = "n"
998+ else :
999+ return_fields = "," .join (
1000+ [
1001+ "n.id AS id" ,
1002+ "n.memory AS memory" ,
1003+ "n.user_name AS user_name" ,
1004+ "n.user_id AS user_id" ,
1005+ "n.session_id AS session_id" ,
1006+ "n.status AS status" ,
1007+ "n.key AS key" ,
1008+ "n.confidence AS confidence" ,
1009+ "n.tags AS tags" ,
1010+ "n.created_at AS created_at" ,
1011+ "n.updated_at AS updated_at" ,
1012+ "n.memory_type AS memory_type" ,
1013+ "n.sources AS sources" ,
1014+ "n.source AS source" ,
1015+ "n.node_type AS node_type" ,
1016+ "n.visibility AS visibility" ,
1017+ "n.usage AS usage" ,
1018+ "n.background AS background" ,
1019+ ]
1020+ )
1021+
1022+ full_node_query = f"{ node_query } RETURN { return_fields } "
1023+ node_result = self .execute_query (full_node_query , timeout = 20 )
9471024 nodes = []
1025+ logger .debug (f"Debugging: { node_result } " )
9481026 for row in node_result :
949- node_wrapper = row .values ()[0 ].as_node ()
950- props = node_wrapper .get_properties ()
1027+ if include_embedding :
1028+ props = row .values ()[0 ].as_node ().get_properties ()
1029+ else :
1030+ props = {k : v .value for k , v in row .items ()}
9511031
9521032 node = self ._parse_node (props )
9531033 nodes .append (node )
@@ -956,7 +1036,7 @@ def export_graph(self) -> dict[str, Any]:
9561036
9571037 try :
9581038 full_edge_query = f"{ edge_query } RETURN a.id AS source, b.id AS target, type(r) as edge"
959- edge_result = self .execute_query (full_edge_query )
1039+ edge_result = self .execute_query (full_edge_query , timeout = 20 )
9601040 edges = [
9611041 {
9621042 "source" : row .values ()[0 ].value ,
@@ -1023,6 +1103,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
10231103 MATCH (n@Memory)
10241104 { where_clause }
10251105 RETURN n
1106+ LIMIT 100
10261107 """
10271108 nodes = []
10281109 try :
@@ -1065,7 +1146,7 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
10651146 node_props = rec ["n" ].as_node ().get_properties ()
10661147 candidates .append (self ._parse_node (node_props ))
10671148 except Exception as e :
1068- logger .error (f"Failed : { e } " )
1149+ logger .error (f"Failed : { e } , traceback: { traceback . format_exc () } " )
10691150 return candidates
10701151
10711152 def drop_database (self ) -> None :
    def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
        """Convert a raw node property map into the standard node dict.

        Args:
            props: Raw property map read from a graph node.

        Returns:
            ``{"id": ..., "memory": ..., "metadata": {...}}`` where metadata
            holds every remaining property, with timestamps normalized,
            ``node_type`` renamed to ``type``, and the vector field exposed as
            ``embedding`` when present.

        Raises:
            KeyError: If ``props`` lacks an ``id`` or ``node_type`` key.
        """
        parsed = {k: self._parse_value(v) for k, v in props.items()}

        # Timestamps arrive either as datetime objects or as naive strings;
        # normalize both to ISO 8601 strings.
        for tf in ("created_at", "updated_at"):
            if tf in parsed and parsed[tf] is not None:
                parsed[tf] = _normalize_datetime(parsed[tf])

        node_id = parsed.pop("id")
        memory = parsed.pop("memory", "")
        parsed.pop("user_name", None)  # internal scoping field, not exported
        metadata = parsed
        metadata["type"] = metadata.pop("node_type")

        # The embedding field is optional: queries may omit it from the RETURN
        # clause (e.g. exports without embeddings), so guard the pop.
        if self.dim_field in metadata:
            metadata["embedding"] = metadata.pop(self.dim_field)

        return {"id": node_id, "memory": memory, "metadata": metadata}
13321415
0 commit comments