#include <utility>
#include <vector>

- extern "C" {
- /* Creates a dummy empty _C module that can be imported from Python.
-    The import from Python will load the .so associated with this extension
-    built from this file, so that all the TORCH_LIBRARY calls below are run. */
- PyObject *PyInit__C(void) {
-     static struct PyModuleDef module_def = {
-         PyModuleDef_HEAD_INIT,
-         "_C",  /* name of module */
-         NULL,  /* module documentation, may be NULL */
-         -1,    /* size of per-interpreter state of the module,
-                   or -1 if the module keeps state in global variables. */
-         NULL,  /* methods */
-     };
-     return PyModule_Create(&module_def);
- }
+ extern "C"
+ {
+     /* Creates a dummy empty _C module that can be imported from Python.
+        The import from Python will load the .so associated with this extension
+        built from this file, so that all the TORCH_LIBRARY calls below are run. */
+     PyObject *PyInit__C(void)
+     {
+         static struct PyModuleDef module_def = {
+             PyModuleDef_HEAD_INIT,
+             "_C",  /* name of module */
+             NULL,  /* module documentation, may be NULL */
+             -1,    /* size of per-interpreter state of the module,
+                       or -1 if the module keeps state in global variables. */
+             NULL,  /* methods */
+         };
+         return PyModule_Create(&module_def);
+     }
}

template <typename scalar_t>
void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
-               const at::Tensor &initials, const at::Tensor &output) {
+               const at::Tensor &initials, const at::Tensor &output)
+ {
    TORCH_CHECK(input.dim() == 2, "Input must be 2D");
    TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
    TORCH_CHECK(weights.sizes() == input.sizes(),
@@ -50,39 +53,33 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
    auto T = input.size(1);
    auto total_size = input.numel();

-     std::pair<scalar_t, scalar_t> buffer[total_size];
-
    const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
    const scalar_t *initials_ptr =
        initials_contiguous.const_data_ptr<scalar_t>();
    const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
    scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();

-     std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
-                    [](const scalar_t &a, const scalar_t &b) {
-                        return std::make_pair(a, b);
-                    });
-
-     at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
-         for (auto b = start; b < end; b++) {
-             std::inclusive_scan(
-                 buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
-                 [](const std::pair<scalar_t, scalar_t> &a,
-                    const std::pair<scalar_t, scalar_t> &b) {
-                     return std::make_pair(a.first * b.first,
-                                           a.second * b.first + b.second);
-                 },
-                 std::make_pair((scalar_t)1.0, initials_ptr[b]));
-         }
-     });
-
-     std::transform(
-         buffer, buffer + total_size, output_ptr,
-         [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
+     at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end)
+                      {
+         for (auto b = start; b < end; b++)
+         {
+             auto initial = initials_ptr[b];
+             auto weights_offset = weights_ptr + b * T;
+             auto input_offset = input_ptr + b * T;
+             auto output_offset = output_ptr + b * T;
+             for (int64_t t = 0; t < T; t++)
+             {
+                 auto w = weights_offset[t];
+                 auto x = input_offset[t];
+                 initial = initial * w + x;
+                 output_offset[t] = initial;
+             }
+         } });
}
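Note on the change above: the old path materialized a std::pair<scalar_t, scalar_t> scratch buffer of total_size elements (a variable-length array, which is not standard C++) and folded it with std::inclusive_scan under the operator (a1, b1) ∘ (a2, b2) = (a1 * a2, b1 * a2 + b2); the new path evaluates the same first-order recurrence h[t] = w[t] * h[t-1] + x[t] directly, one batch row at a time, with no temporary. A minimal standalone sketch of that recurrence, with hypothetical names and plain vectors in place of the tensor data pointers:

    #include <vector>

    // Sketch only: the per-row recurrence computed by scan_cpu,
    // h[t] = w[t] * h[t-1] + x[t], seeded with `initial`.
    std::vector<float> scan_row(const std::vector<float> &w,
                                const std::vector<float> &x, float initial) {
        std::vector<float> out(x.size());
        float h = initial;
        for (size_t t = 0; t < x.size(); t++) {
            h = h * w[t] + x[t];
            out[t] = h;
        }
        return out;
    }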
template <typename scalar_t>
- void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
+ {
    // Ensure input dimensions are correct
    TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
    TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
@@ -106,24 +103,27 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
    const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
    scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();

-     at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
-         for (auto b = start; b < end; b++) {
-             auto out_offset = b * (T + order) + order;
-             auto a_offset = b * T * order;
-             for (int64_t t = 0; t < T; t++) {
-                 scalar_t y = out_ptr[out_offset + t];
-                 for (int64_t i = 0; i < order; i++) {
-                     y -= a_ptr[a_offset + t * order + i] *
-                          out_ptr[out_offset + t - i - 1];
+     at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
+                      {
+         for (auto b = start; b < end; b++)
+         {
+             auto out_offset = out_ptr + b * (T + order) + order;
+             auto a_offset = a_ptr + b * T * order;
+             for (int64_t t = 0; t < T; t++)
+             {
+                 scalar_t y = out_offset[t];
+                 for (int64_t i = 0; i < order; i++)
+                 {
+                     y -= a_offset[t * order + i] * out_offset[t - i - 1];
                 }
-                 out_ptr[out_offset + t] = y;
+                 out_offset[t] = y;
             }
-         }
-     });
+         } });
}
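What this kernel computes is a direct-form all-pole (IIR) recursion with time-varying coefficients: for each row, y[t] = x[t] - sum_{i=1..order} a[t][i-1] * y[t-i], where the `order` samples before t = 0 are the initial conditions the caller writes into the front of `padded_out`. A plain-array sketch of one row under that padding layout (hypothetical names):

    #include <cstdint>
    #include <vector>

    // Sketch only: all-pole recursion for one row. `padded` holds `order`
    // initial-condition samples followed by T input samples and is filtered
    // in place; coefficients are indexed as a[t * order + i].
    void lpc_row(const std::vector<float> &a, std::vector<float> &padded,
                 int64_t order) {
        const int64_t T = static_cast<int64_t>(padded.size()) - order;
        for (int64_t t = 0; t < T; t++) {
            float y = padded[order + t];
            for (int64_t i = 0; i < order; i++)
                y -= a[t * order + i] * padded[order + t - i - 1];
            padded[order + t] = y;
        }
    }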
at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
-                             const at::Tensor &initials) {
+                             const at::Tensor &initials)
+ {
    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
                "Input must be floating point or complex");
    TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
@@ -135,12 +135,14 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        input.scalar_type(), "scan_cpu",
-         [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
+         [&]
+         { scan_cpu<scalar_t>(input, weights, initials, output); });
    return output;
}

at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
-                    const at::Tensor &zi) {
+                    const at::Tensor &zi)
+ {
    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
                "Input must be floating point or complex");
    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
@@ -156,16 +158,19 @@ at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
    auto out = at::cat({zi.flip(1), x}, 1).contiguous();

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
-         x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
+         x.scalar_type(), "lpc_cpu", [&]
+         { lpc_cpu_core<scalar_t>(a, out); });
    return out.slice(1, zi.size(1), out.size(1)).contiguous();
}

- TORCH_LIBRARY(torchlpc, m) {
+ TORCH_LIBRARY(torchlpc, m)
+ {
    m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor");
    m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor");
}

- TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+ TORCH_LIBRARY_IMPL(torchlpc, CPU, m)
+ {
    m.impl("scan", &scan_cpu_wrapper);
    m.impl("lpc", &lpc_cpu);
}
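How this ties back to PyInit__C at the top: importing the `_C` extension from Python loads this shared object, the static initializers generated by TORCH_LIBRARY and TORCH_LIBRARY_IMPL run, and the two ops are registered with the dispatcher (reachable from Python as torch.ops.torchlpc.scan and torch.ops.torchlpc.lpc). For completeness, a sketch of invoking a registered op from C++ through the dispatcher's typed API; `call_scan` is a hypothetical helper and assumes the schema above has already been registered:

    #include <ATen/core/dispatch/Dispatcher.h>

    // Sketch only: look up "torchlpc::scan" and call it through the
    // dispatcher, so the CPU kernel registered above handles CPU tensors.
    at::Tensor call_scan(const at::Tensor &a, const at::Tensor &b,
                         const at::Tensor &c) {
        static auto op =
            c10::Dispatcher::singleton()
                .findSchemaOrThrow("torchlpc::scan", "")
                .typed<at::Tensor(const at::Tensor &, const at::Tensor &,
                                  const at::Tensor &)>();
        return op.call(a, b, c);
    }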