#!/usr/bin/env python3
import argparse
import gc
import itertools as it
import json
import logging
import os
import unicodedata
import wave
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np
import torch
from piper import PiperVoice, SynthesisConfig
from piper.phonemize_espeak import EspeakPhonemizer

try:
    from piper_train.vits import commons
except ImportError:
    from piper_train.vits import commons

_LOGGER = logging.getLogger(__name__)

# Configure root logging only when this file is executed as a script; a module
# imported as a library should not reconfigure the host application's logging.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)


# Main generation function
def generate_samples(
    text: Union[List[str], str],
    output_dir: Union[str, Path],
    model: Union[str, Path],
    max_samples: Optional[int] = None,
    file_names: Optional[Iterable[str]] = None,
    batch_size: int = 1,
    slerp_weights: Tuple[float, ...] = (0.5,),
    length_scales: Tuple[float, ...] = (0.75, 1, 1.25),
    noise_scales: Tuple[float, ...] = (0.667,),
    noise_scale_ws: Tuple[float, ...] = (0.8,),
    max_speakers: Optional[int] = None,
    verbose: bool = False,
    phoneme_input: bool = False,
    **kwargs,
) -> None:
    """
    Generate synthetic speech clips, saving the clips to the specified output directory.

    Every ordered pair of speaker ids (up to ``max_speakers``) is blended via
    SLERP. Each batch takes the next combination of synthesis settings
    (slerp weight, length scale, noise scale, noise width), cycling forever
    until ``max_samples`` clips have been written.

    Args:
        text (List[str]): The text to convert into speech. Can be either a
                          list of strings, or a path to a file with text on each line
                          (blank lines are ignored). Any other string is spoken as-is.
        output_dir (str): The location to save the generated clips (created if missing).
        model (str): The path to the TTS generator model (.pt). Its config is read
                     from ``<model>.json``.
        max_samples (int): The maximum number of samples to generate. Defaults to the
                           number of input texts.
        file_names (List[str]): The names to use when saving the files. Must be the same length
                                as the `text` argument, if a list. When omitted, files are
                                named ``<sample index>.wav``.
        batch_size (int): The batch size to use when generating the clips (must be >= 1).
        slerp_weights (List[float]): The weights to use when mixing speakers via SLERP.
        length_scales (List[float]): Controls the average duration/speed of the generated speech.
        noise_scales (List[float]): A parameter for overall variability of the generated speech.
        noise_scale_ws (List[float]): A parameter for the stochastic duration of words/phonemes.
        max_speakers (int): The maximum speaker number to use, if the model is multi-speaker.
        verbose (bool): Enable or disable more detailed logging messages (default: False).
        phoneme_input (bool): Set to indicate given input text is phoneme input.
        **kwargs: Accepted and ignored.
    Returns:
        None
    Raises:
        ValueError: if ``batch_size`` < 1, or there is no input text while
            samples were requested.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Materialize the inputs first: the default sample count is the number of
    # texts (not len() of a path string), and the file handle is closed promptly.
    if isinstance(text, str) and os.path.isfile(text):
        with open(text, "r", encoding="utf-8") as text_file:
            text_list = [line.strip() for line in text_file if line.strip()]
    elif isinstance(text, list):
        text_list = list(text)
    else:
        text_list = [text]

    if max_samples is None:
        max_samples = len(text_list)
    if max_samples <= 0:
        _LOGGER.info("Nothing to generate (max_samples=%s)", max_samples)
        return
    if not text_list:
        raise ValueError("No input text to synthesize")

    _LOGGER.debug("Loading %s", model)
    model_path = Path(model)

    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted
    # generator checkpoints. map_location="cpu" lets a checkpoint saved on a GPU
    # load on any machine; it is moved to the selected accelerator below.
    torch_model = torch.load(model_path, map_location="cpu", weights_only=False)
    torch_model.eval()
    _LOGGER.info("Successfully loaded the model")

    if torch.cuda.is_available():
        torch_model.cuda()
        _LOGGER.debug("CUDA available, using GPU")
    elif torch.backends.mps.is_available():
        torch_model.to(torch.device("mps"))
        _LOGGER.debug("MPS available, using GPU")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(f"{model_path}.json", "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    voice = config["espeak"]["voice"]
    sample_rate = config["audio"]["sample_rate"]
    num_speakers = config["num_speakers"]
    if max_speakers is not None:
        num_speakers = min(num_speakers, max_speakers)

    max_len = None  # no cap on the number of latent frames passed to the decoder

    texts = it.cycle(text_list)
    name_iter = it.cycle(file_names) if file_names else None
    settings_iter = it.cycle(
        it.product(slerp_weights, length_scales, noise_scales, noise_scale_ws)
    )
    # Every ordered (speaker_1, speaker_2) pair, repeated indefinitely
    speakers_iter = it.cycle(it.product(range(num_speakers), range(num_speakers)))

    total_batches = -(-max_samples // batch_size)  # ceiling division
    sample_idx = 0
    batch_idx = 0
    speakers_batch = list(it.islice(speakers_iter, batch_size))

    while speakers_batch and sample_idx < max_samples:
        slerp_weight, length_scale, noise_scale, noise_scale_w = next(settings_iter)

        with torch.no_grad():
            speaker_1 = torch.LongTensor([pair[0] for pair in speakers_batch])
            speaker_2 = torch.LongTensor([pair[1] for pair in speakers_batch])

            phoneme_ids_by_batch = _right_pad_phoneme_ids(
                [
                    get_phonemes(voice, config, next(texts), verbose, phoneme_input)
                    for _ in speakers_batch
                ]
            )
            audio, phoneme_samples = generate_audio(
                torch_model,
                speaker_1,
                speaker_2,
                phoneme_ids_by_batch,
                slerp_weight,
                noise_scale,
                noise_scale_w,
                length_scale,
                max_len,
            )

            # Silence everything after the predicted end of speech; the zero
            # tail is then removed by np.trim_zeros below.
            for i in range(audio.shape[0]):
                last_sample_idx = int(phoneme_samples[i].flatten().sum().item())
                audio[i, 0, last_sample_idx + 1 :] = 0

            audio_numpy = audio.cpu().numpy()

            if torch.backends.mps.is_available():
                # There seems to be a memory leak if we don't empty the cache
                # after each batch with mps
                torch.mps.empty_cache()
                gc.collect()

        # NOTE(review): peak normalization is applied across the whole batch,
        # so a clip's loudness depends on the other clips in its batch.
        audio_int16 = audio_float_to_int16(audio_numpy)
        for clip in audio_int16:
            audio_data = np.trim_zeros(clip.flatten())

            if name_iter is not None:
                wav_path = output_dir / next(name_iter)
            else:
                wav_path = output_dir / f"{sample_idx}.wav"

            with wave.open(str(wav_path), "wb") as wav_file:
                wav_file.setframerate(sample_rate)
                wav_file.setsampwidth(2)
                wav_file.setnchannels(1)
                wav_file.writeframes(audio_data.tobytes())

            sample_idx += 1
            if sample_idx >= max_samples:
                break

        _LOGGER.debug("Batch %s/%s complete", batch_idx + 1, total_batches)
        speakers_batch = list(it.islice(speakers_iter, batch_size))
        batch_idx += 1

    _LOGGER.info("Done")


def _right_pad_phoneme_ids(
    id_lists: List[List[int]], pad_id: int = 1
) -> List[List[int]]:
    """Right-pad each phoneme-id list to the length of the longest one.

    The default filler is phoneme id 1; the original author noted that this id
    (the '^' character) "seems to work best" as padding.
    """
    max_length = max(len(ids) for ids in id_lists)
    return [ids + [pad_id] * (max_length - len(ids)) for ids in id_lists]


# -----------------------------------------------------------------------------


def generate_samples_onnx(
    text: Union[List[str], str],
    output_dir: Union[str, Path],
    model: Union[str, Path, List[Union[str, Path]]],
    max_samples: Optional[int] = None,
    file_names: Optional[Iterable[str]] = None,
    length_scales: Tuple[float, ...] = (0.75, 1, 1.25),
    noise_scales: Tuple[float, ...] = (0.667,),
    noise_scale_ws: Tuple[float, ...] = (0.8,),
    max_speakers: Optional[int] = None,
    phoneme_input: bool = False,
    **kwargs,
) -> None:
    """
    Generate synthetic speech clips, saving the clips to the specified output directory.

    Cycles through every combination of (voice, length scale, noise scale,
    noise width) and, for each, through every speaker id of that voice (up to
    ``max_speakers``), writing one clip per speaker until ``max_samples``
    clips have been written.

    Args:
        text (List[str]): The text to convert into speech. Can be either a
                          list of strings, or a path to a file with text on each line
                          (blank lines are ignored). Any other string is spoken as-is.
        output_dir (str): The location to save the generated clips (created if missing).
        model (str): The path to the Piper TTS model (.onnx), or a list of such paths.
        max_samples (int): The maximum number of samples to generate. Defaults to the
                           number of input texts.
        file_names (List[str]): The names to use when saving the files. Must be the same length
                                as the `text` argument, if a list. When omitted, files are
                                named ``<sample index>.wav``.
        length_scales (List[float]): Controls the average duration/speed of the generated speech.
        noise_scales (List[float]): A parameter for overall variability of the generated speech.
        noise_scale_ws (List[float]): A parameter for the stochastic duration of words/phonemes.
        max_speakers (int): The maximum speaker number to use, if the model is multi-speaker.
        phoneme_input (bool): Set to indicate given input text is phoneme input.
        **kwargs: Accepted and ignored.

    Returns:
        None

    Raises:
        ValueError: if there is no input text while samples were requested.
    """
    # Materialize the inputs first: the default sample count is the number of
    # texts (not len() of a path string), and the file handle is closed promptly.
    if isinstance(text, str) and os.path.isfile(text):
        with open(text, "r", encoding="utf-8") as text_file:
            text_list = [line.strip() for line in text_file if line.strip()]
    elif isinstance(text, list):
        text_list = list(text)
    else:
        text_list = [text]

    if max_samples is None:
        max_samples = len(text_list)
    if max_samples <= 0:
        _LOGGER.info("Nothing to generate (max_samples=%s)", max_samples)
        return
    if not text_list:
        raise ValueError("No input text to synthesize")

    models = model if isinstance(model, list) else [model]

    _LOGGER.debug("Loading %s", models)
    voices = [PiperVoice.load(m, use_cuda=torch.cuda.is_available()) for m in models]
    _LOGGER.info("Successfully loaded model(s)")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    def speaker_count(voice) -> int:
        """Number of speaker ids to use for ``voice`` after applying max_speakers."""
        count = voice.config.num_speakers
        return count if max_speakers is None else min(count, max_speakers)

    # Without at least one usable speaker the settings cycle below would spin
    # forever without producing a sample.
    if not any(speaker_count(v) > 0 for v in voices):
        _LOGGER.warning("No speakers to synthesize with (max_speakers=%s)", max_speakers)
        return

    texts = it.cycle(text_list)
    name_iter = it.cycle(file_names) if file_names else None
    settings_iter = it.cycle(
        it.product(voices, length_scales, noise_scales, noise_scale_ws)
    )

    sample_idx = 0
    for voice, length_scale, noise_scale, noise_w_scale in settings_iter:
        for speaker_id in range(speaker_count(voice)):
            if name_iter is not None:
                wav_path = output_dir / next(name_iter)
            else:
                wav_path = output_dir / f"{sample_idx}.wav"

            text_input = next(texts)
            syn_config = SynthesisConfig(
                speaker_id=speaker_id,
                length_scale=length_scale,
                noise_scale=noise_scale,
                noise_w_scale=noise_w_scale,
            )

            if phoneme_input:
                # Phoneme input bypasses Piper's phonemizer: build ids manually.
                phoneme_ids = _text_to_onnx_phoneme_ids(
                    text_input, voice.config.phoneme_id_map
                )
                audio = voice.phoneme_ids_to_audio(phoneme_ids, syn_config)

                audio_int16 = audio_float_to_int16(audio[np.newaxis, :])
                with wave.open(str(wav_path), "wb") as wav_file:
                    wav_file.setframerate(voice.config.sample_rate)
                    wav_file.setsampwidth(2)
                    wav_file.setnchannels(1)
                    wav_file.writeframes(audio_int16.flatten().tobytes())
            else:
                with wave.open(str(wav_path), "wb") as wav_file:
                    voice.synthesize_wav(
                        text_input,
                        wav_file=wav_file,
                        syn_config=syn_config,
                    )

            sample_idx += 1
            if sample_idx >= max_samples:
                break
        if sample_idx >= max_samples:
            break

    _LOGGER.info("Done")


def _text_to_onnx_phoneme_ids(text: str, id_map: Dict[str, Any]) -> List[int]:
    """Map a raw phoneme string to Piper phoneme ids: ``^ _ (p _)* $``.

    The text is NFD-normalized and split into code points. Characters missing
    from ``id_map`` are skipped with a warning. When the map does not define
    the boundary symbols, ``^``/``_``/``$`` fall back to ids 1/0/2.
    """
    pad_ids = id_map.get("_", [0])

    # Beginning of utterance
    phoneme_ids = list(id_map.get("^", [1]))
    phoneme_ids.extend(pad_ids)

    for phoneme in unicodedata.normalize("NFD", text):
        p_ids = id_map.get(phoneme)
        if p_ids is None:
            _LOGGER.warning("Phoneme '%s' not found in model's phoneme map", phoneme)
            continue
        phoneme_ids.extend(p_ids)
        phoneme_ids.extend(pad_ids)

    # End of utterance
    phoneme_ids.extend(id_map.get("$", [2]))
    return phoneme_ids


# -----------------------------------------------------------------------------


def generate_audio(
    model,
    speaker_1,
    speaker_2,
    phoneme_ids,
    slerp_weight,
    noise_scale,
    noise_scale_w,
    length_scale,
    max_len,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Run the generator's sub-modules on a batch of phoneme-id sequences.

    The speaker conditioning vector is a SLERP blend of the ``emb_g``
    embeddings of ``speaker_1`` and ``speaker_2``.

    Args:
        model: Generator exposing ``enc_p``, ``emb_g``, ``dp``, ``use_sdp``,
            ``flow`` and ``dec``.
        speaker_1: Tensor of speaker ids, one per batch item.
        speaker_2: Tensor of speaker ids to blend towards, one per batch item.
        phoneme_ids: List of phoneme-id lists; they must all have the same
            length to form a single LongTensor.
        slerp_weight: Interpolation weight from speaker_1 (0) to speaker_2 (1).
        noise_scale: Scale of the Gaussian noise added to the expanded prior.
        noise_scale_w: Noise scale for the duration predictor, used only when
            ``model.use_sdp`` is true.
        length_scale: Multiplier applied to the predicted durations.
        max_len: Optional cap on latent frames given to the decoder
            (``None`` keeps all frames).

    Returns:
        ``(audio, phoneme_samples)``: the decoder output and the ceiled
        per-phoneme durations multiplied by 256 (the hop length).
    """
    # Accelerator choice: CUDA first, then MPS; otherwise stay on the CPU.
    device = None
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")

    tokens = torch.LongTensor(phoneme_ids)
    token_lengths = torch.LongTensor([len(ids) for ids in phoneme_ids])
    if device is not None:
        speaker_1, speaker_2 = speaker_1.to(device), speaker_2.to(device)
        tokens = cast(torch.LongTensor, tokens.to(device))
        token_lengths = cast(torch.LongTensor, token_lengths.to(device))

    hidden, m_text, logs_text, x_mask = model.enc_p(tokens, token_lengths)
    g = slerp(
        model.emb_g(speaker_1), model.emb_g(speaker_2), slerp_weight
    ).unsqueeze(-1)  # [b, h, 1]

    # Only the stochastic duration predictor takes reverse/noise arguments.
    dp_kwargs = (
        {"reverse": True, "noise_scale": noise_scale_w} if model.use_sdp else {}
    )
    logw = model.dp(hidden, x_mask, g=g, **dp_kwargs)

    w_ceil = torch.ceil(torch.exp(logw) * x_mask * length_scale)
    y_lengths = torch.sum(w_ceil, [1, 2]).clamp_min(1).long()
    frame_mask = commons.sequence_mask(y_lengths, int(y_lengths.max().item()))
    y_mask = frame_mask.unsqueeze(1).type_as(x_mask)
    attn = commons.generate_path(w_ceil, x_mask.unsqueeze(2) * y_mask.unsqueeze(-1))
    alignment = attn.squeeze(1)

    # [b, t', t] x [b, t, d] -> [b, d, t']
    m_p = torch.matmul(alignment, m_text.transpose(1, 2)).transpose(1, 2)
    logs_p = torch.matmul(alignment, logs_text.transpose(1, 2)).transpose(1, 2)

    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
    z = model.flow(z_p, y_mask, g=g, reverse=True)
    decoded = model.dec((z * y_mask)[:, :, :max_len], g=g)

    return cast(torch.FloatTensor, decoded), cast(torch.FloatTensor, w_ceil * 256)


# Shared espeak phonemizer, constructed once at import time; get_phonemes()
# uses it to phonemize plain (non-phoneme) text.
_PHONEMIZER = EspeakPhonemizer()


def get_phonemes(
    voice: str,
    config: Dict[str, Any],
    text: str,
    verbose: bool = False,
    phoneme_input: bool = False,
) -> List[int]:
    """Convert text (or a raw phoneme string) to the model's phoneme-id sequence.

    Args:
        voice: espeak voice name passed to the phonemizer (unused when
            ``phoneme_input`` is true).
        config: Model config; only ``config["phoneme_id_map"]`` is read, and it
            must define the ``^``, ``_`` and ``$`` symbols.
        text: Text to phonemize, or a phoneme string when ``phoneme_input``.
        verbose: Log the phoneme list at DEBUG level.
        phoneme_input: Treat ``text`` as phonemes; it is NFD-normalized and
            split into individual code points instead of being phonemized.

    Returns:
        Ids for ``^ _ (p _)* $``: each known phoneme ``p`` is followed by the
        ``_`` id. Phonemes missing from the id map are silently skipped.

    Raises:
        KeyError: if the id map lacks ``^``, ``_`` or ``$``.
    """
    if phoneme_input:
        phonemes = list(unicodedata.normalize("NFD", text))
    else:
        # Combine the per-sentence phoneme lists into one utterance
        phonemes = [
            p
            for sentence_phonemes in _PHONEMIZER.phonemize(voice, text)
            for p in sentence_phonemes
        ]
    if verbose:
        _LOGGER.debug("Phonemes: %s", phonemes)

    id_map = config["phoneme_id_map"]

    # Beginning of utterance
    phoneme_ids = list(id_map["^"])
    pad_ids = id_map["_"]
    phoneme_ids.extend(pad_ids)

    for phoneme in phonemes:
        p_ids = id_map.get(phoneme)
        if p_ids is not None:
            phoneme_ids.extend(p_ids)
            phoneme_ids.extend(pad_ids)

    # End of utterance
    phoneme_ids.extend(id_map["$"])

    return phoneme_ids


def slerp(v1, v2, t: float, DOT_THR: float = 0.9995, zdim: int = -1):
    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.

    `DOT_THR` determines when a pair of vectors is too close to parallel.
        Such pairs use a regular linear interpolation instead. The decision is
        made per vector pair, so in a batch only the near-parallel rows fall
        back to linear interpolation; all other rows are still SLERPed.

    `zdim` is the feature dimension over which to compute norms and find angles.
        For example: if a sequence of 5 vectors is input with shape [5, 768]
        Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.

    Theory Reference:
    https://splines.readthedocs.io/en/latest/rotation/slerp.html
    PyTorch reference:
    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
    Numpy reference:
    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    """

    # take the dot product between normalized vectors
    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
    dot = (v1_norm * v2_norm).sum(zdim)

    lerp = (1 - t) * v1 + t * v2
    near_parallel = torch.abs(dot) > DOT_THR
    if bool(near_parallel.all()):
        return lerp

    # The clamp only affects near-parallel rows (|dot| > DOT_THR), whose result
    # is replaced by `lerp` below; it keeps acos from producing NaN there.
    theta = torch.acos(torch.clamp(dot, -1.0, 1.0))
    theta_t = theta * t
    sin_theta = torch.sin(theta)
    # Avoid dividing by ~0 on near-parallel rows (their values are discarded).
    sin_theta = torch.where(near_parallel, torch.ones_like(sin_theta), sin_theta)

    # compute the sine scaling terms for the vectors
    s1 = torch.sin(theta - theta_t) / sin_theta
    s2 = torch.sin(theta_t) / sin_theta

    # interpolate the vectors
    res = (s1.unsqueeze(zdim) * v1) + (s2.unsqueeze(zdim) * v2)

    return torch.where(near_parallel.unsqueeze(zdim), lerp, res)


def audio_float_to_int16(
    audio: np.ndarray, max_wav_value: float = 32767.0
) -> np.ndarray:
    """Peak-normalize float audio and convert it to 16-bit integer samples.

    The signal is scaled so its largest absolute sample maps to
    ``max_wav_value``. Peaks below 0.01 are treated as 0.01, so near-silent
    input is not amplified to full scale. Results are clipped to
    ``[-max_wav_value, max_wav_value]`` and truncated toward zero by the int16
    cast. The input shape is preserved, and empty input yields an empty int16
    array instead of raising.
    """
    audio = np.asarray(audio)
    if audio.size == 0:
        # np.max raises on empty arrays; there is nothing to normalize.
        return audio.astype("int16")
    audio_norm = audio * (max_wav_value / max(0.01, np.max(np.abs(audio))))
    audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
    return audio_norm.astype("int16")


# -----------------------------------------------------------------------------


def main() -> int:
    """Command-line entry point: parse arguments and generate samples.

    Dispatches to ``generate_samples`` for a single PyTorch generator (.pt)
    or to ``generate_samples_onnx`` for one or more Piper voices (.onnx).

    Returns:
        0 on success, 1 if the ``--model`` arguments are invalid.
    """

    # Get command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("text")
    parser.add_argument(
        "--max-samples",
        required=True,
        type=int,
        help="Maximum number of samples to generate",
    )
    parser.add_argument(
        "--model",
        required=True,
        action="append",
        help="Path to PyTorch generator (.pt) or Piper voice model (.onnx)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="CUDA batch size (generator only)"
    )
    parser.add_argument(
        "--slerp-weights",
        nargs="+",
        type=float,
        default=[0.5],
        help="Speaker blending weights (generator only)",
    )
    parser.add_argument(
        "--length-scales",
        nargs="+",
        type=float,
        default=[1.0, 0.75, 1.25, 1.4],
        help="Audio length scales (< 1 is faster, > 1 is slower)",
    )
    parser.add_argument(
        "--noise-scales",
        nargs="+",
        type=float,
        default=[0.667, 0.75, 0.85, 0.9, 1.0, 1.4],
        help="Noise amounts added to audio (most voices use 0.667)",
    )
    parser.add_argument(
        "--noise-scale-ws",
        nargs="+",
        type=float,
        default=[0.8],
        help="Phoneme width variation (most voices use 0.8)",
    )
    parser.add_argument(
        "--output-dir",
        default="output",
        help="Directory to output WAV files (default: ./output)",
    )
    parser.add_argument(
        "--max-speakers",
        type=int,
        help="Maximum number of speakers to use (default: no limit)",
    )
    parser.add_argument(
        "--phoneme-input", action="store_true", help="Treat input text as phoneme input"
    )
    parser.add_argument("--verbose", action="store_true")
    args = vars(parser.parse_args())

    # Generate speech
    model_paths = [Path(m) for m in args["model"]]
    if not model_paths:
        # argparse's required=True should already prevent this; checked
        # explicitly because assert statements are stripped under `python -O`.
        _LOGGER.error("At least one --model is required")
        return 1

    suffix = model_paths[0].suffix
    if any(mp.suffix != suffix for mp in model_paths[1:]):
        _LOGGER.error("All models must have the same suffix (.pt or .onnx)")
        return 1

    if suffix == ".onnx":
        # Use Piper voice (.onnx)
        generate_samples_onnx(**args)
    elif suffix == ".pt":
        # Use PyTorch generator (.pt)
        if len(model_paths) > 1:
            _LOGGER.error("Only one generator (.pt) is supported")
            return 1

        args["model"] = args["model"][0]
        generate_samples(**args)
    else:
        _LOGGER.error("Models must have .pt or .onnx suffix")
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status (1 on invalid --model arguments) as the
    # process exit code; previously it was discarded and always exited 0.
    raise SystemExit(main())
