@@ -45,17 +45,17 @@ private class Model : NN.Module
4545 private NN . Module fc1 = Linear ( 320 , 50 ) ;
4646 private NN . Module fc2 = Linear ( 50 , 10 ) ;
4747
48- public Model ( ) : base ( )
48+ public Model ( )
4949 {
5050 RegisterModule ( conv1 ) ;
5151 RegisterModule ( conv2 ) ;
5252 RegisterModule ( fc1 ) ;
5353 RegisterModule ( fc2 ) ;
5454 }
5555
56- public override ITorchTensor < float > Forward < T > ( params ITorchTensor < T > [ ] tensors )
56+ public override TorchTensor Forward ( TorchTensor input )
5757 {
58- using ( var l11 = conv1 . Forward ( tensors ) )
58+ using ( var l11 = conv1 . Forward ( input ) )
5959 using ( var l12 = MaxPool2D ( l11 , 2 ) )
6060 using ( var l13 = Relu ( l12 ) )
6161
@@ -79,7 +79,7 @@ public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
7979 private static void Train (
8080 NN . Module model ,
8181 NN . Optimizer optimizer ,
82- IEnumerable < ( ITorchTensor < int > , ITorchTensor < int > ) > dataLoader ,
82+ IEnumerable < ( TorchTensor , TorchTensor ) > dataLoader ,
8383 int epoch ,
8484 long batchSize ,
8585 long size )
@@ -101,7 +101,7 @@ private static void Train(
101101
102102 if ( batchId % _logInterval == 0 )
103103 {
104- Console . WriteLine ( $ "\r Train: epoch { epoch } [{ batchId * batchSize } / { size } ] Loss: { loss . DataItem } ") ;
104+ Console . WriteLine ( $ "\r Train: epoch { epoch } [{ batchId * batchSize } / { size } ] Loss: { loss . DataItem < float > ( ) } ") ;
105105 }
106106
107107 batchId ++ ;
@@ -114,7 +114,7 @@ private static void Train(
114114
115115 private static void Test (
116116 NN . Module model ,
117- IEnumerable < ( ITorchTensor < int > , ITorchTensor < int > ) > dataLoader ,
117+ IEnumerable < ( TorchTensor , TorchTensor ) > dataLoader ,
118118 long size )
119119 {
120120 model . Eval ( ) ;
@@ -127,11 +127,11 @@ private static void Test(
127127 using ( var output = model . Forward ( data ) )
128128 using ( var loss = NN . LossFunction . NLL ( output , target , reduction : NN . Reduction . Sum ) )
129129 {
130- testLoss += loss . DataItem ;
130+ testLoss += loss . DataItem < float > ( ) ;
131131
132132 var pred = output . Argmax ( 1 ) ;
133133
134- correct += pred . Eq ( target ) . Sum ( ) . DataItem ; // Memory leak here
134+ correct += pred . Eq ( target ) . Sum ( ) . DataItem < int > ( ) ; // Memory leak here
135135
136136 data . Dispose ( ) ;
137137 target . Dispose ( ) ;
0 commit comments