ai-content-maker/.venv/Lib/site-packages/torchaudio/datasets/gtzan.py

1119 lines
24 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import os
from pathlib import Path
from typing import Optional, Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_tar
# The following lists prefixed with `filtered_` provide a filtered split
# that:
#
# a. Mitigate a known issue with GTZAN (duplication)
#
# b. Provide a standard split for testing it against other
# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
#
# Those are used when GTZAN is initialised with the `filtered` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
gtzan_genres = [
"blues",
"classical",
"country",
"disco",
"hiphop",
"jazz",
"metal",
"pop",
"reggae",
"rock",
]
filtered_test = [
"blues.00012",
"blues.00013",
"blues.00014",
"blues.00015",
"blues.00016",
"blues.00017",
"blues.00018",
"blues.00019",
"blues.00020",
"blues.00021",
"blues.00022",
"blues.00023",
"blues.00024",
"blues.00025",
"blues.00026",
"blues.00027",
"blues.00028",
"blues.00061",
"blues.00062",
"blues.00063",
"blues.00064",
"blues.00065",
"blues.00066",
"blues.00067",
"blues.00068",
"blues.00069",
"blues.00070",
"blues.00071",
"blues.00072",
"blues.00098",
"blues.00099",
"classical.00011",
"classical.00012",
"classical.00013",
"classical.00014",
"classical.00015",
"classical.00016",
"classical.00017",
"classical.00018",
"classical.00019",
"classical.00020",
"classical.00021",
"classical.00022",
"classical.00023",
"classical.00024",
"classical.00025",
"classical.00026",
"classical.00027",
"classical.00028",
"classical.00029",
"classical.00034",
"classical.00035",
"classical.00036",
"classical.00037",
"classical.00038",
"classical.00039",
"classical.00040",
"classical.00041",
"classical.00049",
"classical.00077",
"classical.00078",
"classical.00079",
"country.00030",
"country.00031",
"country.00032",
"country.00033",
"country.00034",
"country.00035",
"country.00036",
"country.00037",
"country.00038",
"country.00039",
"country.00040",
"country.00043",
"country.00044",
"country.00046",
"country.00047",
"country.00048",
"country.00050",
"country.00051",
"country.00053",
"country.00054",
"country.00055",
"country.00056",
"country.00057",
"country.00058",
"country.00059",
"country.00060",
"country.00061",
"country.00062",
"country.00063",
"country.00064",
"disco.00001",
"disco.00021",
"disco.00058",
"disco.00062",
"disco.00063",
"disco.00064",
"disco.00065",
"disco.00066",
"disco.00069",
"disco.00076",
"disco.00077",
"disco.00078",
"disco.00079",
"disco.00080",
"disco.00081",
"disco.00082",
"disco.00083",
"disco.00084",
"disco.00085",
"disco.00086",
"disco.00087",
"disco.00088",
"disco.00091",
"disco.00092",
"disco.00093",
"disco.00094",
"disco.00096",
"disco.00097",
"disco.00099",
"hiphop.00000",
"hiphop.00026",
"hiphop.00027",
"hiphop.00030",
"hiphop.00040",
"hiphop.00043",
"hiphop.00044",
"hiphop.00045",
"hiphop.00051",
"hiphop.00052",
"hiphop.00053",
"hiphop.00054",
"hiphop.00062",
"hiphop.00063",
"hiphop.00064",
"hiphop.00065",
"hiphop.00066",
"hiphop.00067",
"hiphop.00068",
"hiphop.00069",
"hiphop.00070",
"hiphop.00071",
"hiphop.00072",
"hiphop.00073",
"hiphop.00074",
"hiphop.00075",
"hiphop.00099",
"jazz.00073",
"jazz.00074",
"jazz.00075",
"jazz.00076",
"jazz.00077",
"jazz.00078",
"jazz.00079",
"jazz.00080",
"jazz.00081",
"jazz.00082",
"jazz.00083",
"jazz.00084",
"jazz.00085",
"jazz.00086",
"jazz.00087",
"jazz.00088",
"jazz.00089",
"jazz.00090",
"jazz.00091",
"jazz.00092",
"jazz.00093",
"jazz.00094",
"jazz.00095",
"jazz.00096",
"jazz.00097",
"jazz.00098",
"jazz.00099",
"metal.00012",
"metal.00013",
"metal.00014",
"metal.00015",
"metal.00022",
"metal.00023",
"metal.00025",
"metal.00026",
"metal.00027",
"metal.00028",
"metal.00029",
"metal.00030",
"metal.00031",
"metal.00032",
"metal.00033",
"metal.00038",
"metal.00039",
"metal.00067",
"metal.00070",
"metal.00073",
"metal.00074",
"metal.00075",
"metal.00078",
"metal.00083",
"metal.00085",
"metal.00087",
"metal.00088",
"pop.00000",
"pop.00001",
"pop.00013",
"pop.00014",
"pop.00043",
"pop.00063",
"pop.00064",
"pop.00065",
"pop.00066",
"pop.00069",
"pop.00070",
"pop.00071",
"pop.00072",
"pop.00073",
"pop.00074",
"pop.00075",
"pop.00076",
"pop.00077",
"pop.00078",
"pop.00079",
"pop.00082",
"pop.00088",
"pop.00089",
"pop.00090",
"pop.00091",
"pop.00092",
"pop.00093",
"pop.00094",
"pop.00095",
"pop.00096",
"reggae.00034",
"reggae.00035",
"reggae.00036",
"reggae.00037",
"reggae.00038",
"reggae.00039",
"reggae.00040",
"reggae.00046",
"reggae.00047",
"reggae.00048",
"reggae.00052",
"reggae.00053",
"reggae.00064",
"reggae.00065",
"reggae.00066",
"reggae.00067",
"reggae.00068",
"reggae.00071",
"reggae.00079",
"reggae.00082",
"reggae.00083",
"reggae.00084",
"reggae.00087",
"reggae.00088",
"reggae.00089",
"reggae.00090",
"rock.00010",
"rock.00011",
"rock.00012",
"rock.00013",
"rock.00014",
"rock.00015",
"rock.00027",
"rock.00028",
"rock.00029",
"rock.00030",
"rock.00031",
"rock.00032",
"rock.00033",
"rock.00034",
"rock.00035",
"rock.00036",
"rock.00037",
"rock.00039",
"rock.00040",
"rock.00041",
"rock.00042",
"rock.00043",
"rock.00044",
"rock.00045",
"rock.00046",
"rock.00047",
"rock.00048",
"rock.00086",
"rock.00087",
"rock.00088",
"rock.00089",
"rock.00090",
]
filtered_train = [
"blues.00029",
"blues.00030",
"blues.00031",
"blues.00032",
"blues.00033",
"blues.00034",
"blues.00035",
"blues.00036",
"blues.00037",
"blues.00038",
"blues.00039",
"blues.00040",
"blues.00041",
"blues.00042",
"blues.00043",
"blues.00044",
"blues.00045",
"blues.00046",
"blues.00047",
"blues.00048",
"blues.00049",
"blues.00073",
"blues.00074",
"blues.00075",
"blues.00076",
"blues.00077",
"blues.00078",
"blues.00079",
"blues.00080",
"blues.00081",
"blues.00082",
"blues.00083",
"blues.00084",
"blues.00085",
"blues.00086",
"blues.00087",
"blues.00088",
"blues.00089",
"blues.00090",
"blues.00091",
"blues.00092",
"blues.00093",
"blues.00094",
"blues.00095",
"blues.00096",
"blues.00097",
"classical.00030",
"classical.00031",
"classical.00032",
"classical.00033",
"classical.00043",
"classical.00044",
"classical.00045",
"classical.00046",
"classical.00047",
"classical.00048",
"classical.00050",
"classical.00051",
"classical.00052",
"classical.00053",
"classical.00054",
"classical.00055",
"classical.00056",
"classical.00057",
"classical.00058",
"classical.00059",
"classical.00060",
"classical.00061",
"classical.00062",
"classical.00063",
"classical.00064",
"classical.00065",
"classical.00066",
"classical.00067",
"classical.00080",
"classical.00081",
"classical.00082",
"classical.00083",
"classical.00084",
"classical.00085",
"classical.00086",
"classical.00087",
"classical.00088",
"classical.00089",
"classical.00090",
"classical.00091",
"classical.00092",
"classical.00093",
"classical.00094",
"classical.00095",
"classical.00096",
"classical.00097",
"classical.00098",
"classical.00099",
"country.00019",
"country.00020",
"country.00021",
"country.00022",
"country.00023",
"country.00024",
"country.00025",
"country.00026",
"country.00028",
"country.00029",
"country.00065",
"country.00066",
"country.00067",
"country.00068",
"country.00069",
"country.00070",
"country.00071",
"country.00072",
"country.00073",
"country.00074",
"country.00075",
"country.00076",
"country.00077",
"country.00078",
"country.00079",
"country.00080",
"country.00081",
"country.00082",
"country.00083",
"country.00084",
"country.00085",
"country.00086",
"country.00087",
"country.00088",
"country.00089",
"country.00090",
"country.00091",
"country.00092",
"country.00093",
"country.00094",
"country.00095",
"country.00096",
"country.00097",
"country.00098",
"country.00099",
"disco.00005",
"disco.00015",
"disco.00016",
"disco.00017",
"disco.00018",
"disco.00019",
"disco.00020",
"disco.00022",
"disco.00023",
"disco.00024",
"disco.00025",
"disco.00026",
"disco.00027",
"disco.00028",
"disco.00029",
"disco.00030",
"disco.00031",
"disco.00032",
"disco.00033",
"disco.00034",
"disco.00035",
"disco.00036",
"disco.00037",
"disco.00039",
"disco.00040",
"disco.00041",
"disco.00042",
"disco.00043",
"disco.00044",
"disco.00045",
"disco.00047",
"disco.00049",
"disco.00053",
"disco.00054",
"disco.00056",
"disco.00057",
"disco.00059",
"disco.00061",
"disco.00070",
"disco.00073",
"disco.00074",
"disco.00089",
"hiphop.00002",
"hiphop.00003",
"hiphop.00004",
"hiphop.00005",
"hiphop.00006",
"hiphop.00007",
"hiphop.00008",
"hiphop.00009",
"hiphop.00010",
"hiphop.00011",
"hiphop.00012",
"hiphop.00013",
"hiphop.00014",
"hiphop.00015",
"hiphop.00016",
"hiphop.00017",
"hiphop.00018",
"hiphop.00019",
"hiphop.00020",
"hiphop.00021",
"hiphop.00022",
"hiphop.00023",
"hiphop.00024",
"hiphop.00025",
"hiphop.00028",
"hiphop.00029",
"hiphop.00031",
"hiphop.00032",
"hiphop.00033",
"hiphop.00034",
"hiphop.00035",
"hiphop.00036",
"hiphop.00037",
"hiphop.00038",
"hiphop.00041",
"hiphop.00042",
"hiphop.00055",
"hiphop.00056",
"hiphop.00057",
"hiphop.00058",
"hiphop.00059",
"hiphop.00060",
"hiphop.00061",
"hiphop.00077",
"hiphop.00078",
"hiphop.00079",
"hiphop.00080",
"jazz.00000",
"jazz.00001",
"jazz.00011",
"jazz.00012",
"jazz.00013",
"jazz.00014",
"jazz.00015",
"jazz.00016",
"jazz.00017",
"jazz.00018",
"jazz.00019",
"jazz.00020",
"jazz.00021",
"jazz.00022",
"jazz.00023",
"jazz.00024",
"jazz.00041",
"jazz.00047",
"jazz.00048",
"jazz.00049",
"jazz.00050",
"jazz.00051",
"jazz.00052",
"jazz.00053",
"jazz.00054",
"jazz.00055",
"jazz.00056",
"jazz.00057",
"jazz.00058",
"jazz.00059",
"jazz.00060",
"jazz.00061",
"jazz.00062",
"jazz.00063",
"jazz.00064",
"jazz.00065",
"jazz.00066",
"jazz.00067",
"jazz.00068",
"jazz.00069",
"jazz.00070",
"jazz.00071",
"jazz.00072",
"metal.00002",
"metal.00003",
"metal.00005",
"metal.00021",
"metal.00024",
"metal.00035",
"metal.00046",
"metal.00047",
"metal.00048",
"metal.00049",
"metal.00050",
"metal.00051",
"metal.00052",
"metal.00053",
"metal.00054",
"metal.00055",
"metal.00056",
"metal.00057",
"metal.00059",
"metal.00060",
"metal.00061",
"metal.00062",
"metal.00063",
"metal.00064",
"metal.00065",
"metal.00066",
"metal.00069",
"metal.00071",
"metal.00072",
"metal.00079",
"metal.00080",
"metal.00084",
"metal.00086",
"metal.00089",
"metal.00090",
"metal.00091",
"metal.00092",
"metal.00093",
"metal.00094",
"metal.00095",
"metal.00096",
"metal.00097",
"metal.00098",
"metal.00099",
"pop.00002",
"pop.00003",
"pop.00004",
"pop.00005",
"pop.00006",
"pop.00007",
"pop.00008",
"pop.00009",
"pop.00011",
"pop.00012",
"pop.00016",
"pop.00017",
"pop.00018",
"pop.00019",
"pop.00020",
"pop.00023",
"pop.00024",
"pop.00025",
"pop.00026",
"pop.00027",
"pop.00028",
"pop.00029",
"pop.00031",
"pop.00032",
"pop.00033",
"pop.00034",
"pop.00035",
"pop.00036",
"pop.00038",
"pop.00039",
"pop.00040",
"pop.00041",
"pop.00042",
"pop.00044",
"pop.00046",
"pop.00049",
"pop.00050",
"pop.00080",
"pop.00097",
"pop.00098",
"pop.00099",
"reggae.00000",
"reggae.00001",
"reggae.00002",
"reggae.00004",
"reggae.00006",
"reggae.00009",
"reggae.00011",
"reggae.00012",
"reggae.00014",
"reggae.00015",
"reggae.00016",
"reggae.00017",
"reggae.00018",
"reggae.00019",
"reggae.00020",
"reggae.00021",
"reggae.00022",
"reggae.00023",
"reggae.00024",
"reggae.00025",
"reggae.00026",
"reggae.00027",
"reggae.00028",
"reggae.00029",
"reggae.00030",
"reggae.00031",
"reggae.00032",
"reggae.00042",
"reggae.00043",
"reggae.00044",
"reggae.00045",
"reggae.00049",
"reggae.00050",
"reggae.00051",
"reggae.00054",
"reggae.00055",
"reggae.00056",
"reggae.00057",
"reggae.00058",
"reggae.00059",
"reggae.00060",
"reggae.00063",
"reggae.00069",
"rock.00000",
"rock.00001",
"rock.00002",
"rock.00003",
"rock.00004",
"rock.00005",
"rock.00006",
"rock.00007",
"rock.00008",
"rock.00009",
"rock.00016",
"rock.00017",
"rock.00018",
"rock.00019",
"rock.00020",
"rock.00021",
"rock.00022",
"rock.00023",
"rock.00024",
"rock.00025",
"rock.00026",
"rock.00057",
"rock.00058",
"rock.00059",
"rock.00060",
"rock.00061",
"rock.00062",
"rock.00063",
"rock.00064",
"rock.00065",
"rock.00066",
"rock.00067",
"rock.00068",
"rock.00069",
"rock.00070",
"rock.00091",
"rock.00092",
"rock.00093",
"rock.00094",
"rock.00095",
"rock.00096",
"rock.00097",
"rock.00098",
"rock.00099",
]
filtered_valid = [
"blues.00000",
"blues.00001",
"blues.00002",
"blues.00003",
"blues.00004",
"blues.00005",
"blues.00006",
"blues.00007",
"blues.00008",
"blues.00009",
"blues.00010",
"blues.00011",
"blues.00050",
"blues.00051",
"blues.00052",
"blues.00053",
"blues.00054",
"blues.00055",
"blues.00056",
"blues.00057",
"blues.00058",
"blues.00059",
"blues.00060",
"classical.00000",
"classical.00001",
"classical.00002",
"classical.00003",
"classical.00004",
"classical.00005",
"classical.00006",
"classical.00007",
"classical.00008",
"classical.00009",
"classical.00010",
"classical.00068",
"classical.00069",
"classical.00070",
"classical.00071",
"classical.00072",
"classical.00073",
"classical.00074",
"classical.00075",
"classical.00076",
"country.00000",
"country.00001",
"country.00002",
"country.00003",
"country.00004",
"country.00005",
"country.00006",
"country.00007",
"country.00009",
"country.00010",
"country.00011",
"country.00012",
"country.00013",
"country.00014",
"country.00015",
"country.00016",
"country.00017",
"country.00018",
"country.00027",
"country.00041",
"country.00042",
"country.00045",
"country.00049",
"disco.00000",
"disco.00002",
"disco.00003",
"disco.00004",
"disco.00006",
"disco.00007",
"disco.00008",
"disco.00009",
"disco.00010",
"disco.00011",
"disco.00012",
"disco.00013",
"disco.00014",
"disco.00046",
"disco.00048",
"disco.00052",
"disco.00067",
"disco.00068",
"disco.00072",
"disco.00075",
"disco.00090",
"disco.00095",
"hiphop.00081",
"hiphop.00082",
"hiphop.00083",
"hiphop.00084",
"hiphop.00085",
"hiphop.00086",
"hiphop.00087",
"hiphop.00088",
"hiphop.00089",
"hiphop.00090",
"hiphop.00091",
"hiphop.00092",
"hiphop.00093",
"hiphop.00094",
"hiphop.00095",
"hiphop.00096",
"hiphop.00097",
"hiphop.00098",
"jazz.00002",
"jazz.00003",
"jazz.00004",
"jazz.00005",
"jazz.00006",
"jazz.00007",
"jazz.00008",
"jazz.00009",
"jazz.00010",
"jazz.00025",
"jazz.00026",
"jazz.00027",
"jazz.00028",
"jazz.00029",
"jazz.00030",
"jazz.00031",
"jazz.00032",
"metal.00000",
"metal.00001",
"metal.00006",
"metal.00007",
"metal.00008",
"metal.00009",
"metal.00010",
"metal.00011",
"metal.00016",
"metal.00017",
"metal.00018",
"metal.00019",
"metal.00020",
"metal.00036",
"metal.00037",
"metal.00068",
"metal.00076",
"metal.00077",
"metal.00081",
"metal.00082",
"pop.00010",
"pop.00053",
"pop.00055",
"pop.00058",
"pop.00059",
"pop.00060",
"pop.00061",
"pop.00062",
"pop.00081",
"pop.00083",
"pop.00084",
"pop.00085",
"pop.00086",
"reggae.00061",
"reggae.00062",
"reggae.00070",
"reggae.00072",
"reggae.00074",
"reggae.00076",
"reggae.00077",
"reggae.00078",
"reggae.00085",
"reggae.00092",
"reggae.00093",
"reggae.00094",
"reggae.00095",
"reggae.00096",
"reggae.00097",
"reggae.00098",
"reggae.00099",
"rock.00038",
"rock.00049",
"rock.00050",
"rock.00051",
"rock.00052",
"rock.00053",
"rock.00054",
"rock.00055",
"rock.00056",
"rock.00071",
"rock.00072",
"rock.00073",
"rock.00074",
"rock.00075",
"rock.00076",
"rock.00077",
"rock.00078",
"rock.00079",
"rock.00080",
"rock.00081",
"rock.00082",
"rock.00083",
"rock.00084",
"rock.00085",
]
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
}
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
"""
Loads a file from the dataset and returns the raw waveform
as a Torch Tensor, its sample rate as an integer, and its
genre as a string.
"""
# Filenames are of the form label.id, e.g. blues.00078
label, _ = fileid.split(".")
# Read wav
file_audio = os.path.join(path, label, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, label
class GTZAN(Dataset):
"""*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.
Note:
Please see http://marsyas.info/downloads/datasets.html if you are planning to use
this dataset to publish results.
Note:
As of October 2022, the download link is not currently working. Setting ``download=True``
in GTZAN dataset will result in a URL connection error.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset.
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
subset (str or None, optional): Which subset of the dataset to use.
One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
If ``None``, the entire dataset is used. (default: ``None``).
"""
_ext_audio = ".wav"
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
# super(GTZAN, self).__init__()
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self.root = root
self.url = url
self.folder_in_archive = folder_in_archive
self.download = download
self.subset = subset
if subset is not None and subset not in ["training", "validation", "testing"]:
raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
_extract_tar(archive)
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
if self.subset is None:
# Check every subdirectory under dataset root
# which has the same name as the genres in
# GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
# This lets users remove or move around song files,
# useful when e.g. they want to use only some of the files
# in a genre or want to label other files with a different
# genre.
self._walker = []
root = os.path.expanduser(self._path)
for directory in gtzan_genres:
fulldir = os.path.join(root, directory)
if not os.path.exists(fulldir):
continue
songs_in_genre = os.listdir(fulldir)
songs_in_genre.sort()
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
# Check whether the file is of the form
# `gtzan_genre`.`5 digit number`.wav
genre, num = name.split(".")
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
self._walker.append(name)
else:
if self.subset == "training":
self._walker = filtered_train
elif self.subset == "validation":
self._walker = filtered_valid
elif self.subset == "testing":
self._walker = filtered_test
def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
Tensor:
Waveform
int:
Sample rate
str:
Label
"""
fileid = self._walker[n]
item = load_gtzan_item(fileid, self._path, self._ext_audio)
waveform, sample_rate, label = item
return waveform, sample_rate, label
def __len__(self) -> int:
return len(self._walker)