124 lines
4.4 KiB
Python
124 lines
4.4 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""Command-line for audio compression."""
|
||
|
|
||
|
import argparse
|
||
|
from pathlib import Path
|
||
|
import sys
|
||
|
|
||
|
import torchaudio
|
||
|
|
||
|
from .compress import compress, decompress, MODELS
|
||
|
from .utils import save_audio, convert_audio
|
||
|
|
||
|
|
||
|
# File extension marking compressed files: `main` decompresses inputs that
# carry it and uses it for the default output name when compressing.
SUFFIX = '.ecdc'
|
||
|
|
||
|
|
||
|
def get_parser():
    """Build the argument parser of the `encodec` command-line tool.

    Returns:
        An `argparse.ArgumentParser` with two positional paths (`input` and
        the optional `output`) plus the `--bandwidth`, `--hq`, `--lm`,
        `--force`, `--decompress_suffix` and `--rescale` options.
    """
    description = (
        'High fidelity neural audio codec. '
        'If input is a .ecdc, decompresses it. '
        'If input is .wav, compresses it. If output is also wav, '
        'do a compression/decompression cycle.'
    )
    cli = argparse.ArgumentParser(prog='encodec', description=description)

    # Positional arguments: `output` may be omitted.
    cli.add_argument(
        'input',
        type=Path,
        help='Input file, whatever is supported by torchaudio on your system.',
    )
    cli.add_argument(
        'output',
        type=Path,
        nargs='?',
        help='Output file, otherwise inferred from input file.',
    )

    cli.add_argument(
        '-b', '--bandwidth',
        type=float,
        default=6,
        choices=[1.5, 3.0, 6.0, 12.0, 24.0],
        help='Target bandwidth (1.5, 3, 6, 12 or 24). 1.5 is not supported with --hq.',
    )

    # Boolean switches, registered in the order they appear in `--help`.
    leading_flags = (
        ('-q', '--hq', 'Use HQ stereo model operating on 48 kHz sampled audio.'),
        ('-l', '--lm', 'Use a language model to reduce the model size (5x slower though).'),
        ('-f', '--force', 'Overwrite output file if it exists.'),
    )
    for short_name, long_name, help_text in leading_flags:
        cli.add_argument(short_name, long_name, action='store_true', help=help_text)

    cli.add_argument(
        '-s', '--decompress_suffix',
        type=str,
        default='_decompressed',
        help='Suffix for the decompressed output file (if no output path specified)',
    )
    cli.add_argument(
        '-r', '--rescale',
        action='store_true',
        help='Automatically rescale the output to avoid clipping.',
    )
    return cli
|
||
|
|
||
|
|
||
|
def fatal(*args):
    """Report an error on stderr and terminate the process.

    Args:
        *args: Values to print, space-separated, exactly as `print` would.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    message = ' '.join(str(part) for part in args)
    print(message, file=sys.stderr)
    sys.exit(1)
|
||
|
|
||
|
|
||
|
def check_output_exists(args):
    """Abort through `fatal` unless `args.output` can safely be written.

    Two conditions are checked, in order: the parent folder of the output
    path must exist, and the output file itself must not already exist
    unless `args.force` is set.

    Args:
        args: Parsed command-line namespace; reads `output` (a path) and
            `force` (a bool).
    """
    destination = args.output

    if not destination.parent.exists():
        fatal(f"Output folder for {destination} does not exist.")

    if destination.exists():
        if not args.force:
            fatal(f"Output file {destination} exist. Use -f / --force to overwrite.")
|
||
|
|
||
|
|
||
|
def check_clipping(wav, args):
    """Warn on stderr when the waveform's peak magnitude exceeds 0.99.

    Nothing is checked when `args.rescale` is set. The warning is only
    informational; the function never raises or exits on clipping.

    Args:
        wav: Waveform object supporting `.abs().max()`.
        args: Parsed command-line namespace; reads `rescale`.
    """
    if not args.rescale:
        limit = 0.99
        peak = wav.abs().max()
        if peak > limit:
            warning = (
                f"Clipping!! max scale {peak}, limit is {limit}. "
                "To avoid clipping, use the `-r` option to rescale the output."
            )
            print(warning, file=sys.stderr)
|
||
|
|
||
|
|
||
|
def main():
    """Entry point of the `encodec` command-line tool.

    Parses the command line, then picks a mode from the input suffix:

    * input ends with `SUFFIX`: decompress it into a `.wav` file. When no
      output is given, the name is `<input stem><decompress_suffix>.wav`
      next to the input.
    * any other input: load it with torchaudio, convert it to the model's
      sample rate and channel count, and compress it. The result is written
      as a `SUFFIX` file, or, when the output is a `.wav`, immediately
      decompressed again and saved as audio.

    Invalid inputs (missing input file, unsupported output extension,
    missing output folder, existing output without `--force`, bandwidth not
    supported by the chosen model) are reported through `fatal`.
    """
    args = get_parser().parse_args()
    if not args.input.exists():
        fatal(f"Input file {args.input} does not exist.")

    if args.input.suffix.lower() == SUFFIX:
        # Decompression
        if args.output is None:
            # Build the whole file name in a single `with_name` call.
            # Chaining `.with_suffix('.wav')` afterwards would replace
            # everything after the last dot of the new name, turning
            # `a.b.ecdc` into `a.wav` instead of `a.b_decompressed.wav`
            # (and truncating any `--decompress_suffix` containing a dot).
            args.output = args.input.with_name(
                args.input.stem + args.decompress_suffix + '.wav')
        elif args.output.suffix.lower() != '.wav':
            fatal("Output extension must be .wav")
        check_output_exists(args)
        out, out_sample_rate = decompress(args.input.read_bytes())
        check_clipping(out, args)
        save_audio(out, args.output, out_sample_rate, rescale=args.rescale)
    else:
        # Compression
        if args.output is None:
            args.output = args.input.with_suffix(SUFFIX)
        elif args.output.suffix.lower() not in [SUFFIX, '.wav']:
            fatal(f"Output extension must be .wav or {SUFFIX}")
        check_output_exists(args)

        model_name = 'encodec_48khz' if args.hq else 'encodec_24khz'
        model = MODELS[model_name]()
        if args.bandwidth not in model.target_bandwidths:
            fatal(f"Bandwidth {args.bandwidth} is not supported by the model {model_name}")
        model.set_target_bandwidth(args.bandwidth)

        wav, sr = torchaudio.load(args.input)
        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
        compressed = compress(model, wav, use_lm=args.lm)
        if args.output.suffix.lower() == SUFFIX:
            args.output.write_bytes(compressed)
        else:
            # Directly run decompression stage
            # The extension was validated above, so only `.wav` can reach here.
            assert args.output.suffix.lower() == '.wav'
            out, out_sample_rate = decompress(compressed)
            check_clipping(out, args)
            save_audio(out, args.output, out_sample_rate, rescale=args.rescale)
|
||
|
|
||
|
|
||
|
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
|