Commit 90c8d78 — "Add Transpose Layer for 2D matrix transposition in 4D tensors" (davisking/dlib#3013).
Parent commit: fe6e052.

File tree: 10 files changed (+298, −1 lines).

dlib/cuda/cpu_dlib.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2927,7 +2927,47 @@ namespace dlib
29272927
}
29282928

29292929
// ------------------------------------------------------------------------------------
2930-
// ------------------------------------------------------------------------------------
2930+
2931+
void transpose(
2932+
bool add,
2933+
tensor& dest,
2934+
const tensor& src
2935+
)
2936+
{
2937+
DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
2938+
dest.k() == src.k() &&
2939+
dest.nr() == src.nc() &&
2940+
dest.nc() == src.nr(),
2941+
"Incompatible tensor dimensions.");
2942+
2943+
const float* src_data = src.host();
2944+
float* dest_data = dest.host();
2945+
2946+
const long num_samples = src.num_samples();
2947+
const long k_dim = src.k();
2948+
const long src_nr = src.nr();
2949+
const long src_nc = src.nc();
2950+
const long dest_nr = dest.nr();
2951+
const long dest_nc = dest.nc();
2952+
2953+
parallel_for(0, num_samples * k_dim, [&](long i) {
2954+
const long n = i / k_dim;
2955+
const long k = i % k_dim;
2956+
const long src_nk_offset = (n * src.k() + k) * src_nr;
2957+
const long dest_nk_offset = (n * dest.k() + k) * dest_nr;
2958+
2959+
for (long r = 0; r < src_nr; ++r) {
2960+
for (long c = 0; c < src_nc; ++c) {
2961+
const long src_idx = (src_nk_offset + r) * src_nc + c;
2962+
const long dest_idx = (dest_nk_offset + c) * dest_nc + r;
2963+
2964+
if (add) dest_data[dest_idx] += src_data[src_idx];
2965+
else dest_data[dest_idx] = src_data[src_idx];
2966+
}
2967+
}
2968+
});
2969+
}
2970+
29312971
// ------------------------------------------------------------------------------------
29322972

29332973
}

dlib/cuda/cpu_dlib.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,14 @@ namespace dlib
671671
size_t count_k
672672
);
673673

674+
// -----------------------------------------------------------------------------------
675+
676+
void transpose(
677+
bool add_to,
678+
tensor& dest,
679+
const tensor& src
680+
);
681+
674682
// -----------------------------------------------------------------------------------
675683

676684
class compute_loss_binary_log_per_pixel

dlib/cuda/cuda_dlib.cu

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,6 +2500,46 @@ namespace dlib
25002500

25012501
// ----------------------------------------------------------------------------------------
25022502

2503+
// Transposes each dnr x dnc destination plane by gathering from the source
// tensor: thread i owns destination element i and reads the corresponding
// transposed source element.  All dimension parameters are size_t so the
// index arithmetic stays in one unsigned type (the original `int snc` mixed
// signedness into otherwise size_t expressions).
__global__ void _cuda_transpose(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
                                size_t sk, size_t snr, size_t snc, const float* s, const bool add_to)
{
    const auto plane_size = dnr * dnc;
    const auto sample_size = dk * plane_size;
    for (auto i : grid_stride_range(0, dsize))
    {
        const auto n = i / sample_size;             // sample index
        const auto idx = i % plane_size;            // offset within the dest plane
        const auto in_k = (i / plane_size) % dk;    // k-plane index
        const auto in_r = idx % dnc;                // dest column == source row
        const auto in_c = idx / dnc;                // dest row == source column

        const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
        if (add_to) d[i] += s[in_idx];
        else d[i] = s[in_idx];
    }
}
2521+
2522+
void transpose(
    bool add_to,
    tensor& dest,
    const tensor& src
)
{
    // The kernel gathers, so src and dest must not alias, and dest's matrix
    // dimensions must be the swap of src's.
    DLIB_CASSERT(is_same_object(dest, src) == false);
    DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
                 dest.k() == src.k() &&
                 dest.nr() == src.nc() &&
                 dest.nc() == src.nr(),
                 "Incompatible tensor dimensions.");

    // One logical job per destination element.
    launch_kernel(_cuda_transpose, max_jobs(dest.size()), dest.size(),
                  dest.k(), dest.nr(), dest.nc(), dest.device(),
                  src.k(), src.nr(), src.nc(), src.device(), add_to);
}
2539+
2540+
// ----------------------------------------------------------------------------------------
2541+
2542+
25032543
__device__ float cuda_log1pexp(float x)
25042544
{
25052545
if (x <= -18)

dlib/cuda/cuda_dlib.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,13 @@ namespace dlib
570570
size_t count_k
571571
);
572572

573+
// ----------------------------------------------------------------------------------------
574+
575+
void transpose(
576+
bool add_to,
577+
tensor& dest,
578+
const tensor& src
579+
);
573580

574581
// ----------------------------------------------------------------------------------------
575582

dlib/cuda/tensor_tools.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,21 @@ namespace dlib { namespace tt
12791279
#endif
12801280
}
12811281

1282+
// ----------------------------------------------------------------------------------------
1283+
1284+
void transpose(
1285+
bool add_to,
1286+
tensor& dest,
1287+
const tensor& src
1288+
)
1289+
{
1290+
#ifdef DLIB_USE_CUDA
1291+
cuda::transpose(add_to, dest, src);
1292+
#else
1293+
cpu::transpose(add_to, dest, src);
1294+
#endif
1295+
}
1296+
12821297
// ----------------------------------------------------------------------------------------
12831298

12841299
}}

dlib/cuda/tensor_tools.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,32 @@ namespace dlib { namespace tt
21862186
i.e., copies content of each sample from src in to corresponding place of sample at dest.
21872187
!*/
21882188

2189+
// ----------------------------------------------------------------------------------------
2190+
2191+
void transpose(
2192+
bool add_to,
2193+
tensor& dest,
2194+
const tensor& src
2195+
);
2196+
/*!
2197+
requires
2198+
- is_same_object(dest, src) == false
2199+
- dest.num_samples() == src.num_samples()
2200+
- dest.k() == src.k()
2201+
- dest.nr() == src.nc()
2202+
- dest.nc() == src.nr()
2203+
ensures
2204+
- Performs a transpose operation on the nr() x nc() matrices within src.
2205+
- If (add_to) is false:
2206+
- The result is stored in dest, overwriting its previous contents.
2207+
- For all valid n, k, r, c:
2208+
- #dest(n,k,c,r) == src(n,k,r,c)
2209+
- If (add_to) is true:
2210+
- The result is added to the existing contents of dest.
2211+
- For all valid n, k, r, c:
2212+
- #dest(n,k,c,r) == dest(n,k,c,r) + src(n,k,r,c)
2213+
!*/
2214+
21892215
// ----------------------------------------------------------------------------------------
21902216

21912217
}}

dlib/dnn/layers.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4635,6 +4635,67 @@ namespace dlib
46354635
template <typename SUBNET>
46364636
using reorg = add_layer<reorg_<2, 2>, SUBNET>;
46374637

4638+
// ----------------------------------------------------------------------------------------
4639+
4640+
class transpose_ {
4641+
public:
4642+
transpose_() {}
4643+
template <typename SUBNET> void setup(const SUBNET& /* sub */) {}
4644+
4645+
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output) {
4646+
auto& prev = sub.get_output();
4647+
4648+
output.set_size(prev.num_samples(), prev.k(), prev.nc(), prev.nr());
4649+
tt::transpose(false, output, prev);
4650+
}
4651+
4652+
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) {
4653+
auto& prev = sub.get_gradient_input();
4654+
tt::transpose(true, prev, gradient_input);
4655+
}
4656+
4657+
inline dpoint map_input_to_output(dpoint p) const
4658+
{
4659+
dpoint temp_p;
4660+
temp_p.x() = p.y();
4661+
temp_p.y() = p.x();
4662+
return temp_p;
4663+
}
4664+
inline dpoint map_output_to_input(dpoint p) const
4665+
{
4666+
dpoint temp_p;
4667+
temp_p.x() = p.y();
4668+
temp_p.y() = p.x();
4669+
return temp_p;
4670+
}
4671+
4672+
const tensor& get_layer_params() const { return params; }
4673+
tensor& get_layer_params() { return params; }
4674+
4675+
friend void serialize(const transpose_& /* item */, std::ostream& out) {
4676+
serialize("transpose_", out);
4677+
}
4678+
friend void deserialize(transpose_& /* item */, std::istream& in) {
4679+
std::string version;
4680+
deserialize(version, in);
4681+
if (version != "transpose_")
4682+
throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::transpose_.");
4683+
}
4684+
4685+
friend std::ostream& operator<<(std::ostream& out, const transpose_& /* item */) {
4686+
out << "transpose";
4687+
return out;
4688+
}
4689+
friend void to_xml(const transpose_& /* item */, std::ostream& out) {
4690+
out << "<transpose />\n";
4691+
}
4692+
4693+
private:
4694+
dlib::resizable_tensor params; // unused
4695+
};
4696+
4697+
template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;
4698+
46384699
// ----------------------------------------------------------------------------------------
46394700

46404701
}

dlib/dnn/layers_abstract.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3649,6 +3649,60 @@ namespace dlib
36493649
template <typename SUBNET>
36503650
using reorg = add_layer<reorg_<2, 2>, SUBNET>;
36513651

3652+
// ----------------------------------------------------------------------------------------
3653+
3654+
class transpose_
3655+
{
3656+
/*!
3657+
WHAT THIS OBJECT REPRESENTS
3658+
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
3659+
defined above. In particular, this layer performs a 2D matrix transposition
3660+
on each of the k planes within each sample of a 4D tensor.
3661+
3662+
The dimensions of the tensor output by this layer are as follows (letting
3663+
IN be the input tensor and OUT the output tensor):
3664+
- OUT.num_samples() == IN.num_samples()
3665+
- OUT.k() == IN.k()
3666+
- OUT.nr() == IN.nc()
3667+
- OUT.nc() == IN.nr()
3668+
3669+
The transposition is performed as follows:
3670+
- For each sample i and each k-plane j:
3671+
- OUT[i][j][r][c] = IN[i][j][c][r] for all r in [0, IN.nc()) and c in [0, IN.nr())
3672+
3673+
This layer does not have any learnable parameters.
3674+
!*/
3675+
3676+
public:
3677+
3678+
transpose_() = default;
3679+
3680+
template <typename SUBNET> void setup (const SUBNET& sub);
3681+
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
3682+
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
3683+
3684+
inline dpoint map_input_to_output(dpoint p) const;
3685+
inline dpoint map_output_to_input(dpoint p) const;
3686+
3687+
const tensor& get_layer_params() const;
3688+
tensor& get_layer_params();
3689+
3690+
friend void serialize(const transpose_& item, std::ostream& out);
3691+
friend void deserialize(transpose_& item, std::istream& in);
3692+
3693+
friend std::ostream& operator<<(std::ostream& out, const transpose_& item);
3694+
friend void to_xml(const transpose_& item, std::ostream& out);
3695+
3696+
/*!
3697+
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
3698+
!*/
3699+
private:
3700+
resizable_tensor params; // unused
3701+
};
3702+
3703+
template <typename SUBNET>
3704+
using transpose = add_layer<transpose_, SUBNET>;
3705+
36523706
// ----------------------------------------------------------------------------------------
36533707

36543708
}

dlib/dnn/visitors.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,14 @@ namespace dlib
10211021
update(i);
10221022
}
10231023

1024+
template <typename U, typename E>
1025+
void operator()(size_t i, const add_layer<transpose_, U, E>&)
1026+
{
1027+
start_node(i, "transpose");
1028+
end_node();
1029+
update(i);
1030+
}
1031+
10241032
template <typename T, typename U, typename E>
10251033
void operator()(size_t i, const add_layer<T, U, E>&)
10261034
{

dlib/test/dnn.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,37 @@ namespace
748748
#endif
749749
}
750750

751+
// ----------------------------------------------------------------------------------------
752+
753+
void test_transpose()
754+
{
755+
const long num_samples = 2;
756+
const long k = 3;
757+
const long nr = 4;
758+
const long nc = 5;
759+
760+
resizable_tensor input(num_samples, k, nr, nc);
761+
resizable_tensor output_cpu_a(num_samples, k, nc, nr);
762+
tt::tensor_rand rnd(0);
763+
rnd.fill_uniform(input);
764+
resizable_tensor output_cpu_b(input);
765+
766+
cpu::transpose(false, output_cpu_a, input);
767+
cpu::transpose(true, output_cpu_b, output_cpu_a);
768+
input *= 2;
769+
DLIB_TEST(max(abs(mat(output_cpu_b) - mat(input))) < 1e-5);
770+
771+
#ifdef DLIB_USE_CUDA
772+
input /= 2;
773+
resizable_tensor output_cuda_a, output_cuda_b(input);
774+
output_cuda_a.copy_size(output_cpu_a);
775+
cuda::transpose(false, output_cuda_a, input);
776+
cuda::transpose(true, output_cuda_b, output_cuda_a);
777+
DLIB_TEST(max(abs(mat(output_cpu_a) - mat(output_cuda_a))) < 1e-5);
778+
DLIB_TEST(max(abs(mat(output_cpu_b) - mat(output_cuda_b))) < 1e-5);
779+
#endif
780+
}
781+
751782
// ----------------------------------------------------------------------------------------
752783

753784
void test_basic_tensor_ops()
@@ -2280,6 +2311,12 @@ namespace
22802311
auto res = test_layer(l);
22812312
DLIB_TEST_MSG(res, res);
22822313
}
2314+
{
2315+
print_spinner();
2316+
transpose_ l;
2317+
auto res = test_layer(l);
2318+
DLIB_TEST_MSG(res, res);
2319+
}
22832320
}
22842321

22852322
// ----------------------------------------------------------------------------------------
@@ -4489,6 +4526,7 @@ namespace
44894526
test_batch_normalize_conv();
44904527
test_layer_normalize();
44914528
test_rms_normalize();
4529+
test_transpose();
44924530
test_basic_tensor_ops();
44934531
test_layers();
44944532
test_visit_functions();

0 comments on this commit.