# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 09:11:53 2025

@author: Moritz Romeike
"""

# --------------------------------------------------------------------------------
# Programmcode 28 (Python): Optimale Clusteranzahl via Silhouette-Koeffizient
# --------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# ==== Falls D (Gower-Distanzmatrix) & PAM noch nicht aus #27 vorhanden: =========

def gower_distance_matrix(df_cat: pd.DataFrame, df_num: pd.DataFrame,
                          w_cat: float = 0.5, w_num: float = 0.5) -> np.ndarray:
    """
    Pairwise Gower-style distance matrix for mixed categorical/numeric data.

    Categorical part: fraction of categorical columns in which two rows differ.
    Numeric part: mean over numeric columns of |x_i - x_j| / column range,
    where NaN differences count as 0 and a constant column uses range 1.

    If both parts exist they are combined as ``w_cat * cat + w_num * num``;
    if only one part exists it is used unweighted; with no columns at all
    every distance is 0.

    Returns an (n, n) float array with a zero diagonal, n = len(df_cat).
    """
    cat_vals = df_cat.to_numpy(dtype=object)
    num_vals = df_num.to_numpy(dtype=float)
    n_rows = len(df_cat)
    n_cat = cat_vals.shape[1]
    n_num = num_vals.shape[1]

    spans = None
    if n_num > 0:
        # Column ranges used to scale differences; a zero range is replaced
        # by 1 so the division is well-defined.
        spans = np.nanmax(num_vals, axis=0) - np.nanmin(num_vals, axis=0)
        spans[spans == 0] = 1.0

    def cat_part(row):
        # Share of categorical columns in which `row` differs from each row.
        return (cat_vals[row] != cat_vals).sum(axis=1) / n_cat

    def num_part(row):
        # Mean range-scaled absolute difference; missing values contribute 0.
        scaled = np.abs(num_vals[row] - num_vals) / spans
        return np.nan_to_num(scaled, nan=0.0).mean(axis=1)

    out = np.zeros((n_rows, n_rows), dtype=float)
    for row in range(n_rows):
        if n_cat > 0 and n_num > 0:
            out[row] = w_cat * cat_part(row) + w_num * num_part(row)
        elif n_cat > 0:
            out[row] = cat_part(row)
        elif n_num > 0:
            out[row] = num_part(row)
        # else: no columns at all -> the row stays 0
    np.fill_diagonal(out, 0.0)
    return out

def total_cost(D: np.ndarray, medoids: np.ndarray) -> float:
    """
    Total PAM cost of a medoid set.

    Every point is charged its distance to the closest medoid (row-wise
    minimum over the medoid columns of D); the charges are summed and
    returned as a Python float.
    """
    dist_to_closest = D[:, medoids].min(axis=1)
    return float(dist_to_closest.sum())

def pam_build(D: np.ndarray, k: int) -> np.ndarray:
    """
    PAM BUILD phase: greedily choose k initial medoids.

    The first medoid is the point with the smallest total distance to all
    other points. Each further medoid is the non-medoid whose addition lowers
    the summed distance-to-nearest-medoid the most (ties -> lowest index).

    Parameters
    ----------
    D : (n, n) distance matrix.
    k : number of medoids, must satisfy 1 <= k <= n.

    Returns
    -------
    Sorted integer array with the k medoid indices.

    Raises
    ------
    ValueError
        If k is outside 1..n (previously k > n failed with an obscure
        TypeError and k <= 0 silently returned one medoid).
    """
    n = D.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must satisfy 1 <= k <= n (n={n}), got k={k}")

    first = int(np.argmin(np.sum(D, axis=1)))
    medoids = [first]
    chosen = {first}
    # current[j] = distance from point j to its nearest chosen medoid
    current = np.min(D[:, medoids], axis=1)
    while len(medoids) < k:
        current_total = np.sum(current)  # loop-invariant for the scan below
        best_gain, best_idx = None, None
        for i in range(n):
            if i in chosen:
                continue
            new_min = np.minimum(current, D[:, i])
            gain = current_total - np.sum(new_min)
            # strict '>' keeps the lowest index on ties
            if (best_gain is None) or (gain > best_gain):
                best_gain, best_idx = gain, i
        medoids.append(best_idx)
        chosen.add(best_idx)
        # adding a medoid can only shrink each nearest-medoid distance
        current = np.minimum(current, D[:, best_idx])
    return np.array(sorted(medoids), dtype=int)

def _first_improving_swap(D: np.ndarray, medoids: np.ndarray):
    """
    Search for the first single swap (medoid slot x non-medoid) that lowers
    the total cost by more than 1e-12.

    Slots are tried in order, and within a slot the non-medoids in ascending
    index order. Returns the new sorted medoid array, or None if no swap
    improves the cost.
    """
    baseline = total_cost(D, medoids)
    outsiders = [p for p in range(D.shape[0]) if p not in medoids]
    for slot in range(len(medoids)):
        for newcomer in outsiders:
            trial = medoids.copy()
            trial[slot] = newcomer
            trial = np.array(sorted(trial), dtype=int)
            if total_cost(D, trial) + 1e-12 < baseline:
                return trial
    return None


def pam_swap(D: np.ndarray, medoids: np.ndarray) -> np.ndarray:
    """
    PAM SWAP phase (first-improvement variant).

    Starting from a copy of `medoids`, repeatedly apply the first swap that
    lowers the total cost and restart the search, until no swap helps.
    Returns the final medoid set (the untouched copy if no swap was made).
    """
    current = medoids.copy()
    while True:
        better = _first_improving_swap(D, current)
        if better is None:
            return current
        current = better
# ==== Ende Nachlade-Block =======================================================

# ==== Silhouette aus Distanzmatrix (ohne scikit-learn) ==========================
def silhouette_values_from_distance(D: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """
    Per-point silhouette values from a precomputed distance matrix.

    Parameters
    ----------
    D : (n, n) distance matrix with a zero diagonal.
    labels : integer cluster labels 0..K-1 (label values without members
        are allowed).

    Returns
    -------
    Array of shape (n,) with s(i) for every point.

    Definition (Rousseeuw 1987, as in R's cluster::silhouette):
        a(i) = mean distance from i to the other members of its own cluster,
        b(i) = smallest mean distance from i to the members of another cluster,
        s(i) = (b(i) - a(i)) / max(a(i), b(i)).
    Conventions: s(i) = 0 if i is the only member of its cluster (otherwise
    a(i) = 0 would force s(i) = 1 and inflate the mean silhouette), and
    s(i) = 0 if no other non-empty cluster exists (b(i) undefined; the naive
    formula would give inf/inf = NaN).
    """
    D = np.asarray(D, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = D.shape[0]
    s = np.zeros(n, dtype=float)
    if n == 0:
        return s

    K = int(labels.max()) + 1
    # Member indices of each cluster, computed once
    idx_by_cluster = [np.where(labels == c)[0] for c in range(K)]

    for i in range(n):
        ci = labels[i]
        own_idx = idx_by_cluster[ci]
        if own_idx.size <= 1:
            continue  # singleton cluster: s(i) = 0 by convention

        # a(i): D[i, i] = 0 is included in the sum, hence the divisor size - 1
        a_i = np.sum(D[i, own_idx]) / (own_idx.size - 1)

        other_means = [np.mean(D[i, idx])
                       for c, idx in enumerate(idx_by_cluster)
                       if c != ci and idx.size > 0]
        if not other_means:
            continue  # only one non-empty cluster: s(i) = 0

        b_i = min(other_means)
        denom = max(a_i, b_i)
        s[i] = (b_i - a_i) / denom if denom > 0 else 0.0
    return s

def silhouette_plot(values: np.ndarray, labels: np.ndarray, title: str):
    """
    Draw a silhouette plot: horizontal bars grouped by cluster (similar to the R plot).

    Parameters
    ----------
    values : per-point silhouette values.
    labels : cluster labels 0..K-1, same length as `values`.
    title : figure title.

    Within each cluster the bars are sorted ascending. Clusters are stacked
    from the bottom (cluster 1 lowest) with a gap of two bar heights between
    them. A dashed red vertical line marks the overall mean. The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    values = np.asarray(values, dtype=float)
    labels = np.asarray(labels)
    K = int(labels.max()) + 1
    mean_sil = np.mean(values)

    plt.figure(figsize=(6, 5))
    y_base = 0
    yticks = []
    yticklabels = []
    for c in range(K):
        seg_sorted = np.sort(values[labels == c])
        if seg_sorted.size == 0:
            continue  # label value without members
        y = np.arange(y_base, y_base + seg_sorted.size)
        plt.barh(y, seg_sorted, height=1.0)
        # barh centres each bar on its y value, so the centre of the group is
        # halfway between its first and last bar (not y_base + size/2).
        yticks.append(y_base + (seg_sorted.size - 1) / 2)
        yticklabels.append(f"Cluster {c+1}")
        y_base += seg_sorted.size + 2  # gap between clusters

    plt.axvline(mean_sil, color="red", linestyle="--", linewidth=1, label=f"Ø = {mean_sil:.3f}")
    plt.xlabel("Silhouette")
    plt.yticks(yticks, yticklabels)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# ==== Example: silhouette for K=3 and K=2 (analogous to the R code) ============
# Load the data (same as in #27) unless D already exists in this session,
# e.g. because program #27 was run beforehand.
if "D" not in globals():
    # Minimal loader (adjust path/columns if needed)
    csv_path = "data_kmedoids_clustering.csv"
    df = pd.read_csv(csv_path, sep=';', dtype=str, encoding='utf-8', engine='python')
    df.columns = df.columns.str.strip()
    # Trim every cell. NOTE(review): astype(str) presumably turns missing
    # cells into the literal text "nan" (a category of its own) — confirm.
    for c in df.columns:
        df[c] = df[c].astype(str).str.strip()

    # Same columns as in #27
    CAT_COLS = ["Kontakt", "Bankverbindung", "Rohstoff"]
    NUM_COLS = ["Rechnungsbetrag"]

    # German number format: drop "." (thousands separator) and turn "," into
    # the decimal point; unparseable values become NaN.
    # NOTE(review): a file using "." as decimal separator would be silently
    # mis-scaled here — confirm the input format.
    for c in NUM_COLS:
        df[c] = pd.to_numeric(
            df[c].str.translate(str.maketrans({".": None, ",": "."})),
            errors="coerce",
        )

    D = gower_distance_matrix(df[CAT_COLS], df[NUM_COLS], w_cat=0.5, w_num=0.5)

def _pam_silhouette(D, k):
    """
    Fit PAM (BUILD + SWAP) with k medoids on D and evaluate it.

    Returns (initial medoids, final medoids, labels, silhouette values);
    labels are the index of each point's nearest medoid, i.e. 0..k-1.
    """
    init = pam_build(D, k)
    med = pam_swap(D, init)
    lab = np.argmin(D[:, med], axis=1)
    return init, med, lab, silhouette_values_from_distance(D, lab)


# PAM is the expensive step: store every fitted K so it is computed only once
# (the K=2/K=3 examples are reused by the K scan below).
pam_runs = {}

# --- K=3 ---
pam_runs[3] = _pam_silhouette(D, 3)
med3_init, med3, labels3, sil3 = pam_runs[3]
print("Silhouette (K=3) – Mittelwert:", np.mean(sil3).round(6))
silhouette_plot(sil3, labels3, "Silhouette-Plot (K=3)")

# --- K=2 ---
pam_runs[2] = _pam_silhouette(D, 2)
med2_init, med2, labels2, sil2 = pam_runs[2]
print("Silhouette (K=2) – Mittelwert:", np.mean(sil2).round(6))
silhouette_plot(sil2, labels2, "Silhouette-Plot (K=2)")

# ==== Mean silhouette for K=2..8 ================================================
# The silhouette is only defined for 2 <= K <= n-1 (and PAM needs K <= n),
# so the scan is capped at n-1 for small data sets.
k_values = range(2, min(8, D.shape[0] - 1) + 1)
avg_sil = []
for k in k_values:
    if k not in pam_runs:
        pam_runs[k] = _pam_silhouette(D, k)
    init, med, lab, s = pam_runs[k]
    avg_sil.append(np.mean(s))

plt.figure(figsize=(6,4))
plt.plot(list(k_values), avg_sil, marker="o")
plt.xticks(list(k_values))
plt.xlabel("Anzahl der Cluster (K)")
plt.ylabel("Mittlerer Silhouette-Koeffizient")
plt.title("Silhouette-Analyse zur Wahl von K")
plt.grid(True, linestyle="--", alpha=0.6)
plt.tight_layout()
plt.show()

best_k = list(k_values)[int(np.argmax(avg_sil))]
print("Beste K (max. mittlere Silhouette):", best_k, "mit", round(max(avg_sil), 6))
# --------------------------------------------------------------------------------
