Commit 253098e

Fix layer_normalize gradients (davisking#3001)

* Fix layer_normalize gradients
* fix layer_norm CPU
* attempt to fix the cuda version
* fix gamma_grad and beta_grad
* update cuda test
* use a block of size 1 to avoid race conditions
* improve the speed of CUDA path of layer_norm

1 parent 27a0135 commit 253098e
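
For context, the corrected semantics normalize each sample over all of its K·R·C elements and then apply a per-channel affine transform. A sketch of the forward formula implied by the diff below (the notation is mine, not from the commit):

    \mu_n = \frac{1}{KRC} \sum_{k,r,c} x_{n,k,r,c}, \qquad
    \hat{x}_{n,k,r,c} = \frac{x_{n,k,r,c} - \mu_n}{\sqrt{\sigma_n^2 + \varepsilon}}, \qquad
    y_{n,k,r,c} = \gamma_k \, \hat{x}_{n,k,r,c} + \beta_k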

File tree

8 files changed: +213 -149 lines changed

dlib/cuda/cpu_dlib.cpp

Lines changed: 73 additions & 57 deletions
@@ -1270,22 +1270,19 @@ namespace dlib
             const tensor& beta
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
             DLIB_CASSERT(
                 have_same_dimensions(gamma, beta) &&
-                src.k() == gamma.k() &&
-                src.nr() == gamma.nr() &&
-                src.nc() == gamma.nc() &&
+                gamma.k() == src.k() &&
+                gamma.nr() == 1 &&
+                gamma.nc() == 1 &&
                 eps > 0,
+                "\nsrc.k():    " << src.k() <<
                 "\ngamma.k():  " << gamma.k() <<
                 "\ngamma.nr(): " << gamma.nr() <<
                 "\ngamma.nc(): " << gamma.nc() <<
                 "\nbeta.k():   " << beta.k() <<
                 "\nbeta.nr():  " << beta.nr() <<
                 "\nbeta.nc():  " << beta.nc() <<
-                "\nsrc.k():    " << src.k() <<
-                "\nsrc.nr():   " << src.nr() <<
-                "\nsrc.nc():   " << src.nc() <<
                 "\neps:  " << eps
             );
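
The tightened assertions make gamma and beta per-channel parameters rather than full-plane tensors. A minimal sketch of shapes that pass the new checks (tensor names and sizes here are illustrative, not from the commit):

    // Illustrative shapes only: src holds N samples of K channels, each R x C.
    dlib::resizable_tensor src, gamma, beta;
    src.set_size(2, 3, 4, 5);    // N = 2, K = 3, R = 4, C = 5
    gamma.set_size(1, 3, 1, 1);  // gamma.k() == src.k(), gamma.nr() == gamma.nc() == 1
    beta.set_size(1, 3, 1, 1);   // must have the same dimensions as gamma
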
@@ -1296,43 +1293,50 @@ namespace dlib
             // first compute means and invstds
             means = 0;
             invstds = 0;
-            const auto p_invstds = invstds.host();
-            const auto p_means = means.host();
-            auto p_src = src.host();
+            const float* p_src = src.host();
+            float* p_invstds = invstds.host();
+            float* p_means = means.host();
+            const long num = src.nr() * src.nc();
             // compute means, and sum of squares
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    float val = p_src[n*num+i];
-                    p_means[n] += val;
-                    p_invstds[n] += val*val;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        p_means[n] += *p_src;
+                        p_invstds[n] += (*p_src) * (*p_src);
+                        ++p_src;
+                    }
                 }
             }
-            means /= num;
-            invstds /= num;
+            means /= src.k() * num;
+            invstds /= src.k() * num;
             // copy data back to host
-            invstds.host(); means.host();
+            invstds.host();
+            means.host();

             // compute variances
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                auto var = p_invstds[n] - p_means[n] * p_means[n];
-                p_invstds[n] = 1.0f / std::sqrt(var + eps);
+                p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps);
             }

             p_src = src.host();
-            auto p_dest = dest.host();
-            auto p_gamma = gamma.host();
-            auto p_beta = beta.host();
+            float* p_dest = dest.host();
+            const float* p_gamma = gamma.host();
+            const float* p_beta = beta.host();
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    *p_dest = (*p_src - p_means[n])*p_invstds[n];
-                    *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
-                    ++p_src;
-                    ++p_dest;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        *p_dest = (*p_src - p_means[n]) * p_invstds[n];
+                        *p_dest = (*p_dest) * p_gamma[k] + p_beta[k];
+                        ++p_src;
+                        ++p_dest;
+                    }
                 }
             }
         }
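
The forward pass accumulates the sum and sum of squares in one sweep and converts them to an inverse standard deviation afterwards, using the usual one-pass shortcut (a restatement of what the code above does, not an addition to it):

    \text{invstd}_n = \frac{1}{\sqrt{\sigma_n^2 + \varepsilon}}
    = \frac{1}{\sqrt{\mathbb{E}[x_n^2] - \mathbb{E}[x_n]^2 + \varepsilon}}
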
@@ -1346,22 +1350,26 @@ namespace dlib
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& beta_grad
+            tensor& beta_grad,
+            resizable_tensor& dmeans,
+            resizable_tensor& dvars
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
+            const long num = src.nr() * src.nc();
             DLIB_CASSERT(src.num_samples() == means.size());
             DLIB_CASSERT(src.num_samples() == invstds.size());
-            DLIB_CASSERT(src.k() == gamma.k());
-            DLIB_CASSERT(src.nr() == gamma_grad.nr());
-            DLIB_CASSERT(src.nc() == beta_grad.nc());
+            DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
+            DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
+            DLIB_CASSERT(gamma.k() == src.k());
+            DLIB_CASSERT(gamma.nr() == 1);
+            DLIB_CASSERT(gamma.nc() == 1);
             DLIB_CASSERT(have_same_dimensions(gradient_input, src));
             DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
-            DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
             DLIB_CASSERT(eps > 0);

             beta_grad = 0;
             gamma_grad = 0;
+
             auto p_grad = gradient_input.host();
             auto p_src = src.host();
             const auto p_gamma = gamma.host();
@@ -1370,7 +1378,6 @@ namespace dlib
             const auto p_invstds = invstds.host();
             const auto p_means = means.host();

-            resizable_tensor dvars, dmeans;
             dvars.copy_size(invstds);
             dmeans.copy_size(means);
             dvars = 0;
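
dvars and dmeans used to be local temporaries; they are now caller-provided, so a caller that runs many backward passes can reuse the same buffers instead of reallocating on every call. A hypothetical holding pattern (names illustrative, not from the commit):

    // Illustrative: a layer object keeps the workspace alive across calls.
    dlib::resizable_tensor dmeans, dvars;  // resized via copy_size() inside
                                           // layer_normalize_gradient, as above
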
@@ -1380,53 +1387,62 @@ namespace dlib

             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f);
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float x_hat = (*p_src - p_means[n])*p_invstds[n];
-                    p_beta_grad[i] += *p_grad;
-                    p_gamma_grad[i] += (*p_grad)*x_hat;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float x_hat = (*p_src - p_means[n]) * p_invstds[n];
+                        p_beta_grad[k] += *p_grad;
+                        p_gamma_grad[k] += (*p_grad) * x_hat;

-                    const float dx = *p_grad * p_gamma[n];
+                        const float dx = *p_grad * p_gamma[k];

-                    p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
+                        p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow;

-                    ++p_grad;
-                    ++p_src;
+                        ++p_grad;
+                        ++p_src;
+                    }
                 }
             }

-            const float invnum = 1.0f/num;
             p_grad = gradient_input.host();
             p_src = src.host();
+            const float invnum = 1.0f / (src.k() * num);
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float dx = *p_grad * p_gamma[i];
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float dx = *p_grad * p_gamma[k];

-                    p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
+                        p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum;

-                    ++p_grad;
-                    ++p_src;
+                        ++p_grad;
+                        ++p_src;
+                    }
                 }
             }
             p_grad = gradient_input.host();
             p_src = src.host();
             auto p_src_grad = src_grad.host();
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    const float dx = *p_grad * p_gamma[i];
-
-                    *p_src_grad += dx*p_invstds[n] +
-                        p_dvars[n] *2*(*p_src - p_means[n])*invnum +
-                        p_dmeans[n]*invnum;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        const float dx = *p_grad * p_gamma[k];

+                        *p_src_grad += dx * p_invstds[n] +
+                            p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum +
+                            p_dmeans[n] * invnum;

-                    ++p_grad;
-                    ++p_src;
-                    ++p_src_grad;
+                        ++p_grad;
+                        ++p_src;
+                        ++p_src_grad;
+                    }
                 }
             }
         }
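
For reference, the three loops implement the standard layer-norm backward equations, with M = K·R·C elements per sample and g the incoming gradient (my notation, matching the code above):

    \frac{\partial L}{\partial \beta_k} = \sum_{n,r,c} g, \qquad
    \frac{\partial L}{\partial \gamma_k} = \sum_{n,r,c} g \, \hat{x}

    \frac{\partial L}{\partial \sigma_n^2} = \sum_{k,r,c} g \, \gamma_k \, (x - \mu_n) \left(-\tfrac{1}{2}\right) \text{invstd}_n^3

    \frac{\partial L}{\partial \mu_n} = \sum_{k,r,c} \left( -g \, \gamma_k \, \text{invstd}_n - \frac{\partial L}{\partial \sigma_n^2} \, \frac{2 (x - \mu_n)}{M} \right)

    \frac{\partial L}{\partial x} = g \, \gamma_k \, \text{invstd}_n + \frac{\partial L}{\partial \sigma_n^2} \, \frac{2 (x - \mu_n)}{M} + \frac{1}{M} \frac{\partial L}{\partial \mu_n}

Note the invstd_pow hoisted out of the inner loop is exactly the -1/2 · invstd³ factor in the variance gradient.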

dlib/cuda/cpu_dlib.h

Lines changed: 3 additions & 1 deletion
@@ -250,7 +250,9 @@ namespace dlib
             const tensor& gamma,
             tensor& src_grad,
             tensor& gamma_grad,
-            tensor& beta_grad
+            tensor& beta_grad,
+            resizable_tensor& dmeans,
+            resizable_tensor& dvars
         );

     // -----------------------------------------------------------------------------------
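
For completeness, a hedged sketch of calling the updated routine. Only the trailing parameters appear in this hunk; the leading ones are assumed from the surrounding declaration, so verify the exact order against the header:

    // Assumed full signature (check dlib/cuda/cpu_dlib.h):
    // void layer_normalize_gradient(double eps, const tensor& gradient_input,
    //     const tensor& means, const tensor& invstds, const tensor& src,
    //     const tensor& gamma, tensor& src_grad, tensor& gamma_grad,
    //     tensor& beta_grad, resizable_tensor& dmeans, resizable_tensor& dvars);
    dlib::resizable_tensor dmeans, dvars;  // caller-owned workspace, reusable across calls
    dlib::cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src,
                                        gamma, src_grad, gamma_grad, beta_grad,
                                        dmeans, dvars);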
