80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
|
from typing import Any
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def validate_numpy_array(value: Any) -> np.ndarray:
    r"""
    Validates the input and makes sure it returns a numpy array (i.e. on CPU).

    Args:
        value (Any): the input value; a numpy array, a list, or a torch tensor.

    Raises:
        TypeError: if the value is not a numpy array, a torch tensor or a list.

    Returns:
        np.ndarray: numpy array of the value. A numpy array input is returned
        unchanged (same object).
    """
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, list):
        return np.array(value)
    if torch.is_tensor(value):
        # ``Tensor.numpy()`` refuses tensors that require grad, so detach first;
        # ``.cpu()`` moves device tensors to host memory so numpy can view them.
        return value.detach().cpu().numpy()
    raise TypeError("Value must be a numpy array, a torch tensor or a list")
|
||
|
|
||
|
|
||
|
def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None):
    """Get the most probable state means from the log_alpha_scaled.

    For every timestep ``t`` the state with the largest ``log_alpha_scaled[t]``
    value is selected and ``means[t, state]`` is taken. The selected means are
    cast to the dtype of ``log_alpha_scaled`` and optionally passed through
    ``decoder``.

    Args:
        log_alpha_scaled (torch.Tensor): Log alpha scaled values.
            - Shape: :math:`(T, N)`
        means (torch.Tensor): Means of the states, indexed as ``means[t, state]``.
            - Shape: :math:`(T, N, D_out)`
        decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None.
            It is called as ``decoder(x, lengths, reverse=True)`` with ``x`` of shape
            :math:`(1, D_out, T)` and ``lengths == [T]``; the first element of its
            output is used.

    Returns:
        torch.Tensor: the selected means, shape :math:`(T, D_out)`, when ``decoder``
        is None; otherwise the decoder's first output with its leading batch
        dimension squeezed and then transposed.
    """
    # Index of the most probable state at every timestep: shape (T,).
    max_state_numbers = torch.max(log_alpha_scaled, dim=1).indices
    max_len = means.shape[0]
    n_mel_channels = means.shape[2]
    # Broadcast the per-timestep state index over the feature dimension so that
    # ``gather`` along dim 1 (the state axis) picks means[t, state_t, :].
    index = max_state_numbers[:, None, None].expand(max_len, 1, n_mel_channels)
    means = torch.gather(means, 1, index).squeeze(1).to(log_alpha_scaled.dtype)
    if decoder is None:
        return means

    lengths = torch.tensor([means.shape[0]], device=means.device)
    # (T, D_out) -> (1, D_out, T) for the decoder, then back to time-major.
    decoded = decoder(means.T.unsqueeze(0), lengths, reverse=True)[0]
    return decoded.squeeze(0).T
|
||
|
|
||
|
|
||
|
def plot_transition_probabilities_to_numpy(states, transition_probabilities, output_fig=False):
    """Generate a transition probability plot for the given states.

    Despite its name, this returns a matplotlib figure, not a numpy array.

    Args:
        states (torch.IntTensor): the states, used as x-axis tick labels
            (numpy arrays and lists are accepted as well).
        transition_probabilities (torch.FloatTensor): the transition probability
            of each state, plotted one point per state.
        output_fig (bool): if True, the figure is left open in pyplot; otherwise
            it is closed before being returned. Defaults to False.

    Returns:
        matplotlib.figure.Figure: the figure containing the plot.
    """
    states = validate_numpy_array(states)
    transition_probabilities = validate_numpy_array(transition_probabilities)

    fig, ax = plt.subplots(figsize=(30, 3))
    ax.plot(transition_probabilities, "o")
    ax.set_title("Transition probability of state")
    ax.set_xlabel("hidden state")
    ax.set_ylabel("probability")
    # One tick per plotted point, labelled with the corresponding state id.
    ax.set_xticks(list(range(len(transition_probabilities))))
    ax.set_xticklabels([int(x) for x in states], rotation=90)
    # Act on this figure explicitly instead of pyplot's implicit "current" figure.
    fig.tight_layout()
    if not output_fig:
        plt.close(fig)
    return fig
|