# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 10:14:39 2025

@author: Moritz Romeike
"""

# ------------------------------------------------------------------------
# Programmcode 38 (Python, ohne sklearn):
# KNN-Analyse LED-Produktionsdaten (k=3) + Konfusionsmatrix/Heatmap + Scatter
# Abhängigkeiten: numpy, pandas, matplotlib, scipy (Dendrogramme), openpyxl (Excel-Import)
# ------------------------------------------------------------------------
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ----------------------------- Daten laden -------------------------------
# Input data is expected in the current working directory.
BASE = Path.cwd()
XLSX = BASE.joinpath("Kap_10.4.2_led_produktionsdaten.xlsx")

def load_data(path=None):
    """
    Load the LED production data.

    Tries the Excel file first. If it is missing or cannot be read, falls
    back to a CSV file with the same name and a ``.csv`` suffix. The CSV
    column delimiter is detected automatically, so ``,`` and ``;`` both work.

    Parameters
    ----------
    path : str or pathlib.Path, optional
        Location of the Excel file. Defaults to the module constant ``XLSX``.

    Returns
    -------
    pandas.DataFrame
        The raw data as read from the file.

    Raises
    ------
    FileNotFoundError
        If neither the Excel file nor the CSV fallback could be loaded.
    """
    xlsx = XLSX if path is None else Path(path)

    if xlsx.exists():
        try:
            return pd.read_excel(xlsx, engine="openpyxl")
        except Exception as e:  # e.g. openpyxl not installed, or a corrupt file
            print(f"⚠️ Konnte Excel nicht lesen ({e}). Versuche CSV-Fallback …")

    # CSV fallback: same file name, ".csv" suffix.
    csv_path = xlsx.with_suffix(".csv")
    if csv_path.exists():
        # sep=None makes pandas sniff the delimiter (needs the python engine).
        return pd.read_csv(csv_path, sep=None, engine="python")

    # Fail loudly instead of returning None (which would only crash later).
    raise FileNotFoundError(
        f"Keine Daten gefunden: weder '{xlsx}' noch '{csv_path}' konnte gelesen werden."
    )

df = load_data()
# Strip stray whitespace from the header names (and force them to str).
df.columns = [str(name).strip() for name in df.columns]

# ----------------------- Target class (as in the R code) -----------------
# Bin the luminous flux into three equal-width intervals (pd.cut with an
# integer bin count), including the lowest edge.
if "Lichtausbeute_Lumen" not in df.columns:
    raise KeyError("Spalte 'Lichtausbeute_Lumen' fehlt.")
df["Qualitaetsklasse"] = pd.cut(
    df["Lichtausbeute_Lumen"],
    labels=["Niedrig", "Mittel", "Hoch"],
    bins=3,
    include_lowest=True,
)

# KNN input columns.
# NOTE(review): "Lichtausbeute_Lumen" is both a feature and the column the
# target class is binned from, so the classifier effectively sees the answer
# (target leakage) and accuracy is inflated. Presumably kept to mirror the R
# version. The features are also not standardised, so whichever column has
# the largest numeric spread dominates the Euclidean distance.
features = ["Stromaufnahme_mA", "Waermeentwicklung_C", "Lichtausbeute_Lumen"]
missing = [name for name in features if name not in df.columns]
if missing:
    raise KeyError(f"Spalte '{missing[0]}' fehlt.")

# ----------------------------- Train/Test --------------------------------
# Random 70/30 split with a fixed seed. This plays the role of R's
# set.seed(123), but numpy's generator differs from R's, so the concrete
# split does not match the R run.
rng = np.random.default_rng(123)
n = len(df)
idx = rng.permutation(n)  # == np.arange(n) shuffled in place by rng
cut = int(0.7 * n)
train_idx, test_idx = np.split(idx, [cut])

train, test = (df.iloc[rows].reset_index(drop=True) for rows in (train_idx, test_idx))

# Feature matrices as float, labels as plain strings.
X_train, X_test = (part[features].to_numpy(dtype=float) for part in (train, test))
y_train, y_test = (part["Qualitaetsklasse"].astype(str).to_numpy() for part in (train, test))

# --------------------------- KNN (ohne sklearn) --------------------------
def euclidean(a: np.ndarray, b: np.ndarray) -> float:
    """
    Return the Euclidean (L2) distance between the vectors ``a`` and ``b``.

    Reference implementation only; ``knn_predict`` does not call it and
    computes distances vectorised with ``np.linalg.norm`` instead.
    """
    delta = a - b
    return np.sqrt(np.square(delta).sum())

def knn_predict(X_train, y_train, X_test, k=3, *, seed=42):
    """
    Classify each row of ``X_test`` by majority vote of its ``k`` nearest
    training rows (no sklearn).

    - Distance: Euclidean.
    - Vote ties are broken at random using a generator seeded with ``seed``,
      so results are reproducible for identical inputs.

    Parameters
    ----------
    X_train : array-like, shape (n_train, n_features)
        Training features.
    y_train : array-like, shape (n_train,)
        Class labels of the training rows.
    X_test : array-like, shape (n_test, n_features)
        Rows to classify.
    k : int, default 3
        Number of neighbours; must be >= 1. If ``k`` exceeds ``n_train``,
        all training rows vote.
    seed : int, default 42
        Seed of the tie-break generator (keyword-only).

    Returns
    -------
    numpy.ndarray of str, shape (n_test,)
        Predicted label per test row.

    Raises
    ------
    ValueError
        If ``k < 1`` or the training set is empty.
    """
    # Accept lists as well as arrays; y_train must support fancy indexing.
    X_train = np.asarray(X_train, dtype=float)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test, dtype=float)

    # Without these checks an empty neighbour set fails deep inside numpy
    # with an unhelpful "zero-size array" reduction error.
    if k < 1:
        raise ValueError(f"k muss >= 1 sein, erhalten: {k}")
    if len(X_train) == 0:
        raise ValueError("Die Trainingsmenge ist leer.")

    rng_local = np.random.default_rng(seed)  # deterministic tie-breaking
    preds = []
    for x in X_test:
        # Distances from x to every training row (vectorised).
        dists = np.linalg.norm(X_train - x, axis=1)
        nn_idx = np.argsort(dists)[:k]

        # Majority vote among the k nearest labels.
        values, counts = np.unique(y_train[nn_idx], return_counts=True)
        winners = values[counts == counts.max()]
        # Draw from the RNG only on a tie, so the random stream is consumed
        # exactly once per tied vote.
        preds.append(winners[0] if len(winners) == 1 else rng_local.choice(winners))
    return np.array(preds, dtype=str)

y_pred = knn_predict(X_train, y_train, X_test, k=3)

# -------------------------- Konfusionsmatrix -----------------------------
# Fixed class order for the rows/columns of the confusion matrix.
labels = ["Niedrig", "Mittel", "Hoch"]
label_to_idx = dict(zip(labels, range(len(labels))))

# Count (actual, predicted) pairs; pairs with a label outside `labels`
# are skipped.
cm = np.zeros((len(labels), len(labels)), dtype=int)
for actual, predicted in zip(y_test, y_pred):
    if actual not in label_to_idx or predicted not in label_to_idx:
        continue
    cm[label_to_idx[actual], label_to_idx[predicted]] += 1

# Accuracy = diagonal (correct) / all counted pairs; NaN if nothing counted.
acc = cm.trace() / cm.sum() if cm.sum() > 0 else np.nan

print("Konfusionsmatrix (Zeilen = Ist, Spalten = Vorhersage):")
print(pd.DataFrame(cm, index=labels, columns=labels))
print(f"\nAccuracy: {acc:.4f}")

# Optional: einfache Kennzahlen je Klasse
def per_class_metrics(cm):
    """
    Compute one-vs-rest precision, recall and F1 for every class.

    Parameters
    ----------
    cm : numpy.ndarray
        Confusion matrix with actual classes in rows and predicted classes
        in columns; one class per row.

    Returns
    -------
    tuple of numpy.ndarray
        ``(precision, recall, f1)``, one float per row of ``cm``. A score
        whose denominator is zero is reported as 0.0.
    """
    n_cls = cm.shape[0]
    diag = np.arange(n_cls)
    tp = cm[diag, diag]
    fp = cm.sum(axis=0)[:n_cls] - tp  # predicted as class i, but wrong
    fn = cm.sum(axis=1) - tp          # actually class i, but missed

    def _ratio_or_zero(num, den):
        # Element-wise num / den, leaving 0.0 wherever den == 0.
        out = np.zeros(n_cls, dtype=float)
        np.divide(num, den, out=out, where=den != 0)
        return out

    precision = _ratio_or_zero(tp, tp + fp)
    recall = _ratio_or_zero(tp, tp + fn)
    f1 = _ratio_or_zero(2 * precision * recall, precision + recall)
    return precision, recall, f1

# Per-class report, one line per label in the fixed label order.
metrics = per_class_metrics(cm)
prec, rec, f1 = metrics
print("\nKlassenscores:")
for i, lab in enumerate(labels):
    line = f"  {lab:7s}  Precision={prec[i]:.3f}  Recall={rec[i]:.3f}  F1={f1[i]:.3f}"
    print(line)

# ---------------------------- Heatmap-Plot -------------------------------

from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable

# 1) Cluster order: Ward linkage on the rows (actual classes) and on the
#    columns (predicted classes) of the confusion matrix.
row_link = linkage(cm,   method="ward")
col_link = linkage(cm.T, method="ward")
row_leaves = dendrogram(row_link, no_plot=True)["leaves"]
col_leaves = dendrogram(col_link, no_plot=True)["leaves"]

# Reorder matrix and tick labels to the dendrogram leaf order.
cm_ord     = cm[np.ix_(row_leaves, col_leaves)]
row_labels = [labels[i] for i in row_leaves]
col_labels = [labels[j] for j in col_leaves]

# 2) Figure layout: 2x2 grid. Column dendrogram top right, row dendrogram
#    bottom left, heatmap bottom right; the top-left cell stays empty.
fig = plt.figure(figsize=(12, 4.5))
gs = gridspec.GridSpec(
    2, 2,
    width_ratios=[0.18, 1.00],
    height_ratios=[0.28, 1.00],
    wspace=0.08, hspace=0.02
)

ax_col = fig.add_subplot(gs[0, 1])   # column dendrogram (top)
ax_row = fig.add_subplot(gs[1, 0])   # row dendrogram (left)
ax_hm  = fig.add_subplot(gs[1, 1])   # heatmap (right)

# 3) Dendrograms
dendrogram(col_link, ax=ax_col, orientation="top", no_labels=True, color_threshold=0)
ax_col.set_xticks([]); ax_col.set_yticks([]); ax_col.set_frame_on(False)

dendrogram(row_link, ax=ax_row, orientation="left", no_labels=True, color_threshold=0)
# scipy draws a left-oriented dendrogram with its first leaf at the bottom
# (y grows upwards), while imshow puts row 0 of cm_ord at the top. Flip the
# y-axis so every leaf sits next to its own heatmap row.
ax_row.invert_yaxis()
ax_row.set_xticks([]); ax_row.set_yticks([]); ax_row.set_frame_on(False)

# 4) Heatmap
vmax = float(cm_ord.max()) if cm_ord.size else 1.0
im = ax_hm.imshow(cm_ord, cmap="Blues", vmin=0, vmax=vmax, aspect="auto", interpolation="nearest")

# Axes & labels
ax_hm.set_xticks(np.arange(len(col_labels)))
ax_hm.set_xticklabels(col_labels, rotation=0)
ax_hm.set_yticks(np.arange(len(row_labels)))
ax_hm.set_yticklabels(row_labels)

# Y tick labels on the right-hand side (the row dendrogram occupies the left)
ax_hm.yaxis.tick_right()
ax_hm.yaxis.set_label_position("right")
ax_hm.tick_params(axis="y", pad=24)  # distance of the labels from the right axis

# Cell grid: white lines along minor ticks placed on the cell borders
ax_hm.set_xticks(np.arange(-.5, len(col_labels), 1), minor=True)
ax_hm.set_yticks(np.arange(-.5, len(row_labels), 1), minor=True)
ax_hm.grid(which="minor", color="white", linewidth=1.2)
# Hide the minor tick marks themselves. tick_right() above moved the y ticks
# (major and minor) to the right side, so right=False is needed too.
ax_hm.tick_params(which="minor", bottom=False, left=False, right=False)

# Cell values. White text on the darker upper half of the colour scale keeps
# the numbers readable on the "Blues" colormap.
for i in range(cm_ord.shape[0]):
    for j in range(cm_ord.shape[1]):
        value = cm_ord[i, j]
        text_color = "white" if value > 0.5 * vmax else "black"
        ax_hm.text(j, i, f"{value:.2f}", ha="center", va="center", fontsize=9, color=text_color)

# 5) Colorbar to the right of the heatmap
divider = make_axes_locatable(ax_hm)
cax = divider.append_axes("right", size="3%", pad=0.9)   # larger pad keeps it clear of the y labels
cb = fig.colorbar(im, cax=cax)
cb.ax.tick_params(labelsize=9)

# append_axes shrinks ax_hm horizontally to make room for the colorbar.
# Reserve the same (invisible) space next to the column dendrogram so its
# leaves stay centred above the heatmap columns.
make_axes_locatable(ax_col).append_axes("right", size="3%", pad=0.9).set_visible(False)

# Title
fig.suptitle("Heatmap der Konfusionsmatrix", fontsize=14, y=0.98)

plt.subplots_adjust(left=0.06, right=0.96, top=0.92, bottom=0.10)
plt.show()


# -------- Plot: Lichtausbeute vs. Wärmeentwicklung nach Qualitätsklasse ---
# Scatter plot: one marker style per quality class, in the fixed label order.
fig_scatter, ax_scatter = plt.subplots()
markers = {"Niedrig": "o", "Mittel": "s", "Hoch": "x"}
class_as_str = df["Qualitaetsklasse"].astype(str)
for class_name in labels:
    points = df[class_as_str == class_name]
    ax_scatter.scatter(
        points["Waermeentwicklung_C"],
        points["Lichtausbeute_Lumen"],
        label=class_name,
        marker=markers[class_name],
        alpha=0.7,
    )
ax_scatter.set_title("Lichtausbeute vs. Wärmeentwicklung")
ax_scatter.set_xlabel("Wärmeentwicklung (°C)")
ax_scatter.set_ylabel("Lichtausbeute (Lumen)")
ax_scatter.legend(title="Qualitätsklasse")
fig_scatter.tight_layout()
plt.show()
# ------------------------------------------------------------------------
