diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index b0d1c69916..b5fc3d48fa 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -68,11 +68,13 @@ def fake_tensorrt_execute_engine( output_shape.append(min_val) else: output_shape.extend(outputs_mode_dict["opt"][out_idx].size()) - fake_outputs.append( - torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype) + torch.empty( + output_shape, + dtype=outputs_mode_dict["opt"][out_idx].dtype, + device=inputs[0].device, + ) ) - return fake_outputs