# -*- coding: utf-8 -*-
"""
Created on Sat Oct 11 08:48:53 2025

@author: Moritz Romeike
"""

# ------------------------------------------------------------------------
# Programmcode 08 (python): Random Forest Modell inkl. Datenbereinigung und Evaluation
# ------------------------------------------------------------------------

# ============================================================
# ROBUSTER INSTALLER + IMPORTS 
# ============================================================
import sys, subprocess, importlib, shutil
from pathlib import Path

def _try_install(cmd_args, desc):
    try:
        print(f"⚙️  {desc}: {' '.join(cmd_args)}")
        subprocess.check_call(cmd_args)
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ {desc} fehlgeschlagen (Exit {e.returncode}).")
        return False

def ensure_sklearn():
    """Versucht scikit-learn robust zu installieren. True, wenn Import klappt."""
    try:
        importlib.import_module("sklearn")
        return True
    except ImportError:
        pass

    py = sys.executable
    # 1) pip-Toolchain aktualisieren
    _try_install([py, "-m", "pip", "install", "-U", "pip", "setuptools", "wheel"],
                 "Upgrade pip/setuptools/wheel")

    # 2) Numerik zuerst (Wheels)
    _try_install([py, "-m", "pip", "install", "--only-binary", ":all:", "numpy", "scipy"],
                 "Installiere numpy/scipy (Wheels bevorzugt)")

    # 3) scikit-learn Standardversuch
    if _try_install([py, "-m", "pip", "install", "scikit-learn"], "Installiere scikit-learn"):
        try:
            importlib.import_module("sklearn")
            return True
        except ImportError:
            pass

    # 4) Fallbacks (nur Wheels, dann ohne Build-Isolation)
    if _try_install([py, "-m", "pip", "install", "--only-binary", ":all:", "scikit-learn"],
                    "Installiere scikit-learn (nur Wheels)"):
        try:
            importlib.import_module("sklearn")
            return True
        except ImportError:
            pass

    if _try_install([py, "-m", "pip", "install", "--no-build-isolation", "scikit-learn"],
                    "Installiere scikit-learn (kein Build-Isolation)"):
        try:
            importlib.import_module("sklearn")
            return True
        except ImportError:
            pass

    # 5) conda (falls vorhanden)
    conda = shutil.which("conda")
    if conda:
        _try_install([conda, "install", "-y", "scikit-learn"], "Installiere scikit-learn via conda")
        try:
            importlib.import_module("sklearn")
            return True
        except ImportError:
            pass

    return False

# Basis-Imports (NumPy/Pandas/Matplotlib funktionieren ohne sklearn)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# KONFIGURATION
# ============================================================
# Wenn True, werden ALLE verfügbaren numerischen Prädiktoren (außer Zielvariable)
# als Features verwendet. Wenn False, wird – falls vorhanden – nur "Likelihood" genutzt.
USE_ALL_NUMERIC_FEATURES = True

# Pfad zur Excel-Datei (liegt i. d. R. neben diesem Skript)
base_path = Path(__file__).resolve().parent
data_path = base_path / "Kap_2.2_Risk_Management_Data_Imputed.xlsx"

TARGET_COL = "Risk_Score"

# Optional: Spalten, bei denen Komma→Punkt erzwungen werden soll (falls als Text gespeichert)
FORCE_NUMERIC_COLS = ["Likelihood", "Risk_Score", "Impact", "ControlQuality",
                      "Frequency", "Severity"]

# ============================================================
# DATEN LADEN & BEREINIGEN
# ============================================================
df = pd.read_excel(data_path)

# Komma-zu-Punkt & numerisch coercen (nur auf sinnvolle/gelistete Spalten anwenden, falls vorhanden)
for col in FORCE_NUMERIC_COLS:
    if col in df.columns:
        df[col] = df[col].astype(str).str.replace(",", ".", regex=False)
        df[col] = pd.to_numeric(df[col], errors="coerce")

# Zusätzlich: alle objekt-Spalten versuchen numerisch zu konvertieren, ohne Zwang
for col in df.select_dtypes(include=["object"]).columns:
    # Nur konvertieren, wenn das wie eine Zahl aussieht (z.B. "12,3" / "12.3")
    sample = df[col].dropna().astype(str).head(50)
    if sample.str.match(r"^[+-]?(\d+([.,]\d*)?|[.,]\d+)$").mean() > 0.7:
        df[col] = df[col].astype(str).str.replace(",", ".", regex=False)
        df[col] = pd.to_numeric(df[col], errors="ignore")

# Zielvariable prüfen
if TARGET_COL not in df.columns:
    raise ValueError(f"Zielvariable '{TARGET_COL}' nicht in Daten gefunden.")

# Features wählen
if USE_ALL_NUMERIC_FEATURES:
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    feature_cols = [c for c in numeric_cols if c != TARGET_COL]
    # Fallback: falls keine numerischen Prädiktoren gefunden, versuche "Likelihood"
    if not feature_cols and "Likelihood" in df.columns:
        feature_cols = ["Likelihood"]
else:
    feature_cols = ["Likelihood"] if "Likelihood" in df.columns else []

if not feature_cols:
    raise ValueError("Keine geeigneten Feature-Spalten gefunden.")

# NA-Zeilen für Modellspalten entfernen
model_cols = feature_cols + [TARGET_COL]
df_model = df[model_cols].dropna().copy()

X = df_model[feature_cols].copy()
y = df_model[TARGET_COL].copy()

print(f"Verwendete Features: {feature_cols}")
print(f"Zielvariable: {TARGET_COL}")
print(f"Datensätze nach Bereinigung: {len(df_model)}")

# ============================================================
# MODELLIERUNG
# ============================================================
has_sklearn = ensure_sklearn()

if has_sklearn:
    # --- Random Forest (wie in R) ---
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error, r2_score

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=123
    )

    rf = RandomForestRegressor(
        n_estimators=500,
        random_state=123,
        n_jobs=-1
    )
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)

    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    r2   = r2_score(y_test, y_pred)

    print("\n=== Random Forest Modellgüte ===")
    print(f"RMSE : {rmse:.6f}")
    print(f"R²   : {r2:.6f}")

    # ------------------------------
    # Plot 1: Actual vs Predicted
    # ------------------------------
    plt.figure(figsize=(6, 5))
    plt.scatter(y_test, y_pred, alpha=0.7, label="Testdaten")
    lims = [min(y_test.min(), y_pred.min()), max(y_test.max(), y_pred.max())]
    plt.plot(lims, lims, linewidth=2, label="y = x", color="red")
    plt.xlim(lims); plt.ylim(lims)
    plt.xlabel("Tatsächlicher Risk Score")
    plt.ylabel("Vorhergesagter Risk Score")
    plt.title("Random Forest: Tatsächliche vs. Vorhergesagte Werte")
    plt.legend(); plt.tight_layout(); plt.show()

    # ------------------------------
    # Plot 2: Variablenwichtigkeit (analog R: varImpPlot)
    # ------------------------------
    importances = rf.feature_importances_
    imp_df = (pd.DataFrame({"Variable": feature_cols, "Wichtigkeit": importances})
                .sort_values("Wichtigkeit", ascending=True)
                .reset_index(drop=True))
    imp_df["Wichtigkeit (%)"] = 100 * imp_df["Wichtigkeit"] / imp_df["Wichtigkeit"].sum()

    plt.figure(figsize=(8, 4))
    # Farbabstufung optional – bei 1 Feature ist das naturgemäß 100 %
    norm = imp_df["Wichtigkeit"] / (imp_df["Wichtigkeit"].max() if imp_df["Wichtigkeit"].max() > 0 else 1.0)
    bars = plt.barh(imp_df["Variable"], imp_df["Wichtigkeit"], color=plt.cm.Blues(norm))
    plt.xlabel("Wichtigkeit")
    plt.title("Variablenwichtigkeit im Random-Forest-Modell")
    for bar, val in zip(bars, imp_df["Wichtigkeit (%)"]):
        plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
                 f"{val:.1f}%", va="center", fontsize=9)
    plt.tight_layout(); plt.show()

    print("\n=== Variablenwichtigkeit (Feature Importances) ===")
    print(imp_df.to_string(index=False, float_format="%.4f"))

else:
    # --- Fallback: OLS mit NumPy (ohne RF-Importances) ---
    print("\n⚠️  scikit-learn konnte nicht installiert/importiert werden.")
    print("→ Fallback auf OLS (NumPy). Keine Random-Forest-Variablenwichtigkeit verfügbar.")

    # Intercept anhängen und Kleinste Quadrate
    X_design = np.column_stack([np.ones(len(X)), X.values])
    coeffs, *_ = np.linalg.lstsq(X_design, y.values, rcond=None)
    intercept, betas = coeffs[0], coeffs[1:]

    y_pred = X_design @ coeffs
    ss_res = np.sum((y.values - y_pred) ** 2)
    ss_tot = np.sum((y.values - y.values.mean()) ** 2)
    r2 = 1 - ss_res/ss_tot if ss_tot > 0 else np.nan
    rmse = float(np.sqrt(np.mean((y.values - y_pred) ** 2)))

    print("\n=== OLS Fallback (NumPy) ===")
    print(f"Intercept: {intercept:.6f}")
    for name, b in zip(feature_cols, betas):
        print(f"Beta({name}): {b:.6f}")
    print(f"RMSE  : {rmse:.6f}")
    print(f"R²    : {r2:.6f}")

    # Plot: Daten (bei 1 Feature) + Regressionslinie; bei >1 Feature: Actual vs Predicted
    if len(feature_cols) == 1:
        x = X.values.ravel()
        xs = np.linspace(x.min(), x.max(), 200)
        ys = intercept + betas[0] * xs

        plt.figure(figsize=(6, 5))
        plt.scatter(x, y.values, alpha=0.7, label="Daten")
        plt.plot(xs, ys, linewidth=2, label="OLS-Linie", color="red")
        plt.xlabel(feature_cols[0]); plt.ylabel(TARGET_COL)
        plt.title(f"OLS-Fallback: {TARGET_COL} ~ {feature_cols[0]}")
        plt.legend(); plt.tight_layout(); plt.show()
    else:
        plt.figure(figsize=(6, 5))
        plt.scatter(y.values, y_pred, alpha=0.7, label="Daten")
        lims = [min(y.values.min(), y_pred.min()), max(y.values.max(), y_pred.max())]
        plt.plot(lims, lims, linewidth=2, label="y = x", color="red")
        plt.xlim(lims); plt.ylim(lims)
        plt.xlabel("Tatsächlicher Risk Score")
        plt.ylabel("Vorhergesagter Risk Score")
        plt.title("OLS-Fallback: Tatsächliche vs. Vorhergesagte Werte")
        plt.legend(); plt.tight_layout(); plt.show()
