From b881f11d0d35d19bd92aaaa984c7fee0d0227060 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 3 Aug 2019 13:26:52 -0700 Subject: [PATCH] add modes to conv function --- src/dspbase.jl | 71 ++++++++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/src/dspbase.jl b/src/dspbase.jl index 4ecadd3b4..f525bd307 100644 --- a/src/dspbase.jl +++ b/src/dspbase.jl @@ -289,41 +289,50 @@ function _conv_clip!(y::AbstractArray, minpad, axesu, axesv) end """ - conv(u,v) + conv(u,v; mode = "full") + +Convolution of two arrays. Passing "full" mode is FFT algorithm, "same" mode is +output of length max(M, N), and "valid" is output where all signals are +overlapped. This is an adaptation from the Python numpy convolve function. -Convolution of two arrays. Uses FFT algorithm. """ -function conv(u::AbstractArray{T, N}, - v::AbstractArray{T, N}) where {T<:BLAS.BlasFloat, N} - su = size(u) - sv = size(v) - minpad = su .+ sv .- 1 - padsize = map(n -> n > 1024 ? nextprod([2,3,5], n) : nextpow(2, n), minpad) - y = _conv(u, v, padsize) - _conv_clip!(y, minpad, axes(u), axes(v)) -end -function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N - fu, fv = promote(u, v) - conv(fu, fv) -end -conv(u::AbstractArray{T, N}, v::AbstractArray{T, N}) where {T<:Number, N} = - conv(float(u), float(v)) -conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} = - round.(Int, conv(float(u), float(v))) -function conv(u::AbstractArray{<:Number, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N - conv(float(u), v) -end -function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:Number, N}) where N - conv(u, float(v)) +function conv(u::StridedVector{T}, v::StridedVector{T}; mode="full") where + T<:BLAS.BlasFloat + nu = length(u) + nv = length(v) + if nu==0||nv==0 + throw( DomainError("parameter u or v", + "Argument vectors are supposed to be non-empty.") ) + elseif nv>nu + u, v = v, u + nu, nv = nv, nu + end + if mode=="full" + n = nu+nv-1 + return [u[max(1, i+1-nv):min(i,nu)]'*v[i