Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions audio_diffusion/blocks.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import math

import torch
from torch import nn
from torch.nn import functional as F


class ResidualBlock(nn.Module):
    """Add the output of a main branch to the output of a skip branch.

    Args:
        main: iterable of modules; unpacked into an ``nn.Sequential``.
        skip: optional module applied to the input on the skip path.
            When ``None`` an ``nn.Identity`` is used, i.e. the input itself
            is added to the main branch output.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare against None instead of relying on truthiness: container
        # modules define __len__, so e.g. an empty nn.Sequential is falsy and
        # would be silently swapped for a different module object.
        self.skip = skip if skip is not None else nn.Identity()

    def forward(self, x):
        """Return ``main(x) + skip(x)``."""
        # The argument is called ``x`` so the ``input`` builtin is not shadowed.
        return self.main(x) + self.skip(x)


# Noise level (and other) conditioning
class ResConvBlock(ResidualBlock):
Expand All @@ -25,6 +28,7 @@ def __init__(self, c_in, c_mid, c_out, is_last=False):
nn.GELU() if not is_last else nn.Identity(),
], skip)


class SelfAttention1d(nn.Module):
def __init__(self, c_in, n_head=1, dropout_rate=0.):
super().__init__()
Expand All @@ -35,24 +39,29 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.):
self.out_proj = nn.Conv1d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate, inplace=True)

def forward(self, x):
    """Apply multi-head self-attention along the last axis and add it to ``x``.

    The parameter is named ``x`` so the ``input`` builtin is not shadowed.

    Args:
        x: tensor unpacked as ``(n, c, s)``. ``c`` must be divisible by
            ``self.n_head``. The ``view`` below also requires
            ``self.qkv_proj`` to emit ``3 * c`` channels (assumed from the
            constructor, which is not fully visible here -- TODO confirm).

    Returns:
        ``x + self.dropout(self.out_proj(y))`` where ``y`` is the attention
        output reshaped back to ``(n, c, s)``.
    """
    n, c, s = x.shape
    qkv = self.qkv_proj(self.norm(x))
    # Split channels into 3 * n_head groups and move the sequence axis in
    # front of the per-head channel axis: (n, 3 * n_head, s, c // n_head).
    qkv = qkv.view(
        [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
    q, k, v = qkv.chunk(3, dim=1)
    # Scaling q and k by d ** -0.25 each yields the usual d ** -0.5 factor on
    # the attention logits.
    scale = k.shape[3] ** -0.25
    att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
    # No explicit ``del`` of qkv/q/v: the chunks are views sharing one
    # storage, and all locals are released when the method returns.
    y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
    return x + self.dropout(self.out_proj(y))


class SkipBlock(nn.Module):
    """Concatenate the output of a sub-network with its own input.

    Args:
        *main: modules, wrapped in order into an ``nn.Sequential``.
    """

    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, x):
        """Return ``cat([main(x), x])`` along dimension 1."""
        # The argument is called ``x`` so the ``input`` builtin is not shadowed.
        return torch.cat([self.main(x), x], dim=1)


class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
Expand All @@ -61,22 +70,22 @@ def __init__(self, in_features, out_features, std=1.):
self.weight = nn.Parameter(torch.randn(
[out_features // 2, in_features]) * std)

def forward(self, x):
    """Project ``x`` with ``self.weight`` and return cosine/sine features.

    Args:
        x: tensor whose last dimension matches ``self.weight.shape[1]``.

    Returns:
        ``cat([cos(f), sin(f)], dim=-1)`` with ``f = 2 * pi * x @ weight.T``,
        so the last dimension has ``2 * self.weight.shape[0]`` entries.
    """
    # ``*`` and ``@`` share precedence and associate left to right, so this
    # is ((2 * pi) * x) @ weight.T.
    f = 2 * math.pi * x @ self.weight.T
    return torch.cat([f.cos(), f.sin()], dim=-1)


# 1-D filter taps keyed by kernel name. The resampling modules in this file
# look a kernel up by name and turn it into a tensor via
# ``torch.tensor(_kernels[kernel])``. Every kernel is symmetric and its taps
# sum to 1 (up to float rounding).
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537],
}


Expand All @@ -87,7 +96,7 @@ def __init__(self, kernel='linear', pad_mode='reflect'):
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)

def forward(self, x):
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
Expand All @@ -103,7 +112,7 @@ def __init__(self, kernel='linear', pad_mode='reflect'):
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)

def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
Expand Down
23 changes: 13 additions & 10 deletions audio_diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
from torch.nn import functional as F

from .blocks import SkipBlock, FourierFeatures, SelfAttention1d, ResConvBlock, Downsample1d, Upsample1d
from .utils import append_dims, expand_to_planes
from .utils import expand_to_planes


class DiffusionAttnUnet1D(nn.Module):
def __init__(
self,
global_args,
io_channels = 2,
depth=14,
n_attn_layers = 6,
c_mults = [128, 128, 256, 256] + [512] * 10
self,
global_args,
io_channels=2,
depth=14,
n_attn_layers=6,
c_mults=None
):
super().__init__()

if c_mults is None:
c_mults = [128, 128, 256, 256] + [512] * 10
self.timestep_embed = FourierFeatures(1, 16)

attn_layer = depth - n_attn_layers - 1
Expand Down Expand Up @@ -72,11 +75,11 @@ def __init__(

def forward(self, input, t, cond=None):
    """Run the network on ``input`` concatenated with its conditioning.

    Args:
        input: tensor indexed as ``input.shape[2]`` for its length; the
            name is kept (although it shadows the builtin) because callers
            may pass it by keyword.
        t: 1-D tensor of timesteps; expanded to a column with ``t[:, None]``
            before being embedded by ``self.timestep_embed``.
        cond: optional conditioning tensor. When given it is linearly
            interpolated to length ``input.shape[2]`` and appended.

    Returns:
        ``self.net`` applied to the concatenation along dimension 1 of
        ``input``, the timestep embedding, and (if present) ``cond``.
    """
    # Broadcast the per-sample embedding along the length axis so that it
    # can be concatenated with ``input``.
    timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)

    inputs = [input, timestep_embed]

    if cond is not None:
        cond = F.interpolate(cond, (input.shape[2],), mode='linear', align_corners=False)
        inputs.append(cond)

    return self.net(torch.cat(inputs, dim=1))
40 changes: 24 additions & 16 deletions audio_diffusion/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from contextlib import contextmanager
import math
import random
import warnings
from contextlib import contextmanager

import torch
from torch import nn
import random
import math
from torch import nn
from torch import optim


def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
Expand Down Expand Up @@ -37,6 +38,7 @@ def eval_mode(model):
the previous mode on exit."""
return train_mode(model, False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
"""Incorporates updated model parameters into an exponential moving averaged
Expand Down Expand Up @@ -149,16 +151,19 @@ def _get_closed_form_lr(self):
def get_alphas_sigmas(t):
    """Return ``(cos(t * pi / 2), sin(t * pi / 2))`` for a tensor ``t``.

    The two results satisfy ``alpha ** 2 + sigma ** 2 == 1`` element-wise.
    Judging by the name they are presumably the signal and noise scaling
    factors for a timestep -- confirm against the callers.
    """
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.

    Args:
        x: tensor to extend.
        target_dims: desired number of dimensions of the result.

    Returns:
        ``x`` indexed with trailing ``None`` entries, i.e. with size-1
        dimensions appended; ``x[...]`` when no dimensions are missing.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]


def expand_to_planes(input, shape):
    """Repeat ``input`` along a new trailing axis ``shape[2]`` times.

    Args:
        input: tensor that gains one trailing axis. The three-entry repeat
            spec means it is expected to be 2-D. The name is kept (although
            it shadows the builtin) for keyword-call compatibility.
        shape: shape of the tensor to match; only ``shape[2]`` is read.

    Returns:
        For a 2-D ``input`` of shape ``(a, b)``, a tensor of shape
        ``(a, b, shape[2])`` holding copies of ``input`` along the last axis.
    """
    return input[..., None].repeat([1, 1, shape[2]])


class PadCrop(nn.Module):
def __init__(self, n_samples, randomize=True):
super().__init__()
Expand All @@ -173,23 +178,26 @@ def __call__(self, signal):
output[:, :min(s, self.n_samples)] = signal[:, start:end]
return output


class RandomPhaseInvert(nn.Module):
    """Negate the signal with probability ``p``.

    Args:
        p: probability of returning ``-signal``; defaults to 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    # NOTE: this overrides ``__call__`` directly instead of ``forward``, so
    # ``nn.Module`` hooks do not run for this transform. Kept as-is so the
    # behaviour is unchanged.
    def __call__(self, signal):
        """Return ``-signal`` if a uniform draw is below ``p``, else ``signal``."""
        return -signal if (random.random() < self.p) else signal


class Stereo(nn.Module):
    """Coerce a 1-D or 2-D signal tensor to two rows.

    A 1-D signal of length ``s`` or a ``(1, s)`` signal is duplicated into
    ``(2, s)``. A 2-D signal with more than two rows is cut down to its
    first two rows. Anything else is returned unchanged.
    """

    # NOTE: this overrides ``__call__`` directly instead of ``forward``, so
    # ``nn.Module`` hooks do not run for this transform.
    def __call__(self, signal):
        """Return ``signal`` converted as described on the class."""
        signal_shape = signal.shape
        # Check if it's mono
        if len(signal_shape) == 1:  # s -> 2, s
            signal = signal.unsqueeze(0).repeat(2, 1)
        elif len(signal_shape) == 2:
            if signal_shape[0] == 1:  # 1, s -> 2, s
                signal = signal.repeat(2, 1)
            elif signal_shape[0] > 2:  # ?, s -> 2, s
                signal = signal[:2, :]

        return signal
Loading