@@ -37,22 +37,22 @@ def apply(self, ugraph):
3737 life_span = defaultdict (lambda : [None , None ])
3838 for op_info in ugraph .ops_info .values ():
3939 for tensor in op_info .input_tensors :
40- ref_cnts [tensor . name ] += 1
40+ ref_cnts [tensor ] += 1
4141 for time_slot , op_name in enumerate (ugraph .topo_order ):
4242 op_info = ugraph .ops_info [op_name ]
4343 for tensor in op_info .output_tensors :
44- life_span [tensor . name ][0 ] = time_slot
44+ life_span [tensor ][0 ] = time_slot
4545 for tensor in op_info .input_tensors :
46- ref_cnts [tensor . name ] -= 1
47- if ref_cnts [tensor . name ] == 0 :
48- life_span [tensor . name ][1 ] = time_slot
46+ ref_cnts [tensor ] -= 1
47+ if ref_cnts [tensor ] == 0 :
48+ life_span [tensor ][1 ] = time_slot
4949 time_alloc_plan = {}
50- for tensor_name , (start , end ) in life_span .items ():
50+ for tensor_info , (start , end ) in life_span .items ():
5151 time_alloc = TimeslotAllocation (
5252 time_slot_start = start ,
5353 time_slot_end = end
5454 )
55- time_alloc_plan [tensor_name ] = time_alloc
55+ time_alloc_plan [tensor_info ] = time_alloc
5656 logger .info ('topo ordered tensor life span analysis done' )
5757 ugraph .attributes [self .KWARGS_NAMESCOPE ] = time_alloc_plan
5858
@@ -126,7 +126,7 @@ def apply(self, ugraph):
126126 continue
127127 # all output tensor should not overlap with tensors that's still alive
128128 for out_tensor , known_tensor in product (op_info .output_tensors , tensors_to_schedule ):
129- time_alloc = time_alloc_plan [known_tensor . name ]
129+ time_alloc = time_alloc_plan [known_tensor ]
130130 if time_slot in time_alloc :
131131 nonoverlap_map [out_tensor ].add (known_tensor )
132132 # all output tensors should not overlap with each other
@@ -178,11 +178,11 @@ def _solve_space_alloc(self, tensors_to_schedule, nonoverlap_map):
178178 var_end = model .NewIntVar (0 , self .max_pool_size , '{}_end' .format (tensor .name ))
179179 size = self ._compute_tensor_bytes_size (tensor )
180180 intv_var = model .NewIntervalVar (var_start , size , var_end , '{}_alloc' .format (tensor .name ))
181- inter_vars [tensor . name ] = intv_var
182- tensor_allocs [tensor . name ] = _VarMemorySpan (var_start , var_end , size )
181+ inter_vars [tensor ] = intv_var
182+ tensor_allocs [tensor ] = _VarMemorySpan (var_start , var_end , size )
183183 for tensor in tensors_to_schedule :
184- inter_var = inter_vars [tensor . name ]
185- nonoverlap_vars = [inter_vars [t . name ] for t in nonoverlap_map [tensor ]]
184+ inter_var = inter_vars [tensor ]
185+ nonoverlap_vars = [inter_vars [t ] for t in nonoverlap_map [tensor ]]
186186 for other in nonoverlap_vars :
187187 model .AddNoOverlap ([inter_var , other ])
188188 var_mempool_size = model .NewIntVar (0 , self .max_pool_size , 'mempool_size' )
@@ -195,8 +195,8 @@ def _solve_space_alloc(self, tensors_to_schedule, nonoverlap_map):
195195 opt_mempool_size = None
196196 if status == cp_model .OPTIMAL :
197197 opt_mempool_size = solver .Value (var_mempool_size )
198- for name , alloc in tensor_allocs .items ():
199- alloc_plan [name ] = SpaceAllocation (
198+ for entity , alloc in tensor_allocs .items ():
199+ alloc_plan [entity ] = SpaceAllocation (
200200 offset_start = solver .Value (alloc .start ),
201201 size = alloc .size ,
202202 data_alignment = self .data_alignment ,
@@ -307,7 +307,7 @@ def apply(self, ugraph):
307307 max_offset_end = allocate_table [in_o .name ]['offsetend' ]
308308 allocs .append (
309309 TimeSpaceAllocation (
310- entity_name = in_o . name ,
310+ entity = in_o ,
311311 time_alloc = time_alloc ,
312312 space_alloc = space_alloc ,
313313 )
@@ -328,7 +328,8 @@ def apply(self, ugraph):
328328 max_offset_end = allocate_table [out_o .name ]['offsetend' ]
329329 allocs .append (
330330 TimeSpaceAllocation (
331- entity_name = out_o .name ,
331+ entity = out_o
332+ ,
332333 time_alloc = time_alloc ,
333334 space_alloc = space_alloc ,
334335 )
0 commit comments