diff --git a/audio_diffusion/blocks.py b/audio_diffusion/blocks.py
index 6261a90..65cc63a 100644
--- a/audio_diffusion/blocks.py
+++ b/audio_diffusion/blocks.py
@@ -1,16 +1,19 @@
 import math
+
 import torch
 from torch import nn
 from torch.nn import functional as F
 
+
 class ResidualBlock(nn.Module):
     def __init__(self, main, skip=None):
         super().__init__()
         self.main = nn.Sequential(*main)
         self.skip = skip if skip else nn.Identity()
 
-    def forward(self, input):
-        return self.main(input) + self.skip(input)
+    def forward(self, x):
+        return self.main(x) + self.skip(x)
+
 
 # Noise level (and other) conditioning
 class ResConvBlock(ResidualBlock):
@@ -25,6 +28,7 @@ def __init__(self, c_in, c_mid, c_out, is_last=False):
             nn.GELU() if not is_last else nn.Identity(),
         ], skip)
 
+
 class SelfAttention1d(nn.Module):
     def __init__(self, c_in, n_head=1, dropout_rate=0.):
         super().__init__()
@@ -35,24 +39,29 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.):
         self.out_proj = nn.Conv1d(c_in, c_in, 1)
         self.dropout = nn.Dropout(dropout_rate, inplace=True)
 
-    def forward(self, input):
-        n, c, s = input.shape
-        qkv = self.qkv_proj(self.norm(input))
+    def forward(self, x):  # renamed from `input`, which shadows the Python builtin
+        n, c, s = x.shape
+        qkv = self.qkv_proj(self.norm(x))
         qkv = qkv.view(
             [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
         q, k, v = qkv.chunk(3, dim=1)
-        scale = k.shape[3]**-0.25
+        del qkv
+        scale = k.shape[3] ** -0.25
         att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+        del q
         y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
-        return input + self.dropout(self.out_proj(y))
+        del v
+        return x + self.dropout(self.out_proj(y))
+
 
 class SkipBlock(nn.Module):
     def __init__(self, *main):
         super().__init__()
         self.main = nn.Sequential(*main)
 
-    def forward(self, input):
-        return torch.cat([self.main(input), input], dim=1)
+    def forward(self, x):
+        return torch.cat([self.main(x), x], dim=1)
+
 
 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1.):
@@ -61,22 +70,22 @@ def __init__(self, in_features, out_features, std=1.):
         self.weight = nn.Parameter(torch.randn(
             [out_features // 2, in_features]) * std)
 
-    def forward(self, input):
-        f = 2 * math.pi * input @ self.weight.T
+    def forward(self, x):
+        f = 2 * math.pi * x @ self.weight.T
         return torch.cat([f.cos(), f.sin()], dim=-1)
 
 
 _kernels = {
     'linear': [1 / 8, 3 / 8, 3 / 8, 1 / 8],
-    'cubic': 
+    'cubic':
         [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
-        0.43359375, 0.11328125, -0.03515625, -0.01171875],
-    'lanczos3': 
+         0.43359375, 0.11328125, -0.03515625, -0.01171875],
+    'lanczos3':
         [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-        -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
-        0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-        -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
 }
 
 
@@ -87,7 +96,7 @@ def __init__(self, kernel='linear', pad_mode='reflect'):
         kernel_1d = torch.tensor(_kernels[kernel])
         self.pad = kernel_1d.shape[0] // 2 - 1
         self.register_buffer('kernel', kernel_1d)
-    
+
     def forward(self, x):
         x = F.pad(x, (self.pad,) * 2, self.pad_mode)
         weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
@@ -103,7 +112,7 @@ def __init__(self, kernel='linear', pad_mode='reflect'):
         kernel_1d = torch.tensor(_kernels[kernel]) * 2
         self.pad = kernel_1d.shape[0] // 2 - 1
         self.register_buffer('kernel', kernel_1d)
-    
+
     def forward(self, x):
         x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
         weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
diff --git a/audio_diffusion/models.py b/audio_diffusion/models.py
index fbaae41..c333891 100644
--- a/audio_diffusion/models.py
+++ b/audio_diffusion/models.py
@@ -3,19 +3,22 @@ from torch.nn import functional as F
 
 from .blocks import SkipBlock, FourierFeatures, SelfAttention1d, ResConvBlock, Downsample1d, Upsample1d
-from .utils import append_dims, expand_to_planes
+from .utils import expand_to_planes
+
 
 class DiffusionAttnUnet1D(nn.Module):
     def __init__(
-        self,
-        global_args,
-        io_channels = 2,
-        depth=14,
-        n_attn_layers = 6,
-        c_mults = [128, 128, 256, 256] + [512] * 10
+            self,
+            global_args,
+            io_channels=2,
+            depth=14,
+            n_attn_layers=6,
+            c_mults=None
     ):
         super().__init__()
 
+        if c_mults is None:
+            c_mults = [128, 128, 256, 256] + [512] * 10
+
         self.timestep_embed = FourierFeatures(1, 16)
 
         attn_layer = depth - n_attn_layers - 1
@@ -72,11 +75,11 @@ def __init__(
     def forward(self, input, t, cond=None):
         timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
-        
+
         inputs = [input, timestep_embed]
 
         if cond is not None:
-            cond = F.interpolate(cond, (input.shape[2], ), mode='linear', align_corners=False)
+            cond = F.interpolate(cond, (input.shape[2],), mode='linear', align_corners=False)
             inputs.append(cond)
 
-        return self.net(torch.cat(inputs, dim=1))
\ No newline at end of file
+        return self.net(torch.cat(inputs, dim=1))
diff --git a/audio_diffusion/utils.py b/audio_diffusion/utils.py
index 7915578..0611ea7 100644
--- a/audio_diffusion/utils.py
+++ b/audio_diffusion/utils.py
@@ -1,12 +1,13 @@
-from contextlib import contextmanager
+import math
+import random
 import warnings
+from contextlib import contextmanager
 
 import torch
-from torch import nn
-import random
-import math
+from torch import nn
 from torch import optim
 
+
 def append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
     dims_to_append = target_dims - x.ndim
@@ -37,6 +38,7 @@ def eval_mode(model):
     the previous mode on exit."""
     return train_mode(model, False)
 
+
 @torch.no_grad()
 def ema_update(model, averaged_model, decay):
     """Incorporates updated model parameters into an exponential moving averaged
@@ -149,6 +151,7 @@ def _get_closed_form_lr(self):
 def get_alphas_sigmas(t):
     return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
 
+
 def append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
     dims_to_append = target_dims - x.ndim
@@ -156,9 +159,11 @@ def append_dims(x, target_dims):
         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
     return x[(...,) + (None,) * dims_to_append]
 
+
 def expand_to_planes(input, shape):
     return input[..., None].repeat([1, 1, shape[2]])
 
+
 class PadCrop(nn.Module):
     def __init__(self, n_samples, randomize=True):
         super().__init__()
@@ -173,23 +178,26 @@ def __call__(self, signal):
         output[:, :min(s, self.n_samples)] = signal[:, start:end]
         return output
 
+
 class RandomPhaseInvert(nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
         self.p = p
+
     def __call__(self, signal):
         return -signal if (random.random() < self.p) else signal
 
+
 class Stereo(nn.Module):
-  def __call__(self, signal):
-    signal_shape = signal.shape
-    # Check if it's mono
-    if len(signal_shape) == 1: # s -> 2, s
-      signal = signal.unsqueeze(0).repeat(2, 1)
-    elif len(signal_shape) == 2:
-      if signal_shape[0] == 1: #1, s -> 2, s
-        signal = signal.repeat(2, 1)
-      elif signal_shape[0] > 2: #?, s -> 2,s
-        signal = signal[:2, :]
-
-    return signal
+    def __call__(self, signal):
+        signal_shape = signal.shape
+        # Check if it's mono
+        if len(signal_shape) == 1:  # s -> 2, s
+            signal = signal.unsqueeze(0).repeat(2, 1)
+        elif len(signal_shape) == 2:
+            if signal_shape[0] == 1:  # 1, s -> 2, s
+                signal = signal.repeat(2, 1)
+            elif signal_shape[0] > 2:  # ?, s -> 2,s
+                signal = signal[:2, :]
+
+        return signal
diff --git a/dataset/dataset.py b/dataset/dataset.py
index f9ae8b1..f1c03dd 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,98 +1,100 @@
+import os
+import random
+from functools import partial
+from glob import glob
+from multiprocessing import Pool, cpu_count
+
 import torch
 import torchaudio
+import tqdm
 from torchaudio import transforms as T
-import random
-from glob import glob
-import os
+
 from audio_diffusion.utils import Stereo, PadCrop, RandomPhaseInvert
-import tqdm
-from multiprocessing import Pool, cpu_count
-from functools import partial
+
 
 class SampleDataset(torch.utils.data.Dataset):
-  def __init__(self, paths, global_args):
-    super().__init__()
-    self.filenames = []
-    print(f"Random crop: {global_args.random_crop}")
-    self.augs = torch.nn.Sequential(
-      PadCrop(global_args.sample_size, randomize=global_args.random_crop),
-      RandomPhaseInvert(),
-    )
-
-    self.encoding = torch.nn.Sequential(
-      Stereo()
-    )
-
-    for path in paths:
-      for ext in ['wav','flac','ogg','aiff','aif','mp3']:
-        self.filenames += glob(f'{path}/**/*.{ext}', recursive=True)
-
-    self.sr = global_args.sample_rate
-    if hasattr(global_args,'load_frac'):
-      self.load_frac = global_args.load_frac
-    else:
-      self.load_frac = 1.0
-    self.num_gpus = global_args.num_gpus
-
-    self.cache_training_data = global_args.cache_training_data
-
-    if self.cache_training_data: self.preload_files()
-
-
-  def load_file(self, filename):
-    audio, sr = torchaudio.load(filename)
-    if sr != self.sr:
-      resample_tf = T.Resample(sr, self.sr)
-      audio = resample_tf(audio)
-    return audio
-
-  def load_file_ind(self, file_list,i): # used when caching training data
-    return self.load_file(file_list[i]).cpu()
-
-  def get_data_range(self): # for parallel runs, only grab part of the data
-    start, stop = 0, len(self.filenames)
-    try:
-      local_rank = int(os.environ["LOCAL_RANK"])
-      world_size = int(os.environ["WORLD_SIZE"])
-      interval = stop//world_size
-      start, stop = local_rank*interval, (local_rank+1)*interval
-      print("local_rank, world_size, start, stop =",local_rank, world_size, start, stop)
-      return start, stop
-      #rank = os.environ["RANK"]
-    except KeyError as e: # we're on GPU 0 and the others haven't been initialized yet
-      start, stop = 0, len(self.filenames)//self.num_gpus
-      return start, stop
-
-  def preload_files(self):
-    n = int(len(self.filenames)*self.load_frac)
-    print(f"Caching {n} input audio files:")
-    wrapper = partial(self.load_file_ind, self.filenames)
-    start, stop = self.get_data_range()
-    with Pool(processes=cpu_count()) as p:  # //8 to avoid FS bottleneck and/or too many processes (b/c * num_gpus)
-      self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start,stop)), total=stop-start))
-
-  def __len__(self):
-    return len(self.filenames)
-
-  def __getitem__(self, idx):
-    audio_filename = self.filenames[idx]
-    try:
-      if self.cache_training_data:
-        audio = self.audio_files[idx] # .copy()
-      else:
-        audio = self.load_file(audio_filename)
-
-      #Run augmentations on this sample (including random crop)
-      if self.augs is not None:
-        audio = self.augs(audio)
-
-      audio = audio.clamp(-1, 1)
-
-      #Encode the file to assist in prediction
-      if self.encoding is not None:
-        audio = self.encoding(audio)
-
-      return (audio, audio_filename)
-    except Exception as e:
-      # print(f'Couldn\'t load file {audio_filename}: {e}')
-      return self[random.randrange(len(self))]
\ No newline at end of file
+    def __init__(self, paths, global_args):
+        super().__init__()
+        self.filenames = []
+        print(f"Random crop: {global_args.random_crop}")
+        self.augs = torch.nn.Sequential(
+            PadCrop(global_args.sample_size, randomize=global_args.random_crop),
+            RandomPhaseInvert(),
+        )
+
+        self.encoding = torch.nn.Sequential(
+            Stereo()
+        )
+
+        for path in paths:
+            for ext in ['wav', 'flac', 'ogg', 'aiff', 'aif', 'mp3']:
+                self.filenames += glob(f'{path}/**/*.{ext}', recursive=True)
+
+        self.sr = global_args.sample_rate
+        if hasattr(global_args, 'load_frac'):
+            self.load_frac = global_args.load_frac
+        else:
+            self.load_frac = 1.0
+        self.num_gpus = global_args.num_gpus
+
+        self.cache_training_data = global_args.cache_training_data
+
+        if self.cache_training_data: self.preload_files()
+
+    def load_file(self, filename):
+        audio, sr = torchaudio.load(filename)
+        if sr != self.sr:
+            resample_tf = T.Resample(sr, self.sr)
+            audio = resample_tf(audio)
+        return audio
+
+    def load_file_ind(self, file_list, i):  # used when caching training data
+        return self.load_file(file_list[i]).cpu()
+
+    def get_data_range(self):  # for parallel runs, only grab part of the data
+        start, stop = 0, len(self.filenames)
+        try:
+            local_rank = int(os.environ["LOCAL_RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+            interval = stop // world_size
+            start, stop = local_rank * interval, (local_rank + 1) * interval
+            print("local_rank, world_size, start, stop =", local_rank, world_size, start, stop)
+            return start, stop
+            # rank = os.environ["RANK"]
+        except KeyError as e:  # we're on GPU 0 and the others haven't been initialized yet
+            start, stop = 0, len(self.filenames) // self.num_gpus
+            return start, stop
+
+    def preload_files(self):
+        n = int(len(self.filenames) * self.load_frac)
+        print(f"Caching {n} input audio files:")
+        wrapper = partial(self.load_file_ind, self.filenames)
+        start, stop = self.get_data_range()
+        with Pool(processes=cpu_count()) as p:  # //8 to avoid FS bottleneck and/or too many processes (b/c * num_gpus)
+            self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start, stop)), total=stop - start))
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, idx):
+        audio_filename = self.filenames[idx]
+        try:
+            if self.cache_training_data:
+                audio = self.audio_files[idx]  # .copy()
+            else:
+                audio = self.load_file(audio_filename)
+
+            # Run augmentations on this sample (including random crop)
+            if self.augs is not None:
+                audio = self.augs(audio)
+
+            audio = audio.clamp(-1, 1)
+
+            # Encode the file to assist in prediction
+            if self.encoding is not None:
+                audio = self.encoding(audio)
+
+            return (audio, audio_filename)
+        except Exception as e:
+            # print(f'Couldn\'t load file {audio_filename}: {e}')
+            return self[random.randrange(len(self))]
diff --git a/train_uncond.py b/train_uncond.py
index 1d8fb74..2f02dcd 100644
--- a/train_uncond.py
+++ b/train_uncond.py
@@ -31,16 +31,19 @@ def get_alphas_sigmas(t):
     noise (sigma), given a timestep."""
     return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
 
+
 def get_crash_schedule(t):
     sigma = torch.sin(t * math.pi / 2) ** 2
     alpha = (1 - sigma ** 2) ** 0.5
     return alpha_sigma_to_t(alpha, sigma)
 
+
 def alpha_sigma_to_t(alpha, sigma):
     """Returns a timestep, given the scaling factors for the clean image and for
     the noise."""
     return torch.atan2(sigma, alpha) / math.pi * 2
 
+
 @torch.no_grad()
 def sample(model, x, steps, eta):
     """Draws samples from a model given starting noise."""
@@ -69,9 +72,9 @@ def sample(model, x, steps, eta):
         if i < steps - 1:
             # If eta > 0, adjust the scaling factor for the predicted noise
             # downward according to the amount of additional noise to add
-            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
-                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
-            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
+            ddim_sigma = eta * (sigmas[i + 1] ** 2 / sigmas[i] ** 2).sqrt() * \
+                         (1 - alphas[i] ** 2 / alphas[i + 1] ** 2).sqrt()
+            adjusted_sigma = (sigmas[i + 1] ** 2 - ddim_sigma ** 2).sqrt()
 
             # Recombine the predicted noise and predicted denoised image in the
             # correct proportions for the next step
@@ -85,7 +88,6 @@ def sample(model, x, steps, eta):
 
     return pred
 
-
 class DiffusionUncond(pl.LightningModule):
     def __init__(self, global_args):
         super().__init__()
@@ -94,10 +96,10 @@ def __init__(self, global_args):
         self.diffusion_ema = deepcopy(self.diffusion)
         self.rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=global_args.seed)
         self.ema_decay = global_args.ema_decay
-        
+
     def configure_optimizers(self):
         return optim.Adam([*self.diffusion.parameters()], lr=4e-5)
-    
+
     def training_step(self, batch, batch_idx):
         reals = batch[0]
@@ -133,6 +135,7 @@ def on_before_zero_grad(self, *args, **kwargs):
         decay = 0.95 if self.current_epoch < 25 else self.ema_decay
         ema_update(self.diffusion, self.diffusion_ema, decay)
 
+
 class ExceptionCallback(pl.Callback):
     def on_exception(self, trainer, module, err):
         print(f'{type(err).__name__}: {err}', file=sys.stderr)
@@ -150,14 +153,14 @@ def __init__(self, global_args):
 
     @rank_zero_only
     @torch.no_grad()
-    #def on_train_epoch_end(self, trainer, module):
-    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
-        
+    # def on_train_epoch_end(self, trainer, module):
+    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
+
         if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
             return
-        
+
         self.last_demo_step = trainer.global_step
-        
+
         noise = torch.randn([self.num_demos, 2, self.demo_samples]).to(module.device)
 
         try:
@@ -167,24 +170,23 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
             fakes = rearrange(fakes, 'b d n -> d (b n)')
 
             log_dict = {}
-            
+
             filename = f'demo_{trainer.global_step:08}.wav'
             fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
             torchaudio.save(filename, fakes, self.sample_rate)
-
             log_dict[f'demo'] = wandb.Audio(filename,
-                                                sample_rate=self.sample_rate,
-                                                caption=f'Demo')
-            
+                                            sample_rate=self.sample_rate,
+                                            caption=f'Demo')
+
             log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
 
             trainer.logger.experiment.log(log_dict, step=trainer.global_step)
         except Exception as e:
             print(f'{type(e).__name__}: {e}', file=sys.stderr)
 
 
-def main(): 
+def main():
     args = get_all_args()
 
     args.latent_dim = 0
@@ -201,7 +203,8 @@ def main():
 
     wandb_logger = pl.loggers.WandbLogger(project=args.name)
     exc_callback = ExceptionCallback()
-    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1, dirpath=save_path)
+    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1,
+                                                 dirpath=save_path)
     demo_callback = DemoCallback(args)
     diffusion_model = DiffusionUncond(args)
@@ -215,7 +218,7 @@ def main():
         # num_nodes = args.num_nodes,
         # strategy='ddp',
         precision=16,
-        accumulate_grad_batches=args.accum_batches, 
+        accumulate_grad_batches=args.accum_batches,
         callbacks=[ckpt_callback, demo_callback, exc_callback],
         logger=wandb_logger,
         log_every_n_steps=1,
@@ -224,6 +227,6 @@ def main():
 
     diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path)
 
+
 if __name__ == '__main__':
     main()
-
diff --git a/viz/viz.py b/viz/viz.py
index c2695ac..c4707d4 100644
--- a/viz/viz.py
+++ b/viz/viz.py
@@ -1,27 +1,77 @@
-
-import math
-from pathlib import Path
+import librosa
+import numpy as np
+import torchaudio.transforms as T
+from matplotlib import pyplot as plt
 from matplotlib.backends.backend_agg import FigureCanvasAgg
-import matplotlib.cm as cm
-import matplotlib.pyplot as plt
-from matplotlib.colors import Normalize
 from matplotlib.figure import Figure
-import numpy as np
-from PIL import Image
-import torch
-from torch import optim, nn
-from torch.nn import functional as F
-import torchaudio
-import torchaudio.transforms as T
-import librosa
-from einops import rearrange
-import wandb
-import numpy as np
-import pandas as pd
+
+
+def img_is_color(img):
+    if len(img.shape) == 3:
+        # Check the color channels to see if they're all the same.
+        c1, c2, c3 = img[:, :, 0], img[:, :, 1], img[:, :, 2]
+        if (c1 == c2).all() and (c2 == c3).all():
+            return True
+
+    return False
+
+
+def show_image_list(list_images, list_cmaps=None, grid=True, num_cols=2, figsize=(20, 10)):
+    """
+    Shows a grid of images, where each image is a Numpy array. The images can be either
+    RGB or grayscale.
+
+    Parameters:
+    ----------
+    list_images: list
+        List of the images to be displayed.
+    list_cmaps: list or None
+        Optional list of cmap values for each image. If None, the cmap is
+        inferred automatically from the image contents.
+    grid: boolean
+        If True, show a grid over each image.
+    num_cols: int
+        Number of columns to show.
+    figsize: tuple of width, height
+        Value to be passed to pyplot.figure().
+    """
+
+    if list_cmaps is not None:
+        assert isinstance(list_cmaps, list)
+        assert len(list_images) == len(list_cmaps), '%d imgs != %d cmaps' % (len(list_images), len(list_cmaps))
+
+    num_images = len(list_images)
+    num_cols = min(num_images, num_cols)
+    num_rows = int(num_images / num_cols) + (1 if num_images % num_cols != 0 else 0)
+
+    # Create a grid of subplots.
+    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
+
+    # Create list of axes for easy iteration.
+    if isinstance(axes, np.ndarray):
+        list_axes = list(axes.flat)
+    else:
+        list_axes = [axes]
+
+    for i in range(num_images):
+        img = list_images[i]
+        cmap = list_cmaps[i] if list_cmaps is not None else (None if img_is_color(img) else 'gray')
+
+        list_axes[i].imshow(img, cmap=cmap)
+        list_axes[i].grid(grid)
+
+    for i in range(num_images, len(list_axes)):
+        list_axes[i].set_visible(False)
+
+    fig.tight_layout()
+    return fig
+
 
 
-def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None, db_range=[35,120]):
+def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None, db_range=[35, 120]):
    """
     # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html
@@ -38,7 +88,7 @@ def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=N
     fig.colorbar(im, ax=axs)
     canvas.draw()
     rgba = np.asarray(canvas.buffer_rgba())
-    return Image.fromarray(rgba)
+    return rgba
 
 
 def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000):
@@ -51,10 +101,11 @@ def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000):
     n_mels = 80
 
     mel_spectrogram_op = T.MelSpectrogram(
-        sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, 
-        hop_length=hop_length, center=True, pad_mode="reflect", power=power, 
+        sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
+        hop_length=hop_length, center=True, pad_mode="reflect", power=power,
         norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk")
 
     melspec = mel_spectrogram_op(waveform.float())
-    melspec = melspec[0] # TODO: only left channel for now
-    return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)')
+    return show_image_list([spectrogram_image(melspec[0], title="left channel", ylabel='mel bins (log freq)'),
+                            spectrogram_image(melspec[1], title="right channel", ylabel='mel bins (log freq)')
+                            ])