"""Trainiert 'Plappi'-Keyword-Modell (ONNX [1,16,96]->[1,1]).
Features EXAKT wie die WASM-Engine (Audio [-1,1], melspec/10+2, 1280er-Chunks=5 Frames,
76er-Fenster Schritt 8 -> Embedding, 16er-Embedding-History).
Daten: Wort rechtsbündig ans Ende (gleiche Struktur pos/neg, kein Längen-Shortcut).
Positiv = LETZTES Streaming-Fenster eines Plappi-Clips. Negativ = alle Teil-/Stille-Fenster
(auch der Plappi-Clips) + alle Fenster der Negativ-Clips. So feuert das Modell NUR auf das
fertige Plappi und nicht auf Zwischenzustände (gegen runWav-MAX-Fehlalarme).
"""
import glob
import os
from math import gcd

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from scipy.io import wavfile
from scipy.signal import resample_poly

# Feature-extractor models (mel spectrogram + embedding) used to reproduce the
# WASM engine's features; the path points into the web client's model folder.
MODELS = "clients/web/wakeword/models"
MEL = ort.InferenceSession(f"{MODELS}/melspectrogram.onnx")
EMB = ort.InferenceSession(f"{MODELS}/embedding_model.onnx")

# Names of the first input / first output tensor of each session.
MEL_IN = MEL.get_inputs()[0].name
MEL_OUT = MEL.get_outputs()[0].name
EMB_IN = EMB.get_inputs()[0].name
EMB_OUT = EMB.get_outputs()[0].name

WIN = 16     # embeddings per classifier input window
FIX = 32000  # clip length in samples: 2 s (i.e. 16 kHz)


def load_rightaligned(p, length=None, target_sr=16000):
    """Load a wav file as mono float32 in [-1, 1], right-aligned in a fixed buffer.

    The clip is trimmed to the span whose magnitude exceeds 2 % of its peak and
    then placed at the END of a zero buffer. Positive and negative clips thus
    share the same layout: leading silence, word last.

    Args:
        p: path to the wav file.
        length: output length in samples; defaults to the module constant FIX.
        target_sr: sample rate the feature models are fed with; files with a
            different rate are resampled to it.

    Returns:
        float32 array of shape (length,).
    """
    if length is None:
        length = FIX
    sr, a = wavfile.read(p)
    if a.ndim > 1:
        a = a[:, 0]  # first channel only

    # Scale integer PCM to [-1, 1]. Float wavs are taken as-is. Before, only
    # int16 was scaled, so int32/uint8 files produced out-of-range audio.
    if a.dtype == np.int16:
        a = a.astype(np.float32) / 32768.0
    elif a.dtype == np.int32:
        a = a.astype(np.float32) / 2147483648.0
    elif a.dtype == np.uint8:
        a = (a.astype(np.float32) - 128.0) / 128.0
    else:
        a = a.astype(np.float32)

    # The buffer length is defined in samples at target_sr, so other rates
    # must be converted or the clip duration and spectrum are wrong.
    if sr != target_sr and a.size:
        g = gcd(int(target_sr), int(sr))
        a = resample_poly(a, int(target_sr) // g, int(sr) // g).astype(np.float32)

    # Trim leading/trailing quiet parts (below 2 % of the peak). The size
    # check avoids max() on an empty file; when the peak is > 0 the peak
    # sample itself passes the threshold, so nz is never empty.
    if a.size:
        e = np.abs(a)
        peak = e.max()
        if peak > 0:
            nz = np.flatnonzero(e > 0.02 * peak)
            a = a[nz[0]:nz[-1] + 1]

    # Keep the END of over-long clips: the finished word has to land in the
    # last streaming window, which is the positive example.
    a = a[max(0, len(a) - length):]
    out = np.zeros(length, dtype=np.float32)
    out[length - len(a):] = a  # right-aligned: word at the end
    return out


def all_windows(audio):
    """Return every WIN-embedding window that the WASM engine's streaming would produce.

    The audio is consumed in 1280-sample chunks; a trailing partial chunk is
    ignored. Each chunk's mel output is scaled (/10 + 2) and cut into five
    32-bin frames. Whenever 76 frames are buffered, one embedding is computed
    from them and the buffer then advances by 8 frames. The embedding history
    starts as WIN zero vectors, so the earliest windows are zero-padded.

    Returns:
        list of float32 arrays, one per embedding step, each stacking the
        current WIN-long embedding history.
    """
    history = [np.zeros(96, dtype=np.float32) for _ in range(WIN)]
    frames = []
    windows = []
    last_start = len(audio) - 1280
    for start in range(0, last_start + 1, 1280):
        block = audio[start:start + 1280][None, :].astype(np.float32)
        mel_raw = MEL.run([MEL_OUT], {MEL_IN: block})[0]
        mel = np.array(mel_raw).reshape(-1) / 10.0 + 2.0
        frames.extend(mel[k * 32:(k + 1) * 32] for k in range(5))

        while len(frames) >= 76:
            mel_win = np.array(frames[:76], dtype=np.float32).reshape(1, 76, 32, 1)
            emb_raw = EMB.run([EMB_OUT], {EMB_IN: mel_win})[0]
            embedding = np.array(emb_raw).reshape(-1).astype(np.float32)
            # Slide the history: drop the oldest embedding, append the newest.
            history = history[1:] + [embedding]
            # np.array on a list always copies, so later slides cannot alias it.
            windows.append(np.array(history, dtype=np.float32))
            frames = frames[8:]
    return windows


posP = sorted(glob.glob("wakeword/data/pos/*.wav"))
negP = sorted(glob.glob("wakeword/data/neg/*.wav"))
print("Dateien: pos", len(posP), "neg", len(negP), flush=True)

# Feature cache. NOTE: it is only checked for existence and is NOT refreshed
# when the wav files or the feature pipeline change; delete both .npy files
# to force a recompute. Both files are required. Previously only Xp2.npy was
# checked, so a run interrupted between the two np.save calls crashed on load.
if os.path.exists("wakeword/data/Xp2.npy") and os.path.exists("wakeword/data/Xn2.npy"):
    Xp, Xn = np.load("wakeword/data/Xp2.npy"), np.load("wakeword/data/Xn2.npy")
    print("Features aus Cache", flush=True)
else:
    Xpos, Xneg = [], []
    for p in posP:
        ws = all_windows(load_rightaligned(p))
        if not ws:
            continue
        Xpos.append(ws[-1])        # last window (word finished) = positive
        Xneg.extend(ws[:-1])       # partial / silence windows = negative
    for p in negP:
        Xneg.extend(all_windows(load_rightaligned(p)))
    # reshape keeps the (n, WIN, 96) rank even for an empty list; a bare (0,)
    # array would make the later np.concatenate fail.
    Xp = np.array(Xpos, dtype=np.float32).reshape(-1, WIN, 96)
    Xn = np.array(Xneg, dtype=np.float32).reshape(-1, WIN, 96)
    np.save("wakeword/data/Xp2.npy", Xp); np.save("wakeword/data/Xn2.npy", Xn)
print("Fenster: pos", Xp.shape, "neg", Xn.shape, flush=True)
# Without positives the classifier cannot learn anything, and the export at
# the end would overwrite plappi.onnx with a useless model.
if len(Xp) == 0:
    raise SystemExit("Keine positiven Trainingsfenster gefunden (wakeword/data/pos/*.wav leer?)")

X = np.concatenate([Xp, Xn], 0).astype(np.float32)
y = np.concatenate([np.ones(len(Xp)), np.zeros(len(Xn))]).astype(np.float32)
rng = np.random.default_rng(0)
# Stratified 90/10 split: positives and negatives are shuffled and cut
# separately. Positives are a small minority (one window per positive clip),
# so a plain random split can leave validation with no or very few positives.
# NOTE(review): windows of the same clip overlap and can still land in both
# splits, so validation numbers are optimistic. A clip-level split would need
# clip ids, which the feature cache does not store.
pos_idx = rng.permutation(np.flatnonzero(y > 0.5))
neg_idx = rng.permutation(np.flatnonzero(y < 0.5))
pcut, ncut = int(len(pos_idx) * 0.9), int(len(neg_idx) * 0.9)
tr_idx = np.concatenate([pos_idx[:pcut], neg_idx[:ncut]])
va_idx = np.concatenate([pos_idx[pcut:], neg_idx[ncut:]])
Xtr, ytr = torch.tensor(X[tr_idx]), torch.tensor(y[tr_idx]).unsqueeze(1)
Xva, yva = torch.tensor(X[va_idx]), torch.tensor(y[va_idx]).unsqueeze(1)
# Class-balancing weight taken from the training split, the data it is applied
# to, so total positive and negative weight in the loss are equal.
pos_w = ncut / max(1, pcut)
w_tr = torch.where(ytr > 0.5, torch.tensor(float(pos_w)), torch.tensor(1.0))


class Net(nn.Module):
    """Two-layer MLP keyword classifier.

    Input: a batch of embedding histories, flattened per sample to WIN * 96
    features (exported with shape [1, WIN, 96]). Output: sigmoid score of
    shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        width = 128
        # Same construction order as before, so parameter initialisation
        # consumes the RNG identically.
        self.l1 = nn.Linear(WIN * 96, width)
        self.ln1 = nn.LayerNorm(width)
        self.l2 = nn.Linear(width, width)
        self.ln2 = nn.LayerNorm(width)
        self.drop = nn.Dropout(0.3)
        self.out = nn.Linear(width, 1)

    def forward(self, x):
        h = x.reshape(x.shape[0], -1)
        # Each hidden stage: Linear -> LayerNorm -> ReLU -> Dropout.
        for dense, norm in ((self.l1, self.ln1), (self.l2, self.ln2)):
            h = self.drop(torch.relu(norm(dense(h))))
        return torch.sigmoid(self.out(h))


EPOCHS = 300
# Seed torch as well as numpy, so weight init and dropout masks are
# reproducible between runs.
torch.manual_seed(0)
net = Net()
opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=2e-4)
for ep in range(EPOCHS):
    # Full-batch step: the whole training set in a single forward/backward pass.
    net.train(); opt.zero_grad()
    pred = net(Xtr)
    # Per-sample BCE scaled by w_tr (positives weighted by pos_w).
    loss = (nn.functional.binary_cross_entropy(pred, ytr, reduction='none') * w_tr).mean()
    loss.backward(); opt.step()
    if ep % 50 == 0 or ep == EPOCHS - 1:
        net.eval()
        with torch.no_grad():
            pv = net(Xva)
            posm, negm = (yva > 0.5), (yva < 0.5)
            # nan instead of 0 when a class is absent from validation, so
            # "no data" is not mistaken for a 0 % rate.
            tp = (pv[posm] > 0.5).float().mean().item() if posm.any() else float("nan")
            fp = (pv[negm] > 0.5).float().mean().item() if negm.any() else float("nan")
        print(f"ep{ep} loss {loss.item():.3f} TP {tp:.2f} FP {fp:.3f}", flush=True)

net.eval()
OUT_PATH = f"{MODELS}/plappi.onnx"
torch.onnx.export(net, torch.zeros(1, WIN, 96), OUT_PATH,
                  input_names=["x"], output_names=["score"], opset_version=14)

# The exporter may write weights to a side file (plappi.onnx.data). Re-save
# with every tensor inlined so the web client needs a single file.
# onnx.load() resolves external data relative to the model's directory. If
# that fails, the error propagates here, before the side file is deleted,
# rather than leaving a model that references a missing file.
m = onnx.load(OUT_PATH)
onnx.save(m, OUT_PATH, save_as_external_data=False)
if os.path.exists(f"{OUT_PATH}.data"):
    os.remove(f"{OUT_PATH}.data")
onnx.checker.check_model(OUT_PATH)

# Round-trip check: onnxruntime must reproduce torch's score for a
# [1, WIN, 96] input and return shape [1, 1].
sess = ort.InferenceSession(OUT_PATH)
probe = torch.rand(1, WIN, 96)
with torch.no_grad():
    ref = net(probe).numpy()
got = sess.run(None, {sess.get_inputs()[0].name: probe.numpy()})[0]
if got.shape != (1, 1) or not np.allclose(got, ref, atol=1e-4):
    raise RuntimeError(f"ONNX-Export weicht von torch ab: shape {got.shape}, "
                       f"onnx {got.ravel()[:1]} vs torch {ref.ravel()[:1]}")
print("EXPORT fertig", flush=True)
