11import re
2+ from collections import defaultdict
23from itertools import chain
34from pathlib import Path
4- from collections import defaultdict
55
66from utensor_cgen .backend .base import BackendPart
7+ from utensor_cgen .backend .graph_lower .generic_graph_lower import \
8+ TopoOrderTensorTimeslotPlanner
79from utensor_cgen .backend .utensor .snippets .composer import Composer
810from utensor_cgen .backend .utensor .snippets .legacy import (
911 ContextGlobalArrayContainer , WeightSnippet )
1012from utensor_cgen .backend .utensor .snippets .rearch import (
11- DeclareRamTensorSnippet , DeclareRomTensorSnippet ,
12- FreeTensorSnippet , SimpleContainer , TimeSlotContainer
13- )
13+ DeclareRamTensorSnippet , DeclareRomTensorSnippet , FreeTensorSnippet ,
14+ SimpleContainer , TimeSlotContainer )
1415from utensor_cgen .backend .utensor .snippets .template_env import env
15- from utensor_cgen .backend .graph_lower .generic_graph_lower import TopoOrderTensorTimeslotPlanner
1616from utensor_cgen .logger import logger
1717from utensor_cgen .utils import Configuration , class_property
1818
@@ -183,8 +183,10 @@ def _time_slot_generate_files(
183183 ):
184184 template_vars = {}
185185 template_vars ['model_name' ] = ugraph .name
186- template_vars ['meta_data_pool_size' ] = self ._compute_meta_data_size (ugraph )
187- template_vars ['ram_data_pool_size' ] = self ._compute_ram_data_size (ugraph )
186+ (template_vars ['meta_data_pool_size' ],
187+ template_vars ['meta_dtype' ]) = self ._compute_meta_data_size (ugraph )
188+ (template_vars ['ram_data_pool_size' ],
189+ template_vars ['ram_dtype' ]) = self ._compute_ram_data_size (ugraph )
188190 template_vars ['placeholders' ] = placeholders
189191 template_vars ['out_tensor_var_names' ] = [
190192 tensor_var_map [tensor .name ] for tensor in chain (* [
@@ -349,16 +351,23 @@ def default_config(cls):
349351 return config
350352
351353 def _compute_meta_data_size (self , ugraph ):
352- # TODO: if mem_optimizer is None, use a default mem optimizer
353354 if self .meta_data_pool_size == 'auto' :
354- # TODO: compute actual meta data size with ugraph
355- size = 2048
355+ # NOTE: simple heuristic, num of tensors * 64, maybe more or less depending on target platform
356+ # NOTE: assuming user is using localCircularArenaAllocator
357+ # TODO: target aware estimation
358+ tensors = set ()
359+ for op_info in ugraph .ops_info .values ():
360+ tensors .update (op_info .input_tensors )
361+ tensors .update (op_info .output_tensors )
362+ size = len (tensors ) * 64
356363 else :
357364 size = self .meta_data_pool_size
358- return size
365+ dtype_str = self ._get_mem_pool_dtype_str (size )
366+ return size , dtype_str
359367
360368 def _compute_ram_data_size (self , ugraph ):
361- # TODO: if mem_optimizer is None, use a default mem optimizer
369+ # TODO: if tensor alloc plan is None, use a default mem estimator
370+ # NOTE: assuming user is using localCircularArenaAllocator
362371 if self .ram_data_pool_size == 'auto' :
363372 # TODO: compute actual ram data size with ugraph
364373 if '_tensor_alloc' in ugraph .attributes :
@@ -367,4 +376,12 @@ def _compute_ram_data_size(self, ugraph):
367376 size = 256
368377 else :
369378 size = self .ram_data_pool_size
370- return size
379+ dtype_str = self ._get_mem_pool_dtype_str (size )
380+ return size , dtype_str
381+
382+ @staticmethod
383+ def _get_mem_pool_dtype_str (size ):
384+ # NOTE: assuming user is using localCircularArenaAllocator
385+ if size > 2 ** 15 :
386+ return 'uint32_t'
387+ return 'uint16_t'
0 commit comments