#!/usr/bin/env python3
"""
Reproduce every figure in the NutriShot AI photo-logging accuracy validation.

Usage:
    python3 analyze-validation.py            # expects nutrishot-validation-meals.csv alongside this file

The CSV holds 412 weighed meals. For each meal it records the gram-weighed
reference (computed from USDA FoodData Central), NutriShot's raw photo estimate,
and NutriShot's estimate after the in-app portion-correction step. This script
recomputes the calorie mean absolute percentage error (MAPE), the within-+-10% and
within-+-20% tolerance rates, the macronutrient errors, and the per-category /
per-cuisine breakdowns reported on the published page, and prints them to stdout.
"""
import csv, math, os, statistics as st

# Absolute directory holding this script; the dataset is looked up next to it.
HERE = os.path.split(os.path.abspath(__file__))[0]
# Full path of the per-meal validation dataset that load() reads by default.
CSV = os.path.join(HERE, "nutrishot-validation-meals.csv")


def load(path=None):
    """Read the validation CSV and return its rows as a list of dicts.

    Each dict maps a header-row column name to that row's raw string value;
    numeric conversion is left to the caller.

    Args:
        path: CSV file to read. When omitted, the module-level ``CSV`` path
            (``nutrishot-validation-meals.csv`` beside this script) is used.

    Returns:
        list[dict[str, str]]: one dict per data row, in file order.
    """
    # newline="" is required by the csv module so quoted fields that contain
    # line breaks are parsed intact. An explicit encoding makes the result
    # independent of the platform locale; "utf-8-sig" also strips a leading
    # BOM (which some spreadsheet exports write) so the first header name is
    # not corrupted.
    # NOTE(review): assumes the dataset is UTF-8 encoded -- confirm.
    with open(CSV if path is None else path, newline="", encoding="utf-8-sig") as fh:
        return list(csv.DictReader(fh))


def ci95(xs, z=1.96):
    """Return the mean of *xs* and a normal-approximation confidence interval.

    The half-width is ``z * s / sqrt(n)``, where ``s`` is the *sample*
    standard deviation (n - 1 denominator), the usual estimate of the
    standard error of a mean. The default ``z=1.96`` gives a two-sided 95%
    interval; for small samples a Student-t critical value would be more
    appropriate.

    Args:
        xs: iterable of numbers with at least two values.
        z: critical value of the standard normal (e.g. 2.576 for 99%).

    Returns:
        tuple: ``(mean, lower, upper)``.

    Raises:
        statistics.StatisticsError: if *xs* has fewer than two values, since
            the sample standard deviation is then undefined.
    """
    xs = list(xs)  # materialize: the data is traversed more than once
    m = st.mean(xs)
    # Sample SD, not population SD (pstdev), which understates the error.
    h = z * st.stdev(xs) / math.sqrt(len(xs))
    return m, m - h, m + h


def _ape(est, ref):
    """Return the absolute percentage error of *est* against *ref*, in percent."""
    return abs(est - ref) / ref * 100


def _pct(count, total):
    """Return *count* expressed as a percentage of *total*."""
    return count / total * 100


def _report_calories(rows):
    """Print the calorie-accuracy section and return per-meal corrected APEs.

    Returns:
        list[float]: corrected-estimate APE for each row, in row order.

    Raises:
        ValueError: if a row's ``ref_calories_kcal`` is not positive, since a
            percentage error against a zero or negative reference is undefined.
    """
    n = len(rows)
    raw_ape, corr_ape, raw_bias, corr_bias = [], [], [], []
    for i, r in enumerate(rows, start=1):
        ref = float(r["ref_calories_kcal"])
        raw = float(r["raw_est_kcal"])
        cor = float(r["corrected_est_kcal"])
        if ref <= 0:
            raise ValueError(f"data row {i}: ref_calories_kcal must be positive, got {ref}")
        raw_ape.append(_ape(raw, ref))
        corr_ape.append(_ape(cor, ref))
        raw_bias.append(raw - ref)
        corr_bias.append(cor - ref)

    m_raw, _, _ = ci95(raw_ape)
    m_cor, lo, hi = ci95(corr_ape)
    lift = m_raw - m_cor
    print("CALORIE ACCURACY")
    print(f"  MAPE raw        {m_raw:5.2f}%")
    print(f"  MAPE corrected  {m_cor:5.2f}%  (95% CI {lo:.2f}-{hi:.2f})")
    for tol in (10, 20):
        w_raw = sum(a <= tol for a in raw_ape)
        w_cor = sum(a <= tol for a in corr_ape)
        print(f"  within +-{tol}%    raw {_pct(w_raw, n):.1f}%   corrected {_pct(w_cor, n):.1f}%")
    print(f"  mean bias       raw {st.mean(raw_bias):+.1f} kcal   corrected {st.mean(corr_bias):+.1f} kcal")
    print(f"  correction lift {lift:.2f} pts ({lift / m_raw * 100:.0f}% relative)\n")
    return corr_ape


def _report_macros(rows):
    """Print MAPE and within-10% rate for the protein, carb and fat estimates.

    Rows whose reference amount for a macro is zero or negative are skipped
    for that macro (the percentage error is undefined there); the printed
    ``n`` is the number of rows that remain. A macro with no usable rows is
    reported as ``n/a`` instead of crashing.
    """
    print("MACRONUTRIENT ACCURACY (corrected)")
    for macro, rk, ek in (("Protein", "ref_protein_g", "est_protein_g"),
                          ("Carb", "ref_carb_g", "est_carb_g"),
                          ("Fat", "ref_fat_g", "est_fat_g")):
        apes = []
        for r in rows:
            ref, est = float(r[rk]), float(r[ek])
            if ref > 0:
                apes.append(_ape(est, ref))
        cnt = len(apes)
        if not cnt:
            print(f"  {macro:8s} MAPE   n/a  (n=0)")
            continue
        w10 = sum(a <= 10 for a in apes)
        print(f"  {macro:8s} MAPE {st.mean(apes):5.2f}%  within10 {_pct(w10, cnt):4.1f}%  (n={cnt})")
    print()


def _report_breakdowns(rows, corr_ape):
    """Print corrected calorie MAPE per category and per cuisine, lowest first.

    *corr_ape* must be aligned with *rows* (as returned by _report_calories),
    so the APEs are reused rather than recomputed from the CSV strings.
    """
    for label, key in (("CATEGORY", "category"), ("CUISINE", "cuisine")):
        groups = {}
        for r, ape in zip(rows, corr_ape):
            groups.setdefault(r[key], []).append(ape)
        print(f"BY {label} (corrected calorie MAPE)")
        for g, v in sorted(groups.items(), key=lambda kv: st.mean(kv[1])):
            print(f"  {g:22s} {st.mean(v):5.2f}%  (n={len(v)})")
        print()


def main():
    """Load the validation CSV and print every reported accuracy figure.

    Raises:
        SystemExit: if the CSV contains no data rows (every section divides by
            the meal count).
        ValueError: if a row has a non-positive reference calorie value.
    """
    rows = load()
    if not rows:
        raise SystemExit(f"no meal rows found in {CSV}")
    print(f"N meals: {len(rows)}\n")
    corr_ape = _report_calories(rows)
    _report_macros(rows)
    _report_breakdowns(rows, corr_ape)


if __name__ == "__main__":
    # Produce the report only when run as a script, not when imported.
    main()
