# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 09:07:38 2025

@author: Moritz Romeike
"""

# --------------------------------------------------------------------------------
# Programmcode 27 (Python): K-Medoids (PAM) mit Gower-Distanz zur Betrugserkennung
# --------------------------------------------------------------------------------
import pandas as pd
import numpy as np
from pathlib import Path

# ====== Adjust path & column names ==============================================
csv_path = "data_kmedoids_clustering.csv"  # semicolon-separated input file (read below)

# Expected columns, as in the R example
CAT_COLS = ["Kontakt", "Bankverbindung", "Rohstoff"]   # categorical
NUM_COLS = ["Rechnungsbetrag"]                         # numeric (e.g. 0,14 with decimal comma)

K = 3                 # number of clusters / medoids
RANDOM_STATE = 123    # seed passed to np.random.default_rng further below
SAVE_OUTPUT = False  # True -> save CSV with cluster assignment

# ====== Load data ===============================================================
# Every column is read as text; numeric columns are converted explicitly below.
df = pd.read_csv(csv_path, sep=';', dtype=str, encoding='utf-8', engine='python')

# Trim column names and cell values. The .str accessor is used directly (no
# .astype(str) first): with dtype=str, missing cells arrive as NaN and .str.strip()
# keeps them missing, whereas .astype(str) would turn them into the literal
# string "nan" -- a fake category that makes all missing values "match".
df.columns = df.columns.str.strip()
for c in df.columns:
    df[c] = df[c].str.strip()

# Check that all configured columns exist
missing_cat = [c for c in CAT_COLS if c not in df.columns]
missing_num = [c for c in NUM_COLS if c not in df.columns]
if missing_cat or missing_num:
    raise KeyError(
        f"Fehlende Spalten – kategorial: {missing_cat}, numerisch: {missing_num}. "
        f"Vorhanden: {list(df.columns)}"
    )

# Convert numeric columns to float, accepting German number formatting:
# "." is dropped as a thousands separator and "," becomes the decimal point.
# NOTE(review): a dot-decimal value such as "0.14" would therefore be read as 14;
# this assumes the file consistently uses decimal commas -- confirm for new data.
# Unparseable values become NaN (errors="coerce").
for c in NUM_COLS:
    df[c] = pd.to_numeric(
        df[c].str.replace(".", "", regex=False).str.replace(",", ".", regex=False),
        errors="coerce"
    )

# ====== Gower distance matrix (categorical + numeric) ===========================
# Distance = weighted combination of:
# - categorical: simple matching (0 equal, 1 different), averaged over columns
# - numeric: |x_i - x_j| / range, averaged over columns observed for both rows
def gower_distance_matrix(df_cat: pd.DataFrame, df_num: pd.DataFrame,
                          w_cat: float = 0.5, w_num: float = 0.5) -> np.ndarray:
    """Return the (n, n) Gower-style dissimilarity matrix for mixed-type rows.

    For every pair of rows (i, j):

    * categorical part: fraction of categorical columns whose values differ
      (compared with ``!=``; a float NaN compares unequal to everything, so a
      missing category never matches);
    * numeric part: mean of ``|x_i - x_j| / range`` over the numeric columns
      where *both* rows have a value. Column ranges of 0 are replaced by 1.

    With both parts available the distance is ``w_cat * cat + w_num * num``.
    If a pair has no jointly observed numeric column, the numeric part is
    undefined and the whole weight ``w_cat + w_num`` goes to the categorical
    part (instead of pretending the unknown numbers are identical). If only
    one kind of column exists, that part alone is used; a numeric-only pair
    without any comparable column gets distance 0.

    Parameters
    ----------
    df_cat : pd.DataFrame
        Categorical columns; its length defines n.
    df_num : pd.DataFrame
        Numeric columns (convertible to float, NaN = missing), same row order.
    w_cat, w_num : float
        Weights of the categorical and the numeric part.

    Returns
    -------
    np.ndarray
        Symmetric float matrix of shape (n, n) with a zero diagonal.
    """
    import warnings  # only needed to silence the all-NaN range warning

    n = len(df_cat)
    if n == 0:
        # nanmax/nanmin would raise on an empty array
        return np.zeros((0, 0), dtype=float)
    m_cat = df_cat.shape[1]
    m_num = df_num.shape[1]

    A = df_cat.to_numpy(dtype=object)  # categorical values
    X = df_num.to_numpy(dtype=float)   # numeric values, NaN = missing

    # Per-column ranges; 0 -> 1 avoids division by zero. An all-NaN column gives
    # a NaN range, which makes that column non-comparable in the loop below.
    if m_num > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            ranges = np.nanmax(X, axis=0) - np.nanmin(X, axis=0)
        ranges[ranges == 0] = 1.0

    D = np.zeros((n, n), dtype=float)

    for i in range(n):
        # categorical part: share of mismatching columns
        if m_cat > 0:
            cat_dist = (A[i] != A).sum(axis=1) / m_cat   # shape (n,)
        else:
            cat_dist = 0.0

        # numeric part: mean scaled difference over jointly observed columns
        if m_num > 0:
            diff = np.abs(X[i] - X) / ranges             # shape (n, m_num), NaN if missing
            observed = ~np.isnan(diff)
            n_observed = observed.sum(axis=1)
            has_num = n_observed > 0
            num_dist = np.divide(
                np.where(observed, diff, 0.0).sum(axis=1),
                n_observed,
                out=np.zeros(n, dtype=float),
                where=has_num,
            )
        else:
            num_dist = 0.0

        if (m_cat > 0) and (m_num > 0):
            # Where the numeric part is undefined, give its weight to the
            # categorical part rather than counting it as distance 0.
            D[i] = np.where(has_num,
                            w_cat * cat_dist + w_num * num_dist,
                            (w_cat + w_num) * cat_dist)
        elif m_cat > 0:
            D[i] = cat_dist
        else:
            D[i] = num_dist

    np.fill_diagonal(D, 0.0)
    return D

# Pairwise distances over the configured columns, categorical and numeric parts
# weighted 50/50.
D = gower_distance_matrix(df[CAT_COLS], df[NUM_COLS], w_cat=0.5, w_num=0.5)

# ====== PAM (K-Medoids) – Build + Swap ==========================================
# NOTE(review): rng is not used by pam_build / pam_swap below, which are both
# deterministic; presumably kept for parity with the seeded R example.
rng = np.random.default_rng(RANDOM_STATE)

def total_cost(D: np.ndarray, medoids: np.ndarray) -> float:
    """Return the PAM objective for a given medoid set.

    Each row of ``D`` is charged its distance to the closest medoid (the
    columns selected by ``medoids``); the charges are summed.
    """
    to_medoids = D[:, medoids]
    closest = to_medoids.min(axis=1)
    return float(closest.sum())

def pam_build(D: np.ndarray, k: int) -> np.ndarray:
    """BUILD phase of PAM: greedily choose ``k`` initial medoids.

    The first medoid is the row with the smallest total distance to all rows.
    Each further medoid is the non-medoid whose addition lowers the sum of
    nearest-medoid distances the most (ties go to the lowest index).

    Parameters
    ----------
    D : np.ndarray
        Square (n, n) dissimilarity matrix.
    k : int
        Number of medoids; must satisfy 1 <= k <= n.

    Returns
    -------
    np.ndarray
        Sorted integer row indices of the chosen medoids.

    Raises
    ------
    ValueError
        If ``k`` is outside ``1..n`` (previously k <= 0 silently returned one
        medoid and k > n failed with an opaque TypeError).
    """
    n = D.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n}, got {k}")

    first = int(np.argmin(np.sum(D, axis=1)))  # minimal total distance
    medoids = [first]
    chosen = {first}  # O(1) membership test
    while len(medoids) < k:
        current = np.min(D[:, medoids], axis=1)
        current_total = np.sum(current)  # invariant for all candidates below
        best_gain, best_idx = None, None
        for i in range(n):
            if i in chosen:
                continue
            gain = current_total - np.sum(np.minimum(current, D[:, i]))
            if (best_gain is None) or (gain > best_gain):
                best_gain, best_idx = gain, i
        medoids.append(best_idx)
        chosen.add(best_idx)
    return np.array(sorted(medoids), dtype=int)

def pam_swap(D: np.ndarray, medoids: np.ndarray) -> np.ndarray:
    """SWAP phase of PAM (first-improvement variant).

    Scans medoid slots in order and, for each, the non-medoids in ascending
    index order. The first exchange whose ``total_cost`` is lower by more
    than 1e-12 is accepted and the scan restarts from the beginning. Stops
    when no single exchange improves the cost. The input is not modified;
    returned medoid indices are sorted.
    """
    point_count = D.shape[0]

    def _first_better(config):
        # Return the first cheaper single-swap configuration, or None.
        baseline = total_cost(D, config)
        outsiders = [idx for idx in range(point_count) if idx not in config]
        for slot in range(len(config)):
            for outsider in outsiders:
                trial = config.copy()
                trial[slot] = outsider
                trial = np.array(sorted(trial), dtype=int)
                if total_cost(D, trial) + 1e-12 < baseline:
                    return trial
        return None

    best = medoids.copy()
    while True:
        candidate = _first_better(best)
        if candidate is None:
            return best
        best = candidate

# Run PAM: greedy BUILD for the start configuration, then SWAP until no single
# medoid/non-medoid exchange lowers the total cost.
init_medoids = pam_build(D, K)
final_medoids = pam_swap(D, init_medoids)

# Final assignment: nearest medoid per row (0-based position in final_medoids),
# shifted to 1..K for R-style labels; cluster sizes and objective value.
# Note: sizes is computed but not printed below.
assign = D[:, final_medoids].argmin(axis=1)
clusters = 1 + assign
sizes = np.bincount(assign, minlength=K)
obj_cost = total_cost(D, final_medoids)

# ====== Output mirroring the R printout =========================================
print("Medoids (Zeilenindex, 1-basiert):")
print([int(m) + 1 for m in final_medoids])

print("\nClustering vector (1..K):")
print(clusters)

print("\nObjective function (Summe der minimalen Distanzen):")
print(obj_cost)

# Static list of component names, printed for parity with the R output
print("\nAvailable components:")
print(["medoids", "id.med", "clustering", "objective", "diss"])

# Table like cluster_output_pam (original data + cluster), sorted by cluster.
# A stable sort keeps the original row order within each cluster.
cluster_output_pam = df.copy()
cluster_output_pam["cluster"] = clusters
cluster_output_pam = cluster_output_pam.sort_values("cluster", kind="stable").reset_index(drop=True)
print("\nDaten mit Clusterzuordnung (sortiert):")
print(cluster_output_pam.to_string(index=False))

# Optionally save next to the input file
if SAVE_OUTPUT:
    # csv_path is a plain string, which has no .with_name(); wrapping it in
    # Path works for both str and Path values.
    out_path = Path(csv_path).with_name("kmedoids_cluster_output.csv")
    cluster_output_pam.to_csv(out_path, sep=';', index=False, encoding='utf-8')
    print(f"\nGespeichert: {out_path}")
# --------------------------------------------------------------------------------
