# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 08:50:43 2025

@author: Moritz Romeike
"""

# ------------------------------------------------------------------------
# Programmcode 26 (Python, new): K-Modes clustering for fraud detection
# Needs only pandas/numpy (matplotlib is optional, for the plot); tolerates
# stray whitespace in column names and cells, fails fast on missing columns
# ------------------------------------------------------------------------
import pandas as pd
import numpy as np
from collections import Counter
from pathlib import Path

# === Settings ============================================================
# Input data: path to the CSV (the original note refers to R's read.csv2,
# i.e. a typically semicolon-separated file) and the columns to cluster on,
# named as in the R script.
csv_path = "data_kmodes_clustering.csv"
expected_cols = ["Kontakt", "Bankverbindung", "Rohstoff"]

# Algorithm parameters: number of clusters, iteration cap, RNG seed.
K = 3
MAX_ITER = 10
RANDOM_STATE = 123

# Output: set to True to write a CSV that includes the cluster column.
SAVE_OUTPUT = False

# Plotting is optional; a missing or broken matplotlib must not abort the run.
try:
    import matplotlib.pyplot as plt
except Exception:
    HAS_MPL = False
else:
    HAS_MPL = True

# === Load and clean the data =============================================
# Semicolon-separated file; every column is read as text.
df = pd.read_csv(
    csv_path,
    sep=';',
    dtype=str,
    encoding='utf-8',
    engine='python',
)

# Strip surrounding whitespace from the header names and from every cell.
df.columns = df.columns.str.strip()
for name in df.columns:
    # NOTE(review): astype(str) may turn missing cells into the literal text
    # "nan", which would then count as a category of its own — confirm.
    cleaned = df[name].astype(str).str.strip()
    df[name] = cleaned

available = list(df.columns)
print("Gefundene Spalten:", available)

# Fail early when any of the expected columns is absent.
missing = [name for name in expected_cols if name not in available]
if missing:
    raise KeyError(f"Fehlende Spalten {missing}. Vorhanden: {available}")

# Categorical feature matrix (rows x expected columns) as Python objects.
cols = expected_cols
cat_array = df.loc[:, cols].to_numpy(dtype=object)

# === Hilfsfunktionen (K-Modes) ==========================================
def compute_modes(cat_array: np.ndarray, labels: np.ndarray, k: int):
    """Return the per-column mode of every cluster.

    Args:
        cat_array: 2-D array of categorical values (rows x columns).
        labels: Cluster index (0-based) of each row of ``cat_array``.
        k: Number of clusters to report.

    Returns:
        A list of length ``k``. Entry ``c`` is an object array holding the
        most frequent category of each column among the rows labelled ``c``,
        or ``None`` when no row carries that label. On ties the category
        encountered first wins.
    """

    def _column_modes(rows):
        # One Counter per column; max() keeps the first category on ties.
        picks = []
        for column in rows.T:
            tally = Counter(column)
            picks.append(max(tally, key=tally.get))
        return np.array(picks, dtype=object)

    result = []
    for cluster_id in range(k):
        rows = cat_array[labels == cluster_id]
        result.append(_column_modes(rows) if rows.shape[0] else None)
    return result

def assign_clusters(cat_array: np.ndarray, modes: list[np.ndarray]):
    """Assign every row to the mode with the smallest Hamming distance.

    Args:
        cat_array: 2-D array of categorical values (rows x columns).
        modes: One mode vector per cluster; ``None`` marks a cluster
            without a mode.

    Returns:
        ``(labels, dists)`` where ``dists[i, c]`` is the number of columns in
        which row ``i`` differs from mode ``c`` (``10**9`` for a ``None``
        mode, so such a cluster is never the nearest unless all are), and
        ``labels[i]`` is the index of the smallest entry in row ``i``
        (first index on ties).
    """
    n_rows, n_modes = cat_array.shape[0], len(modes)
    dists = np.empty((n_rows, n_modes), dtype=int)
    for idx in range(n_modes):
        mode = modes[idx]
        if mode is None:
            # Sentinel distance for clusters that currently have no mode.
            dists[:, idx] = 10**9
        else:
            dists[:, idx] = (cat_array != mode).sum(axis=1)
    return dists.argmin(axis=1), dists

def within_cluster_distance(cat_array: np.ndarray, labels: np.ndarray, modes: list[np.ndarray]):
    """Sum the Hamming distances of each cluster's rows to its mode.

    The original note likens this to the ``withindiff`` value in R.

    Args:
        cat_array: 2-D array of categorical values (rows x columns).
        labels: Cluster index (0-based) of each row of ``cat_array``.
        modes: One mode vector per cluster; ``None`` entries contribute 0.

    Returns:
        Integer array with one total per cluster.
    """
    totals = []
    for cluster_id, mode in enumerate(modes):
        if mode is None:
            totals.append(0)
        else:
            # Count every cell of the cluster's rows that differs from the mode.
            mismatches = cat_array[labels == cluster_id] != mode
            totals.append(int(mismatches.sum()))
    return np.array(totals, dtype=int)

# === Mode initialisation (works for string categories) ===================
rng = np.random.default_rng(RANDOM_STATE)

# The distinct data rows serve as candidate start modes.
df_unique = pd.DataFrame(cat_array, columns=cols).drop_duplicates()
unique_rows = df_unique.to_numpy(dtype=object)
n_unique = unique_rows.shape[0]
if n_unique == 0:
    raise ValueError("Keine Datenzeilen gefunden.")

# Draw with replacement only when there are fewer distinct rows than clusters.
need_replace = n_unique < K
init_idx = rng.choice(n_unique, size=K, replace=need_replace)
modes = [unique_rows[i].copy() for i in init_idx]

# === Iterations ==========================================================
labels_prev = None
for it in range(1, MAX_ITER + 1):
    labels, _ = assign_clusters(cat_array, modes)

    # Stop as soon as the assignment no longer changes.
    converged = labels_prev is not None and np.array_equal(labels, labels_prev)
    if converged:
        break
    labels_prev = labels

    # Recompute the modes from the new assignment.
    modes = compute_modes(cat_array, labels, K)

    # Clusters that lost all members get a random distinct row as new mode.
    for c, mode in enumerate(modes):
        if mode is None:
            modes[c] = unique_rows[rng.integers(0, n_unique)].copy()

# === Results =============================================================
# Number of rows per cluster (labels are 0-based indices below K).
sizes = np.bincount(labels, minlength=K).astype(int)
withindiff = within_cluster_distance(cat_array, labels, modes)

# Mode table with 1-based cluster numbers as index.
modes_df = pd.DataFrame(modes, columns=cols, index=np.arange(1, K + 1))

size_text = ', '.join(map(str, sizes))
print(f"\nK-modes clustering with {K} clusters of sizes {size_text}\n")
print("Cluster modes:")
print(modes_df.to_string(index=True))

print("\nClustering vector:")
print(labels + 1)  # 1-based, as in R

print("\nWithin cluster simple-matching distance by cluster:")
print(withindiff)

print("\nAvailable components:")
print(['cluster', 'size', 'modes', 'withindiff', 'iterations', 'weighted'])

# Original data plus the 1-based cluster assignment, ordered by cluster.
df_out = (
    df.assign(cluster=labels + 1)
    .sort_values("cluster")
    .reset_index(drop=True)
)

print("\nDaten mit Clusterzuordnung (sortiert):")
print(df_out.to_string(index=False))

# Optionally save the clustered data next to the input file.
if SAVE_OUTPUT:
    # csv_path is a plain string, which has no with_name(); wrap it in Path
    # so the output file name can replace the input file name.
    out_path = Path(csv_path).with_name("kmodes_cluster_output.csv")
    df_out.to_csv(out_path, sep=';', index=False, encoding='utf-8')
    print(f"\nGespeichert: {out_path}")

# === (Optional) simple bar chart of the cluster sizes ====================
# plt is bound by the guarded import at the top of the file whenever
# HAS_MPL is True, so no second import is needed here.
if HAS_MPL:
    try:
        plt.figure(figsize=(5, 3))
        plt.bar([f"C{c}" for c in range(1, K + 1)], sizes)
        plt.title("Clustergrößen (K-Modes)")
        plt.xlabel("Cluster")
        plt.ylabel("Anzahl")
        plt.tight_layout()
        plt.show()
    except Exception as exc:
        # Plotting stays best-effort, but the reason for skipping it is
        # reported instead of being silently discarded.
        print(f"\nPlot übersprungen: {exc}")
# ------------------------------------------------------------------------
