Skip to content

Commit 739868e

Browse files
committed
Reworked implementation to be based on code_ast
1 parent 9b7cdf3 commit 739868e

File tree

9 files changed

+386
-271
lines changed

9 files changed

+386
-271
lines changed

code_graph/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66
from .javalang import javalang_analyses, java_preprocess
77

88

9+
DEFAULT_ANALYSES = ["ast", "cfg", "dataflow"]
10+
11+
912
def codegraph(source_code, lang = "guess", analyses = None, **kwargs):
1013
root_node, tokens = preprocess_code(source_code, lang, **kwargs)
1114

1215
graph_analyses = load_lang_analyses(tokens[0].config.lang)
1316

14-
if analyses is None:
15-
analyses = graph_analyses.keys()
16-
else:
17-
assert all(a in graph_analyses.keys() for a in analyses), \
18-
"Not all analyses are supported. Available analyses are: %s" % ", ".join(GRAPH_ANALYSES.keys())
19-
17+
if analyses is None: analyses = DEFAULT_ANALYSES
18+
19+
assert all(a in graph_analyses.keys() for a in analyses), \
20+
"Not all analyses are supported. Available analyses are: %s" % ", ".join(graph_analyses.keys())
21+
2022
graph = CodeGraph(root_node, tokens, lang = lang)
2123

2224
for analysis in analyses:

code_graph/ast.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ def __init__(self, graph):
1010
def visit(self, ast_node):
1111
graph = self.graph
1212

13-
for child in ast_node.children:
14-
graph.add_relation(ast_node, child, "child")
13+
if not graph.is_token(ast_node):
14+
for child in ast_node.children:
15+
graph.add_relation(ast_node, child, "child")
1516

1617
prev_sibling = ast_node.prev_sibling
1718
if prev_sibling is not None:

code_graph/graph.py

Lines changed: 103 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@
55

66
from code_tokenize.tokens import Token
77

8+
89
class CodeGraph:
910

1011
def __init__(self, root_node, tokens, lang = "python"):
1112

1213
self.tokens = tokens
13-
self.lang = lang
14+
self.root_node = root_node
15+
self.lang = lang
1416

15-
# Internal container
16-
self._ast_nodes = {} # Nodes indexed by an AST node
17-
self._anonymous_nodes = [] # Unindexed nodes, can only be indexed by traversal
17+
self.__nodes = [] # General container for all nodes
18+
self.__ast_nodes = {} # Nodes indexed by AST
1819

20+
# Init graph
1921
self.token_nodes = []
2022

2123
prev_token = self._add_token(tokens[0])
@@ -27,71 +29,69 @@ def __init__(self, root_node, tokens, lang = "python"):
2729
prev_token = token_node
2830
self.token_nodes.append(token_node)
2931

30-
self.root_node = self.add_or_get_node(root_node)
32+
self.root_node = self.add_node(root_node)
3133

32-
# Helper methods --------------------------------
34+
# Add nodes ----------------------------------------------------------------
3335

34-
def _add_anonymous(self, node_obj):
35-
self._anonymous_nodes.append(node_obj)
36-
return node_obj
36+
def _add_node(self, node):
3737

38-
def _add_ast_node(self, ast_node, node_obj = None):
39-
ast_node_key = node_key(ast_node)
38+
if hasattr(node, "_graph_idx"):
39+
# Node is already assigned to a graph
40+
graph_idx = node._graph_idx
41+
assert self.__nodes[graph_idx] == node, "Node is already assigned to another graph"
42+
return node
4043

41-
if ast_node_key in self._ast_nodes:
42-
return self._ast_nodes[ast_node_key]
44+
if hasattr(node, "ast_node") and node.ast_node is not None:
45+
ast_key = node_key(node.ast_node)
46+
try:
47+
return self.__ast_nodes[ast_key]
48+
except KeyError:
49+
self.__ast_nodes[ast_key] = node
4350

44-
if node_obj is None:
45-
node_obj = SyntaxNode(ast_node)
46-
47-
self._ast_nodes[ast_node_key] = node_obj
48-
return node_obj
51+
node._graph_idx = len(self.__nodes)
52+
self.__nodes.append(node)
53+
return node
4954

50-
def _add_token(self, token):
51-
token_node = TokenNode(token)
55+
def _add_ast_node(self, ast_node):
56+
return self._add_node(SyntaxNode(ast_node))
5257

53-
if token_node.ast_node is not None:
54-
return self._add_ast_node(token_node.ast_node, token_node)
55-
else:
56-
return self._add_anonymous(token_node)
58+
def _add_token(self, token):
59+
return self._add_node(TokenNode(token))
5760

58-
# ADD methods -----------------------------------
59-
# Currently, we do not support removal
60-
6161
def add_node(self, node):
62-
if isinstance(node, Token): return self._add_token(node)
63-
if isinstance(node, SyntaxNode): return self._add_ast_node(node.ast_node, node)
64-
if isinstance(node, Node): return self._add_anonymous(node)
62+
if isinstance(node, Node): return self._add_node(node)
63+
if isinstance(node, Token): return self._add_token(node)
64+
if isinstance(node, str): return self._add_node(SymbolNode(node))
6565

6666
return self._add_ast_node(node)
6767

68-
def add_or_get_node(self, node):
69-
if isinstance(node, SyntaxNode):
70-
return self.add_or_get_node(node.ast_node)
71-
72-
if isinstance(node, Node): return node
73-
try:
74-
return self._add_ast_node(node)
75-
except Exception:
76-
raise ValueError("Cannot add or get node %s. Only AST nodes can be indexed." % str(node))
77-
7868
def add_relation(self, source_node, target_node, relation = "ast", no_create = False):
7969

8070
if no_create:
8171
if not self.has_node(source_node): return
8272
if not self.has_node(target_node): return
8373

84-
source_node = self.add_or_get_node(source_node)
85-
target_node = self.add_or_get_node(target_node)
74+
source_node = self.add_node(source_node)
75+
target_node = self.add_node(target_node)
8676
source_node.add_successor(target_node, relation)
8777

8878
# API GET methods-----------------------------------------
8979

9080
def has_node(self, ast_node):
91-
return node_key(ast_node) in self._ast_nodes
81+
try:
82+
return self.__nodes[ast_node._graph_idx] == ast_node
83+
except (IndexError, AttributeError):
84+
return node_key(ast_node) in self.__ast_nodes
85+
86+
def node_by_ast(self, ast_node):
87+
if not self.has_node(ast_node): return None
88+
return self.__ast_nodes[node_key(ast_node)]
89+
90+
def is_token(self, ast_node):
91+
return hasattr(self.node_by_ast(ast_node), "token")
9292

9393
def nodes(self):
94-
return chain(self._ast_nodes.values(), self._anonymous_nodes)
94+
return self.__nodes
9595

9696
def todot(self, file_name = None, edge_colors = None):
9797
dotwriter = GraphToDot(self, edge_colors)
@@ -119,7 +119,7 @@ def tokens_only(self):
119119
# Internal GET methods -----------------------------------
120120

121121
def __len__(self):
122-
return len(self._ast_nodes) + len(self._anonymous_nodes)
122+
return len(self.__nodes)
123123

124124
def __iter__(self):
125125
return iter(self.nodes())
@@ -199,6 +199,9 @@ def predecessors(self):
199199
def predecessors_by_type(self, edge_type):
200200
return self._iter_predecessors(edge_type)
201201

202+
def clone(self):
203+
return Node()
204+
202205
# Node types --------------------------------------------------------
203206

204207
class SyntaxNode(Node):
@@ -210,6 +213,9 @@ def __init__(self, ast_node):
210213
def node_name(self):
211214
return self.ast_node.type
212215

216+
def clone(self):
217+
return SyntaxNode(self.ast_node)
218+
213219
def __hash__(self):
214220
return hash(node_key(self.ast_node))
215221

@@ -223,14 +229,16 @@ def __init__(self, token):
223229

224230
def node_name(self):
225231
return self.token.text
232+
233+
def clone(self):
234+
return TokenNode(self.token)
226235

227236
def __hash__(self):
228237
if self.ast_node is not None:
229238
return hash(node_key(self.ast_node))
230239
return hash(self.token.text)
231240

232241

233-
234242
class SymbolNode(Node):
235243

236244
def __init__(self, symbol):
@@ -240,12 +248,19 @@ def __init__(self, symbol):
240248
def node_name(self):
241249
return self.symbol
242250

251+
def clone(self):
252+
return SymbolNode(self.symbol)
253+
254+
def __hash__(self):
255+
return hash(self.symbol)
256+
243257

244258
# Utils --------------------------------------------------------
245259

246260
def node_key(node):
247261
start_pos, end_pos = node.start_point, node.end_point
248-
return (node.type, start_pos[0], start_pos[1], end_pos[0], end_pos[1])
262+
child_count = node.child_count
263+
return (node.type, child_count, start_pos[0], start_pos[1], end_pos[0], end_pos[1])
249264

250265

251266
class GraphToDot:
@@ -254,10 +269,6 @@ def __init__(self, graph, edge_colors = None):
254269
self.graph = graph
255270
self.edge_colors = {} if edge_colors is None else edge_colors
256271

257-
def _map_nodes_to_ix(self):
258-
for ix, node in enumerate(self.graph):
259-
node._dot_node_id = ix
260-
261272
def _dot_edge(self, source_id, rel_type, target_id):
262273

263274
edge_color = self.edge_colors.get(rel_type, "black")
@@ -266,7 +277,6 @@ def _dot_edge(self, source_id, rel_type, target_id):
266277
return f'node{source_id} -> node{target_id} [label="{rel_type}" {edge_style}];\n'
267278

268279
def run(self, writeable):
269-
self._map_nodes_to_ix()
270280

271281
def escape(token):
272282
return token.replace('"', '\\"')
@@ -281,7 +291,7 @@ def escape(token):
281291
continue
282292
node_name = node.node_name()
283293
writeable.write(
284-
f'\tnode{node._dot_node_id}[shape="rectangle", label="{node_name}"];\n'
294+
f'\tnode{node._graph_idx}[shape="rectangle", label="{node_name}"];\n'
285295
)
286296

287297
# Tokens
@@ -291,12 +301,12 @@ def escape(token):
291301
for token_node in tokens:
292302
token_text = escape(token_node.node_name())
293303
writeable.write(
294-
f'\t\tnode{token_node._dot_node_id}[shape="rectangle", label="{token_text}"];\n'
304+
f'\t\tnode{token_node._graph_idx}[shape="rectangle", label="{token_text}"];\n'
295305
)
296306

297307
for _, edge_type, next_token in token_node.successors_by_type("next_token"):
298308
next_token_edges.append(
299-
self._dot_edge(token_node._dot_node_id, edge_type, next_token._dot_node_id)
309+
self._dot_edge(token_node._graph_idx, edge_type, next_token._graph_idx)
300310
)
301311

302312
for edge in next_token_edges:
@@ -308,21 +318,41 @@ def escape(token):
308318
for _, edge_type, target_node in src_node.successors():
309319
if edge_type == "next_token": continue
310320
edge_str = self._dot_edge(
311-
src_node._dot_node_id,
321+
src_node._graph_idx,
312322
edge_type,
313-
target_node._dot_node_id
323+
target_node._graph_idx
314324
)
315325
writeable.write(f"\t{edge_str}")
316326

317327
writeable.write("}\n")
318328

319-
# Cleanup
320-
for src_node in self.graph:
321-
del src_node._dot_node_id
322-
323329

324330
# Propagate to leaves ----------------------------------------------------------------
325331

332+
def _children(root_node):
333+
return [c for _, edge_type, c in root_node.successors() if edge_type == "child"]
334+
335+
def _bfs_token_search(root_node):
336+
token_nodes = []
337+
queue = [root_node]
338+
339+
while len(queue) > 0:
340+
current_node = queue.pop()
341+
342+
if hasattr(current_node, "token"): token_nodes.append(current_node)
343+
344+
if len(token_nodes) == 0:
345+
queue.extend(_children(current_node))
346+
347+
return min(token_nodes, key =lambda c: c._graph_idx)
348+
349+
350+
def _left_token_search(root_node):
351+
while not hasattr(root_node, "token"):
352+
root_node = min(_children(root_node), key=lambda c: c._graph_idx)
353+
return root_node
354+
355+
326356
def _compute_representer(graph):
327357
representer = {}
328358

@@ -331,20 +361,11 @@ def _compute_representer(graph):
331361

332362
while len(queue) > 0:
333363
current_node = queue.pop(-1)
334-
335-
path = []
336-
while not hasattr(current_node, "token"):
337-
path.append(current_node)
338-
syntax_node = current_node.ast_node
339-
children = [graph.add_or_get_node(c)
340-
for c in syntax_node.children
341-
if graph.has_node(c)]
342-
if len(children) == 0: break
343-
first, *others = children
344-
queue.extend(others)
345-
current_node = first
346-
347-
for r in path: representer[r] = current_node
364+
365+
if current_node in representer: continue
366+
367+
representer[current_node] = _left_token_search(current_node)
368+
queue.extend(_children(current_node))
348369

349370
return representer
350371

@@ -357,26 +378,25 @@ def graph_to_tokens_only(graph):
357378

358379
output = CodeGraph(tokens[0].ast_node, tokens, lang = graph.lang)
359380

360-
for token in graph.tokens:
361-
if not hasattr(token, "ast_node"): continue
362-
token_node = graph.add_or_get_node(token.ast_node)
363-
output_node = output.add_or_get_node(token.ast_node)
381+
for ix, token_node in enumerate(graph.token_nodes):
382+
output_node = output.token_nodes[ix]
364383

365384
for _, edge_type, successor in token_node.successors():
366385
if edge_type in SYNTAX_TYPES: continue
367386
if not hasattr(successor, "token"): continue
368-
output_succ = output.add_or_get_node(successor.ast_node)
387+
388+
output_succ = output.token_nodes[successor._graph_idx]
369389
output.add_relation(output_node, output_succ, edge_type)
370-
371390

372391
for node, representer in representers.items():
373-
output_representer = output.add_or_get_node(representer.ast_node)
392+
token_idx = representer._graph_idx
393+
output_representer = output.token_nodes[token_idx]
374394

375395
for _, edge_type, successor in node.successors():
376396
if edge_type in SYNTAX_TYPES: continue
377397
if successor not in representers: continue
378398
successor_representer = representers[successor]
379-
output_successor_representer = output.add_or_get_node(successor_representer.ast_node)
399+
output_successor_representer = output.token_nodes[successor_representer._graph_idx]
380400
output.add_relation(output_representer,
381401
output_successor_representer,
382402
edge_type)

0 commit comments

Comments
 (0)