diff --git a/MANIFEST.in b/MANIFEST.in
index 0240be0..187e82c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1,2 @@
-recursive-include philtorch *.txt
\ No newline at end of file
+recursive-include philtorch *.txt
+recursive-include philtorch *.cuh
\ No newline at end of file
diff --git a/philtorch/csrc/host_lti_recur.cpp b/philtorch/csrc/host_lti_recur.cpp
new file mode 100644
index 0000000..9f24af2
--- /dev/null
+++ b/philtorch/csrc/host_lti_recur.cpp
@@ -0,0 +1,85 @@
+#include <ATen/ATen.h>
+#include <torch/library.h>
+
+template <typename scalar_t>
+void host_lti_batch_linear_recurrence(int B, int T,
+                                      const scalar_t *a,
+                                      scalar_t *out)
+{
+// Process each batch in parallel
+#pragma omp parallel for
+    for (int b = 0; b < B; ++b)
+    {
+        const scalar_t a_b = a[b];
+        scalar_t *out_b = out + b * T;
+
+        // t loop
+        for (int t = 1; t < T; ++t)
+        {
+            out_b[t] += a_b * out_b[t - 1];
+        }
+    }
+}
+
+template <typename scalar_t>
+void host_lti_shared_linear_recurrence(int B, int T,
+                                       const scalar_t a,
+                                       scalar_t *out)
+{
+// Process each batch in parallel
+#pragma omp parallel for
+    for (int b = 0; b < B; ++b)
+    {
+        scalar_t *out_b = out + b * T;
+        // t loop
+        for (int t = 1; t < T; ++t)
+        {
+            out_b[t] += a * out_b[t - 1];
+        }
+    }
+}
+
+at::Tensor lti_recur_cpu_impl(const at::Tensor &a,
+                              const at::Tensor &zi, const at::Tensor &x)
+{
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "zi must have the same scalar type as input");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "A must have the same scalar type as input");
+    TORCH_CHECK(a.dim() <= 1, "A must be a vector or a scalar");
+    TORCH_CHECK(zi.dim() == 1, "zi must be a vector");
+
+    auto n_steps = x.size(1) + 1; // +1 for the initial state
+    auto n_batches = x.size(0);
+    auto output =
+        at::cat({zi.unsqueeze(1), x}, 1).contiguous();
+    auto a_contiguous = a.contiguous();
+
+    if (a.dim() == 1 && a.numel() == n_batches)
+    {
+
+        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+            x.scalar_type(), "host_lti_batch_linear_recurrence", [&]
+            { host_lti_batch_linear_recurrence(
+                  n_batches, n_steps,
+                  a_contiguous.const_data_ptr<scalar_t>(),
+                  output.mutable_data_ptr<scalar_t>()); });
+    }
+    else
+    {
+        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+            x.scalar_type(), "host_lti_shared_linear_recurrence", [&]
+            { host_lti_shared_linear_recurrence(
+                  n_batches, n_steps,
+                  a_contiguous.item<scalar_t>(),
+                  output.mutable_data_ptr<scalar_t>()); });
+    }
+
+    return output.slice(1, 1, output.size(1))
+        .contiguous(); // Remove the initial state from the output
+}
+
+TORCH_LIBRARY_IMPL(philtorch, CPU, m)
+{
+    m.impl("lti_recur", &lti_recur_cpu_impl);
+}
\ No newline at end of file
diff --git a/philtorch/csrc/host_recur2.cpp b/philtorch/csrc/host_recur2.cpp
index 30ddc09..2d3e3fd 100644
--- a/philtorch/csrc/host_recur2.cpp
+++ b/philtorch/csrc/host_recur2.cpp
@@ -166,4 +166,6 @@ TORCH_LIBRARY(philtorch, m)
     m.def("philtorch::lti_recur2(Tensor A, Tensor zi, Tensor x) -> Tensor");
 
     m.def("philtorch::lti_recurN(Tensor A, Tensor zi, Tensor x) -> Tensor");
+
+    m.def("philtorch::lti_recur(Tensor A, Tensor zi, Tensor x) -> Tensor");
 }
diff --git a/philtorch/csrc/lti_recur.cu b/philtorch/csrc/lti_recur.cu
new file mode 100644
index 0000000..718679f
--- /dev/null
+++ b/philtorch/csrc/lti_recur.cu
@@ -0,0 +1,150 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/library.h>
+#include <cuda/std/tuple>
+#include <thrust/tuple.h>
+#include <thrust/functional.h>
+#include <thrust/zip_function.h>
+#include <thrust/scan.h>
+#include <thrust/execution_policy.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/transform_iterator.h>
+#include <thrust/iterator/transform_output_iterator.h>
+#include <thrust/iterator/zip_iterator.h>
+
+template <typename T>
+struct recur_binary_op
+{
+    __host__ __device__ cuda::std::tuple<T, T> operator()(
+        const cuda::std::tuple<T, T> &a,
+        const cuda::std::tuple<T, T> &b) const
+    {
+        auto [a_first, a_second] = a;
+        auto [b_first, b_second] = b;
+        return cuda::std::make_tuple(a_first * b_first,
+                                     a_second * b_first + b_second);
+    }
+};
+
+template <typename T>
+struct take_second
+{
+    __host__ __device__ T operator()(const cuda::std::tuple<T, T> &state) const
+    {
+        return thrust::get<1>(state);
+    }
+};
+
+template <typename T>
+struct lti_batch_recur_input_op
+{
+    const T *decays;
+    int n_steps;
+    __host__ __device__ cuda::std::tuple<T, T> operator()(int i, const T &x) const
+    {
+        int idx = i / n_steps;
+        int offset = i % n_steps;
+        if (offset > 0)
+            return thrust::make_tuple(decays[idx], x);
+        return thrust::make_tuple(0, x);
+    }
+};
+
+template <typename T>
+struct lti_shared_recur_input_op
+{
+    const T decay;
+    int n_steps;
+    __host__ __device__ cuda::std::tuple<T, T> operator()(int i, const T &x) const
+    {
+        int offset = i % n_steps;
+        if (offset > 0)
+            return thrust::make_tuple(decay, x);
+        return thrust::make_tuple(0, x);
+    }
+};
+
+template <typename scalar_t>
+void lti_batch_linear_recurrence(const scalar_t *decays,
+                                 const scalar_t *impulses,
+                                 scalar_t *out, int n_steps, int n_batches)
+{
+    auto total_steps = n_steps * n_batches;
+    thrust::counting_iterator<int> it(0);
+    auto batch_input_op = thrust::make_zip_function(lti_batch_recur_input_op<scalar_t>{decays, n_steps});
+    thrust::inclusive_scan(
+        thrust::device,
+        thrust::make_transform_iterator(
+            thrust::make_zip_iterator(it, impulses), batch_input_op),
+        thrust::make_transform_iterator(
+            thrust::make_zip_iterator(it + total_steps, impulses + total_steps), batch_input_op),
+        thrust::make_transform_output_iterator(out, take_second<scalar_t>()),
+        recur_binary_op<scalar_t>());
+}
+
+template <typename scalar_t>
+void lti_shared_linear_recurrence(const scalar_t decay,
+                                  const scalar_t *impulses,
+                                  scalar_t *out, int n_steps, int n_batches)
+{
+    auto total_steps = n_steps * n_batches;
+    thrust::counting_iterator<int> it(0);
+    auto shared_input_op = thrust::make_zip_function(lti_shared_recur_input_op<scalar_t>{decay, n_steps});
+    thrust::inclusive_scan(
+        thrust::device,
+        thrust::make_transform_iterator(
+            thrust::make_zip_iterator(it, impulses), shared_input_op),
+        thrust::make_transform_iterator(
+            thrust::make_zip_iterator(it + total_steps, impulses + total_steps), shared_input_op),
+        thrust::make_transform_output_iterator(out, take_second<scalar_t>()),
+        recur_binary_op<scalar_t>());
+}
+
+at::Tensor lti_recur_cuda_impl(const at::Tensor &a,
+                               const at::Tensor &zi, const at::Tensor &x)
+{
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "zi must have the same scalar type as input");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "A must have the same scalar type as input");
+    TORCH_CHECK(a.dim() <= 1, "A must be a vector or a scalar");
+    TORCH_CHECK(zi.dim() == 1, "zi must be a vector");
+
+    auto n_steps = x.size(1) + 1; // +1 for the initial state
+    auto n_batches = x.size(0);
+    auto x_contiguous =
+        at::cat({zi.unsqueeze(1), x}, 1).contiguous();
+    auto a_contiguous = a.contiguous();
+    auto output = at::empty_like(x_contiguous);
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+    if (a.dim() == 1 && a.numel() == n_batches)
+    {
+
+        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+            x.scalar_type(), "lti_batch_linear_recurrence", [&]
+            { lti_batch_linear_recurrence(
+                  a_contiguous.const_data_ptr<scalar_t>(),
+                  x_contiguous.const_data_ptr<scalar_t>(),
+                  output.mutable_data_ptr<scalar_t>(),
+                  n_steps, n_batches); });
+    }
+    else
+    {
+        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+            x.scalar_type(), "lti_shared_linear_recurrence", [&]
+            { lti_shared_linear_recurrence(
+                  a_contiguous.item<scalar_t>(),
+                  x_contiguous.const_data_ptr<scalar_t>(),
+                  output.mutable_data_ptr<scalar_t>(),
+                  n_steps, n_batches); });
+    }
+
+    return output.slice(1, 1, output.size(1))
+        .contiguous(); // Remove the initial state from the output
+}
+
+TORCH_LIBRARY_IMPL(philtorch, CUDA, m) { m.impl("lti_recur", &lti_recur_cuda_impl); }
\ No newline at end of file
diff --git a/tests/test_recur_ext.py b/tests/test_recur_ext.py
index c500d5b..7d1ca3a 100644
--- a/tests/test_recur_ext.py
+++ b/tests/test_recur_ext.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from philtorch.mat import companion
+from philtorch.lti import linear_recurrence
 
 from .test_lti_lfilter import _generate_random_signal
 from .test_lti_ssm import _generate_random_filter_coeffs
@@ -65,3 +66,30 @@ def test_lti_recurN_equiv(device: str, batch: bool, order: int):
         x_torch,
     )
     assert torch.allclose(lti_y, ltv_y)
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize("batch", [True, False])
+def test_lti_recur_equiv(device: str, batch: bool):
+    B = 3
+    T = 101
+
+    # Random decay in (-1, 1), input signal, and initial state
+    a_torch = torch.rand(B if batch else 1).to(device).double() * 2 - 1
+    x_torch = torch.randn(B, T).to(device).double()
+    zi = x_torch.new_zeros(B).normal_()
+
+    lti_y = torch.ops.philtorch.lti_recur(a_torch, zi, x_torch)
+    torch_y = linear_recurrence(a_torch, zi, x_torch)
+    assert torch.allclose(lti_y, torch_y), torch.max(torch.abs(lti_y - torch_y))
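
For context, the new philtorch::lti_recur op computes the first-order recurrence y[b, t] = a[b] * y[b, t - 1] + x[b, t] with initial state y[b, -1] = zi[b]. The CPU backend runs that loop per batch under OpenMP, while the CUDA backend rewrites it as one inclusive scan over (gain, bias) pairs combined by recur_binary_op, inserting a zero gain at the first element of each row so a single flat scan over all batches never leaks state across rows. The sketch below is illustrative only — the helper names are made up and are not part of philtorch — and checks the scan formulation against a naive loop in plain PyTorch.

import torch


def naive_recurrence(a, zi, x):
    # Reference loop: y[:, t] = a * y[:, t - 1] + x[:, t], with y[:, -1] = zi.
    B, T = x.shape
    y = torch.empty_like(x)
    prev = zi
    for t in range(T):
        prev = a * prev + x[:, t]
        y[:, t] = prev
    return y


def scan_recurrence(a, zi, x):
    # Same recurrence as a left fold of the associative operator used by
    # recur_binary_op: (g1, b1) op (g2, b2) = (g1 * g2, b1 * g2 + b2).
    # Element 0 of each row carries (0, zi) so no state crosses row boundaries.
    B, T = x.shape
    gains = torch.cat([x.new_zeros(B, 1), a.expand(B)[:, None].expand(B, T)], dim=1)
    biases = torch.cat([zi[:, None], x], dim=1)
    y = torch.empty_like(biases)
    g, b = gains[:, 0], biases[:, 0]
    y[:, 0] = b
    for t in range(1, T + 1):
        g, b = g * gains[:, t], b * gains[:, t] + biases[:, t]
        y[:, t] = b
    return y[:, 1:]  # drop the prepended initial state, as the C++ code does


if __name__ == "__main__":
    B, T = 3, 101
    a = torch.rand(B, dtype=torch.float64) * 2 - 1
    zi = torch.randn(B, dtype=torch.float64)
    x = torch.randn(B, T, dtype=torch.float64)
    assert torch.allclose(naive_recurrence(a, zi, x), scan_recurrence(a, zi, x))

Because the combining operator is associative, the same fold can be evaluated in parallel, which is what thrust::inclusive_scan does in lti_recur.cu before take_second extracts the bias component as the output sample.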