3 changes: 2 additions & 1 deletion MANIFEST.in
@@ -1 +1,2 @@
-recursive-include philtorch *.txt
+recursive-include philtorch *.txt
+recursive-include philtorch *.cuh
85 changes: 85 additions & 0 deletions philtorch/csrc/host_lti_recur.cpp
@@ -0,0 +1,85 @@
#include <torch/script.h>
#include <torch/torch.h>

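// In-place first-order recurrence with one coefficient per batch; on entry
// each row of `out` holds [zi_b, x_b].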
template <typename scalar_t>
void host_lti_batch_linear_recurrence(int B, int T,
const scalar_t *a,
scalar_t *out)
{
// Process each batch in parallel
#pragma omp parallel for
for (int b = 0; b < B; ++b)
{
const scalar_t a_b = a[b];
scalar_t *out_b = out + b * T;

// Sequential update along the time dimension
for (int t = 1; t < T; ++t)
{
out_b[t] += a_b * out_b[t - 1];
}
}
}

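// Same recurrence with a single coefficient shared by all batches.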
template <typename scalar_t>
void host_lti_shared_linear_recurrence(int B, int T,
const scalar_t a,
scalar_t *out)
{
// Process each batch in parallel
#pragma omp parallel for
for (int b = 0; b < B; ++b)
{
scalar_t *out_b = out + b * T;
// Sequential update along the time dimension
for (int t = 1; t < T; ++t)
{
out_b[t] += a * out_b[t - 1];
}
}
}

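// CPU entry point for philtorch::lti_recur. `a` is either a length-B vector of
// per-batch coefficients or a scalar shared across batches, `zi` is the (B,)
// initial state, and `x` is the (B, T) input. The initial state is prepended,
// the recurrence runs in place, and the leading column is sliced off again.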
at::Tensor lti_recur_cpu_impl(const at::Tensor &a,
const at::Tensor &zi, const at::Tensor &x)
{
TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
"zi must have the same scalar type as input");
TORCH_CHECK(a.scalar_type() == x.scalar_type(),
"A must have the same scalar type as input");
TORCH_CHECK(a.dim() <= 1, "A must be a vector or a scalar");
TORCH_CHECK(zi.dim() == 1, "zi must be a vector");

auto n_steps = x.size(1) + 1; // +1 for the initial state
auto n_batches = x.size(0);
auto output =
at::cat({zi.unsqueeze(1), x}, 1).contiguous();
auto a_contiguous = a.contiguous();

if (a.dim() == 1 && a.numel() == n_batches)
{

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.scalar_type(), "host_lti_batch_linear_recurrence", [&]
{ host_lti_batch_linear_recurrence<scalar_t>(
n_batches, n_steps,
a_contiguous.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>()); });
}
else
{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.scalar_type(), "host_lti_shared_linear_recurrence", [&]
{ host_lti_shared_linear_recurrence<scalar_t>(
n_batches, n_steps,
a_contiguous.item<scalar_t>(),
output.mutable_data_ptr<scalar_t>()); });
}

return output.slice(1, 1, output.size(1))
.contiguous(); // Remove the initial state from the output
}

TORCH_LIBRARY_IMPL(philtorch, CPU, m)
{
m.impl("lti_recur", &lti_recur_cpu_impl);
}
2 changes: 2 additions & 0 deletions philtorch/csrc/host_recur2.cpp
Expand Up @@ -166,4 +166,6 @@ TORCH_LIBRARY(philtorch, m)

m.def("philtorch::lti_recur2(Tensor A, Tensor zi, Tensor x) -> Tensor");
m.def("philtorch::lti_recurN(Tensor A, Tensor zi, Tensor x) -> Tensor");

m.def("philtorch::lti_recur(Tensor A, Tensor zi, Tensor x) -> Tensor");
}
150 changes: 150 additions & 0 deletions philtorch/csrc/lti_recur.cu
@@ -0,0 +1,150 @@
#include <assert.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/pair.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include <torch/script.h>
#include <torch/torch.h>

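// Associative operator for the scan: composing y -> a1*y + b1 with
// y -> a2*y + b2 gives y -> (a1*a2)*y + (b1*a2 + b2), so an inclusive scan
// over (decay, impulse) pairs evaluates h[t] = decay[t]*h[t-1] + impulse[t].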
template <typename T>
struct recur_binary_op
{
__host__ __device__ cuda::std::tuple<T, T> operator()(
const cuda::std::tuple<T, T> &a,
const cuda::std::tuple<T, T> &b) const
{
auto [a_first, a_second] = a;
auto [b_first, b_second] = b;
return cuda::std::make_tuple(a_first * b_first,
a_second * b_first + b_second);
}
};

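// Keeps only the recurrence value (second element) of the scanned pair when
// writing to the output iterator.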
template <typename T>
struct take_second
{
__host__ __device__ T operator()(const cuda::std::tuple<T, T> &state) const
{
return thrust::get<1>(state);
}
};

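// Maps a flattened (batch * step) index and its impulse to a (decay, impulse)
// pair. The decay is forced to zero at the first element of every row so the
// scan does not carry state across batch boundaries.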
template <typename T>
struct lti_batch_recur_input_op
{
const T *decays;
int n_steps;
__host__ __device__ cuda::std::tuple<T, T> operator()(int i, const T &x) const
{
int idx = i / n_steps;
int offset = i % n_steps;
if (offset > 0)
return thrust::make_tuple(decays[idx], x);
return thrust::make_tuple(0, x);
}
};

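// Same mapping, but with one decay shared by every batch.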
template <typename T>
struct lti_shared_recur_input_op
{
const T decay;
int n_steps;
__host__ __device__ cuda::std::tuple<T, T> operator()(int i, const T &x) const
{
int offset = i % n_steps;
if (offset > 0)
return thrust::make_tuple(decay, x);
return thrust::make_tuple(0, x);
}
};

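// Evaluates all rows with a single device-wide inclusive scan: a counting
// iterator and the impulses are zipped, mapped to (decay, impulse) pairs,
// combined with recur_binary_op, and the value component is written to `out`.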
template <typename scalar_t>
void lti_batch_linear_recurrence(const scalar_t *decays,
const scalar_t *impulses,
scalar_t *out, int n_steps, int n_batches)
{
auto total_steps = n_steps * n_batches;
thrust::counting_iterator<int> it(0);
auto batch_input_op = thrust::make_zip_function(lti_batch_recur_input_op<scalar_t>{decays, n_steps});
thrust::inclusive_scan(
thrust::device,
thrust::make_transform_iterator(
thrust::make_zip_iterator(it, impulses), batch_input_op),
thrust::make_transform_iterator(
thrust::make_zip_iterator(it + total_steps, impulses + total_steps), batch_input_op),
thrust::make_transform_output_iterator(out, take_second<scalar_t>()),
recur_binary_op<scalar_t>());
}

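// Shared-coefficient variant of the scan above.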
template <typename scalar_t>
void lti_shared_linear_recurrence(const scalar_t decay,
const scalar_t *impulses,
scalar_t *out, int n_steps, int n_batches)
{
auto total_steps = n_steps * n_batches;
thrust::counting_iterator<int> it(0);
auto shared_input_op = thrust::make_zip_function(lti_shared_recur_input_op<scalar_t>{decay, n_steps});
thrust::inclusive_scan(
thrust::device,
thrust::make_transform_iterator(
thrust::make_zip_iterator(it, impulses), shared_input_op),
thrust::make_transform_iterator(
thrust::make_zip_iterator(it + total_steps, impulses + total_steps), shared_input_op),
thrust::make_transform_output_iterator(out, take_second<scalar_t>()),
recur_binary_op<scalar_t>());
}

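// CUDA entry point for philtorch::lti_recur; mirrors the CPU implementation:
// prepend `zi`, run the scan out of place into `output`, then drop the
// leading column.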
at::Tensor lti_recur_cuda_impl(const at::Tensor &a,
const at::Tensor &zi, const at::Tensor &x)
{
TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
"zi must have the same scalar type as input");
TORCH_CHECK(a.scalar_type() == x.scalar_type(),
"A must have the same scalar type as input");
TORCH_CHECK(a.dim() <= 1, "A must be a vector or a scalar");
TORCH_CHECK(zi.dim() == 1, "zi must be a vector");

auto n_steps = x.size(1) + 1; // +1 for the initial state
auto n_batches = x.size(0);
auto x_contiguous =
at::cat({zi.unsqueeze(1), x}, 1).contiguous();
auto a_contiguous = a.contiguous();
auto output = at::empty_like(x_contiguous);

const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

if (a.dim() == 1 && a.numel() == n_batches)
{

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.scalar_type(), "lti_batch_linear_recurrence", [&]
{ lti_batch_linear_recurrence<scalar_t>(
a_contiguous.const_data_ptr<scalar_t>(),
x_contiguous.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
n_steps, n_batches); });
}
else
{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.scalar_type(), "lti_shared_linear_recurrence", [&]
{ lti_shared_linear_recurrence<scalar_t>(
a_contiguous.item<scalar_t>(),
x_contiguous.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
n_steps, n_batches); });
}

return output.slice(1, 1, output.size(1))
.contiguous(); // Remove the initial state from the output
}

TORCH_LIBRARY_IMPL(philtorch, CUDA, m) { m.impl("lti_recur", &lti_recur_cuda_impl); }
28 changes: 28 additions & 0 deletions tests/test_recur_ext.py
@@ -1,6 +1,7 @@
import pytest
import torch
from philtorch.mat import companion
from philtorch.lti import linear_recurrence

from .test_lti_lfilter import _generate_random_signal
from .test_lti_ssm import _generate_random_filter_coeffs
@@ -65,3 +66,30 @@ def test_lti_recurN_equiv(device: str, batch: bool, order: int):
x_torch,
)
assert torch.allclose(lti_y, ltv_y)


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
@pytest.mark.parametrize("batch", [True, False])
def test_lti_recur_equiv(device: str, batch: bool):
B = 3
T = 101

# Random coefficients in (-1, 1), a random input signal, and a random initial state
a_torch = torch.rand(B if batch else 1).to(device).double() * 2 - 1
x_torch = torch.randn(B, T).to(device).double()
zi = x_torch.new_zeros(B).normal_()

lti_y = torch.ops.philtorch.lti_recur(a_torch, zi, x_torch)
torch_y = linear_recurrence(a_torch, zi, x_torch)
assert torch.allclose(lti_y, torch_y), torch.max(torch.abs(lti_y - torch_y))