Skip to content

Commit 73a90f6

Browse files
committed
Merge
2 parents 6ad3632 + 5d030fb commit 73a90f6

File tree

9 files changed

+330
-104
lines changed

9 files changed

+330
-104
lines changed

utensor_cgen/backend/graph_lower/alloc_plan.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __contains__(self, offset):
5353

5454
@attr.s
5555
class TimeSpaceAllocation(object):
56-
entity_name = attr.ib(validator=instance_of(six.string_types))
56+
entity = attr.ib()
5757
_time_alloc = attr.ib(validator=instance_of(TimeslotAllocation), repr=False)
5858
_space_alloc = attr.ib(validator=instance_of(SpaceAllocation), repr=False)
5959
time_slot_start = attr.ib(init=False)
@@ -70,11 +70,11 @@ def __attrs_post_init__(self):
7070
self.size = self._space_alloc.size
7171

7272
@classmethod
73-
def init(cls, entity_name, time_slot_start, time_slot_end, offset_start, size):
73+
def init(cls, entity, time_slot_start, time_slot_end, offset_start, size):
7474
time_alloc = TimeslotAllocation(time_slot_start, time_slot_end)
7575
space_alloc = SpaceAllocation(offset_start, size)
7676
return cls(
77-
entity_name=entity_name,
77+
entity=entity,
7878
time_alloc=time_alloc,
7979
space_alloc=space_alloc
8080
)
@@ -94,30 +94,30 @@ def __init__(self, allocs, total_size):
9494
raise ValueError(
9595
'expecting value of {} of type {}, get {}'.format(k, TimeSpaceAllocation, type(v))
9696
)
97-
self.plan = {alloc.entity_name: alloc for alloc in allocs}
97+
self.plan = {alloc.entity: alloc for alloc in allocs}
9898
self.total_size = total_size
9999

100-
def __setitem__(self, entity_name, alloc):
100+
def __setitem__(self, entity, alloc):
101101
if not isinstance(alloc, TimeSpaceAllocation):
102102
raise ValueError(
103103
'the value should be of type {}, get {}'.format(TimeSpaceAllocation, type(alloc))
104104
)
105-
if entity_name in self._plan:
105+
if entity in self._plan:
106106
logger.warning(
107-
'duplicate entity_name detected: {}'.format(entity_name)
107+
'duplicate entity detected: {}'.format(entity)
108108
)
109-
self._plan[entity_name] = alloc
109+
self._plan[entity] = alloc
110110

111-
def __getitem__(self, entity_name):
112-
if entity_name not in self.plan:
113-
raise KeyError('%s not found' % entity_name)
114-
return self.plan[entity_name]
111+
def __getitem__(self, entity):
112+
if entity not in self.plan:
113+
raise KeyError('%s not found' % entity)
114+
return self.plan[entity]
115115

116-
def __contains__(self, entity_name):
117-
return entity_name in self.plan
116+
def __contains__(self, entity):
117+
return entity in self.plan
118118

119-
def __delitem__(self, entity_name):
120-
del self.plan[entity_name]
119+
def __delitem__(self, entity):
120+
del self.plan[entity]
121121

122122
def __getattr__(self, attr_name):
123123
return getattr(self.plan, attr_name)

utensor_cgen/backend/graph_lower/generic_graph_lower.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,22 @@ def apply(self, ugraph):
3737
life_span = defaultdict(lambda: [None, None])
3838
for op_info in ugraph.ops_info.values():
3939
for tensor in op_info.input_tensors:
40-
ref_cnts[tensor.name] += 1
40+
ref_cnts[tensor] += 1
4141
for time_slot, op_name in enumerate(ugraph.topo_order):
4242
op_info = ugraph.ops_info[op_name]
4343
for tensor in op_info.output_tensors:
44-
life_span[tensor.name][0] = time_slot
44+
life_span[tensor][0] = time_slot
4545
for tensor in op_info.input_tensors:
46-
ref_cnts[tensor.name] -= 1
47-
if ref_cnts[tensor.name] == 0:
48-
life_span[tensor.name][1] = time_slot
46+
ref_cnts[tensor] -= 1
47+
if ref_cnts[tensor] == 0:
48+
life_span[tensor][1] = time_slot
4949
time_alloc_plan = {}
50-
for tensor_name, (start, end) in life_span.items():
50+
for tensor_info, (start, end) in life_span.items():
5151
time_alloc = TimeslotAllocation(
5252
time_slot_start=start,
5353
time_slot_end=end
5454
)
55-
time_alloc_plan[tensor_name] = time_alloc
55+
time_alloc_plan[tensor_info] = time_alloc
5656
logger.info('topo ordered tensor life span analysis done')
5757
ugraph.attributes[self.KWARGS_NAMESCOPE] = time_alloc_plan
5858

@@ -126,7 +126,7 @@ def apply(self, ugraph):
126126
continue
127127
# all output tensor should not overlap with tensors that's still alive
128128
for out_tensor, known_tensor in product(op_info.output_tensors, tensors_to_schedule):
129-
time_alloc = time_alloc_plan[known_tensor.name]
129+
time_alloc = time_alloc_plan[known_tensor]
130130
if time_slot in time_alloc:
131131
nonoverlap_map[out_tensor].add(known_tensor)
132132
# all output tensors should not overlap with each other
@@ -178,11 +178,11 @@ def _solve_space_alloc(self, tensors_to_schedule, nonoverlap_map):
178178
var_end = model.NewIntVar(0, self.max_pool_size, '{}_end'.format(tensor.name))
179179
size = self._compute_tensor_bytes_size(tensor)
180180
intv_var = model.NewIntervalVar(var_start, size, var_end, '{}_alloc'.format(tensor.name))
181-
inter_vars[tensor.name] = intv_var
182-
tensor_allocs[tensor.name] = _VarMemorySpan(var_start, var_end, size)
181+
inter_vars[tensor] = intv_var
182+
tensor_allocs[tensor] = _VarMemorySpan(var_start, var_end, size)
183183
for tensor in tensors_to_schedule:
184-
inter_var = inter_vars[tensor.name]
185-
nonoverlap_vars = [inter_vars[t.name] for t in nonoverlap_map[tensor]]
184+
inter_var = inter_vars[tensor]
185+
nonoverlap_vars = [inter_vars[t] for t in nonoverlap_map[tensor]]
186186
for other in nonoverlap_vars:
187187
model.AddNoOverlap([inter_var, other])
188188
var_mempool_size = model.NewIntVar(0, self.max_pool_size, 'mempool_size')
@@ -195,8 +195,8 @@ def _solve_space_alloc(self, tensors_to_schedule, nonoverlap_map):
195195
opt_mempool_size = None
196196
if status == cp_model.OPTIMAL:
197197
opt_mempool_size = solver.Value(var_mempool_size)
198-
for name, alloc in tensor_allocs.items():
199-
alloc_plan[name] = SpaceAllocation(
198+
for entity, alloc in tensor_allocs.items():
199+
alloc_plan[entity] = SpaceAllocation(
200200
offset_start=solver.Value(alloc.start),
201201
size=alloc.size,
202202
data_alignment=self.data_alignment,
@@ -307,7 +307,7 @@ def apply(self, ugraph):
307307
max_offset_end = allocate_table[in_o.name]['offsetend']
308308
allocs.append(
309309
TimeSpaceAllocation(
310-
entity_name=in_o.name,
310+
entity=in_o,
311311
time_alloc=time_alloc,
312312
space_alloc=space_alloc,
313313
)
@@ -328,7 +328,8 @@ def apply(self, ugraph):
328328
max_offset_end = allocate_table[out_o.name]['offsetend']
329329
allocs.append(
330330
TimeSpaceAllocation(
331-
entity_name=out_o.name,
331+
entity=out_o
332+
,
332333
time_alloc=time_alloc,
333334
space_alloc=space_alloc,
334335
)

0 commit comments

Comments
 (0)