Skip to content

Commit 3760ba3

Browse files
Rprop Fix (#1421)
* Rprop Fix --------- Co-authored-by: Ozan Aydin <148207261+ozanMSFT@users.noreply.github.com>
1 parent cf744cd commit 3760ba3

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

src/TorchSharp/Optimizers/Rprop.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public class Rprop : OptimizerHelper
8686
/// <param name="min_step">Minimum allowed step size.</param>
8787
/// <param name="max_step">Maximum allowed step size.</param>
8888
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing.</param>
89-
public Rprop(IEnumerable<Parameter> parameters, double lr = 0.01, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
89+
public Rprop(IEnumerable<Parameter> parameters, double lr = 1e-2, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
9090
: this(new ParamGroup[] { new() { Parameters = parameters } }, lr, etaminus, etaplus, min_step, max_step, maximize)
9191
{
9292
}
@@ -156,10 +156,6 @@ public override Tensor step(Func<Tensor> closure = null)
156156

157157
state.step += 1;
158158

159-
grad = (max_step != 0)
160-
? grad.add(param, alpha: max_step)
161-
: grad.alias();
162-
163159
var sign = grad.mul(state.prev).sign();
164160
sign[sign.gt(0)] = (Tensor)etaplus;
165161
sign[sign.lt(0)] = (Tensor)etaminus;

test/TorchSharpTest/TestTraining.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ public void TrainingRprop()
11701170

11711171
var loss = TrainLoop(seq, x, y, optimizer);
11721172

1173-
LossIsClose(229.68f, loss);
1173+
LossIsClose(77.279f, loss);
11741174
}
11751175

11761176

@@ -1187,7 +1187,7 @@ public void TrainingRpropMax()
11871187

11881188
var loss = TrainLoop(seq, x, y, optimizer, maximize:true);
11891189

1190-
LossIsClose(229.68f, -loss);
1190+
LossIsClose(77.279f, -loss);
11911191
}
11921192

11931193
[Fact]
@@ -1203,7 +1203,7 @@ public void TrainingRpropEtam()
12031203

12041204
var loss = TrainLoop(seq, x, y, optimizer);
12051205

1206-
LossIsClose(201.417f, loss);
1206+
LossIsClose(171.12f, loss);
12071207
}
12081208

12091209
[Fact]
@@ -1219,7 +1219,7 @@ public void TrainingRpropEtap()
12191219

12201220
var loss = TrainLoop(seq, x, y, optimizer);
12211221

1222-
LossIsClose(221.365f, loss);
1222+
LossIsClose(65.859f, loss);
12231223
}
12241224

12251225

@@ -1240,7 +1240,7 @@ public void TrainingRpropParamGroups()
12401240

12411241
var loss = TrainLoop(seq, x, y, optimizer);
12421242

1243-
LossIsClose(78.619f, loss);
1243+
LossIsClose(66.479f, loss);
12441244
}
12451245

12461246
/// <summary>

0 commit comments

Comments (0)