# -*- coding: utf-8 -*-
"""
Created on Mon Nov 10 15:34:56 2025

@author: Moritz Romeike
"""

# ------------------------------------------------------------------------
# Programmcode 39 (Python): Entscheidungsbaum zur Kundenklassifizierung
# Inntal AG – Datengenerierung, Boxplots, PairGrid (ggpairs-like),
# Segment-Mittelwerte, Decision Tree 
# ------------------------------------------------------------------------
import sys, subprocess, importlib

def ensure_package(mod_name, pip_name=None):
    """Import *mod_name*, installing it with pip first if it is missing.

    Parameters
    ----------
    mod_name : str
        Module name passed to ``importlib.import_module`` (e.g. ``"numpy"``).
    pip_name : str, optional
        Distribution name for ``pip install`` when it differs from the import
        name; defaults to *mod_name*.

    Returns
    -------
    module
        The imported module object.

    Raises
    ------
    subprocess.CalledProcessError
        If ``pip install`` exits with a non-zero status.
    ModuleNotFoundError
        If the module still cannot be imported after installation.
    """
    try:
        return importlib.import_module(mod_name)
    except ModuleNotFoundError:
        pass  # not installed yet -> install below

    # Install into the interpreter that is running this script.
    subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or mod_name])
    # The import system caches directory listings of sys.path entries; a
    # package installed while the interpreter is running may be invisible
    # until those caches are invalidated.
    importlib.invalidate_caches()
    return importlib.import_module(mod_name)

# Required packages (installed on demand via ensure_package if missing).
np  = ensure_package("numpy")
pd  = ensure_package("pandas")
ensure_package("matplotlib")      # pyplot itself is imported below
sns = ensure_package("seaborn")
# scipy provides scipy.stats.pearsonr for the pair-plot correlation panels;
# seaborn does not pull it in, so it must be ensured explicitly.
ensure_package("scipy")
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatch, Patch

import numpy as np
import pandas as pd

# --------------------------- Generate data -------------------------------
# Segment labels; also the level order of the ordered categorical below.
cat_order = ["preisbewusst", "qualitätsorientiert", "gewerblich"]

rng = np.random.default_rng(123)
n_seg = 50

# (mean, std) of a normal distribution per variable, one pair per segment in
# cat_order order. The dict order is both the column order and the order of
# the RNG draws, which keeps the data reproducible for seed 123.
# The draws are not truncated, so e.g. Ruecksendequote can fall below 0.
_verteilungen = {
    "Bestellvolumen":      [(100, 20), (80, 15), (300, 50)],
    "Preis_pro_Einheit":   [(2.0, 0.2), (4.5, 0.3), (3.0, 0.4)],
    "Bestellhaeufigkeit":  [(2, 0.5), (1, 0.3), (5, 1.0)],
    "Ruecksendequote":     [(0.05, 0.02), (0.02, 0.01), (0.03, 0.01)],
    "Kundenzufriedenheit": [(3.5, 0.5), (4.8, 0.3), (4.2, 0.4)],
}

_spalten = {"Kundensegment": [seg for seg in cat_order for _ in range(n_seg)]}
for _var, _params in _verteilungen.items():
    _spalten[_var] = np.concatenate([rng.normal(mu, sd, n_seg) for mu, sd in _params])

kunden = pd.DataFrame(_spalten)
kunden["Kundensegment"] = pd.Categorical(kunden["Kundensegment"], categories=cat_order, ordered=True)

# --------------------------- Boxplots ----------------------------
# One boxplot figure per numeric variable, boxes ordered like cat_order.
boxplot_vars = ["Bestellvolumen","Preis_pro_Einheit","Bestellhaeufigkeit","Ruecksendequote","Kundenzufriedenheit"]

for v in boxplot_vars:
    plt.figure()
    data = [kunden.loc[kunden["Kundensegment"] == seg, v] for seg in cat_order]
    # Matplotlib >= 3.9 calls the keyword `tick_labels` (`labels` is deprecated
    # there). ensure_package does not upgrade an already installed Matplotlib,
    # and older releases reject `tick_labels` with a TypeError -> fall back.
    try:
        plt.boxplot(data, tick_labels=cat_order)
    except TypeError:
        plt.boxplot(data, labels=cat_order)
    plt.title(f"{v} je Kundensegment (Inntal AG)")
    plt.xlabel("Kundensegment")
    plt.ylabel(v)
    plt.tight_layout()
    plt.show()

# ------------------- "GGpairs" - multivariate visualization ---------------
# Light grey panels on a white figure with a subtle grid.
_theme_overrides = {
    "axes.facecolor": "#f5f5f5",
    "figure.facecolor": "white",
    "grid.color": "#e6e6e6",
}
sns.set_theme(style="whitegrid", rc=_theme_overrides)

# Segment -> colour for the pair plot; keys must match the cat_order labels.
palette = dict([
    ("preisbewusst", "#D64045"),         # strong red
    ("qualitätsorientiert", "#1FA187"),  # strong turquoise/green
    ("gewerblich", "#2C5AA0"),           # strong blue
])

from scipy.stats import pearsonr
def _stars(p): return "***" if p < 1e-3 else "**" if p < 1e-2 else "*" if p < 5e-2 else ""

def upper_corr_panel(x, y, **kws):
    """Annotate an upper-triangle PairGrid panel with Pearson correlations.

    Only ``x.name`` and ``y.name`` are taken from the arguments. The values are
    re-read from the global ``kunden`` frame, so the text always covers all
    customers. The panel shows:

    * ``Corr: r`` over all rows, followed by significance stars from ``_stars``;
    * one line per segment in ``cat_order``, coloured with ``palette``.

    A correlation needs at least three finite (x, y) pairs. Otherwise the
    overall line is omitted and the segment value is printed as NaN.
    """
    ax = plt.gca()
    # Draw the text only once per panel, even if called again for this Axes.
    if getattr(ax, "_corr_done", False):
        return

    def finite_pairs(frame):
        """Return the x/y columns of *frame* for rows where both are finite."""
        xv = frame[x.name].to_numpy()
        yv = frame[y.name].to_numpy()
        ok = np.isfinite(xv) & np.isfinite(yv)
        return xv[ok], yv[ok]

    x_valid, y_valid = finite_pairs(kunden)
    if len(x_valid) >= 3:
        r_total, p_total = pearsonr(x_valid, y_valid)
        ax.text(0.03, 0.94, f"Corr: {r_total:+.3f}{_stars(p_total)}",
                transform=ax.transAxes, ha="left", va="top",
                fontsize=11, fontweight="bold", color="#222")

    top, step = 0.78, 0.14  # first segment line / line spacing, in axes fractions
    for row, segment in enumerate(cat_order):
        x_seg, y_seg = finite_pairs(kunden[kunden["Kundensegment"] == segment])
        r_seg = pearsonr(x_seg, y_seg)[0] if len(x_seg) >= 3 else np.nan
        ax.text(0.03, top - row * step, f"{segment}: {r_seg:+.3f}",
                transform=ax.transAxes, ha="left", va="top",
                fontsize=10, color=palette[segment])
    ax._corr_done = True

def diag_kde(x, color=None, **kws):
    """Draw a filled KDE of *x* (PairGrid diagonal callback).

    *color* is forwarded to ``sns.kdeplot``; other keyword arguments in
    ``**kws`` are ignored. No ``ax`` is passed, so seaborn plots on the
    current Axes.
    """
    kde_style = {"fill": True, "alpha": 0.6, "linewidth": 1.2}
    sns.kdeplot(x=x, color=color, **kde_style)

def lower_scatter(x, y, color=None, **kws):
    """Scatter *y* against *x* (PairGrid lower-triangle callback).

    *color* is forwarded to ``sns.scatterplot``; other keyword arguments in
    ``**kws`` are ignored. Markers are size 30, slightly transparent and
    drawn without edges.
    """
    marker_style = {"s": 30, "alpha": 0.9, "edgecolor": "none"}
    sns.scatterplot(x=x, y=y, color=color, **marker_style)

# Pair plot of all numeric variables, coloured by customer segment.
g = sns.PairGrid(
    data=kunden,
    vars=boxplot_vars,
    hue="Kundensegment",
    hue_order=cat_order,
    palette=palette,
    height=2.5,
    aspect=1.05,
    diag_sharey=False,
)
# Lower triangle: scatter; diagonal: KDE; upper triangle: correlation text.
for grid_mapper, panel_func in ((g.map_lower, lower_scatter),
                                (g.map_diag, diag_kde),
                                (g.map_upper, upper_corr_panel)):
    grid_mapper(panel_func)

from matplotlib.lines import Line2D


def _legend_marker(segment):
    """Return a marker-only legend handle in the segment's palette colour."""
    colour = palette[segment]
    return Line2D([0], [0], marker='o', linestyle='',
                  markerfacecolor=colour, markeredgecolor=colour,
                  markersize=8, label=segment)


# Manual figure-level legend, placed above the top-right corner of the grid.
handles = [_legend_marker(segment) for segment in cat_order]
g.fig.legend(handles=handles, labels=cat_order, title="Kundensegment",
             loc="upper right", bbox_to_anchor=(0.97, 1.12), frameon=True)

plt.suptitle("Multivariate Analyse der Kundenmerkmale (Inntal AG)",
             fontsize=16, y=1.03, fontweight="bold")
plt.tight_layout()
plt.show()

# ------------------------- Mean per segment -------------------------
print("\nMittelwerte je Kundensegment:")
# Pass observed=False explicitly. With the categorical "Kundensegment" grouper
# it keeps every segment in the output (the long-standing default). pandas >= 2.1
# emits a FutureWarning when the argument is omitted, because the default is
# scheduled to change.
print(
    kunden.groupby("Kundensegment", observed=False)[boxplot_vars]
          .mean()
          .add_prefix("Ø_")
          .round(3)
)

# ------------------------- Decision tree (fixed R splits) ------------------
# Split thresholds taken verbatim from the R figure:
THRESH_ROOT = 170.0                 # root split: Bestellvolumen < 170
THRESH_LEFT = 3.4                   # left child split: Preis_pro_Einheit < 3.4

# Assign every customer to a leaf using the fixed rules.
_volumen = kunden["Bestellvolumen"]
_preis = kunden["Preis_pro_Einheit"]

left_mask = _volumen < THRESH_ROOT
right_mask = ~left_mask

leaf_left_left = left_mask & (_preis < THRESH_LEFT)     # leaf "preisbewusst"
leaf_left_right = left_mask & (_preis >= THRESH_LEFT)   # leaf "qualitätsorientiert"
leaf_right = right_mask                                 # leaf "gewerblich"

leaf_masks = [leaf_left_left, leaf_left_right, leaf_right]
leaf_names = ["preisbewusst", "qualitätsorientiert", "gewerblich"]

# Total number of customers; denominator of each leaf's share.
N = len(kunden)
def leaf_stats(mask):
    """Summarise the customers selected by boolean *mask* (one tree leaf).

    Returns ``(counts, probs, share)``:

    * counts: float array of customers per segment, ordered like ``cat_order``
      (segments absent from the leaf count as 0);
    * probs: *counts* divided by their sum, or all zeros for an empty leaf;
    * share: number of customers in the leaf divided by the global ``N``.
    """
    segmente = kunden.loc[mask, "Kundensegment"]
    counts = (segmente.value_counts()
                      .reindex(cat_order, fill_value=0)
                      .to_numpy(dtype=float))
    total = counts.sum()
    if total > 0:
        probs = counts / total
    else:
        probs = np.zeros_like(counts)
    return counts, probs, total / N

# (counts, probs, share) per leaf, in leaf_masks order.
leaf_info = list(map(leaf_stats, leaf_masks))

# Leaf fill colours as in the R figure, keyed by segment label.
COLORS = dict([
    ("preisbewusst", "#F0806B"),         # red
    ("qualitätsorientiert", "#BEBEBE"),  # grey
    ("gewerblich", "#62B36F"),           # green
])

# Hand-drawn tree mimicking rpart.plot(type=3, extra=104, under=TRUE, faclen=0).
# All positions below are in data coordinates of a 0..1 x 0..1 axes without axis lines.
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_axis_off()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Title
ax.text(0.5, 0.95, "Entscheidungsbaum zur Kundensegmentierung (Inntal AG)",
        ha="center", va="center", fontsize=18, fontweight="bold")

# Node and edge positions, placed by hand to match the R screenshot.
# The root sits at x≈0.64 in the template; 0.62 is used here (not the exact centre).
x_root, y_root = 0.62, 0.78

# Intermediate child points (only used to route the edges; the leaves sit below).
x_left_mid,  y_left_mid  = 0.43, 0.55
x_right_mid, y_right_mid = 0.91, 0.30  # anchor point for the right edge only

# Leaf box centres
x_leaf_LL, y_leaf_LL = 0.27, 0.23   # preisbewusst
x_leaf_LR, y_leaf_LR = 0.60, 0.23   # qualitätsorientiert
x_leaf_R,  y_leaf_R  = 0.91, 0.23   # gewerblich

# Edges root -> left child / right side
ax.plot([x_root, x_left_mid],  [y_root-0.01, y_left_mid+0.05], color="#666666", lw=1.5)
ax.plot([x_root, x_right_mid], [y_root-0.01, y_right_mid+0.05], color="#666666", lw=1.5)

# Split labels next to the edges (as in rpart.plot type=3); int() prints "170", not "170.0".
ax.text((x_root+x_left_mid)/2,  (y_root+y_left_mid)/2+0.06, f"Bestellvolumen < {int(THRESH_ROOT)}",
        ha="center", va="center", fontsize=13, fontweight="bold")
ax.text(x_right_mid, y_right_mid+0.07, f">= {int(THRESH_ROOT)}", ha="center", va="center",
        fontsize=13, fontweight="bold")

# Edges from the left split point down to its two leaves (ending just above the leaf boxes)
ax.plot([x_left_mid, x_leaf_LL], [y_left_mid, y_leaf_LL+0.07], color="#666666", lw=1.5)
ax.plot([x_left_mid, x_leaf_LR], [y_left_mid, y_leaf_LR+0.07], color="#666666", lw=1.5)

# Labels of the left split
ax.text((x_left_mid+x_leaf_LL)/2, y_left_mid+0.02, f"Preis_pro_Einheit < {THRESH_LEFT}",
        ha="center", va="center", fontsize=13, fontweight="bold")
ax.text(x_leaf_LR, y_left_mid+0.02, f">= {THRESH_LEFT}",
        ha="center", va="center", fontsize=13, fontweight="bold")

# Small horizontal "notch" at the root (purely cosmetic, as in the screenshot)
ax.plot([x_root-0.08, x_root+0.08], [y_root, y_root], color="#333333", lw=2)

def draw_leaf(x, y, name, counts, probs, share):
    """Draw one rpart.plot-style leaf centred at (x, y) on the global ``ax``.

    The rounded box shows *name*. Below it, the class proportions (ordered
    like ``cat_order``, two decimals) and the leaf's share of all customers in
    percent are printed, mirroring ``extra=104, under=TRUE``.

    Parameters
    ----------
    x, y : float
        Box centre in data coordinates of ``ax`` (the axes span 0..1).
    name : str
        Label printed in the box. It is also the fill-colour key for an
        empty leaf.
    counts : array-like
        Per-segment counts. Not drawn; kept for interface compatibility.
    probs : array-like
        Per-segment proportions, ordered like ``cat_order``.
    share : float
        Fraction (0..1) of all customers that fall into this leaf.
    """
    probs = np.asarray(probs, dtype=float)
    # The fill colour follows the leaf's majority class. Fall back to *name*
    # when the leaf is empty (all proportions zero).
    if probs.size and probs.sum() > 0:
        maj = cat_order[int(np.argmax(probs))]
    else:
        maj = name
    # The boxstyle pad is in data units, like the 0.21 x 0.07 box itself. It
    # must stay small, otherwise the rounded boxes grow far beyond their
    # nominal size and overlap neighbouring leaves and edges.
    # clip_on=False: the right-most box reaches slightly past x = 1.
    box = FancyBboxPatch((x - 0.105, y - 0.035), 0.21, 0.07,
                         boxstyle="round,pad=0.01", ec="#333333",
                         fc=COLORS[maj], lw=1.0, mutation_aspect=1.0,
                         alpha=0.95, clip_on=False)
    ax.add_patch(box)
    ax.text(x, y, name, ha="center", va="center", fontsize=13, color="black")

    # Proportions in cat_order order (like extra=104), then the leaf share.
    probs_str = " ".join(f"{p:.2f}" for p in probs)
    ax.text(x, y - 0.065, probs_str, ha="center", va="center", fontsize=12, color="#000000")
    ax.text(x, y - 0.095, f"{int(round(share * 100))}%", ha="center", va="center",
            fontsize=12, color="#000000")

# Leaf statistics, in exactly the order of leaf_masks (LL, LR, R).
counts_LL, probs_LL, share_LL = leaf_info[0]
counts_LR, probs_LR, share_LR = leaf_info[1]
counts_R, probs_R, share_R = leaf_info[2]

# (x, y, label, counts, probs, share) for each leaf, drawn left to right.
leaf_layout = (
    (x_leaf_LL, y_leaf_LL, "preisbewusst",        counts_LL, probs_LL, share_LL),
    (x_leaf_LR, y_leaf_LR, "qualitätsorientiert", counts_LR, probs_LR, share_LR),
    (x_leaf_R,  y_leaf_R,  "gewerblich",          counts_R,  probs_R,  share_R),
)
for leaf_args in leaf_layout:
    draw_leaf(*leaf_args)

# Legend at the top left: one colour patch per segment.
handles = [Patch(facecolor=COLORS[segment], edgecolor="#333333", label=segment)
           for segment in cat_order]
ax.legend(handles=handles, loc="upper left", bbox_to_anchor=(0.08, 0.88), frameon=False)

plt.tight_layout()
plt.show()
# ------------------------------------------------------------------------
