@@ -22,7 +22,10 @@ def __init__(self, config):
2222 def transform (self , ugraph ):
2323 logger .info ("Transforming graph: %s" , ugraph .name )
2424 logger .info ("Transform pipeline: %s" , ' -> ' .join (self .trans_methods ))
25- self ._check_non_quantized (ugraph )
25+ if not self ._check_generic (ugraph ):
26+ raise ValueError (
27+ 'the given graph is not generic:\n {}' .format (ugraph )
28+ )
2629 new_ugraph = self .transformer .transform (ugraph )
2730 new_ugraph .name = ugraph .name
2831 logger .info ('Graph transormation done' )
@@ -35,26 +38,9 @@ def transform(self, ugraph):
3538 return new_ugraph
3639
3740 @classmethod
38- def _check_non_quantized (cls , ugraph ):
39- is_quantized = False
40- quant_ops = set ([
41- "Dequantize" , "QuantizedMaxPool" ,
42- "QuantizeV2" , "QuantizedMatMul" ,
43- "QuantizedRelu" , "QuantizedAdd" ,
44- "RequantizationRange" ,
45- "Requantize" ,
46- "QuantizedReshape" ,
47- "QuantizedConv2D"
48- ])
49- for op_info in ugraph .ops_info .values ():
50- if op_info .op_type in quant_ops :
51- is_quantized = True
52- break
53- if is_quantized :
54- logger .warning ((
55- "Expecting non-quantized graph, "
56- "graph transformation/optimization might not work properly"
57- ))
41+ def _check_generic (cls , ugraph ):
42+ # TODO: do the real check once we have full list of generic ops
43+ return True
5844
5945 @class_property
6046 def default_config (cls ):
0 commit comments