# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 08:58:38 2025

@author: Moritz Romeike
"""

# ------------------------------------------------------------------------
# Programmcode_26_inkl-Plots (Python): K-Modes + Auswertung + Plots
# ------------------------------------------------------------------------
import pandas as pd
import numpy as np
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt

# === Settings ============================================================
# Input CSV (adjust as needed). Kept as a Path because the save steps derive
# sibling output files from it via ``csv_path.with_name(...)``, which a plain
# string does not provide.
csv_path = Path("data_kmodes_clustering.csv")
# Categorical columns that must exist in the CSV and are used for clustering.
expected_cols = ["Kontakt", "Bankverbindung", "Rohstoff"]
K = 3                 # number of clusters
MAX_ITER = 10         # upper bound on assign/update iterations
RANDOM_STATE = 123    # seed for the random number generator
SAVE_OUTPUT = False   # True -> save the clustering result as CSV
SAVE_PLOTS = False    # True -> save the plots as PNG

# === Load and clean the data =============================================
# The file is semicolon-separated; every column is read as text.
df = pd.read_csv(csv_path, sep=";", dtype=str, encoding="utf-8", engine="python")

# Remove surrounding whitespace from the header names and from every cell.
df.columns = df.columns.str.strip()
for column_name in df.columns:
    df[column_name] = df[column_name].astype(str).str.strip()

print("Spalten gefunden:", list(df.columns))

# Stop early when a required column is absent (order follows expected_cols).
missing = [name for name in expected_cols if name not in df.columns]
if missing:
    raise KeyError(f"Fehlende Spalten {missing}. Vorhanden: {list(df.columns)}")

# Object array with one row per record, restricted to the clustering columns.
cols = expected_cols
cat_array = df[cols].to_numpy(dtype=object)

# === K-Modes Hilfsfunktionen ============================================
def compute_modes(arr: np.ndarray, labels: np.ndarray, k: int):
    """Return the column-wise mode of every cluster.

    Args:
        arr: 2-D array of categorical values, one row per record.
        labels: Cluster index of every row of ``arr``.
        k: Number of clusters; indices ``0 .. k-1`` are evaluated.

    Returns:
        A list of length ``k``. Entry ``c`` is an object array holding the
        most frequent value of each column among the rows labelled ``c``,
        or ``None`` when no row carries that label.
    """
    result = []
    for cluster_id in range(k):
        rows = arr[labels == cluster_id]
        if rows.shape[0] == 0:
            # Empty cluster: the caller decides how to re-seed it.
            result.append(None)
        else:
            per_column = [
                Counter(rows[:, col]).most_common(1)[0][0]
                for col in range(rows.shape[1])
            ]
            result.append(np.array(per_column, dtype=object))
    return result

def assign_clusters(arr: np.ndarray, modes: list[np.ndarray]):
    """Assign every row to the mode with the smallest Hamming distance.

    Args:
        arr: 2-D array of categorical values, one row per record.
        modes: One mode vector per cluster; ``None`` marks a cluster without
            a mode, which receives a large sentinel distance instead.

    Returns:
        A tuple ``(labels, dists)``: ``labels`` holds the index of the
        closest mode per row, ``dists`` is the integer row-by-cluster
        distance matrix.
    """
    sentinel = 10**9  # distance used for clusters that have no mode
    n_clusters = len(modes)
    dists = np.empty((arr.shape[0], n_clusters), dtype=int)
    for idx in range(n_clusters):
        centre = modes[idx]
        if centre is not None:
            # Hamming distance: number of columns that differ from the mode.
            dists[:, idx] = (arr != centre).sum(axis=1)
        else:
            dists[:, idx] = sentinel
    return dists.argmin(axis=1), dists

def within_cluster_distance(arr: np.ndarray, labels: np.ndarray, modes: list[np.ndarray]):
    """Sum the simple-matching distances to the mode inside each cluster.

    Args:
        arr: 2-D array of categorical values, one row per record.
        labels: Cluster index of every row of ``arr``.
        modes: One mode vector per cluster; ``None`` entries are skipped.

    Returns:
        Integer array with one total per cluster. Clusters whose mode is
        ``None`` keep the value 0.
    """
    totals = np.zeros(len(modes), dtype=int)
    for idx, centre in enumerate(modes):
        if centre is None:
            continue
        # Every differing cell of a member row counts as one mismatch.
        mismatches = arr[labels == idx] != centre
        totals[idx] = int(np.count_nonzero(mismatches))
    return totals

# === Initialisation (robust for strings) =================================
rng = np.random.default_rng(RANDOM_STATE)

# Starting modes are drawn from the distinct rows of the data.
df_unique = pd.DataFrame(cat_array, columns=cols).drop_duplicates()
unique_rows = df_unique.to_numpy(dtype=object)
n_unique = unique_rows.shape[0]
if n_unique == 0:
    raise ValueError("Keine Datenzeilen gefunden.")

# With fewer distinct rows than clusters, rows have to be drawn repeatedly.
need_replace = n_unique < K
init_idx = rng.choice(n_unique, size=K, replace=need_replace)
modes = [unique_rows[row_idx].copy() for row_idx in init_idx]

# === Iterations ==========================================================
labels_prev = None
for it in range(1, MAX_ITER + 1):
    labels, _ = assign_clusters(cat_array, modes)
    unchanged = labels_prev is not None and np.array_equal(labels, labels_prev)
    if unchanged:
        break  # converged: the assignment did not change
    labels_prev = labels
    # Recompute the modes; a cluster that ended up empty is re-seeded with a
    # randomly chosen distinct row (drawn in cluster order).
    modes = [
        mode if mode is not None
        else unique_rows[rng.integers(0, unique_rows.shape[0])].copy()
        for mode in compute_modes(cat_array, labels, K)
    ]

# === Results =============================================================
# Number of rows per cluster and summed mismatch count per cluster.
sizes = np.array([np.sum(labels == c) for c in range(K)], dtype=int)
withindiff = within_cluster_distance(cat_array, labels, modes)
modes_df = pd.DataFrame(modes, columns=cols)
modes_df.index = np.arange(1, K + 1)  # clusters are reported 1-based

print(f"\nK-modes clustering with {K} clusters of sizes {', '.join(map(str, sizes))}\n")
print("Cluster modes:")
print(modes_df.to_string(index=True))
print("\nClustering vector:")
print(labels + 1)
print("\nWithin cluster simple-matching distance by cluster:")
print(withindiff)
print("\nAvailable components:")
print(['cluster','size','modes','withindiff','iterations','weighted'])

# Attach the 1-based cluster number to the cleaned data and sort by it.
df_out = df.copy()
df_out["cluster"] = labels + 1
df_out = df_out.sort_values("cluster").reset_index(drop=True)
print("\nDaten mit Clusterzuordnung (sortiert):")
print(df_out.to_string(index=False))

if SAVE_OUTPUT:
    # csv_path may be a plain string, which has no with_name(); wrapping it
    # in Path places the output file next to the input CSV.
    out_path = Path(csv_path).with_name("kmodes_cluster_output.csv")
    df_out.to_csv(out_path, sep=';', index=False, encoding='utf-8')
    print(f"\nGespeichert: {out_path}")

# === PLOTS ===============================================================
# 1) Cluster sizes: one bar per cluster, labelled C1..CK.
plt.figure(figsize=(5, 3))
plt.bar([f"C{c}" for c in range(1, K + 1)], sizes)
plt.title("Clustergrößen (K-Modes)")
plt.xlabel("Cluster")
plt.ylabel("Anzahl")
plt.tight_layout()
if SAVE_PLOTS:
    # csv_path may be a plain string, which has no with_name(); wrapping it
    # in Path stores the PNG next to the input CSV.
    plt.savefig(Path(csv_path).with_name("plot_cluster_groessen.png"), dpi=150)
plt.show()

# Helper: grouped bar plot (two groups)
def grouped_barplot(count_df_a: pd.DataFrame, label_a: str,
                    count_df_b: pd.DataFrame, label_b: str,
                    title: str, xlabel: str, ylabel: str,
                    filename: str | None = None):
    """Plot two frequency tables side by side as a grouped bar chart.

    Args:
        count_df_a: First frequency table; must contain the columns named by
            ``xlabel`` (categories) and ``ylabel`` (counts).
        label_a: Legend label for the first table.
        count_df_b: Second frequency table with the same two columns.
        label_b: Legend label for the second table.
        title: Figure title.
        xlabel: Name of the category column; also used as x-axis label.
        ylabel: Name of the count column; also used as y-axis label.
        filename: File name for the PNG. The figure is only saved when this
            is given and the module-level ``SAVE_PLOTS`` flag is true; the
            file is written next to the input CSV.
    """
    # Union of the categories of both tables; a category missing in one
    # table gets the count 0 there.
    all_keys = pd.DataFrame({xlabel: pd.unique(pd.concat([count_df_a[xlabel], count_df_b[xlabel]], ignore_index=True))})
    df_a = all_keys.merge(count_df_a, on=xlabel, how="left").fillna({ylabel: 0})
    df_b = all_keys.merge(count_df_b, on=xlabel, how="left").fillna({ylabel: 0})

    # Order the categories by their combined frequency, ascending.
    order = (df_a[ylabel] + df_b[ylabel]).sort_values(ascending=True).index
    df_a = df_a.loc[order].reset_index(drop=True)
    df_b = df_b.loc[order].reset_index(drop=True)

    # Two bars per category, shifted by half a bar width to each side.
    x = np.arange(len(df_a))
    w = 0.45
    plt.figure(figsize=(10, 4))
    plt.bar(x - w/2, df_a[ylabel], width=w, label=label_a, color="dimgrey")
    plt.bar(x + w/2, df_b[ylabel], width=w, label=label_b, color="navy")
    plt.xticks(x, df_a[xlabel], rotation=0)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    if filename and SAVE_PLOTS:
        # csv_path may be a plain string, which has no with_name(); wrapping
        # it in Path makes the sibling-file lookup work for str and Path.
        plt.savefig(Path(csv_path).with_name(filename), dpi=150)
    plt.show()

# 2) Kontakt: cluster 1 vs. whole population
df_plot = df_out.loc[:, ["Kontakt", "cluster"]].copy()
cluster1 = df_plot.loc[df_plot["cluster"] == 1]
freq_c1 = (
    cluster1["Kontakt"]
    .value_counts()
    .rename_axis("Kontakt")
    .reset_index(name="Freq")
)
freq_all = (
    df_plot["Kontakt"]
    .value_counts()
    .rename_axis("Kontakt")
    .reset_index(name="Freq")
)
grouped_barplot(
    count_df_a=freq_all, label_a="Grundgesamtheit",
    count_df_b=freq_c1, label_b="Cluster 1",
    title="Vergleich der Kontaktpersonen: Gesamt vs. Cluster 1",
    xlabel="Kontakt", ylabel="Freq",
    filename="plot_kontakt_cluster1.png",
)

# 3) Rohstoff: cluster 1 vs. whole population
df_plot_r = df_out.loc[:, ["Rohstoff", "cluster"]].copy()
cluster1_r = df_plot_r.loc[df_plot_r["cluster"] == 1]
# The category column is named "Kategorie" in both tables so that the helper
# receives consistent column names.
freq_c1_r = (
    cluster1_r["Rohstoff"]
    .value_counts()
    .rename_axis("Kategorie")
    .reset_index(name="Freq")
)
freq_all_r = (
    df_plot_r["Rohstoff"]
    .value_counts()
    .rename_axis("Kategorie")
    .reset_index(name="Freq")
)
grouped_barplot(
    count_df_a=freq_all_r, label_a="Grundgesamtheit",
    count_df_b=freq_c1_r, label_b="Cluster 1",
    title="Vergleich der Rohstoffe: Gesamt vs. Cluster 1",
    xlabel="Kategorie", ylabel="Freq",
    filename="plot_rohstoff_cluster1.png",
)

# Hinweis: Für Cluster 2/3 einfach '== 1' durch '== 2' oder '== 3' ersetzen.
# ------------------------------------------------------------------------
