From c7cfc5d2647e26471dc394f95846a0830e7bec34 Mon Sep 17 00:00:00 2001 From: pseeth Date: Thu, 20 Jul 2023 11:05:10 -0700 Subject: [PATCH] Chunked inference for codec (#22) * Adding some devtools. * Adding delay calculation * Chunked inference for codec. * Version bump * Removing prod.yml, updating to recent main. * Turning padding off only when chunking codes. * Updating README, removing unused things. * Missed a padding. * Adding some checks to make sure pads are the same. * Factoring out latent dim, backwards compatible. * Adding latent dim, and the 44khz 16kbps model config. * Ran pre-commit. * Chunked vs unchunked inference. * Fixing padding stuff. * n quantizers back in encode * don't load unsupported versions * correct docstring * bitrate config + 16kbps models * update audiotools dep * fix argbind issue * minor correction * bump version * change model path * update audiotools deps --------- Co-authored-by: prem Co-authored-by: Ishaan Kumar --- Dockerfile.dev | 13 ++ README.md | 61 +++++-- conf/final/44khz-16kbps.yml | 124 +++++++++++++++ dac/__init__.py | 5 +- dac/__main__.py | 2 +- dac/model/__init__.py | 1 + dac/model/base.py | 306 ++++++++++++++++++++++++++++-------- dac/model/dac.py | 78 +++++---- dac/nn/quantize.py | 11 +- dac/utils/__init__.py | 61 ++++--- dac/utils/decode.py | 94 ++--------- dac/utils/encode.py | 116 ++------------ docker-compose.yml | 37 +++++ requirements.txt | 10 ++ scripts/mushra.py | 104 ++++++++++++ scripts/train.py | 13 +- setup.py | 4 +- 17 files changed, 703 insertions(+), 337 deletions(-) create mode 100644 Dockerfile.dev create mode 100644 conf/final/44khz-16kbps.yml create mode 100644 docker-compose.yml create mode 100644 requirements.txt create mode 100644 scripts/mushra.py diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 0000000..72730ab --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,13 @@ +ARG IMAGE=pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime +ARG GITHUB_TOKEN=none + +FROM $IMAGE + +RUN echo machine github.com login ${GITHUB_TOKEN} > ~/.netrc + +COPY requirements.txt /requirements.txt + +RUN apt update && apt install -y git + +# install the package +RUN pip install --upgrade -r /requirements.txt diff --git a/README.md b/README.md index fd8ba9c..6303f05 100644 --- a/README.md +++ b/README.md @@ -66,33 +66,42 @@ for more options. ### Programmatic Usage ```py import dac -from dac.utils import load_model -from dac.model import DAC - -from dac.utils.encode import process as encode -from dac.utils.decode import process as decode - from audiotools import AudioSignal -# Init an empty model -model = DAC() +# Download a model +model_path = dac.utils.download(model_type="44khz") +model = dac.DAC.load(model_path) -# Load compatible pre-trained model -model = load_model(tag="latest", model_type="44khz") -model.eval() model.to('cuda') # Load audio signal file signal = AudioSignal('input.wav') -# Encode audio signal -encoded_out = encode(signal, 'cuda', model) +# Encode audio signal as one long file +# (may run out of GPU memory on long files) +signal.to(model.device) + +x = model.preprocess(signal.audio_data, signal.sample_rate) +z, codes, latents, _, _ = model.encode(x) # Decode audio signal -recon = decode(encoded_out, 'cuda', model, preserve_sample_rate=True) +y = model.decode(z) + +# Alternatively, use the `compress` and `decompress` functions +# to compress long files. 
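# (`compress` works on fixed-size chunks of `win_duration` seconds, 1.0 by
# default, so GPU memory stays roughly constant for arbitrarily long inputs;
# e.g. model.compress(signal, win_duration=5.0) uses 5-second chunks.)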
+ +signal = signal.cpu() +x = model.compress(signal) + +# Save and load to and from disk +x.save("compressed.dac") +x = dac.DACFile.load("compressed.dac") + +# Decompress it back to an AudioSignal +y = model.decompress(x) # Write to file -recon.write('recon.wav') +y.write('output.wav') ``` ### Docker image @@ -131,6 +140,28 @@ Please install the correct dependencies pip install -e ".[dev]" ``` +## Environment setup + +We have provided a Dockerfile and docker compose setup that makes running experiments easy. + +To build the docker image do: + +``` +docker compose build +``` + +Then, to launch a container, do: + +``` +docker compose run -p 8888:8888 -p 6006:6006 dev +``` + +The port arguments (`-p`) are optional, but useful if you want to launch a Jupyter and Tensorboard instances within the container. The +default password for Jupyter is `password`, and the current directory +is mounted to `/u/home/src`, which also becomes the working directory. + +Then, run your training command. + ### Single GPU training ``` diff --git a/conf/final/44khz-16kbps.yml b/conf/final/44khz-16kbps.yml new file mode 100644 index 0000000..3ee405d --- /dev/null +++ b/conf/final/44khz-16kbps.yml @@ -0,0 +1,124 @@ +# Model setup +DAC.sample_rate: 44100 +DAC.encoder_dim: 64 +DAC.encoder_rates: [2, 4, 8, 8] +DAC.latent_dim: 128 +DAC.decoder_dim: 1536 +DAC.decoder_rates: [8, 8, 4, 2] + +# Quantization +DAC.n_codebooks: 18 # Max bitrate of 16kbps +DAC.codebook_size: 1024 +DAC.codebook_dim: 8 +DAC.quantizer_dropout: 0.5 + +# Discriminator +Discriminator.sample_rate: 44100 +Discriminator.rates: [] +Discriminator.periods: [2, 3, 5, 7, 11] +Discriminator.fft_sizes: [2048, 1024, 512] +Discriminator.bands: + - [0.0, 0.1] + - [0.1, 0.25] + - [0.25, 0.5] + - [0.5, 0.75] + - [0.75, 1.0] + +# Optimization +AdamW.betas: [0.8, 0.99] +AdamW.lr: 0.0001 +ExponentialLR.gamma: 0.999996 + +amp: false +val_batch_size: 100 +device: cuda +num_iters: 400000 +save_iters: [10000, 50000, 100000, 200000] +valid_freq: 1000 +sample_freq: 10000 +num_workers: 32 +val_idx: [0, 1, 2, 3, 4, 5, 6, 7] +seed: 0 +lambdas: + mel/loss: 15.0 + adv/feat_loss: 2.0 + adv/gen_loss: 1.0 + vq/commitment_loss: 0.25 + vq/codebook_loss: 1.0 + +VolumeNorm.db: [const, -16] + +# Transforms +build_transform.preprocess: + - Identity +build_transform.augment_prob: 0.0 +build_transform.augment: + - Identity +build_transform.postprocess: + - VolumeNorm + - RescaleAudio + - ShiftPhase + +# Loss setup +MultiScaleSTFTLoss.window_lengths: [2048, 512] +MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] +MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] +MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] +MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] +MelSpectrogramLoss.pow: 1.0 +MelSpectrogramLoss.clamp_eps: 1.0e-5 +MelSpectrogramLoss.mag_weight: 0.0 + +# Data +batch_size: 72 +train/AudioDataset.duration: 0.38 +train/AudioDataset.n_examples: 10000000 + +val/AudioDataset.duration: 5.0 +val/build_transform.augment_prob: 1.0 +val/AudioDataset.n_examples: 250 + +test/AudioDataset.duration: 10.0 +test/build_transform.augment_prob: 1.0 +test/AudioDataset.n_examples: 1000 + +AudioLoader.shuffle: true +AudioDataset.without_replacement: true + +train/build_dataset.folders: + speech_fb: + - /data/daps/train + speech_hq: + - /data/vctk + - /data/vocalset + - /data/read_speech + - /data/french_speech + speech_uq: + - /data/emotional_speech/ + - /data/common_voice/ + - /data/german_speech/ + - /data/russian_speech/ + - /data/spanish_speech/ 
+ music_hq: + - /data/musdb/train + music_uq: + - /data/jamendo + general: + - /data/audioset/data/unbalanced_train_segments/ + - /data/audioset/data/balanced_train_segments/ + +val/build_dataset.folders: + speech_hq: + - /data/daps/val + music_hq: + - /data/musdb/test + general: + - /data/audioset/data/eval_segments/ + +test/build_dataset.folders: + speech_hq: + - /data/daps/test + music_hq: + - /data/musdb/test + general: + - /data/audioset/data/eval_segments/ diff --git a/dac/__init__.py b/dac/__init__.py index 7f988ef..51205ef 100644 --- a/dac/__init__.py +++ b/dac/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.5" +__version__ = "1.0.0" # preserved here for legacy reasons __model_version__ = "latest" @@ -11,3 +11,6 @@ from . import nn from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/dac/__main__.py b/dac/__main__.py index c19e004..2fa8d15 100644 --- a/dac/__main__.py +++ b/dac/__main__.py @@ -2,7 +2,7 @@ import argbind -from dac.utils import ensure_default_model as download +from dac.utils import download from dac.utils.decode import decode from dac.utils.encode import encode diff --git a/dac/model/__init__.py b/dac/model/__init__.py index 94304fd..02a75b7 100644 --- a/dac/model/__init__.py +++ b/dac/model/__init__.py @@ -1,3 +1,4 @@ from .base import CodecMixin +from .base import DACFile from .dac import DAC from .discriminator import Discriminator diff --git a/dac/model/base.py b/dac/model/base.py index 9da8a9b..546b3cb 100644 --- a/dac/model/base.py +++ b/dac/model/base.py @@ -1,116 +1,294 @@ import math +from dataclasses import dataclass from pathlib import Path from typing import Union +import numpy as np import torch import tqdm from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." 
+ ) + return cls(codes=codes, **artifacts["metadata"]) class CodecMixin: - EXT = ".dac" + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L @torch.no_grad() - def reconstruct( + def compress( self, audio_path_or_signal: Union[str, Path, AudioSignal], - overlap_win_duration: float = 5.0, - overlap_hop_ratio: float = 0.5, + win_duration: float = 1.0, verbose: bool = False, normalize_db: float = -16, - match_input_db: bool = False, - mono: bool = False, - **kwargs, - ): - """Reconstructs an audio signal from a file or AudioSignal object. - This function decomposes the audio signal into overlapping windows - and reconstructs them one by one. The overlapping windows are then - overlap-and-added together to form the final output. + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
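The delay and output-length bookkeeping above is just the standard un-padded conv arithmetic applied layer by layer, with `get_delay` walking the same formulas backwards through the network. A minimal standalone check with a single made-up layer (not one of DAC's real layers), padding disabled as it is in chunked mode:

```py
import math

import torch
from torch import nn

def conv_out_len(L, k, s, d):
    # Same formula get_output_length applies for an un-padded nn.Conv1d
    return math.floor((L - d * (k - 1) - 1) / s + 1)

layer = nn.Conv1d(1, 1, kernel_size=7, stride=2, dilation=1, padding=0)
x = torch.randn(1, 1, 100)

assert layer(x).shape[-1] == conv_out_len(100, k=7, s=2, d=1)  # both give 47
```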
Parameters ---------- audio_path_or_signal : Union[str, Path, AudioSignal] audio signal to reconstruct - overlap_win_duration : float, optional - overlap window duration in seconds, by default 5.0 - overlap_hop_ratio : float, optional - overlap hop ratio, by default 0.5 + win_duration : float, optional + window duration in seconds, by default 5.0 verbose : bool, optional by default False normalize_db : float, optional normalize db, by default -16 - match_input_db : bool, optional - set True to match input db, by default False - mono : bool, optional - set True to convert to mono, by default False + Returns ------- - AudioSignal - reconstructed audio signal + DACFile + Object containing compressed codes and metadata + required for decompression """ - self.eval() audio_signal = audio_path_or_signal if isinstance(audio_signal, (str, Path)): audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) - if mono: - audio_signal = audio_signal.to_mono() + self.eval() + original_padding = self.padding + original_device = audio_signal.device audio_signal = audio_signal.clone() - audio_signal = audio_signal.ffmpeg_resample(self.sample_rate) + original_sr = audio_signal.sample_rate - original_length = audio_signal.signal_length - input_db = audio_signal.ffmpeg_loudness() + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness - # Fix overlap window so that it's divisible by 4 in # of samples - sr = audio_signal.sample_rate - overlap_win_duration = ((overlap_win_duration * sr) // 4) * 4 - overlap_win_duration = overlap_win_duration / sr + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() if normalize_db is not None: audio_signal.normalize(normalize_db) audio_signal.ensure_max_of_audio() - overlap_hop_duration = overlap_win_duration * overlap_hop_ratio - do_overlap_and_add = audio_signal.signal_duration > overlap_win_duration nb, nac, nt = audio_signal.audio_data.shape audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) - if do_overlap_and_add: - pad_length = ( - math.ceil(audio_signal.signal_duration / overlap_win_duration) - * overlap_win_duration - ) - audio_signal.zero_pad_to(int(pad_length * sr)) - audio_signal = audio_signal.collect_windows( - overlap_win_duration, overlap_hop_duration - ) + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + codes = [] range_fn = range if not verbose else tqdm.trange - for i in range_fn(audio_signal.batch_size): - signal_from_batch = AudioSignal( - audio_signal.audio_data[i, ...], audio_signal.sample_rate - ) - signal_from_batch.to(self.device) - _output = self.forward( - signal_from_batch.audio_data, signal_from_batch.sample_rate, **kwargs - ) - _output = _output["audio"].detach() - 
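In the chunked branch above, the window is snapped to a whole number of code frames before encoding. With the 44 kHz model's hop of 512 samples (the product of the encoder strides) and a 5-second window, the rounding works out as below; the numbers are only illustrative:

```py
import math

sample_rate, hop_length = 44100, 512   # hop_length = prod(encoder_rates) = 2*4*8*8
win_duration = 5.0                     # seconds per chunk

n_samples = int(win_duration * sample_rate)                      # 220500
n_samples = int(math.ceil(n_samples / hop_length) * hop_length)  # 220672 (431 frames)
```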
_output_signal = AudioSignal(_output, self.sample_rate).to(self.device) - audio_signal.audio_data[i] = _output_signal.audio_data.cpu() + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. + verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) - recons = audio_signal - recons._loudness = None - recons.stft_data = None + resample_fn = recons.resample + loudness_fn = recons.loudness - if do_overlap_and_add: - recons = recons.overlap_and_add(overlap_hop_duration) - recons.audio_data = recons.audio_data.reshape(nb, nac, -1) + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness - if match_input_db: - recons.ffmpeg_loudness() - recons = recons.normalize(input_db) + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) - recons.truncate_samples(original_length) + self.padding = original_padding return recons diff --git a/dac/model/dac.py b/dac/model/dac.py index b5a29b2..eb754b2 100644 --- a/dac/model/dac.py +++ b/dac/model/dac.py @@ -33,7 +33,11 @@ def __init__(self, dim: int = 16, dilation: int = 1): ) def forward(self, x): - return x + self.block(x) + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y class EncoderBlock(nn.Module): @@ -62,6 +66,7 @@ def __init__( self, d_model: int = 64, strides: list = [2, 4, 8, 8], + d_latent: int = 64, ): super().__init__() # Create first convolution @@ -75,7 +80,7 @@ def __init__( # Create last convolution self.block += [ Snake1d(d_model), - WNConv1d(d_model, d_model, kernel_size=3, padding=1), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), ] # Wrap black into 
nn.Sequential @@ -144,6 +149,7 @@ def __init__( self, encoder_dim: int = 64, encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, decoder_dim: int = 1536, decoder_rates: List[int] = [8, 8, 4, 2], n_codebooks: int = 9, @@ -160,15 +166,19 @@ def __init__( self.decoder_rates = decoder_rates self.sample_rate = sample_rate - self.hop_length = np.prod(decoder_rates) - self.encoder = Encoder(encoder_dim, encoder_rates) + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) self.n_codebooks = n_codebooks self.codebook_size = codebook_size self.codebook_dim = codebook_dim - self.quantizer = ResidualVectorQuantize( - self.encoder.enc_dim, + input_dim=latent_dim, n_codebooks=n_codebooks, codebook_size=codebook_size, codebook_dim=codebook_dim, @@ -176,13 +186,15 @@ def __init__( ) self.decoder = Decoder( - self.encoder.enc_dim, + latent_dim, decoder_dim, decoder_rates, ) self.sample_rate = sample_rate self.apply(init_weights) + self.delay = self.get_delay() + def preprocess(self, audio_data, sample_rate): if sample_rate is None: sample_rate = self.sample_rate @@ -191,12 +203,12 @@ def preprocess(self, audio_data, sample_rate): length = audio_data.shape[-1] right_pad = math.ceil(length / self.hop_length) * self.hop_length - length audio_data = nn.functional.pad(audio_data, (0, right_pad)) - return audio_data, length + + return audio_data def encode( self, audio_data: torch.Tensor, - sample_rate: int = None, n_quantizers: int = None, ): """Encode given audio data and return quantized latent codes @@ -205,9 +217,6 @@ def encode( ---------- audio_data : Tensor[B x 1 x T] Audio data to encode - sample_rate : int, optional - Sample rate of audio data in Hz, by default None - If None, defaults to `self.sample_rate` n_quantizers : int, optional Number of quantizers to use, by default None If None, all quantizers are used. @@ -231,15 +240,13 @@ def encode( "length" : int Number of samples in input audio """ - out = {} - audio_data, length = self.preprocess(audio_data, sample_rate) - out["length"] = length - - out["z"] = self.encoder(audio_data) - out.update(self.quantizer(out["z"], n_quantizers)) - return out + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss - def decode(self, z: torch.Tensor, length: int = None): + def decode(self, z: torch.Tensor): """Decode given latent codes and return audio data Parameters @@ -256,10 +263,7 @@ def decode(self, z: torch.Tensor, length: int = None): "audio" : Tensor[B x 1 x length] Decoded audio data. """ - out = {} - x = self.decoder(z) - out["audio"] = x[..., :length] - return out + return self.decoder(z) def forward( self, @@ -301,17 +305,28 @@ def forward( "audio" : Tensor[B x 1 x length] Decoded audio data. 
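The constructor above derives the latent width and hop size from the encoder settings rather than hard-coding them; with the default arguments that works out as below (the 44khz-16kbps config instead pins `DAC.latent_dim` to 128):

```py
import numpy as np

encoder_dim, encoder_rates = 64, [2, 4, 8, 8]

# Fallback used when latent_dim is None
latent_dim = encoder_dim * (2 ** len(encoder_rates))   # 1024
hop_length = int(np.prod(encoder_rates))                # 512 samples per code frame
```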
""" - out = {} - out.update(self.encode(audio_data, sample_rate, n_quantizers)) - out.update(self.decode(out["z"], out["length"])) - return out + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } if __name__ == "__main__": import numpy as np from functools import partial - model = DAC() + model = DAC().to("cpu") for n, m in model.named_modules(): o = m.extra_repr() @@ -328,6 +343,8 @@ def forward( # Make a forward pass out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) # Create gradient variable grad = torch.zeros_like(out) @@ -342,3 +359,6 @@ def forward( rf = (gradmap != 0).sum() print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/dac/nn/quantize.py b/dac/nn/quantize.py index 2a0b939..b17ff4a 100644 --- a/dac/nn/quantize.py +++ b/dac/nn/quantize.py @@ -192,13 +192,10 @@ def forward(self, z, n_quantizers: int = None): codebook_indices.append(indices_i) latents.append(z_e_i) - return { - "z": z_q, - "codes": torch.stack(codebook_indices, dim=1), - "latents": torch.cat(latents, dim=1), - "vq/commitment_loss": commitment_loss, - "vq/codebook_loss": codebook_loss, - } + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss def from_codes(self, codes: torch.Tensor): """Given the quantized codes, reconstruct the continuous representation diff --git a/dac/utils/__init__.py b/dac/utils/__init__.py index c4def3c..9e107bc 100644 --- a/dac/utils/__init__.py +++ b/dac/utils/__init__.py @@ -9,38 +9,52 @@ Accelerator = ml.Accelerator __MODEL_LATEST_TAGS__ = { - "44khz": "0.0.1", - "24khz": "0.0.4", - "16khz": "0.0.5", + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", } __MODEL_URLS__ = { ( "44khz", "0.0.1", + "8kbps", ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", ( "24khz", "0.0.4", + "8kbps", ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", ( "16khz", "0.0.5", + "8kbps", ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", } @argbind.bind(group="download", positional=True, without_prefix=True) -def ensure_default_model(tag: str = "latest", model_type: str = "44khz"): +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): """ Function that downloads the weights file from URL if a local cache is not found. Parameters ---------- - tag : str - The tag of the model to download. Defaults to "latest". model_type : str The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". 
Returns ------- @@ -56,10 +70,15 @@ def ensure_default_model(tag: str = "latest", model_type: str = "44khz"): "16khz", ], "model_type must be one of '44khz', '24khz', or '16khz'" + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + if tag == "latest": - tag = __MODEL_LATEST_TAGS__[model_type] + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] - download_link = __MODEL_URLS__.get((model_type, tag), None) + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) if download_link is None: raise ValueError( @@ -67,7 +86,11 @@ def ensure_default_model(tag: str = "latest", model_type: str = "44khz"): ) local_path = ( - Path.home() / ".cache" / "descript" / model_type / tag / "dac" / f"weights.pth" + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" ) if not local_path.exists(): local_path.parent.mkdir(parents=True, exist_ok=True) @@ -83,22 +106,18 @@ def ensure_default_model(tag: str = "latest", model_type: str = "44khz"): ) local_path.write_bytes(response.content) - # return the path required by audiotools to load the model - return local_path.parent.parent + return local_path def load_model( - tag: str = "latest", - load_path: str = "", model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, ): if not load_path: - load_path = ensure_default_model(tag, model_type) - kwargs = { - "folder": load_path, - "map_location": "cpu", - "package": False, - } - print(f"Loading weights from {kwargs['folder']}") - generator, _ = DAC.load_from_folder(**kwargs) + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) return generator diff --git a/dac/utils/decode.py b/dac/utils/decode.py index a3597e2..08d44e8 100644 --- a/dac/utils/decode.py +++ b/dac/utils/decode.py @@ -7,90 +7,12 @@ from audiotools import AudioSignal from tqdm import tqdm +from dac import DACFile from dac.utils import load_model warnings.filterwarnings("ignore", category=UserWarning) -@torch.no_grad() -@torch.inference_mode() -def process( - artifacts: dict, - device: str, - generator: torch.nn.Module, - preserve_sample_rate: bool, -) -> AudioSignal: - """Decode encoded audio. The `artifacts` contain codes from chunked windows - of the original audio signal. The codes are decoded one by one and windows are trimmed and concatenated together to form the final output. - - Parameters - ---------- - artifacts : dict - Dictionary of artifacts with the following keys: - - codes: the quantized codes - - metadata: dictionary with following keys - - original_db: the loudness of the input signal - - overlap_hop_duration: the hop duration of the overlap window - - original_length: the original length of the input signal - - is_overlap: whether the input signal was overlapped - - batch_size: the batch size of the input signal - - channels: the number of channels of the input signal - - original_sr: the original sample rate of the input signal - device : str - Device to use - generator : torch.nn.Module - Generator to decode with. - preserve_sample_rate : bool - If True, return audio will have the same sample rate as the original - encoded audio. If False, return audio will have the sample rate of the - generator. 
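`load_model` above now resolves a `(model_type, model_bitrate, tag)` triple to a cached weights file and returns the generator via `DAC.load`. Below is a self-contained sketch of the tuple-based API that model exposes; it uses a randomly initialized `dac.DAC()` so it runs without downloading weights (swap in `load_model(...)` or `dac.utils.download(...)` for real ones):

```py
import torch
import dac

model = dac.DAC()  # randomly initialized; use dac.utils.download + DAC.load for real weights
x = model.preprocess(torch.randn(1, 1, 44100), 44100)

# encode now returns a plain tuple instead of a dict
z, codes, latents, commitment_loss, codebook_loss = model.encode(x)

# Decode straight from the integer codes, as decompress() does chunk by chunk
z_q = model.quantizer.from_codes(codes)[0]  # first element is the dense latent
audio = model.decode(z_q)                   # Tensor[1 x 1 x samples]
```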
- - Returns - ------- - AudioSignal - """ - if isinstance(generator, torch.nn.DataParallel): - generator = generator.module - audio_signal = AudioSignal( - artifacts["codes"].astype(np.int64), generator.sample_rate - ) - metadata = artifacts["metadata"] - - # Decode chunks - output = [] - for i in range(audio_signal.batch_size): - signal_from_batch = AudioSignal( - audio_signal.audio_data[i, ...], audio_signal.sample_rate, device=device - ) - z_q = generator.quantizer.from_codes(signal_from_batch.audio_data)[0] - audio = generator.decode(z_q)["audio"].cpu() - output.append(audio) - - output = torch.cat(output, dim=0) - output_signal = AudioSignal(output, generator.sample_rate) - - # Overlap and add - if metadata["is_overlap"]: - boundary = int(metadata["overlap_hop_duration"] * generator.sample_rate / 2) - # remove window overlap - output_signal.trim(boundary, boundary) - output_signal.audio_data = output_signal.audio_data.reshape( - metadata["batch_size"], metadata["channels"], -1 - ) - # remove padding - output_signal.trim(boundary, boundary) - - # Restore loudness and truncate to original length - output_signal.ffmpeg_loudness() - output_signal = output_signal.normalize(metadata["original_db"]) - output_signal.truncate_samples(metadata["original_length"]) - - if preserve_sample_rate: - output_signal = output_signal.ffmpeg_resample(metadata["original_sr"]) - - return output_signal.to("cpu") - - @argbind.bind(group="decode", positional=True, without_prefix=True) @torch.inference_mode() @torch.no_grad() @@ -99,9 +21,10 @@ def decode( output: str = "", weights_path: str = "", model_tag: str = "latest", - preserve_sample_rate: bool = False, + model_bitrate: str = "8kbps", device: str = "cuda", model_type: str = "44khz", + verbose: bool = False, ): """Decode audio from codes. @@ -117,17 +40,18 @@ def decode( model_tag and model_type. model_tag : str, optional Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. - preserve_sample_rate : bool, optional - If True, return audio will have the same sample rate as the original + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". device : str, optional Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. model_type : str, optional The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. """ generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, tag=model_tag, load_path=weights_path, - model_type=model_type, ) generator.to(device) generator.eval() @@ -146,10 +70,10 @@ def decode( for i in tqdm(range(len(input_files)), desc=f"Decoding files"): # Load file - artifacts = np.load(input_files[i], allow_pickle=True)[()] + artifact = DACFile.load(input_files[i]) # Reconstruct audio from codes - recons = process(artifacts, device, generator, preserve_sample_rate) + recons = generator.decompress(artifact, verbose=verbose) # Compute output path relative_path = input_files[i].relative_to(input) diff --git a/dac/utils/encode.py b/dac/utils/encode.py index b33eb13..aa3f6f4 100644 --- a/dac/utils/encode.py +++ b/dac/utils/encode.py @@ -14,108 +14,6 @@ warnings.filterwarnings("ignore", category=UserWarning) -@torch.no_grad() -@torch.inference_mode() -def process( - signal: AudioSignal, device: str, generator: torch.nn.Module, **kwargs -) -> dict: - """Encode an audio signal. The signal is chunked into overlapping windows - and encoded one by one. 
- - Parameters - ---------- - signal : AudioSignal - Input signal to encode - device : str - Device to use - generator : torch.nn.Module - Generator to encode with - - Returns - ------- - dict - Dictionary of artifacts with the following keys: - - codes: the quantized codes - - metadata: dictionary with following keys - - original_db: the loudness of the input signal - - overlap_hop_duration: the hop duration of the overlap window - - original_length: the original length of the input signal - - is_overlap: whether the input signal was overlapped - - batch_size: the batch size of the input signal - - channels: the number of channels of the input signal - - original_sr: the original sample rate of the input signal - - """ - if isinstance(generator, torch.nn.DataParallel): - generator = generator.module - - original_sr = signal.sample_rate - - # Resample input - audio_signal = signal.ffmpeg_resample(generator.sample_rate) - - original_length = audio_signal.signal_length - input_db = audio_signal.ffmpeg_loudness() - - # Set variables - sr = audio_signal.sample_rate - overlap_win_duration = 5.0 - overlap_hop_ratio = 0.5 - - # Fix overlap window so that it's divisible by 4 in # of samples - overlap_win_duration = ((overlap_win_duration * sr) // 4) * 4 - overlap_win_duration = overlap_win_duration / sr - overlap_hop_duration = overlap_win_duration * overlap_hop_ratio - do_overlap_and_add = audio_signal.signal_duration > overlap_win_duration - - # TODO (eeishaan): Remove this when correct caching logic is implemented and - # overlap of codes is minimal - do_overlap_and_add = False - - # Sanitize input - audio_signal.normalize(-16) - audio_signal.ensure_max_of_audio() - - nb, nac, nt = audio_signal.audio_data.shape - audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) - - if do_overlap_and_add: - pad_length = ( - math.ceil(audio_signal.signal_duration / overlap_win_duration) - * overlap_win_duration - ) - audio_signal.zero_pad_to(int(pad_length * sr)) - audio_signal = audio_signal.collect_windows( - overlap_win_duration, overlap_hop_duration - ) - - codebook_indices = [] - for i in range(audio_signal.batch_size): - signal_from_batch = AudioSignal( - audio_signal.audio_data[i, ...], audio_signal.sample_rate - ) - signal_from_batch.to(device) - codes = generator.encode( - signal_from_batch.audio_data, signal_from_batch.sample_rate, **kwargs - )["codes"].cpu() - codebook_indices.append(codes) - - codebook_indices = torch.cat(codebook_indices, dim=0) - - return { - "codes": codebook_indices.numpy().astype(np.uint16), - "metadata": { - "original_db": input_db, - "overlap_hop_duration": overlap_hop_duration, - "original_length": original_length, - "is_overlap": do_overlap_and_add, - "batch_size": nb, - "channels": nac, - "original_sr": original_sr, - }, - } - - @argbind.bind(group="encode", positional=True, without_prefix=True) @torch.inference_mode() @torch.no_grad() @@ -124,9 +22,12 @@ def encode( output: str = "", weights_path: str = "", model_tag: str = "latest", + model_bitrate: str = "8kbps", n_quantizers: int = None, device: str = "cuda", model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, ): """Encode audio files in input path to .dac format. @@ -141,6 +42,8 @@ def encode( model_tag and model_type. model_tag : str, optional Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 
n_quantizers : int, optional Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. device : str, optional @@ -149,9 +52,10 @@ def encode( The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. """ generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, tag=model_tag, load_path=weights_path, - model_type=model_type, ) generator.to(device) generator.eval() @@ -169,7 +73,7 @@ def encode( signal = AudioSignal(audio_files[i]) # Encode audio to .dac format - artifacts = process(signal, device, generator, **kwargs) + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) # Compute output path relative_path = audio_files[i].relative_to(input) @@ -181,9 +85,7 @@ def encode( output_path = output_dir / output_name output_path.parent.mkdir(parents=True, exist_ok=True) - # Write to file - with open(output_path, "wb") as f: - np.save(f, artifacts) + artifact.save(output_path) if __name__ == "__main__": diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..e16112a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,37 @@ +version: "3.5" +services: + base: + build: + context: . + dockerfile: ./Dockerfile.dev + args: + GITHUB_TOKEN: ${GITHUB_TOKEN} + IMAGE: ${IMAGE} + volumes: + - .:/u/home/src + - ${PATH_TO_DATA}:/data + - ${PATH_TO_RUNS}:/runs + - ~/.config/gcloud:/u/home/.config/gcloud + - ~/.zsh_history:/u/home/.zsh_history + environment: + - GITHUB_TOKEN + - HOST_USER_ID + - HOST_USER_GID + - JUPYTER_TOKEN=password + - PATH_TO_DATA=/data + - PATH_TO_RUNS=/runs + - MPLCONFIGDIR=/u/home/.mplconfig + shm_size: 32G + working_dir: /u/home/src + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] + dev: + extends: base + profiles: + - interactive + stdin_open: true + tty: true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2919657 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +argbind>=0.3.7 +descript-audiotools>=0.7.2 +einops +numpy +torch +torchaudio +tqdm +tensorboard +numba>=0.5.7 +jupyterlab diff --git a/scripts/mushra.py b/scripts/mushra.py new file mode 100644 index 0000000..2bdf006 --- /dev/null +++ b/scripts/mushra.py @@ -0,0 +1,104 @@ +import string +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import argbind +import gradio as gr +from audiotools import preference as pr + + +@argbind.bind(without_prefix=True) +@dataclass +class Config: + folder: str = None + save_path: str = "results.csv" + conditions: List[str] = None + reference: str = None + seed: int = 0 + share: bool = False + n_samples: int = 10 + + +def get_text(wav_file: str): + txt_file = Path(wav_file).with_suffix(".txt") + if Path(txt_file).exists(): + with open(txt_file, "r") as f: + txt = f.read() + else: + txt = "" + return f"""
{txt}
""" + + +def main(config: Config): + with gr.Blocks() as app: + save_path = config.save_path + samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples)) + + reference = config.reference + conditions = config.conditions + + player = pr.Player(app) + player.create() + if reference is not None: + player.add("Play Reference") + + user = pr.create_tracker(app) + ratings = [] + + with gr.Row(): + txt = gr.HTML("") + + with gr.Row(): + gr.Button("Rate audio quality", interactive=False) + with gr.Column(scale=8): + gr.HTML(pr.slider_mushra) + + for i in range(len(conditions)): + with gr.Row().style(equal_height=True): + x = string.ascii_uppercase[i] + player.add(f"Play {x}") + with gr.Column(scale=9): + ratings.append(gr.Slider(value=50, interactive=True)) + + def build(user, samples, *ratings): + # Filter out samples user has done already, by looking in the CSV. + samples.filter_completed(user, save_path) + + # Write results to CSV + if samples.current > 0: + start_idx = 1 if reference is not None else 0 + name = samples.names[samples.current - 1] + result = {"sample": name, "user": user} + for k, r in zip(samples.order[start_idx:], ratings): + result[k] = r + pr.save_result(result, save_path) + + updates, done, pbar = samples.get_next_sample(reference, conditions) + wav_file = updates[0]["value"] + + txt_update = gr.update(value=get_text(wav_file)) + + return ( + updates + + [gr.update(value=50) for _ in ratings] + + [done, samples, pbar, txt_update] + ) + + progress = gr.HTML() + begin = gr.Button("Submit", elem_id="start-survey") + begin.click( + fn=build, + inputs=[user, samples] + ratings, + outputs=player.to_list() + ratings + [begin, samples, progress, txt], + ).then(None, _js=pr.reset_player) + + # Comment this back in to actually launch the script. 
+ app.launch(share=config.share) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + config = Config() + main(config) diff --git a/scripts/train.py b/scripts/train.py index 57e2e41..646ed57 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -149,6 +149,9 @@ def load( generator = DAC() if generator is None else generator discriminator = Discriminator() if discriminator is None else discriminator + tracker.print(generator) + tracker.print(discriminator) + generator = accel.prepare_model(generator) discriminator = accel.prepare_model(discriminator) @@ -208,8 +211,8 @@ def val_loop(batch, state, accel): batch["signal"].clone(), **batch["transform_args"] ) - recons = state.generator(signal.audio_data, signal.sample_rate)["audio"] - recons = AudioSignal(recons, signal.sample_rate) + out = state.generator(signal.audio_data, signal.sample_rate) + recons = AudioSignal(out["audio"], signal.sample_rate) return { "loss": state.mel_loss(recons, signal), @@ -320,8 +323,8 @@ def save_samples(state, val_idx, writer): batch["signal"].clone(), **batch["transform_args"] ) - recons = state.generator(signal.audio_data, signal.sample_rate)["audio"] - recons = AudioSignal(recons, signal.sample_rate) + out = state.generator(signal.audio_data, signal.sample_rate) + recons = AudioSignal(out["audio"], signal.sample_rate) audio_dict = {"recons": recons} if state.tracker.step == 0: @@ -390,7 +393,7 @@ def train( num_workers=num_workers, batch_size=val_batch_size, collate_fn=state.val_data.collate, - persistent_workers=True, + persistent_workers=True if num_workers > 0 else False, ) # Wrap the functions so that they neatly track in TensorBoard + progress bars diff --git a/setup.py b/setup.py index 68cac4d..b681d97 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="descript-audio-codec", - version="0.0.5", + version="1.0.0", classifiers=[ "Intended Audience :: Developers", "Natural Language :: English", @@ -28,7 +28,7 @@ keywords=["audio", "compression", "machine learning"], install_requires=[ "argbind>=0.3.7", - "descript-audiotools==0.7.1", + "descript-audiotools>=0.7.2", "einops", "numpy", "torch",