Skip to content

Commit 72822fe

Browse files
Cydral, arrufat, and davisking
authored
Fix stride indexing bugs in reorg and reorg_gradient functions (CPU & CUDA) (davisking#3012)
* Fix Stride Indexing Bugs in `reorg` and `reorg_gradient` Functions (CPU & CUDA) and Add `add_to` Parameter

* 'add_to' parameter missing in cuda call reorg_gradient.launch_kernel()

* Cleanup: remove using namespace std; (davisking#3016)

  * remove using namespace std from headers
  * more std::
  * more std::
  * more std:: on windows stuff
  * remove uses of using namespace std::chrono
  * do not use C++17 features
  * Add Davis suggestion
  * revert some more stuff
  * revert removing include
  * more std::chrono stuff
  * fix build error

* Adjust comment formatting to be like other dlib comments

---------

Co-authored-by: Adrià <1671644+arrufat@users.noreply.github.com>
Co-authored-by: Davis King <davis@dlib.net>
1 parent 90c8d78 commit 72822fe

File tree

9 files changed

+190
-129
lines changed

9 files changed

+190
-129
lines changed

dlib/cuda/cpu_dlib.cpp

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2333,58 +2333,67 @@ namespace dlib
23332333

23342334
// ----------------------------------------------------------------------------------------
23352335

2336-
void reorg (
2336+
void reorg(
2337+
bool add_to,
23372338
tensor& dest,
23382339
const int row_stride,
23392340
const int col_stride,
23402341
const tensor& src
23412342
)
23422343
{
2343-
DLIB_CASSERT(is_same_object(dest, src)==false);
2344-
DLIB_CASSERT(src.nr() % row_stride == 0);
2345-
DLIB_CASSERT(src.nc() % col_stride == 0);
2346-
DLIB_CASSERT(dest.num_samples() == src.num_samples());
2347-
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
2348-
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
2349-
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
2344+
DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects.");
2345+
DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride.");
2346+
DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride.");
2347+
DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match.");
2348+
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match.");
2349+
DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match.");
2350+
DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match.");
2351+
23502352
const float* s = src.host();
23512353
float* d = dest.host();
23522354

2353-
parallel_for(0, dest.num_samples(), [&](long n)
2355+
const size_t sk = src.k(), snr = src.nr(), snc = src.nc();
2356+
const size_t dk = dest.k(), dnr = dest.nr(), dnc = dest.nc(), dsize = dest.size();
2357+
2358+
dlib::parallel_for(0, dsize, [&](long i)
23542359
{
2355-
for (long k = 0; k < dest.k(); ++k)
2356-
{
2357-
for (long r = 0; r < dest.nr(); ++r)
2358-
{
2359-
for (long c = 0; c < dest.nc(); ++c)
2360-
{
2361-
const auto out_idx = tensor_index(dest, n, k, r, c);
2362-
const auto in_idx = tensor_index(src,
2363-
n,
2364-
k % src.k(),
2365-
r * row_stride + (k / src.k()) / row_stride,
2366-
c * col_stride + (k / src.k()) % col_stride);
2367-
d[out_idx] = s[in_idx];
2368-
}
2369-
}
2370-
}
2360+
const size_t out_plane_size = dnr * dnc;
2361+
const size_t out_sample_size = dk * out_plane_size;
2362+
2363+
const size_t n = i / out_sample_size;
2364+
const size_t out_idx = i % out_sample_size;
2365+
const size_t out_k = out_idx / out_plane_size;
2366+
const size_t out_rc = out_idx % out_plane_size;
2367+
const size_t out_r = out_rc / dnc;
2368+
const size_t out_c = out_rc % dnc;
2369+
2370+
const size_t in_k = out_k % sk;
2371+
const size_t in_r = out_r * row_stride + (out_k / sk) / col_stride;
2372+
const size_t in_c = out_c * col_stride + (out_k / sk) % col_stride;
2373+
2374+
const size_t in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
2375+
2376+
if (add_to) d[i] += s[in_idx];
2377+
else d[i] = s[in_idx];
23712378
});
23722379
}
23732380

2374-
void reorg_gradient (
2381+
void reorg_gradient(
2382+
bool add_to,
23752383
tensor& grad,
23762384
const int row_stride,
23772385
const int col_stride,
23782386
const tensor& gradient_input
23792387
)
23802388
{
2381-
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
2382-
DLIB_CASSERT(grad.nr() % row_stride == 0);
2383-
DLIB_CASSERT(grad.nc() % col_stride == 0);
2384-
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
2385-
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
2386-
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
2387-
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
2389+
DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects.");
2390+
DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride.");
2391+
DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride.");
2392+
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match.");
2393+
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride.");
2394+
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride.");
2395+
DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride.");
2396+
23882397
const float* gi = gradient_input.host();
23892398
float* g = grad.host();
23902399

@@ -2396,13 +2405,15 @@ namespace dlib
23962405
{
23972406
for (long c = 0; c < gradient_input.nc(); ++c)
23982407
{
2399-
const auto in_idx = tensor_index(gradient_input, n, k, r, c);
2400-
const auto out_idx = tensor_index(grad,
2401-
n,
2402-
k % grad.k(),
2403-
r * row_stride + (k / grad.k()) / row_stride,
2404-
c * col_stride + (k / grad.k()) % col_stride);
2405-
g[out_idx] += gi[in_idx];
2408+
const auto in_idx = tensor_index(gradient_input, n, k, r, c);
2409+
const auto out_idx = tensor_index(grad,
2410+
n,
2411+
k % grad.k(),
2412+
r * row_stride + (k / grad.k()) / col_stride,
2413+
c * col_stride + (k / grad.k()) % col_stride);
2414+
2415+
if (add_to) g[out_idx] += gi[in_idx];
2416+
else g[out_idx] = gi[in_idx];
24062417
}
24072418
}
24082419
}

dlib/cuda/cpu_dlib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,13 +502,15 @@ namespace dlib
502502
// -----------------------------------------------------------------------------------
503503

504504
void reorg (
505+
bool add_to,
505506
tensor& dest,
506507
const int row_stride,
507508
const int col_stride,
508509
const tensor& src
509510
);
510511

511512
void reorg_gradient (
513+
bool add_to,
512514
tensor& grad,
513515
const int row_stride,
514516
const int col_stride,

dlib/cuda/cuda_dlib.cu

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,86 +2001,91 @@ namespace dlib
20012001

20022002
__global__ void _cuda_reorg(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
20032003
size_t sk, size_t snr, int snc, const float* s,
2004-
const size_t row_stride, const size_t col_stride)
2004+
const size_t row_stride, const size_t col_stride, const bool add_to)
20052005
{
20062006
const auto out_plane_size = dnr * dnc;
2007-
const auto sample_size = dk * out_plane_size;
2008-
for(auto i : grid_stride_range(0, dsize))
2007+
const auto out_sample_size = dk * out_plane_size;
2008+
for (auto i : grid_stride_range(0, dsize))
20092009
{
2010-
const auto n = i / sample_size;
2011-
const auto idx = i % out_plane_size;
2012-
const auto out_k = (i / out_plane_size) % dk;
2013-
const auto out_r = idx / dnc;
2014-
const auto out_c = idx % dnc;
2010+
const auto n = i / out_sample_size;
2011+
const auto out_idx = i % out_sample_size;
2012+
const auto out_k = out_idx / out_plane_size;
2013+
const auto out_rc = out_idx % out_plane_size;
2014+
const auto out_r = out_rc / dnc;
2015+
const auto out_c = out_rc % dnc;
20152016

20162017
const auto in_k = out_k % sk;
2017-
const auto in_r = out_r * row_stride + (out_k / sk) / row_stride;
2018+
const auto in_r = out_r * row_stride + (out_k / sk) / col_stride;
20182019
const auto in_c = out_c * col_stride + (out_k / sk) % col_stride;
20192020

20202021
const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
2021-
d[i] = s[in_idx];
2022+
if (add_to) d[i] += s[in_idx];
2023+
else d[i] = s[in_idx];
20222024
}
20232025
}
2026+
20242027
__global__ void _cuda_reorg_gradient(size_t ssize, size_t dk, size_t dnr, size_t dnc, float* d,
2025-
size_t sk, size_t snr, int snc, const float* s,
2026-
const size_t row_stride, const size_t col_stride)
2028+
size_t sk, size_t snr, int snc, const float* s, const size_t row_stride,
2029+
const size_t col_stride, const bool add_to
2030+
)
20272031
{
2028-
const auto in_plane_size = snr * snc;
2029-
const auto sample_size = sk * in_plane_size;
20302032
for(auto i : grid_stride_range(0, ssize))
20312033
{
2032-
const auto n = i / sample_size;
2033-
const auto idx = i % in_plane_size;
2034-
const auto in_k = (i / in_plane_size) % sk;
2035-
const auto in_r = idx / snc;
2036-
const auto in_c = idx % snc;
2034+
const auto n = i / (sk * snr * snc);
2035+
const auto sample_idx = i % (sk * snr * snc);
2036+
const auto in_k = (sample_idx / (snr * snc)) % sk;
2037+
const auto in_r = (sample_idx / snc) % snr;
2038+
const auto in_c = sample_idx % snc;
20372039

20382040
const auto out_k = in_k % dk;
2039-
const auto out_r = in_r * row_stride + (in_k / dk) / row_stride;
2041+
const auto out_r = in_r * row_stride + (in_k / dk) / col_stride;
20402042
const auto out_c = in_c * col_stride + (in_k / dk) % col_stride;
2041-
20422043
const auto out_idx = ((n * dk + out_k) * dnr + out_r) * dnc + out_c;
2043-
d[out_idx] += s[i];
2044+
2045+
if (add_to) d[out_idx] += s[i];
2046+
else d[out_idx] = s[i];
20442047
}
20452048
}
20462049

2047-
void reorg (
2050+
void reorg(
2051+
bool add_to,
20482052
tensor& dest,
20492053
const int row_stride,
20502054
const int col_stride,
20512055
const tensor& src
20522056
)
20532057
{
2054-
DLIB_CASSERT(is_same_object(dest, src)==false);
2055-
DLIB_CASSERT(src.nr() % row_stride == 0);
2056-
DLIB_CASSERT(src.nc() % col_stride == 0);
2057-
DLIB_CASSERT(dest.num_samples() == src.num_samples());
2058-
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
2059-
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
2060-
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
2058+
DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects.");
2059+
DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride.");
2060+
DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride.");
2061+
DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match.");
2062+
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match.");
2063+
DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match.");
2064+
DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match.");
20612065

20622066
launch_kernel(_cuda_reorg, dest.size(), dest.k(), dest.nr(), dest.nc(), dest.device(),
2063-
src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride);
2067+
src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride, add_to);
20642068
}
20652069

2066-
void reorg_gradient (
2070+
void reorg_gradient(
2071+
bool add_to,
20672072
tensor& grad,
20682073
const int row_stride,
20692074
const int col_stride,
20702075
const tensor& gradient_input
20712076
)
20722077
{
2073-
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
2074-
DLIB_CASSERT(grad.nr() % row_stride == 0);
2075-
DLIB_CASSERT(grad.nc() % col_stride == 0);
2076-
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
2077-
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
2078-
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
2079-
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
2078+
DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects.");
2079+
DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride.");
2080+
DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride.");
2081+
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match.");
2082+
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride.");
2083+
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride.");
2084+
DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride.");
20802085

20812086
launch_kernel(_cuda_reorg_gradient, gradient_input.size(), grad.k(), grad.nr(), grad.nc(), grad.device(),
2082-
gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
2083-
row_stride, col_stride);
2087+
gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
2088+
row_stride, col_stride, add_to);
20842089
}
20852090

20862091
// ----------------------------------------------------------------------------------------

dlib/cuda/cuda_dlib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,13 +546,15 @@ namespace dlib
546546
// ----------------------------------------------------------------------------------------
547547

548548
void reorg (
549+
bool add_to,
549550
tensor& dest,
550551
const int row_stride,
551552
const int col_stride,
552553
const tensor& src
553554
);
554555

555556
void reorg_gradient (
557+
bool add_to,
556558
tensor& grad,
557559
const int row_stride,
558560
const int col_stride,

dlib/cuda/tensor_tools.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,30 +1219,32 @@ namespace dlib { namespace tt
12191219
// ------------------------------------------------------------------------------------
12201220

12211221
void reorg (
1222+
bool add_to,
12221223
tensor& dest,
12231224
const int row_stride,
12241225
const int col_stride,
12251226
const tensor& src
12261227
)
12271228
{
12281229
#ifdef DLIB_USE_CUDA
1229-
cuda::reorg(dest, row_stride, col_stride, src);
1230+
cuda::reorg(add_to, dest, row_stride, col_stride, src);
12301231
#else
1231-
cpu::reorg(dest, row_stride, col_stride, src);
1232+
cpu::reorg(add_to, dest, row_stride, col_stride, src);
12321233
#endif
12331234
}
12341235

12351236
void reorg_gradient (
1237+
bool add_to,
12361238
tensor& grad,
12371239
const int row_stride,
12381240
const int col_stride,
12391241
const tensor& gradient_input
12401242
)
12411243
{
12421244
#ifdef DLIB_USE_CUDA
1243-
cuda::reorg_gradient(grad, row_stride, col_stride, gradient_input);
1245+
cuda::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input);
12441246
#else
1245-
cpu::reorg_gradient(grad, row_stride, col_stride, gradient_input);
1247+
cpu::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input);
12461248
#endif
12471249
}
12481250

0 commit comments

Comments (0)