diff --git a/pyannote/audio/models/embedding/__init__.py b/pyannote/audio/models/embedding/__init__.py index 2819096c2..2bd53d4f8 100644 --- a/pyannote/audio/models/embedding/__init__.py +++ b/pyannote/audio/models/embedding/__init__.py @@ -21,6 +21,7 @@ # SOFTWARE. +from .wavlm import WavLM_ECAPA_TDNN, WavLMEmbeddings from .wespeaker import ( WeSpeakerResNet34, WeSpeakerResNet152, @@ -36,4 +37,6 @@ "WeSpeakerResNet152", "WeSpeakerResNet221", "WeSpeakerResNet293", + "WavLMEmbeddings", + "WavLM_ECAPA_TDNN", ] diff --git a/pyannote/audio/models/embedding/wavlm.py b/pyannote/audio/models/embedding/wavlm.py new file mode 100644 index 000000000..bdec1dc67 --- /dev/null +++ b/pyannote/audio/models/embedding/wavlm.py @@ -0,0 +1,431 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from functools import lru_cache +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN + +from pyannote.audio.core.model import Model +from pyannote.audio.core.task import Task +from pyannote.audio.models.blocks.pooling import StatsPool +from pyannote.audio.utils.receptive_field import ( + conv1d_num_frames, + conv1d_receptive_field_center, + conv1d_receptive_field_size, +) + + +class WavLMEmbeddings(Model): + """Self-Supervised Representation for Speaker Embeddings extraction + + wav2vec > Stats pooling > Feed forward + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + WAV2VEC_DEFAULTS = "WAVLM_BASE" + + def __init__( + self, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + emb_dim: Optional[int] = 512, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. 
"WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) + elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + self.wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + + self.pooling = StatsPool() + self.embedding = nn.Sequential( + nn.Linear(wav2vec_dim * 2, emb_dim), + nn.Linear(emb_dim, emb_dim), + ) + + self.save_hyperparameters("wav2vec", "wav2vec_layer", "emb_dim") + + @property + def dimension(self) -> int: + """Dimension of output""" + if isinstance(self.specifications, tuple): + raise ValueError("XVectorWavLM does not support multi-tasking.") + + if self.specifications.powerset: + return self.specifications.num_powerset_classes + else: + return len(self.specifications.classes) + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.wav2vec_weights, dim=0 + ) + else: + outputs = outputs[-1] + + outputs = torch.transpose(outputs, 1, 2) + outputs = self.pooling(outputs) + + return self.embedding(outputs) + + +class WavLM_ECAPA_TDNN(Model): + """Self-Supervised Representation for Speaker Embeddings extraction + + wav2vec > Stats pooling > Feed forward + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + wav2vec: dict or str, optional + Defaults to "WAVLM_BASE". + wav2vec_layer: int, optional + Index of layer to use as input to the LSTM. + Defaults (-1) to use average of all layers (with learnable weights). + emb_dim: int, optional + Dimension of the speaker embedding in output + """ + + def __init__( + self, + sample_rate: int = 16000, + num_channels: int = 1, + wav2vec: Union[dict, str] = None, + wav2vec_layer: int = -1, + freeze_wav2vec: bool = True, + emb_dim: Optional[int] = 192, + task: Optional[Task] = None, + ): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + if isinstance(wav2vec, str): + # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE") + if hasattr(torchaudio.pipelines, wav2vec): + bundle = getattr(torchaudio.pipelines, wav2vec) + if sample_rate != bundle.sample_rate: + raise ValueError( + f"Expected {bundle.sample_rate}Hz, found {sample_rate}Hz." + ) + wav2vec_dim = bundle._params["encoder_embed_dim"] + wav2vec_num_layers = bundle._params["encoder_num_layers"] + self.wav2vec = bundle.get_model() + + # `wav2vec` is a path to a self-supervised representation checkpoint + else: + _checkpoint = torch.load(wav2vec) + wav2vec = _checkpoint.pop("config") + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + state_dict = _checkpoint.pop("state_dict") + self.wav2vec.load_state_dict(state_dict) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + # `wav2vec` is a config dictionary understood by `wav2vec2_model` + # this branch is typically used by Model.from_pretrained(...) 
+ elif isinstance(wav2vec, dict): + self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec) + wav2vec_dim = wav2vec["encoder_embed_dim"] + wav2vec_num_layers = wav2vec["encoder_num_layers"] + + if wav2vec_layer < 0: + self.wav2vec_weights = nn.Parameter( + data=torch.ones(wav2vec_num_layers), requires_grad=True + ) + + self.freeze_wav2vec = freeze_wav2vec + + self.ecapa_tdnn = ECAPA_TDNN(input_size=wav2vec_dim, lin_neurons=emb_dim) + + self.save_hyperparameters( + "wav2vec", "wav2vec_layer", "freeze_wav2vec", "emb_dim" + ) + + @property + def dimension(self) -> int: + """Dimension of output""" + if isinstance(self.specifications, tuple): + raise ValueError("XVectorWavLM does not support multi-tasking.") + + if self.specifications.powerset: + return self.specifications.num_powerset_classes + else: + return len(self.specifications.classes) + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + num_frames = num_samples + for conv_layer in self.wav2vec.feature_extractor.conv_layers: + num_frames = conv1d_num_frames( + num_frames, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + + return num_frames + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + receptive_field_size = num_frames + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_size = conv1d_receptive_field_size( + num_frames=receptive_field_size, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_size + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + receptive_field_center = frame + for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers): + receptive_field_center = conv1d_receptive_field_center( + receptive_field_center, + kernel_size=conv_layer.kernel_size, + stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], + dilation=conv_layer.conv.dilation[0], + ) + return receptive_field_center + + def forward(self, waveforms: torch.tensor): + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + """ + + num_layers = ( + None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer + ) + + if self.freeze_wav2vec: + with torch.no_grad(): + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + else: + outputs, _ = self.wav2vec.extract_features( + waveforms.squeeze(1), num_layers=num_layers + ) + + if num_layers is None: + outputs = torch.stack(outputs, dim=-1) @ F.softmax( + self.wav2vec_weights, dim=0 + ) + else: + outputs = outputs[-1] + + outputs = self.ecapa_tdnn(outputs) + + return outputs.squeeze(1) diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index cb6401e2b..f10383f48 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -30,14 +30,11 @@ from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric -from pyannote.audio.core.task import Task - from .mixins import SupervisedRepresentationLearningTaskMixin class SupervisedRepresentationLearningWithArcFace( SupervisedRepresentationLearningTaskMixin, - Task, ): """Supervised representation learning with ArcFace loss @@ -47,6 +44,13 @@ class SupervisedRepresentationLearningWithArcFace( ---------- protocol : Protocol pyannote.database protocol + cache : str, optional + As (meta-)data preparation might take a very long time for large datasets, + it can be cached to disk for later (and faster!) re-use. + When `cache` does not exist, `Task.prepare_data()` generates training + and validation metadata from `protocol` and save them to disk. + When `cache` exists, `Task.prepare_data()` is skipped and (meta)-data + are loaded from disk. Defaults to a temporary path. duration : float, optional Chunks duration in seconds. Defaults to two seconds (2.). min_duration : float, optional @@ -73,6 +77,8 @@ class SupervisedRepresentationLearningWithArcFace( metric : optional Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. Defaults to AUROC (area under the ROC curve). 
+ cache : string, optional + """ #  TODO: add a ".metric" property that tells how speaker embedding trained with this approach @@ -82,6 +88,7 @@ class SupervisedRepresentationLearningWithArcFace( def __init__( self, protocol: Protocol, + cache: Optional[str] = None, min_duration: Optional[float] = None, duration: float = 2.0, num_classes_per_batch: int = 32, @@ -109,12 +116,13 @@ def __init__( pin_memory=pin_memory, augmentation=augmentation, metric=metric, + cache=cache, ) def setup_loss_func(self): - + self.model.eval() _, embedding_size = self.model(self.model.example_input_array).shape - + self.model.train() self.model.loss_func = pytorch_metric_learning.losses.ArcFaceLoss( len(self.specifications.classes), embedding_size, diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index 9b404f9cf..b2e7993b4 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -21,6 +21,9 @@ # SOFTWARE. import math +import pickle +from pathlib import Path +from tempfile import mkstemp from typing import Dict, Sequence, Union import torch @@ -35,12 +38,12 @@ from torchmetrics.classification import BinaryAUROC from tqdm import tqdm -from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.torchmetrics.classification import EqualErrorRate from pyannote.audio.utils.random import create_rng_for_worker -class SupervisedRepresentationLearningTaskMixin: +class SupervisedRepresentationLearningTaskMixin(Task): """Methods common to most supervised representation tasks""" # batch_size = num_classes_per_batch x num_chunks_per_class @@ -75,11 +78,22 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int): self.batch_size_ = batch_size - def setup(self, stage=None): + def prepare_data(self): # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. 
- - self._train = dict() + if self.cache: + # check if cache exists and is not empty: + if self.cache.exists() and self.cache.stat().st_size > 0: + # data was already created, nothing to do + return + # create parent directory if needed + self.cache.parent.mkdir(parents=True, exist_ok=True) + else: + # if no cache was provided by user, create a temporary file + # in system directory used for temp files + self.cache = Path(mkstemp()[1]) + + train = {} desc = f"Loading {self.protocol.name} training labels" for f in tqdm(iterable=self.protocol.train(), desc=desc, unit="file"): @@ -99,10 +113,10 @@ def setup(self, stage=None): duration = sum(segment.duration for segment in speech_turns) # add class to the list of classes - if klass not in self._train: - self._train[klass] = list() + if klass not in train: + train[klass] = list() - self._train[klass].append( + train[klass].append( { "uri": f["uri"], "audio": f["audio"], @@ -111,12 +125,42 @@ def setup(self, stage=None): } ) + prepared_data = {"train": train, "protocol": self.protocol.name} + + self.prepare_validation(prepared_data) + self.post_prepare_data(prepared_data) + + # save prepared data on the disk + with open(self.cache, "wb") as cache_file: + pickle.dump(prepared_data, cache_file) + + def setup(self, stage=None): + if stage == "fit": + self.cache = self.trainer.strategy.broadcast(self.cache) + + try: + with open(self.cache, "rb") as cache_file: + self.prepared_data = pickle.load(cache_file) + except FileNotFoundError: + print( + "Cached data for protocol not found. Ensure that prepare_data() was called", + " and executed correctly or/and that the path to the task cache is correct.", + ) + raise + + # checks that the task current protocol matches the cached protocol + if self.protocol.name != self.prepared_data["protocol"]: + raise ValueError( + f"Protocol specified for the task ({self.protocol.name}) " + f"does not correspond to the cached one ({self.prepared_data['protocol']})" + ) + self.specifications = Specifications( problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=self.duration, min_duration=self.min_duration, - classes=sorted(self._train), + classes=sorted(self.prepared_data["train"]), ) def default_metric( @@ -162,8 +206,10 @@ def train__iter__(self): for _ in range(self.num_chunks_per_class): # select one file at random (with probability proportional to its class duration) file, *_ = rng.choices( - self._train[klass], - weights=[f["duration"] for f in self._train[klass]], + self.prepared_data["train"][klass], + weights=[ + f["duration"] for f in self.prepared_data["train"][klass] + ], k=1, ) @@ -207,7 +253,9 @@ def train__iter__(self): def train__len__(self): duration = sum( - datum["duration"] for data in self._train.values() for datum in data + datum["duration"] + for data in self.prepared_data["train"].values() + for datum in data ) avg_chunk_duration = 0.5 * (self.min_duration + self.duration) return max(self.batch_size, math.ceil(duration / avg_chunk_duration)) @@ -229,6 +277,18 @@ def training_step(self, batch, batch_idx: int): X, y = batch["X"], batch["y"] loss = self.model.loss_func(self.model(X), y) + if not self.model.automatic_optimization: + + wavlm_opt, other_opt = self.model.optimizers() + + wavlm_opt.zero_grad() + other_opt.zero_grad() + + self.model.manual_backward(loss) + + wavlm_opt.step() + other_opt.step() + # skip batch if something went wrong for some reason if torch.isnan(loss): return None @@ -246,7 +306,15 @@ def training_step(self, batch, batch_idx: int): def 
prepare_validation(self, prepared_dict: Dict): if isinstance(self.protocol, SpeakerVerificationProtocol): - prepared_dict["validation"] = list(self.protocol.development_trial()) + prepared_dict["validation"] = [] + for trial in self.protocol.development_trial(): + prepared_dict["validation"].append( + { + "reference": trial["reference"], + "file1": trial["file1"]["audio"], + "file2": trial["file2"]["audio"], + } + ) def val__getitem__(self, idx): if isinstance(self.protocol, SpeakerVerificationProtocol):
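
Taken together, these changes expose two new embedding architectures (WavLMEmbeddings, WavLM_ECAPA_TDNN) and a cached (meta-)data preparation path for the ArcFace task. For readers who want to try the additions end to end, here is a minimal training sketch under stated assumptions: the protocol name and cache path are placeholders, the Trainer import assumes the pytorch_lightning stack pyannote.audio builds on, and the class names and keyword arguments come from the patch above. Note that `wav2vec` is passed explicitly, since leaving it at its `None` default would skip both initialization branches in `__init__`.

import pytorch_lightning as pl
from pyannote.database import registry

from pyannote.audio.models.embedding import WavLM_ECAPA_TDNN
from pyannote.audio.tasks.embedding.arcface import (
    SupervisedRepresentationLearningWithArcFace,
)

# placeholder protocol: any pyannote.database protocol with a train() split
protocol = registry.get_protocol("MyDatabase.SpeakerVerification.MyProtocol")

# (meta-)data preparation is pickled to `cache`; later runs reuse it and skip
# the expensive pass over the training set in prepare_data()
task = SupervisedRepresentationLearningWithArcFace(
    protocol,
    cache="cache/arcface_metadata.pkl",
    duration=2.0,
    num_classes_per_batch=32,
)

# WavLM features (softmax-weighted sum over all transformer layers when
# wav2vec_layer < 0) feeding a speechbrain ECAPA-TDNN head
model = WavLM_ECAPA_TDNN(
    wav2vec="WAVLM_BASE",
    wav2vec_layer=-1,
    freeze_wav2vec=True,
    emb_dim=192,
    task=task,
)

trainer = pl.Trainer(max_epochs=1, devices=1)
trainer.fit(model)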
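
One thing the patch leaves implicit: the new branch added to `training_step` in mixins.py unpacks exactly two optimizers from `self.model.optimizers()` when automatic optimization is disabled, but no matching `configure_optimizers` is included in this diff. The sketch below is a hypothetical override consistent with that expectation (the subclass name and learning rates are illustrative, not from the patch); it only makes sense together with `freeze_wav2vec=False`, since the frozen path runs the backbone under `torch.no_grad()`.

import torch

from pyannote.audio.models.embedding import WavLM_ECAPA_TDNN


class TwoOptimizerWavLM(WavLM_ECAPA_TDNN):
    """WavLM_ECAPA_TDNN variant with separate backbone / head optimizers."""

    def __init__(self, *args, wavlm_lr: float = 1e-5, head_lr: float = 1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        # switch Lightning to manual optimization so that the task's
        # `if not self.model.automatic_optimization:` branch is exercised
        self.automatic_optimization = False
        self.wavlm_lr = wavlm_lr
        self.head_lr = head_lr

    def configure_optimizers(self):
        # first optimizer: WavLM backbone, usually with a much smaller learning rate
        wavlm_opt = torch.optim.AdamW(self.wav2vec.parameters(), lr=self.wavlm_lr)
        # second optimizer: everything else (layer weights, ECAPA-TDNN head and,
        # once the task has attached it, the ArcFace loss parameters)
        other_params = [
            p for name, p in self.named_parameters() if not name.startswith("wav2vec.")
        ]
        other_opt = torch.optim.AdamW(other_params, lr=self.head_lr)
        return [wavlm_opt, other_opt]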