# -*- coding: utf-8 -*-
"""
Created on Sat Oct  4 10:09:12 2025

@author: Moritz Romeike
"""

# -------------------------------------------------------------------------
# Programmcode 37 (Python): Bayes-Netz für Lieferantenausfallrisiken
# - Replikation des R-Codes ohne Bayes-Netz-Zusatzpakete (reine numpy-Inferenz; matplotlib nur für Grafiken)
# - Variablen: U1, U2, U3, R1, R2 (binär: "Ja","Nein"), W (3-stufig: "h","m","n")
# - Struktur: U1->R1, U2->R1, U3->R2, R1->W, R2->W
# -------------------------------------------------------------------------
import numpy as np
import itertools
import matplotlib.pyplot as plt

# ------------------------ States / indices / labels ----------------------
# State order matches the original R code throughout.
# Binary variables: index 0 = "Ja", index 1 = "Nein".
bin_levels = ["Ja", "Nein"]
# Three-level variable W: index 0 = "h", 1 = "m", 2 = "n".
W_levels = ["h", "m", "n"]

# Axis position of every variable in the joint probability tensor,
# i.e. the tensor is indexed as [U1, U2, U3, R1, R2, W].
IDX = {name: axis for axis, name in enumerate(["U1", "U2", "U3", "R1", "R2", "W"])}
# Cardinality per axis: five binary variables followed by the three-level W
# -> [2, 2, 2, 2, 2, 3].
dims = [len(bin_levels)] * 5 + [len(W_levels)]

# ----------------------------- CPTs (as in R) ----------------------------
# All binary states use index 0 = "Ja", 1 = "Nein".
# Priors P(U1), P(U2), P(U3).
P_U1 = np.array([0.3, 0.7])
P_U2 = np.array([0.1, 0.9])
P_U3 = np.array([0.2, 0.8])

# The conditional tables are built the same way R's array() fills them:
# the value vector is laid out column-major (order="F"), so the FIRST axis
# (the child variable) varies fastest and every consecutive group of values
# is one column of the CPT, i.e. one parent configuration.

# P(R2 | U3), axes (R2, U3):
#   U3=Ja   -> [P(R2=Ja), P(R2=Nein)] = [0.15, 0.85]
#   U3=Nein -> [0.001, 0.999]
P_R2_given_U3 = np.array([0.15, 0.85, 0.001, 0.999]).reshape((2, 2), order="F")

# P(R1 | U1, U2), axes (R1, U1, U2). Parent columns in R (column-major) order:
#   (U1=Ja,U2=Ja)     -> [0.25, 0.75]
#   (U1=Nein,U2=Ja)   -> [0.20, 0.80]
#   (U1=Ja,U2=Nein)   -> [0.10, 0.90]
#   (U1=Nein,U2=Nein) -> [0.001, 0.999]
P_R1_given_U1U2 = np.array(
    [0.25, 0.75, 0.20, 0.80, 0.10, 0.90, 0.001, 0.999]
).reshape((2, 2, 2), order="F")

# P(W | R1, R2), axes (W, R1, R2). Parent columns in R (column-major) order:
#   (R1=Ja,R2=Ja)     -> [0.7, 0.29, 0.01]
#   (R1=Nein,R2=Ja)   -> [0.6, 0.35, 0.05]
#   (R1=Ja,R2=Nein)   -> [0.2, 0.65, 0.15]
#   (R1=Nein,R2=Nein) -> [0.0, 0.0, 1.0]
P_W_given_R1R2 = np.array(
    [0.7, 0.29, 0.01, 0.6, 0.35, 0.05, 0.2, 0.65, 0.15, 0.0, 0.0, 1.0]
).reshape((3, 2, 2), order="F")

# ------------------------ Joint distribution P(all) ----------------------
# Chain-rule factorisation of the network:
#   P(U1) P(U2) P(U3) P(R1|U1,U2) P(R2|U3) P(W|R1,R2)
# Einsum subscripts: a=U1, b=U2, c=U3, d=R1, e=R2, f=W. The CPT operands are
# indexed (child, parents...) exactly as they are built above, and the output
# "abcdef" yields axes ordered [U1, U2, U3, R1, R2, W], matching IDX.
joint = np.einsum(
    "a,b,c,dab,ec,fde->abcdef",
    P_U1,
    P_U2,
    P_U3,
    P_R1_given_U1U2,
    P_R2_given_U3,
    P_W_given_R1R2,
)

# With correctly normalised CPTs the total mass is 1 up to rounding. Check it
# explicitly instead of blindly renormalising, which would silently hide a
# mistyped CPT entry; then divide only to remove floating-point drift.
if not np.isclose(joint.sum(), 1.0):
    raise ValueError(
        f"Gemeinsame Verteilung summiert sich zu {joint.sum()} statt 1 – CPTs prüfen."
    )
joint /= joint.sum()

# ---------------------------- Query helpers ------------------------------
def marginal(dist, var):
    """Return the marginal distribution P(var) of a joint probability tensor.

    Parameters
    ----------
    dist : numpy.ndarray
        Joint (possibly evidence-conditioned) distribution whose axes follow
        the variable order given by ``IDX``.
    var : str
        Variable name; must be a key of ``IDX``.

    Returns
    -------
    numpy.ndarray
        1-D array with one probability per state of ``var``.

    Raises
    ------
    KeyError
        If ``var`` is not a known variable. An unknown name must not fall
        through, otherwise every axis would be summed out and the total mass
        returned as if it were a marginal.
    """
    try:
        keep = IDX[var]
    except KeyError:
        raise KeyError(f"Unbekannte Variable: {var!r}") from None
    other_axes = tuple(ax for ax in range(dist.ndim) if ax != keep)
    return dist.sum(axis=other_axes)

def joint_of(dist, vars_):
    """Return the joint marginal over the variables named in ``vars_``.

    Every axis of ``dist`` that is not listed in ``vars_`` is summed out. The
    result's axes follow the order of the names in ``vars_`` (not the order
    in ``IDX``), so ``joint_of(d, ["R2", "R1"])[i, j]`` is P(R2=i, R1=j).

    Parameters
    ----------
    dist : numpy.ndarray
        Joint distribution whose axes follow ``IDX``.
    vars_ : sequence of str
        Variable names to keep, e.g. ``["R1", "R2"]``.
    """
    kept_axes = [IDX[name] for name in vars_]
    summed_axes = tuple(ax for ax in range(dist.ndim) if ax not in kept_axes)
    reduced = dist.sum(axis=summed_axes)

    ascending = sorted(kept_axes)
    if kept_axes == ascending:
        # Surviving axes are already in the requested order.
        return reduced
    # After summation the surviving axes sit in ascending original order, so
    # each requested axis' current position is its rank within `ascending`.
    return np.transpose(reduced, axes=[ascending.index(ax) for ax in kept_axes])

def set_evidence(dist, evidence):
    """
    Condition a joint distribution on observed states and renormalise.

    Every entry inconsistent with the evidence is set to 0; the remaining
    mass is then divided by its total so the result sums to 1 again.

    Parameters
    ----------
    dist : numpy.ndarray
        Joint distribution whose axes follow ``IDX``.
    evidence : dict
        Mapping variable -> observed state, e.g. {"U3":"Ja"}, {"W":"h"} or
        {"U2":"Nein","U3":"Nein"}. U1, U2, U3, R1, R2 accept "Ja"/"Nein";
        W accepts "h"/"m"/"n".

    Returns
    -------
    numpy.ndarray
        New array with the posterior distribution; ``dist`` is not modified.

    Raises
    ------
    ValueError
        For an unknown variable, an unknown state (states are no longer
        silently mapped to "Nein"), or evidence with zero total probability.
    """
    binary_states = {"Ja": 0, "Nein": 1}
    state_tables = {
        "U1": binary_states,
        "U2": binary_states,
        "U3": binary_states,
        "R1": binary_states,
        "R2": binary_states,
        "W": {"h": 0, "m": 1, "n": 2},
    }

    out = np.asarray(dist)
    for var, state in evidence.items():
        # Validate before the IDX lookup so an unknown variable raises the
        # intended ValueError instead of a bare KeyError.
        if var not in state_tables:
            raise ValueError(f"Unbekannte Variable: {var}")
        states = state_tables[var]
        if state not in states:
            raise ValueError(
                f"Unbekannter Zustand {state!r} für {var} "
                f"(erlaubt: {', '.join(states)})"
            )
        ax = IDX[var]
        # Boolean selector along `ax`, reshaped to broadcast over all other
        # axes; np.where allocates a new array, so `dist` stays untouched.
        keep = np.zeros(out.shape[ax], dtype=bool)
        keep[states[state]] = True
        shape = [1] * out.ndim
        shape[ax] = out.shape[ax]
        out = np.where(keep.reshape(shape), out, 0.0)

    s = out.sum()
    if s <= 0:
        raise ValueError("Evidenz inkompatibel – ergibt Summe 0.")
    # Out-of-place division: with empty evidence `out` may still be `dist`.
    return out / s

def pretty_prob(vec, labels):
    """Pair each state label with its probability as a plain Python float.

    Labels and values are matched positionally (surplus items on either side
    are ignored, as with ``zip``). Intended for readable printing of numpy
    probability vectors.
    """
    return dict(zip(labels, map(float, vec)))

# ------------------ Joint distribution & marginal of W -------------------
# Marginal of W without any evidence.
P_W = marginal(joint, "W")
# Sanity output: the joint tensor's total mass (should be 1).
print("Gemeinsame Verteilung hat Summe:", joint.sum())
print("Randverteilung P(W):", pretty_prob(P_W, W_levels))

# --------------------- Visualisation (marginal bars) ---------------------
def plot_marginals(dist, title="Bayes-Netz (Randverteilungen)"):
    """Show one bar chart per network variable with its marginal distribution.

    Builds a 2x3 grid of subplots for U1, U2, U3, R1, R2 and W, each with a
    fixed y-range of [0, 1] and state labels on the x-axis, sets ``title`` as
    the figure title and displays the figure via ``plt.show()``.

    Parameters
    ----------
    dist : numpy.ndarray
        Joint (or posterior) distribution with axes ordered as in ``IDX``.
    title : str
        Figure-level title.
    """
    fig, grid = plt.subplots(2, 3, figsize=(10,6))
    for panel, node in zip(grid.ravel(), ["U1", "U2", "U3", "R1", "R2", "W"]):
        probs = marginal(dist, node)
        tick_labels = W_levels if node == "W" else bin_levels
        positions = range(len(probs))
        panel.bar(positions, probs)
        panel.set_title(node)
        panel.set_xticks(positions)
        panel.set_xticklabels(tick_labels)
        panel.set_ylim(0, 1)
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

plot_marginals(dist=joint, title="Bayes-Netz (Randverteilungen ohne Evidenz)")

# -------------------------- Inference analysis I -------------------------
# Condition on the evidence U3 = "Ja" and plot the resulting marginals.
post_U3_Ja = set_evidence(dist=joint, evidence={"U3": "Ja"})
plot_marginals(dist=post_U3_Ja, title='Evidenz: U3 = "Ja"')

# --------------------- Sensitivity analyses on W -------------------------
def show_W(dist, label):
    """Print the marginal P(W) of ``dist`` as a label->probability dict.

    The line is prefixed with ``label`` so different evidence scenarios can
    be told apart in the console output.
    """
    distribution = pretty_prob(marginal(dist, "W"), W_levels)
    print(f"{label}  P(W):", distribution)

# P(W) without evidence, then with each of U1, U2, U3, R1, R2 fixed to
# "Ja" and to "Nein" in turn (labels like "U1=Ja").
show_W(joint, "Baseline")
for node in ("U1", "U2", "U3", "R1", "R2"):
    for state in bin_levels:
        show_W(set_evidence(joint, {node: state}), f"{node}={state}")

# -------------------------- Inference analysis II ------------------------
# Evidence: W = "h" (high revenue loss).
post_W_h = set_evidence(joint, {"W":"h"})

# Posterior of R2 given W=h.
P_R2_given_W_h = marginal(post_W_h, "R2")
print("\nPosterior P(R2 | W=h):", pretty_prob(P_R2_given_W_h, bin_levels))

# Joint posterior of (R1, R2) given W=h, printed as a small table. The
# product over enumerated labels visits R1 in the outer, R2 in the inner
# position, i.e. the same order as two nested loops.
P_R1R2_given_W_h = joint_of(post_W_h, ["R1","R2"])
print("\nJoint P(R1,R2 | W=h):")
for (r1_idx, r1_lab), (r2_idx, r2_lab) in itertools.product(
    enumerate(bin_levels), repeat=2
):
    print(f"  R1={r1_lab}, R2={r2_lab}: {P_R1R2_given_W_h[r1_idx, r2_idx]:.6f}")

# Optional: plot all marginals after conditioning on W=h.
plot_marginals(post_W_h, 'Evidenz: W = "h"')

# -------------------------- Inference analysis III -----------------------
# Evidence: U2="Nein", U3="Nein", W="h"  -> posterior over (U1, R1, R2).
post_U2U3W = set_evidence(joint, {"U2":"Nein", "U3":"Nein", "W":"h"})
P_U1R1R2 = joint_of(post_U2U3W, ["U1","R1","R2"])

# Table output; product(..., repeat=3) varies U1 slowest and R2 fastest,
# matching three nested loops.
print("\nJoint P(U1,R1,R2 | U2=Nein, U3=Nein, W=h):")
for (u1_idx, u1_lab), (r1_idx, r1_lab), (r2_idx, r2_lab) in itertools.product(
    enumerate(bin_levels), repeat=3
):
    print(f"  U1={u1_lab}, R1={r1_lab}, R2={r2_lab}: {P_U1R1R2[u1_idx, r1_idx, r2_idx]:.6f}")
# -------------------------------------------------------------------------
