# coding=utf-8
# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for Pop2Piano."""

import json
import os
from typing import List, Optional, Tuple, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy
from ...utils import TensorType, is_pretty_midi_available, logging, requires_backends, to_numpy


if is_pretty_midi_available():
    import pretty_midi

logger = logging.get_logger(__name__)


VOCAB_FILES_NAMES = {
    "vocab": "vocab.json",
}


def token_time_to_note(number, cutoff_time_idx, current_idx):
    current_idx += number
    if cutoff_time_idx is not None:
        current_idx = min(current_idx, cutoff_time_idx)

    return current_idx


def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes):
    if note_onsets_ready[number] is not None:
        # offset with onset
        onset_idx = note_onsets_ready[number]
        if onset_idx < current_idx:
            # Time shift after previous note_on
            offset_idx = current_idx
            notes.append([onset_idx, offset_idx, number, default_velocity])
            onsets_ready = None if current_velocity == 0 else current_idx
            note_onsets_ready[number] = onsets_ready
    else:
        note_onsets_ready[number] = current_idx
    return notes


class Pop2PianoTokenizer(PreTrainedTokenizer):
    """
    Constructs a Pop2Piano tokenizer. This tokenizer does not require training.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab (`str`):
            Path to the vocab file which contains the vocabulary.
        default_velocity (`int`, *optional*, defaults to 77):
            Determines the default velocity to be used while creating midi Notes.
        num_bars (`int`, *optional*, defaults to 2):
            Determines the cutoff_time_idx for each token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"1"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"0"`):
            A special token used to make arrays of tokens the same size for batching purposes. It will then be
            ignored by attention mechanisms or loss computation.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"2"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
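
    Example (a minimal sketch rather than an official doctest; the checkpoint name and note values below are
    illustrative):

    ```python
    import numpy as np
    from transformers import Pop2PianoTokenizer

    tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")

    # each row is [onset_idx, offset_idx, pitch, velocity] on the beatstep grid
    notes = np.array([[0, 2, 60, 77], [1, 3, 64, 77]])
    encoding = tokenizer(notes, return_tensors="np")  # `BatchEncoding` with "token_ids" (and "attention_mask")
    ```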
""" model_input_names = ["token_ids", "attention_mask"] vocab_files_names = VOCAB_FILES_NAMES def __init__( self, vocab, default_velocity=77, num_bars=2, unk_token="-1", eos_token="1", pad_token="0", bos_token="2", **kwargs, ): unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token self.default_velocity = default_velocity self.num_bars = num_bars # Load the vocab with open(vocab, "rb") as file: self.encoder = json.load(file) # create mappings for encoder self.decoder = {v: k for k, v in self.encoder.items()} super().__init__( unk_token=unk_token, eos_token=eos_token, pad_token=pad_token, bos_token=bos_token, **kwargs, ) @property def vocab_size(self): """Returns the vocabulary size of the tokenizer.""" return len(self.encoder) def get_vocab(self): """Returns the vocabulary of the tokenizer.""" return dict(self.encoder, **self.added_tokens_encoder) def _convert_id_to_token(self, token_id: int) -> list: """ Decodes the token ids generated by the transformer into notes. Args: token_id (`int`): This denotes the ids generated by the transformers to be converted to Midi tokens. Returns: `List`: A list consists of token_type (`str`) and value (`int`). """ token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME") token_type_value = token_type_value.split("_") token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0]) return [token_type, value] def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int: """ Encodes the Midi tokens to transformer generated token ids. Args: token (`int`): This denotes the token value. token_type (`str`): This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME", "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL". Returns: `int`: returns the id of the token. """ return self.encoder.get(f"{token}_{token_type}", int(self.unk_token)) def relative_batch_tokens_ids_to_notes( self, tokens: np.ndarray, beat_offset_idx: int, bars_per_batch: int, cutoff_time_idx: int, ): """ Converts relative tokens to notes which are then used to generate pretty midi object. Args: tokens (`numpy.ndarray`): Tokens to be converted to notes. beat_offset_idx (`int`): Denotes beat offset index for each note in generated Midi. bars_per_batch (`int`): A parameter to control the Midi output generation. cutoff_time_idx (`int`): Denotes the cutoff time index for each note in generated Midi. """ notes = None for index in range(len(tokens)): _tokens = tokens[index] _start_idx = beat_offset_idx + index * bars_per_batch * 4 _cutoff_time_idx = cutoff_time_idx + _start_idx _notes = self.relative_tokens_ids_to_notes( _tokens, start_idx=_start_idx, cutoff_time_idx=_cutoff_time_idx, ) if len(_notes) == 0: pass elif notes is None: notes = _notes else: notes = np.concatenate((notes, _notes), axis=0) if notes is None: return [] return notes def relative_batch_tokens_ids_to_midi( self, tokens: np.ndarray, beatstep: np.ndarray, beat_offset_idx: int = 0, bars_per_batch: int = 2, cutoff_time_idx: int = 12, ): """ Converts tokens to Midi. 

        This method calls the `relative_batch_tokens_ids_to_notes` method to convert batch tokens to notes, then uses
        the `notes_to_midi` method to convert them to Midi.

        Args:
            tokens (`numpy.ndarray`):
                Denotes tokens which, alongside beatstep, will be converted to Midi.
            beatstep (`np.ndarray`):
                We get beatstep from the feature extractor; it is also used to get the Midi.
            beat_offset_idx (`int`, *optional*, defaults to 0):
                Denotes beat offset index for each note in generated Midi.
            bars_per_batch (`int`, *optional*, defaults to 2):
                A parameter to control the Midi output generation.
            cutoff_time_idx (`int`, *optional*, defaults to 12):
                Denotes the cutoff time index for each note in generated Midi.
        """
        beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
        notes = self.relative_batch_tokens_ids_to_notes(
            tokens=tokens,
            beat_offset_idx=beat_offset_idx,
            bars_per_batch=bars_per_batch,
            cutoff_time_idx=cutoff_time_idx,
        )
        midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
        return midi

    # Taken from the original code
    # Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257
    def relative_tokens_ids_to_notes(self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: float = None):
        """
        Converts relative tokens to notes which will then be used to create Pretty Midi objects.

        Args:
            tokens (`numpy.ndarray`):
                Relative tokens which will be converted to notes.
            start_idx (`float`):
                A parameter which denotes the starting index.
            cutoff_time_idx (`float`, *optional*):
                A parameter used while converting tokens to notes.
        """
        words = [self._convert_id_to_token(token) for token in tokens]

        current_idx = start_idx
        current_velocity = 0
        note_onsets_ready = [None for i in range(sum([k.endswith("NOTE") for k in self.encoder.keys()]) + 1)]
        notes = []
        for token_type, number in words:
            if token_type == "TOKEN_SPECIAL":
                if number == 1:
                    break
            elif token_type == "TOKEN_TIME":
                current_idx = token_time_to_note(
                    number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx
                )
            elif token_type == "TOKEN_VELOCITY":
                current_velocity = number
            elif token_type == "TOKEN_NOTE":
                notes = token_note_to_note(
                    number=number,
                    current_velocity=current_velocity,
                    default_velocity=self.default_velocity,
                    note_onsets_ready=note_onsets_ready,
                    current_idx=current_idx,
                    notes=notes,
                )
            else:
                raise ValueError("Token type not understood!")

        for pitch, note_onset in enumerate(note_onsets_ready):
            # force offset if no offset for each pitch
            if note_onset is not None:
                if cutoff_time_idx is None:
                    cutoff = note_onset + 1
                else:
                    cutoff = max(cutoff_time_idx, note_onset + 1)

                offset_idx = max(current_idx, cutoff)
                notes.append([note_onset, offset_idx, pitch, self.default_velocity])

        if len(notes) == 0:
            return []
        else:
            notes = np.array(notes)
            note_order = notes[:, 0] * 128 + notes[:, 1]
            notes = notes[note_order.argsort()]
            return notes

    def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: float = 0.0):
        """
        Converts notes to Midi.

        Args:
            notes (`numpy.ndarray`):
                This is used to create Pretty Midi objects.
            beatstep (`numpy.ndarray`):
                This is the extrapolated beatstep that we get from the feature extractor.
            offset_sec (`float`, *optional*, defaults to 0.0):
                This represents the offset seconds which is used while creating each Pretty Midi Note.
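
        Example (an illustrative sketch with made-up values; it assumes `tokenizer` is an instantiated
        `Pop2PianoTokenizer` and `np` is `numpy`):

        ```python
        notes = np.array([[0, 2, 60, 77]])  # one note: [onset_idx, offset_idx, pitch, velocity]
        beatstep = np.array([0.0, 0.5, 1.0, 1.5])  # time in seconds of each beat index
        pm = tokenizer.notes_to_midi(notes, beatstep)  # a `pretty_midi.PrettyMIDI` with a single instrument
        ```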
""" requires_backends(self, ["pretty_midi"]) new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0) new_inst = pretty_midi.Instrument(program=0) new_notes = [] for onset_idx, offset_idx, pitch, velocity in notes: new_note = pretty_midi.Note( velocity=velocity, pitch=pitch, start=beatstep[onset_idx] - offset_sec, end=beatstep[offset_idx] - offset_sec, ) new_notes.append(new_note) new_inst.notes = new_notes new_pm.instruments.append(new_inst) new_pm.remove_invalid_notes() return new_pm def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: """ Saves the tokenizer's vocabulary dictionary to the provided save_directory. Args: save_directory (`str`): A path to the directory where to saved. It will be created if it doesn't exist. filename_prefix (`Optional[str]`, *optional*): A prefix to add to the names of the files saved by the tokenizer. """ if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return # Save the encoder. out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"] ) with open(out_vocab_file, "w") as file: file.write(json.dumps(self.encoder)) return (out_vocab_file,) def encode_plus( self, notes: Union[np.ndarray, List[pretty_midi.Note]], truncation_strategy: Optional[TruncationStrategy] = None, max_length: Optional[int] = None, **kwargs, ) -> BatchEncoding: r""" This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated token ids. It only works on a single batch, to process multiple batches please use `batch_encode_plus` or `__call__` method. Args: notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects): This represents the midi notes. If `notes` is a `numpy.ndarray`: - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`. If `notes` is a `list` containing `pretty_midi.Note` objects: - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`. truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*): Indicates the truncation strategy that is going to be used during truncation. max_length (`int`, *optional*): Maximum length of the returned list and optionally padding length (see above). Returns: `BatchEncoding` containing the tokens ids. """ requires_backends(self, ["pretty_midi"]) # check if notes is a pretty_midi object or not, if yes then extract the attributes and put them into a numpy # array. if isinstance(notes[0], pretty_midi.Note): notes = np.array( [[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes] ).reshape(-1, 4) # to round up all the values to the closest int values. 
        notes = np.round(notes).astype(np.int32)
        max_time_idx = notes[:, :2].max()

        # `times[t]` collects the (pitch, velocity) events happening at time index `t`;
        # a velocity of 0 marks a note-off.
        times = [[] for i in range(max_time_idx + 1)]
        for onset, offset, pitch, velocity in notes:
            times[onset].append([pitch, velocity])
            times[offset].append([pitch, 0])

        tokens = []
        current_velocity = 0
        for i, time in enumerate(times):
            if len(time) == 0:
                continue
            tokens.append(self._convert_token_to_id(i, "TOKEN_TIME"))
            for pitch, velocity in time:
                velocity = int(velocity > 0)
                if current_velocity != velocity:
                    current_velocity = velocity
                    tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
                tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))

        total_len = len(tokens)

        # truncation
        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
            tokens, _, _ = self.truncate_sequences(
                ids=tokens,
                num_tokens_to_remove=total_len - max_length,
                truncation_strategy=truncation_strategy,
                **kwargs,
            )

        return BatchEncoding({"token_ids": tokens})

    def batch_encode_plus(
        self,
        notes: Union[np.ndarray, List[pretty_midi.Note]],
        truncation_strategy: Optional[TruncationStrategy] = None,
        max_length: Optional[int] = None,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
                Indicates the truncation strategy that is going to be used during truncation.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).

        Returns:
            `BatchEncoding` containing the token ids.
        """
        encoded_batch_token_ids = []
        for i in range(len(notes)):
            encoded_batch_token_ids.append(
                self.encode_plus(
                    notes[i],
                    truncation_strategy=truncation_strategy,
                    max_length=max_length,
                    **kwargs,
                )["token_ids"]
            )

        return BatchEncoding({"token_ids": encoded_batch_token_ids})

    def __call__(
        self,
        notes: Union[
            np.ndarray,
            List[pretty_midi.Note],
            List[List[pretty_midi.Note]],
        ],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                Activates and controls padding.
                Accepts the following values:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
                Activates and controls truncation. Accepts the following values:

                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length`
                  or to the maximum acceptable input length for the model if that argument is not provided. This will
                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
                  sequences (or a batch of pairs) is provided.
                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence
                  lengths greater than the model maximum admissible input size).
            max_length (`int`, *optional*):
                Controls the maximum length to use by one of the truncation/padding parameters.

                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
                is required by one of the truncation/padding parameters. If the model has no specific maximum input
                length (like XLNet) truncation/padding to a maximum length will be deactivated.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the `return_outputs` attribute.

                [What are attention masks?](../glossary#attention-mask)
            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether or not to print more information and warnings.

        Returns:
            `BatchEncoding` containing the token_ids.
        """
        # check if it is batched or not
        # it is batched if it's a list containing a list of `pretty_midi.Notes` where the outer list contains all the
        # batches and the inner list contains all Notes for a single batch. Otherwise, if a np.ndarray is passed it
        # will be considered batched if it has a shape of `[batch_size, sequence_length, 4]` or ndim=3.
        is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)

        # get the truncation and padding strategy
        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            verbose=verbose,
            **kwargs,
        )

        if is_batched:
            # If the user has not explicitly mentioned `return_attention_mask` as False, we change it to True
            return_attention_mask = True if return_attention_mask is None else return_attention_mask
            token_ids = self.batch_encode_plus(
                notes=notes,
                truncation_strategy=truncation_strategy,
                max_length=max_length,
                **kwargs,
            )
        else:
            token_ids = self.encode_plus(
                notes=notes,
                truncation_strategy=truncation_strategy,
                max_length=max_length,
                **kwargs,
            )

        # since we already have truncated sequences we are just left to do padding
        token_ids = self.pad(
            token_ids,
            padding=padding_strategy,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
            verbose=verbose,
        )

        return token_ids

    def batch_decode(
        self,
        token_ids,
        feature_extractor_output: BatchFeature,
        return_midi: bool = True,
    ):
        r"""
        This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
        transformer to midi_notes and returns them.

        Args:
            token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
                Output token_ids of the `Pop2PianoForConditionalGeneration` model.
            feature_extractor_output (`BatchFeature`):
                Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
                `"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
                `"attention_mask_extrapolated_beatstep"` should be present if they were returned by the feature
                extractor.
            return_midi (`bool`, *optional*, defaults to `True`):
                Whether to return the midi object or not.

        Returns:
            If `return_midi` is True:
                - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
            If `return_midi` is False:
                - `BatchEncoding` containing `notes`.
        """
        # check if they have attention_masks (attention_mask, attention_mask_beatsteps,
        # attention_mask_extrapolated_beatstep) or not
        attention_masks_present = bool(
            hasattr(feature_extractor_output, "attention_mask")
            and hasattr(feature_extractor_output, "attention_mask_beatsteps")
            and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
        )

        # if we are processing batched inputs then we must need attention_masks
        if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
            raise ValueError(
                "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
                "for batched inputs! But one of them was not present."
            )

        # check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
        if attention_masks_present:
            # since we know about the number of examples in token_ids from attention_mask
            if (
                sum(feature_extractor_output["attention_mask"][:, 0] == 0)
                != feature_extractor_output["beatsteps"].shape[0]
                or feature_extractor_output["beatsteps"].shape[0]
                != feature_extractor_output["extrapolated_beatstep"].shape[0]
            ):
                raise ValueError(
                    "Length mismatch between token_ids, beatsteps and extrapolated_beatstep! Found "
Found " f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} " f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}" ) if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]: raise ValueError( f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}" ) else: # if there is no attention mask present then it's surely a single example if ( feature_extractor_output["beatsteps"].shape[0] != 1 or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1 ): raise ValueError( "Length mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, " f"But found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}." ) if attention_masks_present: # check for zeros(since token_ids are seperated by zero arrays) batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0] else: batch_idx = [token_ids.shape[0]] notes_list = [] pretty_midi_objects_list = [] start_idx = 0 for index, end_idx in enumerate(batch_idx): each_tokens_ids = token_ids[start_idx:end_idx] # check where the whole example ended by searching for eos_token_id and getting the upper bound each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1] beatsteps = feature_extractor_output["beatsteps"][index] extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index] # if attention mask is present then mask out real array/tensor if attention_masks_present: attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index] attention_mask_extrapolated_beatstep = feature_extractor_output[ "attention_mask_extrapolated_beatstep" ][index] beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1] extrapolated_beatstep = extrapolated_beatstep[ : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1 ] each_tokens_ids = to_numpy(each_tokens_ids) beatsteps = to_numpy(beatsteps) extrapolated_beatstep = to_numpy(extrapolated_beatstep) pretty_midi_object = self.relative_batch_tokens_ids_to_midi( tokens=each_tokens_ids, beatstep=extrapolated_beatstep, bars_per_batch=self.num_bars, cutoff_time_idx=(self.num_bars + 1) * 4, ) for note in pretty_midi_object.instruments[0].notes: note.start += beatsteps[0] note.end += beatsteps[0] notes_list.append(note) pretty_midi_objects_list.append(pretty_midi_object) start_idx += end_idx + 1 # 1 represents the zero array if return_midi: return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_objects_list}) return BatchEncoding({"notes": notes_list})