Commit 4488dc6

feat: fallback to use openmp for scan (#31)
* fix: update build-backend to use legacy mode and re-enable branch formatting in versioning
* fix: simplify event triggers in version workflow
* fix: increase verbosity of version display in workflow
* fix: add numpy to dependencies in version workflow
* fix: add numba to dependencies in version workflow and pyproject.toml
* fix: update license format in pyproject.toml
* fix: refactor scan_cpu function for improved performance
* refactor lpc_cpu
* fix: update clang++ version in build step to llvm@18 to match macos-latest
* fix: update LDFLAGS and CPPFLAGS paths for libomp in macOS build step
1 parent be99c73 · commit 4488dc6

File tree: 4 files changed (+78, −76 lines)

.github/workflows/python-package.yml

Lines changed: 3 additions & 3 deletions
@@ -77,9 +77,9 @@ jobs:
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
       - name: Build CPP extension with clang++
         run: |
-          export CXX=$(brew --prefix llvm@15)/bin/clang++
-          export LDFLAGS="-L/usr/local/opt/libomp/lib"
-          export CPPFLAGS="-I/usr/local/opt/libomp/include"
+          export CXX=$(brew --prefix llvm@18)/bin/clang++
+          export LDFLAGS="-L/opt/homebrew/opt/libomp/lib"
+          export CPPFLAGS="-I/opt/homebrew/opt/libomp/include"
           pip install -e .[dev]
       - name: Test with pytest
         run: |

.github/workflows/version.yml

Lines changed: 3 additions & 7 deletions
@@ -1,10 +1,6 @@
 name: Display version
 
-on:
-  push:
-    branches: [ "dev", "main", "alpha", "beta" ]
-  pull_request:
-    branches: [ "dev", "main", "alpha", "beta" ]
+on: [push, pull_request]
 
 permissions:
   contents: read
@@ -26,8 +22,8 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install build "setuptools-git-versioning>=2,<3"
+          pip install build "setuptools-git-versioning>=2,<3" numpy numba
           pip install torch --index-url https://download.pytorch.org/whl/cpu
       - name: Display version
         run: |
-          setuptools-git-versioning -v >> $GITHUB_STEP_SUMMARY
+          setuptools-git-versioning -vv >> $GITHUB_STEP_SUMMARY

pyproject.toml

Lines changed: 9 additions & 8 deletions
@@ -4,19 +4,19 @@ requires = [
     "setuptools-git-versioning>=2.0,<3",
     "wheel",
     "torch",
+    "numba",
     "numpy",
 ]
-build-backend = "setuptools.build_meta"
+build-backend = "setuptools.build_meta:__legacy__"
 
 [tool.setuptools-git-versioning]
 enabled = true
 # change the file path
 version_file = "torchlpc/VERSION.txt"
-count_commits_from_version_file = true      # <--- enable commits tracking
-dev_template = "{tag}.{branch}{ccount}"     # suffix for versions will be .dev
-dirty_template = "{tag}.{branch}{ccount}"   # same thing here
-# Temporarily disable branch formatting due to issues with regex in _version.py
-# branch_formatter = "torchlpc._version:format_branch_name"
+count_commits_from_version_file = true    # <--- enable commits tracking
+dev_template = "{tag}.{branch}{ccount}"   # suffix for versions will be .dev
+dirty_template = "{tag}.{branch}{ccount}" # same thing here
+branch_formatter = "torchlpc._version:format_branch_name"
 
 [tool.setuptools.package-data]
 # include VERSION file to a package
@@ -29,6 +29,7 @@ exclude = ["tests", "tests.*"]
 [tool.setuptools]
 # this package will read some included files in runtime, avoid installing it as .zip
 zip-safe = false
+license-files = ["LICENSE"]
 
 [project]
 dynamic = ["version"]
@@ -39,8 +40,8 @@ authors = [{ name = "Chin-Yun Yu", email = "chin-yun.yu@qmul.ac.uk" }]
 maintainers = [{ name = "Chin-Yun Yu", email = "chin-yun.yu@qmul.ac.uk" }]
 description = "Fast, efficient, and differentiable time-varying LPC filtering in PyTorch."
 readme = "README.md"
-license = "MIT"
-license-files = ["LICENSE"]
+license = { text = "MIT" }
+
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",

torchlpc/csrc/scan_cpu.cpp

Lines changed: 63 additions & 58 deletions
@@ -6,26 +6,29 @@
 #include <utility>
 #include <vector>
 
-extern "C" {
-/* Creates a dummy empty _C module that can be imported from Python.
-   The import from Python will load the .so associated with this extension
-   built from this file, so that all the TORCH_LIBRARY calls below are run.*/
-PyObject *PyInit__C(void) {
-    static struct PyModuleDef module_def = {
-        PyModuleDef_HEAD_INIT,
-        "_C",   /* name of module */
-        NULL,   /* module documentation, may be NULL */
-        -1,     /* size of per-interpreter state of the module,
-                   or -1 if the module keeps state in global variables. */
-        NULL,   /* methods */
-    };
-    return PyModule_Create(&module_def);
-}
+extern "C"
+{
+    /* Creates a dummy empty _C module that can be imported from Python.
+       The import from Python will load the .so associated with this extension
+       built from this file, so that all the TORCH_LIBRARY calls below are run.*/
+    PyObject *PyInit__C(void)
+    {
+        static struct PyModuleDef module_def = {
+            PyModuleDef_HEAD_INIT,
+            "_C",   /* name of module */
+            NULL,   /* module documentation, may be NULL */
+            -1,     /* size of per-interpreter state of the module,
+                       or -1 if the module keeps state in global variables. */
+            NULL,   /* methods */
+        };
+        return PyModule_Create(&module_def);
+    }
 }
 
 template <typename scalar_t>
 void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
-              const at::Tensor &initials, const at::Tensor &output) {
+              const at::Tensor &initials, const at::Tensor &output)
+{
     TORCH_CHECK(input.dim() == 2, "Input must be 2D");
     TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
     TORCH_CHECK(weights.sizes() == input.sizes(),
@@ -50,39 +53,33 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
     auto T = input.size(1);
     auto total_size = input.numel();
 
-    std::pair<scalar_t, scalar_t> buffer[total_size];
-
     const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
     const scalar_t *initials_ptr =
         initials_contiguous.const_data_ptr<scalar_t>();
     const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
     scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();
 
-    std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
-                   [](const scalar_t &a, const scalar_t &b) {
-                       return std::make_pair(a, b);
-                   });
-
-    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
-        for (auto b = start; b < end; b++) {
-            std::inclusive_scan(
-                buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
-                [](const std::pair<scalar_t, scalar_t> &a,
-                   const std::pair<scalar_t, scalar_t> &b) {
-                    return std::make_pair(a.first * b.first,
-                                          a.second * b.first + b.second);
-                },
-                std::make_pair((scalar_t)1.0, initials_ptr[b]));
-        }
-    });
-
-    std::transform(
-        buffer, buffer + total_size, output_ptr,
-        [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
+    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end)
+                     {
+        for (auto b = start; b < end; b++)
+        {
+            auto initial = initials_ptr[b];
+            auto weights_offset = weights_ptr + b * T;
+            auto input_offset = input_ptr + b * T;
+            auto output_offset = output_ptr + b * T;
+            for (int64_t t = 0; t < T; t++)
+            {
+                auto w = weights_offset[t];
+                auto x = input_offset[t];
+                initial = initial * w + x;
+                output_offset[t] = initial;
+            }
+        } });
 }
 
 template <typename scalar_t>
-void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
+{
     // Ensure input dimensions are correct
     TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
     TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
@@ -106,24 +103,27 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
     const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
     scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();
 
-    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
-        for (auto b = start; b < end; b++) {
-            auto out_offset = b * (T + order) + order;
-            auto a_offset = b * T * order;
-            for (int64_t t = 0; t < T; t++) {
-                scalar_t y = out_ptr[out_offset + t];
-                for (int64_t i = 0; i < order; i++) {
-                    y -= a_ptr[a_offset + t * order + i] *
-                         out_ptr[out_offset + t - i - 1];
+    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
+                     {
+        for (auto b = start; b < end; b++)
+        {
+            auto out_offset = out_ptr + b * (T + order) + order;
+            auto a_offset = a_ptr + b * T * order;
+            for (int64_t t = 0; t < T; t++)
+            {
+                scalar_t y = out_offset[t];
+                for (int64_t i = 0; i < order; i++)
+                {
+                    y -= a_offset[t * order + i] * out_offset[t - i - 1];
                 }
-                out_ptr[out_offset + t] = y;
+                out_offset[t] = y;
             }
-        }
-    });
+        } });
 }
 
 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
-                            const at::Tensor &initials) {
+                            const at::Tensor &initials)
+{
     TORCH_CHECK(input.is_floating_point() || input.is_complex(),
                 "Input must be floating point or complex");
     TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
@@ -135,12 +135,14 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
 
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
         input.scalar_type(), "scan_cpu",
-        [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
+        [&]
+        { scan_cpu<scalar_t>(input, weights, initials, output); });
     return output;
 }
 
 at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
-                   const at::Tensor &zi) {
+                   const at::Tensor &zi)
+{
     TORCH_CHECK(x.is_floating_point() || x.is_complex(),
                 "Input must be floating point or complex");
     TORCH_CHECK(a.scalar_type() == x.scalar_type(),
@@ -156,16 +158,19 @@ at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
     auto out = at::cat({zi.flip(1), x}, 1).contiguous();
 
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
-        x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
+        x.scalar_type(), "lpc_cpu", [&]
+        { lpc_cpu_core<scalar_t>(a, out); });
     return out.slice(1, zi.size(1), out.size(1)).contiguous();
 }
 
-TORCH_LIBRARY(torchlpc, m) {
+TORCH_LIBRARY(torchlpc, m)
+{
     m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor");
     m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m)
+{
     m.impl("scan", &scan_cpu_wrapper);
     m.impl("lpc", &lpc_cpu);
 }
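
For reference, the scan refactor drops the pair-based std::inclusive_scan (which required a temporary variable-length array of (weight, value) pairs plus two std::transform passes) in favor of a direct per-row recurrence y[t] = w[t] * y[t-1] + x[t], still parallelized over batch rows with at::parallel_for; with the extension built against libomp as the macOS workflow above arranges, that outer loop can fall back to OpenMP threading, per the commit title. Below is a minimal Python sketch of what the registered op computes, assuming that importing torchlpc loads the compiled _C extension so the TORCH_LIBRARY registrations run:

    import torch
    import torchlpc  # assumed to load _C and register torchlpc::scan / torchlpc::lpc

    B, T = 4, 16
    x = torch.randn(B, T, dtype=torch.double)  # input, 2D as scan_cpu checks
    w = torch.randn(B, T, dtype=torch.double)  # time-varying weights, same shape
    zi = torch.randn(B, dtype=torch.double)    # one initial state per row, 1D

    # Registered CPU kernel: per row, y[t] = w[t] * y[t-1] + x[t], y[-1] = zi[b].
    y = torch.ops.torchlpc.scan(x, w, zi)

    # Pure-Python reference of the same first-order recurrence.
    y_ref = torch.empty_like(x)
    for b in range(B):
        acc = zi[b]
        for t in range(T):
            acc = acc * w[b, t] + x[b, t]
            y_ref[b, t] = acc

    assert torch.allclose(y, y_ref)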
