Skip to content

Commit fafdac3

Browse files
authored
Add RMS Normalization Layer (davisking#2999)
* Add RMS Normalization Layer
* Update dnn.cpp
* Add the missing entry in visitors.h to take into account the new rms_norm_ layer
* Fix test function name
* Fix dangling pointer issue in the CUDA implementation of rms_normalize_gradient
* Fix the dnn.cpp test program for the new rms_norm_ layer
* General update of the rms_norm_ class
1 parent 253098e commit fafdac3

File tree

11 files changed

+863
-4
lines changed

11 files changed

+863
-4
lines changed

dlib/cuda/cpu_dlib.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,144 @@ namespace dlib
14471447
}
14481448
}
14491449

1450+
// -----------------------------------------------------------------------------------
1451+
1452+
void rms_normalize(
1453+
const double eps,
1454+
resizable_tensor& dest,
1455+
resizable_tensor& scale,
1456+
const tensor& src,
1457+
const tensor& gamma
1458+
)
1459+
{
1460+
DLIB_CASSERT(
1461+
gamma.k() == src.k() &&
1462+
gamma.nr() == 1 &&
1463+
gamma.nc() == 1 &&
1464+
eps > 0,
1465+
"\nsrc.k(): " << src.k() <<
1466+
"\ngamma.k(): " << gamma.k() <<
1467+
"\ngamma.nr(): " << gamma.nr() <<
1468+
"\ngamma.nc(): " << gamma.nc() <<
1469+
"\neps: " << eps
1470+
);
1471+
1472+
const long ns = src.num_samples();
1473+
const long ks = src.k();
1474+
const long num = src.nr() * src.nc();
1475+
1476+
dest.copy_size(src);
1477+
scale.set_size(ns);
1478+
1479+
// Compute RMS values
1480+
scale = 0;
1481+
const float* p_src = src.host();
1482+
float* p_scale = scale.host();
1483+
for (long n = 0; n < ns; ++n)
1484+
{
1485+
for (long k = 0; k < ks; ++k)
1486+
{
1487+
for (long i = 0; i < num; ++i)
1488+
{
1489+
p_scale[n] += (*p_src) * (*p_src);
1490+
++p_src;
1491+
}
1492+
}
1493+
p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
1494+
}
1495+
scale.host();
1496+
1497+
// Apply RMS normalization
1498+
p_src = src.host();
1499+
float* p_dest = dest.host();
1500+
const float* p_gamma = gamma.host();
1501+
for (long n = 0; n < ns; ++n)
1502+
{
1503+
for (long k = 0; k < ks; ++k)
1504+
{
1505+
for (long i = 0; i < num; ++i)
1506+
{
1507+
*p_dest = (*p_src) * p_scale[n] * p_gamma[k];
1508+
++p_src;
1509+
++p_dest;
1510+
}
1511+
}
1512+
}
1513+
}
1514+
1515+
void rms_normalize_gradient(
1516+
const tensor& gradient_input,
1517+
const tensor& scale,
1518+
const tensor& src,
1519+
const tensor& gamma,
1520+
tensor& src_grad,
1521+
tensor& gamma_grad,
1522+
resizable_tensor& dscale
1523+
)
1524+
{
1525+
DLIB_CASSERT(src.num_samples() == scale.size());
1526+
DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
1527+
DLIB_CASSERT(gamma.k() == src.k());
1528+
DLIB_CASSERT(gamma.nr() == 1);
1529+
DLIB_CASSERT(gamma.nc() == 1);
1530+
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
1531+
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
1532+
1533+
const long ns = src.num_samples();
1534+
const long ks = src.k();
1535+
const long num = src.nr() * src.nc();
1536+
1537+
gamma_grad = 0;
1538+
dscale.copy_size(scale);
1539+
dscale = 0;
1540+
1541+
auto p_grad = gradient_input.host();
1542+
auto p_src = src.host();
1543+
const auto p_gamma = gamma.host();
1544+
const auto p_gamma_grad = gamma_grad.host();
1545+
const auto p_scale = scale.host();
1546+
auto p_dscale = dscale.host();
1547+
1548+
for (long n = 0; n < ns; ++n)
1549+
{
1550+
const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f);
1551+
for (long k = 0; k < ks; ++k)
1552+
{
1553+
for (long i = 0; i < num; ++i)
1554+
{
1555+
const float x_hat = *p_src * p_scale[n];
1556+
p_gamma_grad[k] += (*p_grad) * x_hat;
1557+
1558+
const float dx = *p_grad * p_gamma[k];
1559+
p_dscale[n] += dx * *p_src * scale_pow;
1560+
1561+
++p_grad;
1562+
++p_src;
1563+
}
1564+
}
1565+
}
1566+
1567+
p_grad = gradient_input.host();
1568+
p_src = src.host();
1569+
auto p_src_grad = src_grad.host();
1570+
const float invnum = 1.0f / (ks * num);
1571+
for (long n = 0; n < ns; ++n)
1572+
{
1573+
for (long k = 0; k < ks; ++k)
1574+
{
1575+
for (long i = 0; i < num; ++i)
1576+
{
1577+
const float dx = *p_grad * p_gamma[k];
1578+
*p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum;
1579+
1580+
++p_grad;
1581+
++p_src;
1582+
++p_src_grad;
1583+
}
1584+
}
1585+
}
1586+
}
1587+
14501588
// -----------------------------------------------------------------------------------
14511589

14521590
void threshold (

dlib/cuda/cpu_dlib.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,26 @@ namespace dlib
255255
resizable_tensor& dvars
256256
);
257257

258+
// -----------------------------------------------------------------------------------
259+
260+
// Forward pass of RMS normalization. Resizes dest to match src, resizes
// scale to src.num_samples(), and writes dest = src * scale[n] * gamma[k]
// where scale[n] = 1/sqrt(mean(src[n]^2) + eps). gamma must satisfy
// gamma.k() == src.k() and gamma.nr() == gamma.nc() == 1.
void rms_normalize(
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& scale,
    const tensor& src,
    const tensor& gamma
);

// Backward pass of RMS normalization. scale must hold the per-sample values
// produced by rms_normalize(). Overwrites gamma_grad, accumulates (+=) into
// src_grad, and uses dscale as per-sample scratch storage.
void rms_normalize_gradient(
    const tensor& gradient_input,
    const tensor& scale,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    resizable_tensor& dscale
);
277+
258278
// -----------------------------------------------------------------------------------
259279

260280
void threshold (

dlib/cuda/cuda_dlib.cu

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,6 +2280,166 @@ namespace dlib
22802280
dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
22812281
}
22822282

2283+
// ----------------------------------------------------------------------------------------
2284+
2285+
// CUDA kernel for the forward pass of RMS normalization.
//
// Device pointers:
//   dest  - output, same layout as src (ns samples x ks channels x num elems)
//   scale - per-sample buffer of size ns; the host wrapper zero-initializes
//           it before launch, since phase 1 accumulates into it atomically
//   src   - input tensor data
//   gamma - per-channel weights (ks elements), broadcast across each
//           channel's num spatial elements
//
// NOTE(review): __syncthreads() only synchronizes threads within a single
// block, yet phase 1 accumulates scale[n] across the whole x grid via
// warp_reduce_atomic_add. If max_jobs() can launch more than one block in
// the x dimension, phases 2/3 may read partially accumulated values —
// confirm the launch configuration, or split the phases into separate
// kernel launches.
__global__ void _cuda_rms_normalize(
    float* dest,
    float* scale,
    const float* src,
    const float* gamma,
    float eps,
    size_t ns,
    size_t ks,
    size_t num
)
{
    // Phase 1: for each sample n, accumulate mean(x^2) into scale[n].
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        float sum_squares = 0.0f;
        for (auto i : grid_stride_range(0, ks * num))
        {
            sum_squares += ps[i] * ps[i];
        }
        warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
    }
    __syncthreads();

    // Phase 2: turn the accumulated mean square into the reciprocal RMS,
    // 1/sqrt(ms + eps). grid_stride_range(0, 1) restricts the write so at
    // most one x-thread per sample performs it.
    for (auto n : grid_stride_range_y(0, ns))
    {
        for (auto i : grid_stride_range(0, 1))
        {
            scale[n] = 1.0f / std::sqrt(scale[n] + eps);
        }
    }
    __syncthreads();

    // Phase 3: dest = src * (1/rms) * gamma, where gamma[i / num] selects
    // the channel weight for flattened index i.
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        const auto pd = dest + n * ks * num;
        for (auto i : grid_stride_range(0, ks * num))
        {
            pd[i] = ps[i] * scale[n] * gamma[i / num];
        }
    }
}
2327+
2328+
// Host-side launcher for the RMS normalization forward kernel.
// Resizes dest to match src, sizes scale to one float per sample, and
// zero-initializes scale because the kernel accumulates into it atomically.
void rms_normalize(
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& scale,
    const tensor& src,
    const tensor& gamma
)
{
    DLIB_CASSERT(
        gamma.k() == src.k() &&
        gamma.nr() == 1 &&
        gamma.nc() == 1 &&
        eps > 0,
        "\nsrc.k(): " << src.k() <<
        "\ngamma.k(): " << gamma.k() <<
        "\ngamma.nr(): " << gamma.nr() <<
        "\ngamma.nc(): " << gamma.nc() <<
        "\neps: " << eps
    );

    const long num_samples = src.num_samples();
    const long num_channels = src.k();
    const long plane_size = src.nr() * src.nc();

    scale.set_size(num_samples);
    scale = 0;
    dest.copy_size(src);

    launch_kernel(_cuda_rms_normalize,
        max_jobs(num_channels * plane_size, num_samples),
        dest.device(), scale.device(), src.device(), gamma.device(), eps,
        num_samples, num_channels, plane_size);
}
2359+
2360+
// ----------------------------------------------------------------------------------------
2361+
2362+
// CUDA kernel for the backward pass of RMS normalization.
//
// Device pointers:
//   src_grad       - accumulated into (+=)
//   gamma_grad     - per-channel gradient (ks elems); caller zero-initializes
//   dscale         - per-sample scratch (ns elems); caller zero-initializes
//   src            - forward input
//   gradient_input - incoming gradient, same layout as src
//   scale          - per-sample reciprocal RMS from the forward pass
//   gamma          - per-channel weights (ks elems)
//
// NOTE(review): as in _cuda_rms_normalize, __syncthreads() is block-local
// while the phase-1 reductions are grid-wide atomics — confirm the launch
// configuration keeps all x work in one block, or split the phases.
__global__ void _cuda_rms_normalize_gradient(
    float* src_grad,
    float* gamma_grad,
    float* dscale,
    const float* src,
    const float* gradient_input,
    const float* scale,
    const float* gamma,
    size_t ns,
    size_t ks,
    size_t num
)
{
    // Phase 1: per (sample, channel) pair, reduce this channel's
    // contributions to gamma_grad[k] and to the sample's dscale[n].
    for (auto nk : grid_stride_range_y(0, ns * ks))
    {
        const auto n = nk / ks;
        const auto k = nk % ks;
        const auto ps = src + (n * ks + k) * num;
        const auto pgi = gradient_input + (n * ks + k) * num;
        // d(scale)/d(mean square) factor: -0.5 * scale^3.
        const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
        float temp_gg = 0.0f;
        float temp_ds = 0.0f;
        for (auto i : grid_stride_range(0, num))
        {
            const float x_hat = ps[i] * scale[n];
            // BUG FIX: this loop's i runs over [0, num), so the original
            // gamma[i / num] always read gamma[0]. The weight for this
            // channel is gamma[k], matching the CPU implementation.
            const float dx = pgi[i] * gamma[k];
            temp_gg += pgi[i] * x_hat;
            temp_ds += dx * ps[i] * scale_pow;
        }
        warp_reduce_atomic_add(gamma_grad[k], temp_gg);
        warp_reduce_atomic_add(dscale[n], temp_ds);
    }
    __syncthreads();

    // Phase 2: combine the direct path and the scale path into src_grad.
    // Here i spans a whole sample (ks*num), so gamma[i / num] is the
    // correct channel weight.
    const float invnum = 1.0f / (ks * num);
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        const auto pgi = gradient_input + n * ks * num;
        const auto psg = src_grad + n * ks * num;
        for (auto i : grid_stride_range(0, ks * num))
        {
            const float dx = pgi[i] * gamma[i / num];
            psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
        }
    }
}
2409+
2410+
// Host-side launcher for the RMS normalization backward kernel.
// scale must hold the per-sample values produced by rms_normalize().
// Overwrites gamma_grad, accumulates (+=) into src_grad, and zero-initializes
// dscale as per-sample scratch because the kernel accumulates into both
// gamma_grad and dscale atomically.
void rms_normalize_gradient(
    const tensor& gradient_input,
    const tensor& scale,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    resizable_tensor& dscale
)
{
    DLIB_CASSERT(src.num_samples() == scale.size());
    DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
    DLIB_CASSERT(gamma.k() == src.k());
    DLIB_CASSERT(gamma.nr() == 1);
    DLIB_CASSERT(gamma.nc() == 1);
    DLIB_CASSERT(have_same_dimensions(gradient_input, src));
    DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));

    const long num_samples = src.num_samples();
    const long num_channels = src.k();
    const long plane_size = src.nr() * src.nc();

    gamma_grad = 0;
    dscale.copy_size(scale);
    dscale = 0;

    launch_kernel(_cuda_rms_normalize_gradient,
        max_jobs(num_channels * plane_size, num_samples),
        src_grad.device(), gamma_grad.device(), dscale.device(),
        src.device(), gradient_input.device(), scale.device(), gamma.device(),
        num_samples, num_channels, plane_size);
}
2442+
22832443
// ----------------------------------------------------------------------------------------
22842444

22852445
__global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)

dlib/cuda/cuda_dlib.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,26 @@ namespace dlib
362362
resizable_tensor& dvars
363363
);
364364

365+
// -----------------------------------------------------------------------------------
366+
367+
// CUDA forward pass of RMS normalization. Resizes dest to match src, sizes
// scale to src.num_samples(), and writes dest = src * scale[n] * gamma[k]
// where scale[n] = 1/sqrt(mean(src[n]^2) + eps). gamma must satisfy
// gamma.k() == src.k() and gamma.nr() == gamma.nc() == 1.
void rms_normalize(
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& scale,
    const tensor& src,
    const tensor& gamma
);

// CUDA backward pass of RMS normalization. scale must hold the per-sample
// values produced by rms_normalize(). Overwrites gamma_grad, accumulates
// (+=) into src_grad, and uses dscale as per-sample scratch storage.
void rms_normalize_gradient(
    const tensor& gradient_input,
    const tensor& scale,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    resizable_tensor& dscale
);
384+
365385
// -----------------------------------------------------------------------------------
366386

367387
void threshold (

0 commit comments

Comments
 (0)