Skip to content

LSTM support #2

@mohamad-hasan-sohan-ajini

Description

Hi

My model has some LSTM layers, and `count_ops` throws the following error:

In [37]: count_ops(model, x)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-37-846862382dff> in <module>
----> 1 count_ops(model, x)

/usr/local/lib/python3.5/dist-packages/pthflops/ops.py in count_ops(model, input, custom_ops, ignore_layers, print_readable, verbose, *args)
    212     # Convert pytorch module to ONNX
    213     trace, _ = torch.jit.get_trace_graph(model, input, *args)
--> 214     torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    215     graph = trace.graph()
    216

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _optimize_trace(trace, operator_export_type)
     40 def _optimize_trace(trace, operator_export_type):
     41     from torch.onnx import utils
---> 42     trace.set_graph(utils._optimize_graph(trace.graph(), operator_export_type))
     43
     44

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type)
    153 
    154     if operator_export_type != OperatorExportTypes.RAW:
--> 155         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    156         torch._C._jit_pass_lint(graph)
    157         torch._C._jit_pass_onnx_peephole(graph)

/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     50 def _run_symbolic_function(*args, **kwargs):
     51     from torch.onnx import utils
---> 52     return utils._run_symbolic_function(*args, **kwargs)
     53 
     54 

/usr/local/lib/python3.5/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    502                     return None
    503                 fn = getattr(torch.onnx.symbolic, op_name)
--> 504                 return fn(g, *inputs, **attrs)
    505 
    506         elif ns == "prim":

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in lstm(g, *args)
   1274         return _lstm_packed(g, *args)
   1275     else:
-> 1276         return _lstm_full(g, *args)
   1277
   1278

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in wrapper(g, *args)
     87             assert len(arg_descriptors) == len(args)
     88             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
---> 89             return fn(g, *args)
     90         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
     91         try:

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first)
   1260     hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
   1261     return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
-> 1262                         dropout, train, bidirectional, batch_first)
   1263
   1264

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, num_layers, dropout, train, bidirectional, batch_first, batch_sizes)
   1201             state_indices = i, i + 1
   1202         else:
-> 1203             weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
   1204             weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
   1205

/usr/local/lib/python3.5/dist-packages/torch/onnx/symbolic.py in transform_weights(layer_index)
   1188         elif variant == 'GRU' or variant == 'LSTM':
   1189             weight_ih, weight_hh, bias_ih, bias_hh = \
-> 1190                 [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
   1191         bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
   1192

ValueError: not enough values to unpack (expected 4, got 2)


Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions