# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Various utilities.""" from hashlib import sha256 from pathlib import Path import typing as tp import torch import torchaudio def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int): # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario # e.g., more than 2 frames per position. # The core idea is to use a weight function that is a triangle, # with a maximum value at the middle of the segment. # We use this weighting when summing the frames, and divide by the sum of weights # for each positions at the end. Thus: # - if a frame is the only one to cover a position, the weighting is a no-op. # - if 2 frames cover a position: # ... ... # / \/ \ # / /\ \ # S T , i.e. S offset of second frame starts, T end of first frame. # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset. # After the final normalization, the weight of the second frame at position `t` is # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. # # - if more than 2 frames overlap at a given point, we hope that by induction # something sensible happens. assert len(frames) device = frames[0].device dtype = frames[0].dtype shape = frames[0].shape[:-1] total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] frame_length = frames[0].shape[-1] t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] weight = 0.5 - (t - 0.5).abs() sum_weight = torch.zeros(total_size, device=device, dtype=dtype) out = torch.zeros(*shape, total_size, device=device, dtype=dtype) offset: int = 0 for frame in frames: frame_length = frame.shape[-1] out[..., offset:offset + frame_length] += weight[:frame_length] * frame sum_weight[offset:offset + frame_length] += weight[:frame_length] offset += stride assert sum_weight.min() > 0 return out / sum_weight def _get_checkpoint_url(root_url: str, checkpoint: str): if not root_url.endswith('/'): root_url += '/' return root_url + checkpoint def _check_checksum(path: Path, checksum: str): sha = sha256() with open(path, 'rb') as file: while True: buf = file.read(2**20) if not buf: break sha.update(buf) actual_checksum = sha.hexdigest()[:len(checksum)] if actual_checksum != checksum: raise RuntimeError(f'Invalid checksum for file {path}, ' f'expected {checksum} but got {actual_checksum}') def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): assert wav.shape[0] in [1, 2], "Audio must be mono or stereo." if target_channels == 1: wav = wav.mean(0, keepdim=True) elif target_channels == 2: *shape, _, length = wav.shape wav = wav.expand(*shape, target_channels, length) elif wav.shape[0] == 1: wav = wav.expand(target_channels, -1) wav = torchaudio.transforms.Resample(sr, target_sr)(wav) return wav def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False): limit = 0.99 mx = wav.abs().max() if rescale: wav = wav * min(limit / mx, 1) else: wav = wav.clamp(-limit, limit) torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)