Skip to content

Commit 4883654

Browse files
committed
Defuse conv2d: inject activation op after conv2d
1 parent d190be4 commit 4883654

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

utensor_cgen/ir/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,11 @@ def output_nodes(self):
352352
353353
:rtype: List[:class:`OperationInfo`]
354354
"""
355-
out_ops = []
355+
out_ops = set()
356356
for op in self._ugraph.ops:
357357
for in_tensor in op.input_tensors:
358358
if in_tensor.op_name == self.name and op.name not in out_ops:
359-
out_ops.append(op.name)
359+
out_ops.add(op.name)
360360
break
361361
return [self._ugraph.ops_info[name] for name in out_ops]
362362

utensor_cgen/legalizer/tflite.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import re
12
from copy import deepcopy
23
from functools import reduce
34

4-
from utensor_cgen.ir.base import OperationInfo
5+
from utensor_cgen.ir.base import OperationInfo, TensorInfo
56
from utensor_cgen.utils import topologic_order_graph
67

78
from .api import Legalizer
@@ -45,6 +46,11 @@ class _GraphRewrite(object):
4546

4647
@classmethod
def apply(cls, ugraph):
  """Run every TFLite legalization rewrite on *ugraph*, in order, in place."""
  # order matters: fully-connected rewiring first, then conv2d de-fusing
  for rewrite in (cls._handle_fully_connected, cls._handle_conv_2d):
    rewrite(ugraph)
51+
52+
@classmethod
53+
def _handle_fully_connected(cls, ugraph):
4854
# 1. transpose the filter to make a right multiplication: fc = x @ filter + bias
4955
# 2. if the input is not flatten, inject a reshape op
5056
reshape_cnt = 0
@@ -78,3 +84,55 @@ def apply(cls, ugraph):
7884
reshape_cnt += 1
7985
op_info.input_tensors[0] = out_tensor
8086
topologic_order_graph(ugraph)
87+
88+
@classmethod
def _handle_conv_2d(cls, ugraph):
  """De-fuse the activation of every ``Conv2d`` op in *ugraph*.

  TFLite folds an activation into the conv op itself via the
  ``FusedActivationFunction`` attribute (a string such as ``'1 (RELU)'``).
  For each conv with a fused activation this injects a stand-alone
  activation op consuming the conv's output tensor and rewires all of the
  conv's consumers to read from the activation's output instead.
  The graph is re-ordered topologically at the end.

  :param ugraph: graph to legalize; mutated in place
  :raises ValueError: when the fused-activation attribute is malformed or
    names an activation with no supported stand-alone operator
  """
  # attribute value looks like '<code> (<NAME>)', e.g. '0 (NONE)'
  activation_pattern = re.compile(r'^(\d+) \(\w+\)$')
  # TFLite fused-activation code -> uTensor activation op type.
  # A mapped value of None means "no activation fused"; a missing key
  # means the activation is not supported yet.
  activation_map = {
    '0': None,  # kTfLiteActNone: nothing to inject
    '1': 'ReLUOperator',
    # '2': 'TFLM::TfLiteFusedActivation::kTfLiteActRelu1',
    '3': 'ReLU6Operator',
    # '4': 'TFLM::TfLiteFusedActivation::kTfLiteActTanh',
    # '5': 'TFLM::TfLiteFusedActivation::kTfLiteActSignBit',
    # '6': 'TFLM::TfLiteFusedActivation::kTfLiteActSigmoid',
  }
  for op_info in ugraph.get_ops_by_type('Conv2d'):
    fused_act = op_info.op_attr['FusedActivationFunction']
    match = activation_pattern.match(fused_act)
    # raise the intended ValueError on a malformed attribute instead of
    # letting `None.group(1)` blow up with an AttributeError
    if match is None or match.group(1) not in activation_map:
      raise ValueError(
        'legalization fail, unknown activation: {}'.format(fused_act)
      )
    act_op_type = activation_map[match.group(1)]
    if act_op_type is None:
      # no activation is set, ignore
      continue
    ori_out_tensor = op_info.output_tensors[0]
    act_op_name = '{}/{}'.format(op_info.name, act_op_type.replace('Operator', ''))
    # the activation is elementwise, so its output tensor mirrors the
    # conv output's dtype/shape/attributes
    act_tensor = TensorInfo(
      name='{}:0'.format(act_op_name),
      op_name=act_op_name,
      dtype=ori_out_tensor.dtype,
      shape=ori_out_tensor.shape[:],
      ugraph=ugraph,
      attributes=dict(ori_out_tensor.attributes),
    )
    # constructing OperationInfo with ugraph registers the op on the
    # graph, so the instance itself need not be kept
    OperationInfo(
      name=act_op_name,
      input_tensors=[ori_out_tensor],
      output_tensors=[act_tensor],
      op_type=act_op_type,
      lib_name=ugraph.lib_name,
      ugraph=ugraph,
      op_attr={}
    )
    # rewire every other consumer of the conv output to the activation
    # output; skip the activation op itself in case output_nodes already
    # includes it, which would otherwise rewire it to its own output
    for consumer_op in ori_out_tensor.op.output_nodes:
      if consumer_op.name == act_op_name:
        continue
      for idx, input_tensor in enumerate(consumer_op.input_tensors):
        if input_tensor.name == ori_out_tensor.name:
          consumer_op.input_tensors[idx] = act_tensor
  topologic_order_graph(ugraph)

0 commit comments

Comments
 (0)