@@ -3649,6 +3649,60 @@ namespace dlib
36493649 template <typename SUBNET>
36503650 using reorg = add_layer<reorg_<2 , 2 >, SUBNET>;
36513651
3652+ // ----------------------------------------------------------------------------------------
3653+
3654+ class transpose_
3655+ {
3656+ /* !
3657+ WHAT THIS OBJECT REPRESENTS
3658+ This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
3659+ defined above. In particular, this layer performs a 2D matrix transposition
3660+ on each of the k planes within each sample of a 4D tensor.
3661+
3662+ The dimensions of the tensor output by this layer are as follows (letting
3663+ IN be the input tensor and OUT the output tensor):
3664+ - OUT.num_samples() == IN.num_samples()
3665+ - OUT.k() == IN.k()
3666+ - OUT.nr() == IN.nc()
3667+ - OUT.nc() == IN.nr()
3668+
3669+ The transposition is performed as follows:
3670+ - For each sample i and each k-plane j:
3671+ - OUT[i][j][r][c] = IN[i][j][c][r] for all r in [0, IN.nc()) and c in [0, IN.nr())
3672+
3673+ This layer does not have any learnable parameters.
3674+ !*/
3675+
3676+ public:
3677+
3678+ transpose_ () = default ;
3679+
3680+ template <typename SUBNET> void setup (const SUBNET& sub);
3681+ template <typename SUBNET> void forward (const SUBNET& sub, resizable_tensor& output);
3682+ template <typename SUBNET> void backward (const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
3683+
3684+ inline dpoint map_input_to_output (dpoint p) const ;
3685+ inline dpoint map_output_to_input (dpoint p) const ;
3686+
3687+ const tensor& get_layer_params () const ;
3688+ tensor& get_layer_params ();
3689+
3690+ friend void serialize (const transpose_& item, std::ostream& out);
3691+ friend void deserialize (transpose_& item, std::istream& in);
3692+
3693+ friend std::ostream& operator <<(std::ostream& out, const transpose_& item);
3694+ friend void to_xml (const transpose_& item, std::ostream& out);
3695+
3696+ /* !
3697+ These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
3698+ !*/
3699+ private:
3700+ resizable_tensor params; // unused
3701+ };
3702+
3703+ template <typename SUBNET>
3704+ using transpose = add_layer<transpose_, SUBNET>;
3705+
36523706// ----------------------------------------------------------------------------------------
36533707
36543708}
0 commit comments