Skip to content

Commit 7efca51

Browse files
Added MatMul and parallel versions of the conv and bricked rnn layers
1 parent 6411d15 commit 7efca51

File tree

14 files changed

+1187
-564
lines changed

14 files changed

+1187
-564
lines changed

.gitattributes

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ c_reference/tests/kws/keyword_spotting_io_2.h filter=lfs diff=lfs merge=lfs -tex
6565
c_reference/tests/kws/keyword_spotting_io_3.h filter=lfs diff=lfs merge=lfs -text
6666
c_reference/tests/conv1d/conv1d_regular/conv_param.h filter=lfs diff=lfs merge=lfs -text
6767
c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h filter=lfs diff=lfs merge=lfs -text
68-
c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h filter=lfs diff=lfs merge=lfs -text
6968
c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h filter=lfs diff=lfs merge=lfs -text
7069
c_reference/tests/kws/precnn_params.h filter=lfs diff=lfs merge=lfs -text
7170
c_reference/tests/kws/postcnn_params.h filter=lfs diff=lfs merge=lfs -text

c_reference/include/conv1d.h

Lines changed: 97 additions & 43 deletions
Large diffs are not rendered by default.

c_reference/include/dscnn.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*,
1414
* @brief sub-layers : batchnorm1d -> conv1d_lr
1515
* @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers
1616
* @param[in] input_signal pointer to the input signal. size = in_time * in_channels
17+
* @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params)
1718
* @param[in] in_time number of time steps in the input_signal
1819
* @param[in] in_channels number of input channels
1920
* @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2
@@ -38,7 +39,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*,
3839
* 3: relu
3940
*/
4041
int phon_pred_lr_cnn(float* output_signal, float* input_signal,
41-
unsigned in_time, unsigned in_channels,
42+
conv_layer cnn, unsigned in_time, unsigned in_channels,
4243
const float* const mean, const float* const var,
4344
unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place,
4445
unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size,
@@ -49,6 +50,7 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal,
4950
* @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d
5051
* @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers
5152
* @param[in] input_signal pointer to the input signal. size = in_time * in_channels
53+
* @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params)
5254
* @param[in] in_time number of time steps in the input
5355
* @param[in] in_channels number of input channels
5456
* @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2

c_reference/include/rnn_bricked.h

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,65 @@
44
#ifndef __RNN_BRICKED_H__
55
#define __RNN_BRICKED_H__
66

7-
// Function pointer for the RNN to be passed as a parameter
8-
typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned,
9-
unsigned, const void*, void*, int, int);
7+
/* All the matrices are stored in the row major format
8+
9+
NOTES for using the layers
10+
-> Single-directional Computation
11+
While using the bricked fastgrnn layers, the user needs to adhered to the two following constraints
12+
1) in_time % hop = 0
13+
2) fwd_window % hop = 0 and bwd_window % hop = 0
1014
11-
// NOTES for bi-direction
12-
// If bi_direction = 1, then actual rnn_output_dims is twice the rnn_hidden(rnn_hidden is output dims for each cell).
13-
// Each function will only process its given context(forward/backward).
14-
// The other context will need to be called separately with an appropriate offset.
15-
// E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...)
16-
// 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...)
17-
//
18-
// Each cell will only calculate half the hidden state i.e. rnn_hidden slots of memory from the start of the output pointer
19-
// Hence rnn_hidden is used as an offset for the backward pass. The offset for the forward pass is 0
20-
// This use of an offset is a way to exploit the nature of bi-direction to bypass the concatenation step typically associated with bi-directional passes
21-
//
22-
// Constraints
23-
// For Bi-Directional use, there are 3 constraints
24-
// 1) (in_time - fwd_window) % hop == 0 and (in_time - bwd_window) % hop == 0
25-
// 2) fwd_window % hop == 0 and bwd_window % hop == 0
26-
// 3) sample_first_brick and sample_last_brick = 1
27-
//
28-
// Violation of these constraints can lead to one of the following issues
29-
// 1) segmentation faults
30-
// 2) forward out_time != backward out_time
31-
// 3) mismatch between forward index and backward index during sampling i.e forward index 8 would correspond to backward index 6. This index error continues for all consecutive bricks
32-
// Hence, padding of the input and appropriate window choice is necessary
33-
//
34-
// These constraints can be ignored while performing uni-directional passes. However, it is favorable to follow constraints 1 and 2
15+
Violation of the above two constraints (1 & 2), will cause segmentation faults
16+
The layers first compute all the Wx steps and then compute Uh for all the windows parallelly
17+
Hence, the user needs to adhered to the constraints 1 & 2
3518
19+
-> Bi-directional Computation
20+
For bi-directional cases, there are 2 additionally constraints that would need to be followed
21+
A) sample_first_brick and sample_last_brick = 1
22+
B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call
23+
Each function will only process its given context(forward/backward). The other context will need to be called separately.
24+
E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...)
25+
2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...)
26+
27+
The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used
28+
Violating the conditions would cause index mis-matches or data corruption
29+
If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result
30+
If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only)
31+
*/
32+
33+
/**
34+
* @brief Model parameters for the 1D Convolution Layer
35+
* @var W1 pointer to first low-rank component of W. shape = [rank * in_dims]
36+
* @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank]
37+
* @var wRank rank of W matrix
38+
* @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden]
39+
* @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank]
40+
* @var uRank rank of U matrix
41+
* @var Bg pointer to bias for sigmoid
42+
* @var Bh pointer to bias for tanh
43+
* @var sigmoid_zeta first weight parameter for update from input from next step
44+
* @var sigmoid_nu second weight parameter for update from input from next step
45+
* @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x
46+
* @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x)
47+
* @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h
48+
* @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h)
49+
*/
50+
typedef struct BrickedFastGRNN_LR_Params {
51+
float* W1;
52+
float* W2;
53+
unsigned wRank;
54+
float* U1;
55+
float* U2;
56+
unsigned uRank;
57+
float* Bg;
58+
float* Bh;
59+
float sigmoid_zeta;
60+
float sigmoid_nu;
61+
unsigned block_size_w_to_lr;
62+
unsigned block_size_w_from_lr;
63+
unsigned block_size_u_to_lr;
64+
unsigned block_size_u_from_lr;
65+
} BrickedFastGRNN_LR_Params;
3666

3767
/** Forward Bricking and application of the forward RNN for an input signal
3868
* @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden
@@ -42,18 +72,16 @@ typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned,
4272
* @param[in] in_dims input dimensions
4373
* @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick)
4474
* @param[in] hop hop distance for between bricks
45-
* @param[in] rnn function pointer to the RNN
4675
* @param[in] params pointer to the parameters for the RNN
47-
* @param[in,out] buffers pointer to buffer for the RNN
4876
* @param[in] bi_direction determine if the ouput if for a bi-directional RNN.
4977
* @param[in] sample_first_brick determine if the 1st brick should also be sampled
5078
* -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1
5179
* -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hiddens state. out_time = in_time/hop + 1
5280
*/
53-
int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal,
54-
unsigned in_time, unsigned in_dims, unsigned window, unsigned hop,
55-
rnn_layer rnn, const void* params, void* buffers,
56-
unsigned bi_direction, unsigned sample_first_brick, int normalize);
81+
int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
82+
float* input_signal, unsigned in_time, unsigned in_dims,
83+
unsigned window, unsigned hop, const void* params,
84+
unsigned bi_direction, unsigned sample_first_brick);
5785

5886
/** Backward Bricking and application of the backward RNN for an input signal
5987
* @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden
@@ -63,18 +91,15 @@ int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_
6391
* @param[in] in_dims input dimensions
6492
* @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick)
6593
* @param[in] hop hop distance for between bricks
66-
* @param[in] rnn function pointer to the RNN
6794
* @param[in] params pointer to the parameters for the RNN
68-
* @param[in,out] buffers pointer to buffer for the RNN
6995
* @param[in] bi_direction determine if the ouput if for a bi-directional RNN.
7096
* @param[in] sample_last_brick determine if the last brick should also be sampled
7197
* -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1
7298
* -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hiddens state(last in reverse). out_time = in_time/hop + 1
7399
*/
74-
int backward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal,
75-
unsigned in_time, unsigned in_dims, unsigned window, unsigned hop,
76-
rnn_layer rnn, const void* params, void* buffers,
77-
unsigned bi_direction, unsigned sample_last_brick, int normalize);
78-
100+
int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
101+
float* input_signal, unsigned in_time, unsigned in_dims,
102+
unsigned window, unsigned hop, const void* params,
103+
unsigned bi_direction, unsigned sample_last_brick);
79104

80105
#endif

c_reference/include/utils.h

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,37 +31,83 @@ void matVec(const float* const mat, const float* const vec,
3131
float alpha, float beta,
3232
float* const ret);
3333

34-
/* Matrix-vector multiplication with a row offset
35-
This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis
36-
ret is of size nrows, vec is of size ncols
37-
mat is of size nrows * ncols, stored in row major
38-
depthwise is to change the matVec to depthwise specific convolutions
39-
row_stride is the offset factor between two adjacent rows
40-
Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped
41-
For a normal matVec case, this value will be ncols
42-
Eg : for a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. For this eg ncols will be 100 and row_stride will be 400
43-
vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals
44-
For a normal matVec case, this value will be 1
45-
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. So it's possible to enter a 400 length vector and consider every 4th element. For this ncols will be 100 and vec_stride will be 4*/
34+
/*
35+
Matrix-vector multiplication with a row offset
36+
This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis
37+
ret is of size nrows, vec is of size ncols
38+
mat is of size nrows * ncols, stored in row major
39+
depthwise is to change the matVec to depthwise specific convolutions
40+
row_stride is the offset factor between two adjacent rows
41+
Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped
42+
For a normal matVec case, this value will be ncols
43+
Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication.
44+
Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication.
45+
Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication.
46+
For this eg ncols will be 100 and row_stride will be 400
47+
vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals
48+
For a normal matVec case, this value will be 1
49+
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed.
50+
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed.
51+
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed.
52+
So it's possible to enter a 400 length vector and consider every 4th element.
53+
So it's possible to enter a 400 length vector and consider every 4th element.
54+
So it's possible to enter a 400 length vector and consider every 4th element.
55+
For this ncols will be 100 and vec_stride will be 4
56+
*/
4657
void offset_matVec_conv1d(const float* mat, const float* vec,
4758
unsigned nrows, unsigned ncols,
4859
unsigned row_stride, unsigned vec_stride,
4960
unsigned depthwise, float* ret);
5061

51-
/* Scaled matrix-matrix multiplication: ret = alpha * ret + beta * matA * matB
52-
matA first matrix; size = nrows * ncommon
53-
matB second matrix; size = ncommon * ncols
54-
nrows number of rows in the first matrix
55-
ncommon number of columns in the first matrix/number of rows in the second matrix
56-
ncols number of columns in the second matrix
57-
alpha scaling factor for the previously-stored output matrix
58-
beta scaling factor for the result of the multiplication (matA * matB)
59-
ret matrix multiplication output
60-
*/
61-
void matMul(const float* const matA, const float* const matB,
62+
/*
63+
Tiled (cache-blocked) implementation of the Matrix Multiplication
64+
Note: If only the MatMul output is needed, then please use calloc to initialize the output
65+
An alternative is to use malloc, followed by memset 0
66+
There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix
67+
If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly
68+
This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed
69+
matA first matrix; shape = [nrows, ncommon]
70+
matB second matrix; shape = [ncommon, ncols]
71+
nrows number of rows in the first matrix
72+
ncommon number of columns in the first matrix/number of rows in the second matrix
73+
ncols number of columns in the second matrix
74+
total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored
75+
total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
76+
total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
77+
total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
78+
ret matrix multiplication output. shape = [nrows, ncols]
79+
block_size tile/block size for optimal cache performance. A hardware specific parameter
80+
*/
81+
void tiledMatMul_float(const float* const matA, const float* const matB,
6282
unsigned nrows, unsigned ncommon, unsigned ncols,
63-
float alpha, float beta,
64-
float* const ret);
83+
unsigned total_comm_A, unsigned total_cols_B,
84+
float* const ret, unsigned block_size);
85+
86+
/*
87+
Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format
88+
The result will the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage)
89+
Note: If only the MatMul output is needed, then please use calloc to initialize the output
90+
An alternative is to use malloc, followed by memset 0
91+
There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix
92+
If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly
93+
This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed
94+
matA first matrix; shape = [nrows, ncommon]
95+
matB second matrix; shape = [ncols, ncommon]
96+
nrows number of rows in the first matrix
97+
ncommon number of columns in the first matrix/number of rows in the second matrix
98+
ncols number of columns in the second matrix
99+
total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored
100+
total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
101+
total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
102+
total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
103+
Since matB is transposed the columns are now the ncomm axis
104+
ret matrix multiplication output. shape = [nrows, ncols]
105+
block_size tile/block size for optimal cache performance. A hardware specific parameter
106+
*/
107+
void transposed_tiledMatMul(const float* const matA, const float* const matB,
108+
unsigned nrows, unsigned ncommon, unsigned ncols,
109+
unsigned total_comm_A, unsigned total_comm_B,
110+
float* const ret, unsigned block_size);
65111

66112
// scaled vector addition: ret = scalar1 * vec1 + scalar2 * vector2
67113
void v_add(float scalar1, const float* const vec1,

0 commit comments

Comments
 (0)