@@ -1270,22 +1270,19 @@ namespace dlib
12701270 const tensor& beta
12711271 )
12721272 {
1273- const long num = src.k () * src.nr () * src.nc ();
12741273 DLIB_CASSERT (
12751274 have_same_dimensions (gamma, beta) &&
1276- src .k () == gamma .k () &&
1277- src .nr () == gamma. nr () &&
1278- src .nc () == gamma. nc () &&
1275+ gamma .k () == src .k () &&
1276+ gamma .nr () == 1 &&
1277+ gamma .nc () == 1 &&
12791278 eps > 0 ,
1279+ " \n src.k(): " << src.k () <<
12801280 " \n gamma.k(): " << gamma.k () <<
12811281 " \n gamma.nr(): " << gamma.nr () <<
12821282 " \n gamma.nc(): " << gamma.nc () <<
12831283 " \n beta.k(): " << beta.k () <<
12841284 " \n beta.nr(): " << beta.nr () <<
12851285 " \n beta.nc(): " << beta.nc () <<
1286- " \n src.k(): " << src.k () <<
1287- " \n src.nr(): " << src.nr () <<
1288- " \n src.nc(): " << src.nc () <<
12891286 " \n eps: " << eps
12901287 );
12911288
@@ -1296,43 +1293,50 @@ namespace dlib
12961293 // first compute means and invstds
12971294 means = 0 ;
12981295 invstds = 0 ;
1299- const auto p_invstds = invstds.host ();
1300- const auto p_means = means.host ();
1301- auto p_src = src.host ();
1296+ const float * p_src = src.host ();
1297+ float * p_invstds = invstds.host ();
1298+ float * p_means = means.host ();
1299+ const long num = src.nr () * src.nc ();
13021300 // compute means, and sum of squares
13031301 for (long n = 0 ; n < src.num_samples (); ++n)
13041302 {
1305- for (long i = 0 ; i < num ; ++i )
1303+ for (long k = 0 ; k < src. k () ; ++k )
13061304 {
1307- float val = p_src[n*num+i];
1308- p_means[n] += val;
1309- p_invstds[n] += val*val;
1305+ for (long i = 0 ; i < num; ++i)
1306+ {
1307+ p_means[n] += *p_src;
1308+ p_invstds[n] += (*p_src) * (*p_src);
1309+ ++p_src;
1310+ }
13101311 }
13111312 }
1312- means /= num;
1313- invstds /= num;
1313+ means /= src. k () * num;
1314+ invstds /= src. k () * num;
13141315 // copy data back to host
1315- invstds.host (); means.host ();
1316+ invstds.host ();
1317+ means.host ();
13161318
13171319 // compute variances
13181320 for (long n = 0 ; n < src.num_samples (); ++n)
13191321 {
1320- auto var = p_invstds[n] - p_means[n] * p_means[n];
1321- p_invstds[n] = 1 .0f / std::sqrt (var + eps);
1322+ p_invstds[n] = 1 .0f / std::sqrt (p_invstds[n] - p_means[n] * p_means[n] + eps);
13221323 }
13231324
13241325 p_src = src.host ();
1325- auto p_dest = dest.host ();
1326- auto p_gamma = gamma.host ();
1327- auto p_beta = beta.host ();
1326+ float * p_dest = dest.host ();
1327+ const float * p_gamma = gamma.host ();
1328+ const float * p_beta = beta.host ();
13281329 for (long n = 0 ; n < src.num_samples (); ++n)
13291330 {
1330- for (long i = 0 ; i < num ; ++i )
1331+ for (long k = 0 ; k < src. k () ; ++k )
13311332 {
1332- *p_dest = (*p_src - p_means[n])*p_invstds[n];
1333- *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
1334- ++p_src;
1335- ++p_dest;
1333+ for (long i = 0 ; i < num; ++i)
1334+ {
1335+ *p_dest = (*p_src - p_means[n]) * p_invstds[n];
1336+ *p_dest = (*p_dest) * p_gamma[k] + p_beta[k];
1337+ ++p_src;
1338+ ++p_dest;
1339+ }
13361340 }
13371341 }
13381342 }
@@ -1346,22 +1350,26 @@ namespace dlib
13461350 const tensor& gamma,
13471351 tensor& src_grad,
13481352 tensor& gamma_grad,
1349- tensor& beta_grad
1353+ tensor& beta_grad,
1354+ resizable_tensor& dmeans,
1355+ resizable_tensor& dvars
13501356 )
13511357 {
1352- const long num = src.k () * src. nr () * src.nc ();
1358+ const long num = src.nr () * src.nc ();
13531359 DLIB_CASSERT (src.num_samples () == means.size ());
13541360 DLIB_CASSERT (src.num_samples () == invstds.size ());
1355- DLIB_CASSERT (src.k () == gamma.k ());
1356- DLIB_CASSERT (src.nr () == gamma_grad.nr ());
1357- DLIB_CASSERT (src.nc () == beta_grad.nc ());
1361+ DLIB_CASSERT (have_same_dimensions (gamma, gamma_grad));
1362+ DLIB_CASSERT (have_same_dimensions (gamma_grad, beta_grad));
1363+ DLIB_CASSERT (gamma.k () == src.k ());
1364+ DLIB_CASSERT (gamma.nr () == 1 );
1365+ DLIB_CASSERT (gamma.nc () == 1 );
13581366 DLIB_CASSERT (have_same_dimensions (gradient_input, src));
13591367 DLIB_CASSERT (have_same_dimensions (gradient_input, src_grad));
1360- DLIB_CASSERT (have_same_dimensions (gamma_grad, beta_grad));
13611368 DLIB_CASSERT (eps > 0 );
13621369
13631370 beta_grad = 0 ;
13641371 gamma_grad = 0 ;
1372+
13651373 auto p_grad = gradient_input.host ();
13661374 auto p_src = src.host ();
13671375 const auto p_gamma = gamma.host ();
@@ -1370,7 +1378,6 @@ namespace dlib
13701378 const auto p_invstds = invstds.host ();
13711379 const auto p_means = means.host ();
13721380
1373- resizable_tensor dvars, dmeans;
13741381 dvars.copy_size (invstds);
13751382 dmeans.copy_size (means);
13761383 dvars = 0 ;
@@ -1380,53 +1387,62 @@ namespace dlib
13801387
13811388 for (long n = 0 ; n < src.num_samples (); ++n)
13821389 {
1383- for (long i = 0 ; i < num; ++i)
1390+ const float invstd_pow = -0.5 * std::pow (p_invstds[n], 3 .0f );
1391+ for (long k = 0 ; k < src.k (); ++k)
13841392 {
1385- const float x_hat = (*p_src - p_means[n])*p_invstds[n];
1386- p_beta_grad[i] += *p_grad;
1387- p_gamma_grad[i] += (*p_grad)*x_hat;
1393+ for (long i = 0 ; i < num; ++i)
1394+ {
1395+ const float x_hat = (*p_src - p_means[n]) * p_invstds[n];
1396+ p_beta_grad[k] += *p_grad;
1397+ p_gamma_grad[k] += (*p_grad) * x_hat;
13881398
1389- const float dx = *p_grad * p_gamma[n ];
1399+ const float dx = *p_grad * p_gamma[k ];
13901400
1391- p_dvars[n] += dx* (*p_src - p_means[n])*- 0.5 *p_invstds[n]*p_invstds[n]*p_invstds[n] ;
1401+ p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow ;
13921402
1393- ++p_grad;
1394- ++p_src;
1403+ ++p_grad;
1404+ ++p_src;
1405+ }
13951406 }
13961407 }
13971408
1398- const float invnum = 1 .0f /num;
13991409 p_grad = gradient_input.host ();
14001410 p_src = src.host ();
1411+ const float invnum = 1 .0f / (src.k () * num);
14011412 for (long n = 0 ; n < src.num_samples (); ++n)
14021413 {
1403- for (long i = 0 ; i < num ; ++i )
1414+ for (long k = 0 ; k < src. k () ; ++k )
14041415 {
1405- const float dx = *p_grad * p_gamma[i];
1416+ for (long i = 0 ; i < num; ++i)
1417+ {
1418+ const float dx = *p_grad * p_gamma[k];
14061419
1407- p_dmeans[n] += dx*- p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n])* invnum;
1420+ p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum;
14081421
1409- ++p_grad;
1410- ++p_src;
1422+ ++p_grad;
1423+ ++p_src;
1424+ }
14111425 }
14121426 }
14131427 p_grad = gradient_input.host ();
14141428 p_src = src.host ();
14151429 auto p_src_grad = src_grad.host ();
14161430 for (long n = 0 ; n < src.num_samples (); ++n)
14171431 {
1418- for (long i = 0 ; i < num ; ++i )
1432+ for (long k = 0 ; k < src. k () ; ++k )
14191433 {
1420- const float dx = *p_grad * p_gamma[i];
1421-
1422- *p_src_grad += dx*p_invstds[n] +
1423- p_dvars[n] *2 *(*p_src - p_means[n])*invnum +
1424- p_dmeans[n]*invnum;
1434+ for (long i = 0 ; i < num; ++i)
1435+ {
1436+ const float dx = *p_grad * p_gamma[k];
14251437
1438+ *p_src_grad += dx * p_invstds[n] +
1439+ p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum +
1440+ p_dmeans[n] * invnum;
14261441
1427- ++p_grad;
1428- ++p_src;
1429- ++p_src_grad;
1442+ ++p_grad;
1443+ ++p_src;
1444+ ++p_src_grad;
1445+ }
14301446 }
14311447 }
14321448 }
0 commit comments