-
Notifications
You must be signed in to change notification settings - Fork 22
Open
Description
Hi
My model has some LSTM layers, and `count_ops`
threw the following error:
In [37]: count_ops(model, x)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-37-846862382dff> in <module>
----> 1 count_ops(model, x)
/usr/local/lib/python3.5/dist-packages/pthflops/ops.py in count_ops(model, input, custom_ops, ignore_layers, print_readable, verbose, *args)
212 # Convert pytorch module to ONNX
213 trace, _ = torch.jit.get_trace_graph(model, input, *args)
--> 214 torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
215 graph = trace.graph()
216
/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _optimize_trace(trace, operator_export_type)
40 def _optimize_trace(trace, operator_export_type):
41 from torch.onnx import utils
---> 42 trace.set_graph(utils._optimize_graph(trace.graph(), operator_export_type))
43
44
/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type)
153
154 if operator_export_type != OperatorExportTypes.RAW:
--> 155 graph = torch._C._jit_pass_onnx(graph, operator_export_type)
156 torch._C._jit_pass_lint(graph)
157 torch._C._jit_pass_onnx_peephole(graph)
/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
50 def _run_symbolic_function(*args, **kwargs):
51 from torch.onnx import utils
---> 52 return utils._run_symbolic_function(*args, **kwargs)
53
54
/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
502 return None
503 fn = getattr(torch.onnx.symbolic, op_name)
--> 504 return fn(g, *inputs, **attrs)
505
506 elif ns == "prim":
/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in lstm(g, *args)
1274 return _lstm_packed(g, *args)
1275 else:
-> 1276 return _lstm_full(g, *args)
1277
1278
/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
87 assert len(arg_descriptors) == len(args)
88 args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
---> 89 return fn(g, *args)
90 # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
91 try:
/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first)
1260 hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
1261 return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
-> 1262 dropout, train, bidirectional, batch_first)
1263
1264
/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, num_layers, dropout, train, bidirectional, batch_first, batch_sizes)
1201 state_indices = i, i + 1
1202 else:
-> 1203 weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
1204 weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
1205
/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in transform_weights(layer_index)
1188 elif variant == 'GRU' or variant == 'LSTM':
1189 weight_ih, weight_hh, bias_ih, bias_hh = \
-> 1190 [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
1191 bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
1192
ValueError: not enough values to unpack (expected 4, got 2)
Metadata
Metadata
Assignees
Labels
No labels