55
66from code_tokenize .tokens import Token
77
8+
89class CodeGraph :
910
1011 def __init__ (self , root_node , tokens , lang = "python" ):
1112
1213 self .tokens = tokens
13- self .lang = lang
14+ self .root_node = root_node
15+ self .lang = lang
1416
15- # Internal container
16- self ._ast_nodes = {} # Nodes indexed by an AST node
17- self ._anonymous_nodes = [] # Unindexed nodes, can only be indexed by traversal
17+ self .__nodes = [] # General container for all nodes
18+ self .__ast_nodes = {} # Nodes indexed by AST
1819
20+ # Init graph
1921 self .token_nodes = []
2022
2123 prev_token = self ._add_token (tokens [0 ])
@@ -27,71 +29,69 @@ def __init__(self, root_node, tokens, lang = "python"):
2729 prev_token = token_node
2830 self .token_nodes .append (token_node )
2931
30- self .root_node = self .add_or_get_node (root_node )
32+ self .root_node = self .add_node (root_node )
3133
32- # Helper methods --------------------------------
34+ # Add nodes -------------------------------- --------------------------------
3335
34- def _add_anonymous (self , node_obj ):
35- self ._anonymous_nodes .append (node_obj )
36- return node_obj
36+ def _add_node (self , node ):
3737
38- def _add_ast_node (self , ast_node , node_obj = None ):
39- ast_node_key = node_key (ast_node )
38+ if hasattr (node , "_graph_idx" ):
39+ # Node is already assigned to a graph
40+ graph_idx = node ._graph_idx
41+ assert self .__nodes [graph_idx ] == node , "Node is already assigned to another graph"
42+ return node
4043
41- if ast_node_key in self ._ast_nodes :
42- return self ._ast_nodes [ast_node_key ]
44+ if hasattr (node , "ast_node" ) and node .ast_node is not None :
45+ ast_key = node_key (node .ast_node )
46+ try :
47+ return self .__ast_nodes [ast_key ]
48+ except KeyError :
49+ self .__ast_nodes [ast_key ] = node
4350
44- if node_obj is None :
45- node_obj = SyntaxNode (ast_node )
46-
47- self ._ast_nodes [ast_node_key ] = node_obj
48- return node_obj
51+ node ._graph_idx = len (self .__nodes )
52+ self .__nodes .append (node )
53+ return node
4954
50- def _add_token (self , token ):
51- token_node = TokenNode ( token )
55+ def _add_ast_node (self , ast_node ):
56+ return self . _add_node ( SyntaxNode ( ast_node ) )
5257
53- if token_node .ast_node is not None :
54- return self ._add_ast_node (token_node .ast_node , token_node )
55- else :
56- return self ._add_anonymous (token_node )
58+ def _add_token (self , token ):
59+ return self ._add_node (TokenNode (token ))
5760
58- # ADD methods -----------------------------------
59- # Currently, we do not support removal
60-
6161 def add_node (self , node ):
62- if isinstance (node , Token ): return self ._add_token (node )
63- if isinstance (node , SyntaxNode ): return self ._add_ast_node ( node . ast_node , node )
64- if isinstance (node , Node ): return self ._add_anonymous ( node )
62+ if isinstance (node , Node ): return self ._add_node (node )
63+ if isinstance (node , Token ): return self ._add_token ( node )
64+ if isinstance (node , str ): return self ._add_node ( SymbolNode ( node ) )
6565
6666 return self ._add_ast_node (node )
6767
68- def add_or_get_node (self , node ):
69- if isinstance (node , SyntaxNode ):
70- return self .add_or_get_node (node .ast_node )
71-
72- if isinstance (node , Node ): return node
73- try :
74- return self ._add_ast_node (node )
75- except Exception :
76- raise ValueError ("Cannot add or get node %s. Only AST nodes can be indexed." % str (node ))
77-
7868 def add_relation (self , source_node , target_node , relation = "ast" , no_create = False ):
7969
8070 if no_create :
8171 if not self .has_node (source_node ): return
8272 if not self .has_node (target_node ): return
8373
84- source_node = self .add_or_get_node (source_node )
85- target_node = self .add_or_get_node (target_node )
74+ source_node = self .add_node (source_node )
75+ target_node = self .add_node (target_node )
8676 source_node .add_successor (target_node , relation )
8777
8878 # API GET methods-----------------------------------------
8979
9080 def has_node (self , ast_node ):
91- return node_key (ast_node ) in self ._ast_nodes
81+ try :
82+ return self .__nodes [ast_node ._graph_idx ] == ast_node
83+ except (IndexError , AttributeError ):
84+ return node_key (ast_node ) in self .__ast_nodes
85+
86+ def node_by_ast (self , ast_node ):
87+ if not self .has_node (ast_node ): return None
88+ return self .__ast_nodes [node_key (ast_node )]
89+
90+ def is_token (self , ast_node ):
91+ return hasattr (self .node_by_ast (ast_node ), "token" )
9292
9393 def nodes (self ):
94- return chain ( self ._ast_nodes . values (), self . _anonymous_nodes )
94+ return self .__nodes
9595
9696 def todot (self , file_name = None , edge_colors = None ):
9797 dotwriter = GraphToDot (self , edge_colors )
@@ -119,7 +119,7 @@ def tokens_only(self):
119119 # Internal GET methods -----------------------------------
120120
121121 def __len__ (self ):
122- return len (self ._ast_nodes ) + len ( self . _anonymous_nodes )
122+ return len (self .__nodes )
123123
124124 def __iter__ (self ):
125125 return iter (self .nodes ())
@@ -199,6 +199,9 @@ def predecessors(self):
199199 def predecessors_by_type (self , edge_type ):
200200 return self ._iter_predecessors (edge_type )
201201
202+ def clone (self ):
203+ return Node ()
204+
202205# Node types --------------------------------------------------------
203206
204207class SyntaxNode (Node ):
@@ -210,6 +213,9 @@ def __init__(self, ast_node):
210213 def node_name (self ):
211214 return self .ast_node .type
212215
216+ def clone (self ):
217+ return SyntaxNode (self .ast_node )
218+
213219 def __hash__ (self ):
214220 return hash (node_key (self .ast_node ))
215221
@@ -223,14 +229,16 @@ def __init__(self, token):
223229
224230 def node_name (self ):
225231 return self .token .text
232+
233+ def clone (self ):
234+ return TokenNode (self .token )
226235
227236 def __hash__ (self ):
228237 if self .ast_node is not None :
229238 return hash (node_key (self .ast_node ))
230239 return hash (self .token .text )
231240
232241
233-
234242class SymbolNode (Node ):
235243
236244 def __init__ (self , symbol ):
@@ -240,12 +248,19 @@ def __init__(self, symbol):
240248 def node_name (self ):
241249 return self .symbol
242250
251+ def clone (self ):
252+ return SymbolNode (self .symbol )
253+
254+ def __hash__ (self ):
255+ return hash (self .symbol )
256+
243257
244258# Utils --------------------------------------------------------
245259
246260def node_key (node ):
247261 start_pos , end_pos = node .start_point , node .end_point
248- return (node .type , start_pos [0 ], start_pos [1 ], end_pos [0 ], end_pos [1 ])
262+ child_count = node .child_count
263+ return (node .type , child_count , start_pos [0 ], start_pos [1 ], end_pos [0 ], end_pos [1 ])
249264
250265
251266class GraphToDot :
@@ -254,10 +269,6 @@ def __init__(self, graph, edge_colors = None):
254269 self .graph = graph
255270 self .edge_colors = {} if edge_colors is None else edge_colors
256271
257- def _map_nodes_to_ix (self ):
258- for ix , node in enumerate (self .graph ):
259- node ._dot_node_id = ix
260-
261272 def _dot_edge (self , source_id , rel_type , target_id ):
262273
263274 edge_color = self .edge_colors .get (rel_type , "black" )
@@ -266,7 +277,6 @@ def _dot_edge(self, source_id, rel_type, target_id):
266277 return f'node{ source_id } -> node{ target_id } [label="{ rel_type } " { edge_style } ];\n '
267278
268279 def run (self , writeable ):
269- self ._map_nodes_to_ix ()
270280
271281 def escape (token ):
272282 return token .replace ('"' , '\\ "' )
@@ -281,7 +291,7 @@ def escape(token):
281291 continue
282292 node_name = node .node_name ()
283293 writeable .write (
284- f'\t node{ node ._dot_node_id } [shape="rectangle", label="{ node_name } "];\n '
294+ f'\t node{ node ._graph_idx } [shape="rectangle", label="{ node_name } "];\n '
285295 )
286296
287297 # Tokens
@@ -291,12 +301,12 @@ def escape(token):
291301 for token_node in tokens :
292302 token_text = escape (token_node .node_name ())
293303 writeable .write (
294- f'\t \t node{ token_node ._dot_node_id } [shape="rectangle", label="{ token_text } "];\n '
304+ f'\t \t node{ token_node ._graph_idx } [shape="rectangle", label="{ token_text } "];\n '
295305 )
296306
297307 for _ , edge_type , next_token in token_node .successors_by_type ("next_token" ):
298308 next_token_edges .append (
299- self ._dot_edge (token_node ._dot_node_id , edge_type , next_token ._dot_node_id )
309+ self ._dot_edge (token_node ._graph_idx , edge_type , next_token ._graph_idx )
300310 )
301311
302312 for edge in next_token_edges :
@@ -308,21 +318,41 @@ def escape(token):
308318 for _ , edge_type , target_node in src_node .successors ():
309319 if edge_type == "next_token" : continue
310320 edge_str = self ._dot_edge (
311- src_node ._dot_node_id ,
321+ src_node ._graph_idx ,
312322 edge_type ,
313- target_node ._dot_node_id
323+ target_node ._graph_idx
314324 )
315325 writeable .write (f"\t { edge_str } " )
316326
317327 writeable .write ("}\n " )
318328
319- # Cleanup
320- for src_node in self .graph :
321- del src_node ._dot_node_id
322-
323329
324330# Propagate to leaves ----------------------------------------------------------------
325331
332+ def _children (root_node ):
333+ return [c for _ , edge_type , c in root_node .successors () if edge_type == "child" ]
334+
335+ def _bfs_token_search (root_node ):
336+ token_nodes = []
337+ queue = [root_node ]
338+
339+ while len (queue ) > 0 :
340+ current_node = queue .pop ()
341+
342+ if hasattr (current_node , "token" ): token_nodes .append (current_node )
343+
344+ if len (token_nodes ) == 0 :
345+ queue .extend (_children (current_node ))
346+
347+ return min (token_nodes , key = lambda c : c ._graph_idx )
348+
349+
350+ def _left_token_search (root_node ):
351+ while not hasattr (root_node , "token" ):
352+ root_node = min (_children (root_node ), key = lambda c : c ._graph_idx )
353+ return root_node
354+
355+
326356def _compute_representer (graph ):
327357 representer = {}
328358
@@ -331,20 +361,11 @@ def _compute_representer(graph):
331361
332362 while len (queue ) > 0 :
333363 current_node = queue .pop (- 1 )
334-
335- path = []
336- while not hasattr (current_node , "token" ):
337- path .append (current_node )
338- syntax_node = current_node .ast_node
339- children = [graph .add_or_get_node (c )
340- for c in syntax_node .children
341- if graph .has_node (c )]
342- if len (children ) == 0 : break
343- first , * others = children
344- queue .extend (others )
345- current_node = first
346-
347- for r in path : representer [r ] = current_node
364+
365+ if current_node in representer : continue
366+
367+ representer [current_node ] = _left_token_search (current_node )
368+ queue .extend (_children (current_node ))
348369
349370 return representer
350371
@@ -357,26 +378,25 @@ def graph_to_tokens_only(graph):
357378
358379 output = CodeGraph (tokens [0 ].ast_node , tokens , lang = graph .lang )
359380
360- for token in graph .tokens :
361- if not hasattr (token , "ast_node" ): continue
362- token_node = graph .add_or_get_node (token .ast_node )
363- output_node = output .add_or_get_node (token .ast_node )
381+ for ix , token_node in enumerate (graph .token_nodes ):
382+ output_node = output .token_nodes [ix ]
364383
365384 for _ , edge_type , successor in token_node .successors ():
366385 if edge_type in SYNTAX_TYPES : continue
367386 if not hasattr (successor , "token" ): continue
368- output_succ = output .add_or_get_node (successor .ast_node )
387+
388+ output_succ = output .token_nodes [successor ._graph_idx ]
369389 output .add_relation (output_node , output_succ , edge_type )
370-
371390
372391 for node , representer in representers .items ():
373- output_representer = output .add_or_get_node (representer .ast_node )
392+ token_idx = representer ._graph_idx
393+ output_representer = output .token_nodes [token_idx ]
374394
375395 for _ , edge_type , successor in node .successors ():
376396 if edge_type in SYNTAX_TYPES : continue
377397 if successor not in representers : continue
378398 successor_representer = representers [successor ]
379- output_successor_representer = output .add_or_get_node ( successor_representer .ast_node )
399+ output_successor_representer = output .token_nodes [ successor_representer ._graph_idx ]
380400 output .add_relation (output_representer ,
381401 output_successor_representer ,
382402 edge_type )
0 commit comments