207 lines
7.2 KiB
Python
207 lines
7.2 KiB
Python
|
# Adapted from https://github.com/pytorch/audio/
|
||
|
|
||
|
import hashlib
|
||
|
import logging
|
||
|
import os
|
||
|
import tarfile
|
||
|
import urllib
|
||
|
import urllib.request
|
||
|
import zipfile
|
||
|
from os.path import expanduser
|
||
|
from typing import Any, Iterable, List, Optional
|
||
|
|
||
|
from torch.utils.model_zoo import tqdm
|
||
|
|
||
|
|
||
|
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream the content of a url chunk by chunk.

    This is a generator: nothing is requested until iteration starts.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        bytes: Consecutive chunks of at most ``block_size`` bytes, beginning at
        offset ``start_byte`` of the remote content.
    """

    # A HEAD request reports the total size without transferring the body.
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))

    # If we already have the whole file, there is no need to download it again
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=url_size,
        disable=not progress_bar,
    ) as pbar:
        if start_byte:
            # Only a "206 Partial Content" reply proves the Range header was
            # honoured. Any other reply (plain 200, file:// or ftp:// handlers)
            # starts at byte 0, so drop the prefix the caller already has
            # instead of handing it out a second time.
            if getattr(upointer, "status", None) != 206:
                to_skip = start_byte
                while to_skip > 0:
                    skipped = upointer.read(min(block_size, to_skip))
                    if not skipped:
                        break
                    to_skip -= len(skipped)
            # ``total`` is the full size, so account for the bytes already present.
            pbar.update(start_byte)

        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))
|
||
|
|
||
|
|
||
|
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file already exists and ``resume`` is False, or
            if ``hash_value`` is given and the file on disk does not match it.
    """

    # Only the headers are needed here; close the response as soon as they are read.
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        req_info = response.info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    file_exists = os.path.exists(filepath)
    if file_exists and not resume:
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
    if file_exists:
        # Resume: append to what is already on disk.
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    # The local file already has the advertised size: nothing left to fetch,
    # so just check its integrity.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    # No checksum requested: do not reopen the file for nothing.
    if not hash_value:
        return

    with open(filepath, "rb") as file_obj:
        if not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
|
||
|
|
||
|
|
||
|
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from. It is read from its current position
            until exhausted and must return bytes.
        hash_value (str): Expected hexadecimal digest. Letter case and surrounding
            whitespace are ignored.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if its a valid file, else False.

    Raises:
        ValueError: If ``hash_type`` is neither "sha256" nor "md5".
    """

    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError("Unsupported hash_type {!r}; expected 'sha256' or 'md5'.".format(hash_type))

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024**2)
        if not chunk:
            break
        hash_func.update(chunk)

    # hexdigest() is always lowercase, while published checksums are often
    # uppercase or carry stray whitespace; normalise before comparing.
    return hash_func.hexdigest() == hash_value.strip().lower()
|
||
|
|
||
|
|
||
|
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The file is first opened as a tar archive (any compression ``tarfile`` can
    detect); if that fails it is opened as a zip archive.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten.
        For tar archives these are the regular-file members joined onto
        ``to_path``; for zip archives this is ``ZipFile.namelist()`` unchanged
        (member names relative to ``to_path``, directory entries included).

    Raises:
        ValueError: If a tar member would be written outside of ``to_path``.
        NotImplementedError: If the file is neither a readable tar nor a zip archive.
    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file %s.", from_path)
            # Archives usually come from the network: resolve the destination once
            # so that every member can be checked against it before being written.
            dest_root = os.path.realpath(to_path)
            dest_prefix = dest_root.rstrip(os.sep) + os.sep
            files = []
            for file_ in tar:
                file_path = os.path.join(to_path, file_.name)
                # Refuse absolute names, "../" components and paths that pass
                # through an already-extracted symlink leading out of to_path.
                resolved = os.path.realpath(file_path)
                if resolved != dest_root and not resolved.startswith(dest_prefix):
                    raise ValueError(
                        "Refusing to extract {!r} from {}: it would be written outside of {}.".format(
                            file_.name, from_path, dest_root
                        )
                    )
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive; fall through and try zip.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                # ZipFile.extract() strips absolute prefixes and ".." components itself.
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
|
||
|
|
||
|
|
||
|
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
    """Fetch a dataset from kaggle and unzip it under ``output_path/dataset_name``.

    The ``kaggle`` package is imported lazily inside the function. If importing it,
    authenticating, or downloading raises ``OSError``, a hint about the API token
    file is printed instead of propagating the error.

    Args:
        dataset_path (str):
            The kaggle identifier of the dataset, for example vctk is
            'mfekadu/english-multispeaker-corpus-for-voice-cloning'.
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Directory in which the dataset folder is created.
    """
    target_dir = os.path.join(output_path, dataset_name)
    try:
        import kaggle  # pylint: disable=import-outside-toplevel

        api = kaggle.api
        api.authenticate()
        banner = f"""\nDownloading {dataset_name}..."""
        print(banner)
        api.dataset_download_files(dataset_path, path=target_dir, unzip=True)
    except OSError:
        token_hint = f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
        print(token_hint)
|