Skip to content

Commit 0804fa2

Browse files
committed
Runtime tests / Debug and prepare for packaging
1 parent 0f46667 commit 0804fa2

File tree

6 files changed

+221
-4
lines changed

6 files changed

+221
-4
lines changed

code_graph/graph.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def __init__(self, root_node, tokens, lang = "python"):
1616
self._ast_nodes = {} # Nodes indexed by an AST node
1717
self._anonymous_nodes = [] # Unindexed nodes, can only be indexed by traversal
1818

19-
self.root_node = self.add_node(root_node)
20-
2119
prev_token = self._add_token(tokens[0])
2220
for token in tokens[1:]:
2321
token_node = self._add_token(token)
2422
prev_token.add_successor(token_node, "next_token")
2523
prev_token = token_node
24+
25+
self.root_node = self.add_or_get_node(root_node)
2626

2727
# Helper methods --------------------------------
2828

@@ -61,6 +61,9 @@ def add_node(self, node):
6161
return self._add_ast_node(node)
6262

6363
def add_or_get_node(self, node):
64+
if isinstance(node, SyntaxNode):
65+
return self.add_or_get_node(node.ast_node)
66+
6467
if isinstance(node, Node): return node
6568
try:
6669
return self._add_ast_node(node)
@@ -74,6 +77,9 @@ def add_relation(self, source_node, target_node, relation = "ast"):
7477

7578
# API GET methods-----------------------------------------
7679

80+
def has_node(self, ast_node):
    """Return True if `ast_node` is already indexed in this graph."""
    key = node_key(ast_node)
    return key in self._ast_nodes
82+
7783
def nodes(self):
7884
return chain(self._ast_nodes.values(), self._anonymous_nodes)
7985

@@ -88,6 +94,17 @@ def todot(self, file_name = None, edge_colors = None):
8894
dotwriter.run(f)
8995
f.seek(0)
9096
return f.read()
97+
98+
def tokens_only(self):
    """
    Compute a graph that contains only tokens.

    Any edges of inner nodes are propagated down to the leaves.
    The first token related to an inner node acts as its representative.

    """
    tokens_graph = graph_to_tokens_only(self)
    return tokens_graph
107+
91108

92109
# Internal GET methods -----------------------------------
93110

@@ -290,4 +307,67 @@ def escape(token):
290307

291308
# Cleanup
292309
for src_node in self.graph:
293-
del src_node._dot_node_id
310+
del src_node._dot_node_id
311+
312+
313+
# Propagate to leaves ----------------------------------------------------------------
314+
315+
def _compute_representer(graph):
    """
    Map every visited non-token node of `graph` to its representer.

    Starting at ``graph.root_node``, the first child that is part of the
    graph is followed until a node with a ``token`` attribute is reached
    (or until a node has no children in the graph). Every node passed on
    the way down is mapped to the node where the descent stopped.
    The remaining children are scheduled and handled the same way.

    Returns a dict from graph node to representer node.
    """
    representer = {}
    pending = [graph.root_node]

    while pending:
        node = pending.pop()

        # Nodes passed while descending along first children
        trail = []
        while not hasattr(node, "token"):
            trail.append(node)
            children = [
                graph.add_or_get_node(child)
                for child in node.ast_node.children
                if graph.has_node(child)
            ]
            if not children:
                # No child in the graph: the descent stops at this node
                break

            # Siblings are processed later, the first child continues the descent
            pending.extend(children[1:])
            node = children[0]

        for inner in trail:
            representer[inner] = node

    return representer
339+
340+
341+
# Edge types that describe pure syntax structure; they are not propagated
SYNTAX_TYPES = {"child", "sibling"}


def graph_to_tokens_only(graph):
    """
    Build a new CodeGraph over the tokens of `graph`.

    Two kinds of edges are copied into the result, skipping every edge
    whose type is listed in SYNTAX_TYPES:

    1. edges from a token node to a successor that has a ``token`` attribute,
    2. edges between nodes that have a representer; these are redirected
       so that they connect the representers instead.

    Returns the newly created CodeGraph.
    """
    representers = _compute_representer(graph)
    tokens = graph.tokens

    output = CodeGraph(tokens[0].ast_node, tokens, lang = graph.lang)

    # Pass 1: edges that already connect two tokens
    for token in graph.tokens:
        if not hasattr(token, "ast_node"):
            continue

        source = graph.add_or_get_node(token.ast_node)
        output_source = output.add_or_get_node(token.ast_node)

        for _, relation, succ in source.successors():
            if relation in SYNTAX_TYPES or not hasattr(succ, "token"):
                continue
            output_target = output.add_or_get_node(succ.ast_node)
            output.add_relation(output_source, output_target, relation)

    # Pass 2: edges of represented nodes, moved onto their representers
    for inner, inner_repr in representers.items():
        output_source = output.add_or_get_node(inner_repr.ast_node)

        for _, relation, succ in inner.successors():
            if relation in SYNTAX_TYPES or succ not in representers:
                continue
            succ_repr = representers[succ]
            output_target = output.add_or_get_node(succ_repr.ast_node)
            output.add_relation(output_source, output_target, relation)

    return output

code_graph/visitor.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,24 @@ class ASTVisitor:
44
def __init__(self):
55
self._ast_handler = None
66

7+
# Error handling ------------------------------------------------
8+
9+
def visit_ERROR(self, node):
    """
    An ERROR node is introduced if the parser reacts to a syntax error.

    The subtree rooted at an ERROR node might not be conventional
    or might not even be a tree (it might include cycles).
    The walk function, however, assumes a tree as input and would
    run into an infinite loop for such errors.

    Therefore, the default strategy is to skip error nodes
    (this method returns False).
    Can be overridden by subclasses.

    """
    return False
23+
24+
725
# Custom handler ------------------------------------------------
826

927
@staticmethod
@@ -18,6 +36,9 @@ def _parse_visit_pattern(name):
1836
# or function is the node type
1937
# and definition is the edge type
2038
# Therefore, we register both
39+
# Assumption:
40+
# - No node type function with edge definition
41+
# - or no node type function_definition
2142

2243
atomic_name = "_".join(parts[1:])
2344
node_name = "_".join(parts[1:-1])
@@ -113,18 +134,36 @@ def walk(self, root_node):
113134
if next_node is None:
114135
next_node = self._next_sibling(current_node)
115136

137+
previous_node = current_node
138+
116139
# Step 4: Go up until sibling exists
117140
while next_node is None and current_node.parent is not None:
118141
current_node = current_node.parent
119-
if current_node == root_node: break
142+
if node_equal(current_node, root_node): break
120143
next_node = self._next_sibling(current_node)
121144

145+
if node_equal(previous_node, next_node):
146+
# A loop can occur if the program is not syntactically correct
147+
# Is this enough?
148+
next_node = None
149+
122150
current_node = next_node
123151

124152
def __call__(self, root_node):
125153
return self.walk(root_node)
126154

127155

156+
# Helper --------------------------------
157+
158+
def node_equal(n1, n2):
    """
    Compare two nodes by type and source position.

    Two nodes are equal if their ``type``, ``start_point`` and ``end_point``
    attributes are equal. If one of these attributes is missing on an
    operand that is reached, the plain ``==`` comparison of both operands
    is used instead.
    """
    try:
        # Attributes are compared in order; the first mismatch decides
        for attr in ("type", "start_point", "end_point"):
            if not getattr(n1, attr) == getattr(n2, attr):
                return False
        return True
    except AttributeError:
        return n1 == n2
165+
166+
128167
# Compositions ----------------------------------------------------------------
129168

130169
class VisitorComposition(ASTVisitor):

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
code_tokenize>=0.1.0

run_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import argparse
3+
import json
4+
5+
from time import time
6+
7+
from tqdm import tqdm
8+
9+
from glob import glob
10+
11+
import code_graph as cg
12+
13+
def load_examples(files):
    """
    Lazily yield the decoded JSON objects of all given JSON-lines files.

    Each file is read twice: once to count its lines, so that the
    progress bar can show a total, and once to decode the lines.

    files: sequence of file paths (``len(files)`` is used for the
           progress description).

    Yields the result of ``json.loads`` for every line of every file.
    """
    for n, file in enumerate(files):
        name = os.path.basename(file)
        desc = "File %d / %d: %s" % (n+1, len(files), name)

        # Count the lines inside a context manager so that the handle
        # is closed again instead of being leaked.
        with open(file, "r") as lines:
            total = sum(1 for _ in lines)

        with open(file, "r") as lines:
            for line in tqdm(lines, total = total, desc = desc):
                yield json.loads(line)
21+
22+
def tokens_to_text(tokens):
    """
    Join the tokens of a method with single spaces and wrap the result
    in a Java class named ``Test``.
    """
    return "public class Test{\n%s\n}" % " ".join(tokens)
25+
26+
27+
def main():
    """
    Measure how long the graph construction takes for every example.

    Command line arguments:
        input_dir:   a single file or a directory containing *.jsonl files
        result_file: file that receives one JSON list per parsed example:
                     [number of tokens, len(graph), elapsed seconds]

    Examples for which the graph construction raises an exception are
    printed and skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir")
    parser.add_argument("result_file")
    args = parser.parse_args()

    if os.path.isfile(args.input_dir):
        files = [args.input_dir]
    else:
        files = glob(os.path.join(args.input_dir, "*.jsonl"))

    # The context manager guarantees that the results are flushed and the
    # file is closed, even if the run is aborted by an exception.
    with open(args.result_file, "w") as run_times:
        for example in load_examples(files):
            tokens = example["tokens"]
            length = len(tokens)
            code = tokens_to_text(tokens)

            start_time = time()

            try:
                graph = cg.codegraph(code, lang = "java", syntax_error = "raise")
            except Exception as e:
                # Best effort: report the failing example and continue
                print(e)
                continue

            elapsed = time() - start_time

            run_times.write(json.dumps([length, len(graph), elapsed]) + "\n")


if __name__ == '__main__':
    main()

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[metadata]
2+
description-file = README.md
3+
long_description_content_type = text/markdown

setup.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from setuptools import setup

# Read the README explicitly as UTF-8; the platform default encoding
# is not UTF-8 everywhere and could fail on non-ASCII characters.
with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

# Package metadata and dependencies of code_graph
setup(
  name = 'code_graph',
  packages = ['code_graph'],
  version = '0.0.1',
  license='MIT',
  description = 'Fast program graph generation in Python',
  long_description = long_description,
  long_description_content_type="text/markdown",
  author = 'Cedric Richter',
  author_email = 'cedricr.upb@gmail.com',
  url = 'https://github.com/cedricrupb/code_graph',
  download_url = '',
  keywords = ['code', 'graph', 'program', 'language processing'],
  install_requires=[
          'tree_sitter',
          'GitPython',
          'requests',
          'code_tokenize'
      ],
  classifiers=[
    'Development Status :: 3 - Alpha',
    'Intended Audience :: Developers',
    'Topic :: Software Development :: Build Tools',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.6',
    'Programming Language :: Python :: 3.7',
    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: 3.9',
  ],
)

0 commit comments

Comments
 (0)