1- import skorch .utils
2-
3- # TODO: make it more safe
4- old_to_tensor = skorch .utils .to_tensor
5-
6- def to_tensor (X , device , accept_sparse = False ):
7- if isinstance (X , TensorFrame ):
8- return X
9- return old_to_tensor (X , device , accept_sparse )
10-
11- skorch .utils .to_tensor = to_tensor
121import importlib
13- importlib .reload (skorch .net )
14-
152from typing import Any
163
17- import pandas as pd
4+ import skorch . utils
185import torch
19- import torch .nn as nn
206from numpy .typing import ArrayLike
217from pandas import DataFrame
22- from skorch import NeuralNet , NeuralNetClassifier
23- from skorch .dataset import Dataset as SkorchDataset
8+ from skorch import NeuralNet
249from torch import Tensor
2510
2611import torch_frame
@@ -29,20 +14,34 @@ def to_tensor(X, device, accept_sparse=False):
2914 TextEmbedderConfig ,
3015 TextTokenizerConfig ,
3116)
32- from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
17+ from torch_frame .data .dataset import Dataset
3318from torch_frame .data .loader import DataLoader
3419from torch_frame .data .tensor_frame import TensorFrame
3520from torch_frame .typing import IndexSelectType
3621from torch_frame .utils import infer_df_stype
3722
23+ # TODO: make it more safe
24+ old_to_tensor = skorch .utils .to_tensor
25+
26+
27+ def to_tensor (X , device , accept_sparse = False ):
28+ if isinstance (X , TensorFrame ):
29+ return X
30+ return old_to_tensor (X , device , accept_sparse )
31+
32+
33+ skorch .utils .to_tensor = to_tensor
34+
35+ importlib .reload (skorch .net )
36+
3837
3938class NeuralNetPytorchFrameDataLoader (DataLoader ):
4039 def __init__ (self , dataset : Dataset | TensorFrame , * args ,
4140 device : torch .device , ** kwargs ):
4241 super ().__init__ (dataset , * args , ** kwargs )
4342 self .device = device
4443
45- def collate_fn (
44+ def collate_fn ( # type: ignore
4645 self , index : IndexSelectType ) -> tuple [TensorFrame , Tensor | None ]:
4746 index = torch .tensor (index )
4847 res = super ().collate_fn (index ).to (self .device )
0 commit comments