We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
skorch.utils.to_tensor()
1 parent df8ecc4 commit ca95b8fCopy full SHA for ca95b8f
torch_frame/utils/skorch.py
@@ -1,3 +1,17 @@
1
+import skorch.utils
2
+
3
+# TODO: make it more safe
4
+old_to_tensor = skorch.utils.to_tensor
5
6
+def to_tensor(X, device, accept_sparse=False):
7
+ if isinstance(X, TensorFrame):
8
+ return X
9
+ return old_to_tensor(X, device, accept_sparse)
10
11
+skorch.utils.to_tensor = to_tensor
12
+import importlib
13
+importlib.reload(skorch.net)
14
15
from typing import Any
16
17
import pandas as pd
0 commit comments