"""
Elastic Net with Inverse Variance Weights

CTF Admin Modifications (2026-02-19):
--------------------------------------
1. Added [CTF-DEBUG] progress statements for HPC monitoring
   Reason: Long-running jobs need progress output for debugging failures
"""
import numpy as np
import pandas as pd
from sklearn.linear_model import ElasticNet
import time

# Column names expected in the monthly characteristics panel.
ID_COL = "id"
EOM_COL = "eom"
TEST_COL = "ctff_test"
TARGET_COL = "ret_exc_lead1m"
FEATURES_LIST_COL = "features"

# Column names expected in the daily returns panel.
DAILY_DATE_COL = "date"
DAILY_RET_COL = "ret_exc"

# Rolling estimation window of 60 months, split 48/12 into tuning
# train/validation months.
ROLLING_MONTHS = 60
TUNE_TRAIN_MONTHS = 48
TUNE_VAL_MONTHS = 12
RETUNE_EVERY_MONTHS = 120  # re-run the hyperparameter search every 10 years

# Inverse-variance weighting: trailing window of daily returns, and the minimum
# observation count before an asset's variance estimate is trusted.
DAILY_VAR_YEARS = 5
MIN_DAILY_OBS_PER_ASSET = 200

# ElasticNet hyperparameter grids and the long-only selection cutoff.
ALPHA_GRID = np.logspace(-5, 1, 10)
L1_GRID = [0.2, 0.4, 0.6, 0.8]
TOP_PCT = 0.10  # hold the top decile by predicted return


def _impute_cs_median(df: pd.DataFrame, cols, group_col: str) -> pd.DataFrame:
    """Fill missing feature values with the per-group (cross-sectional) median,
    falling back to 0.0 where an entire cross-section is missing."""
    med = df.groupby(group_col)[cols].transform("median")
    df[cols] = df[cols].fillna(med).fillna(0.0)
    return df


def _rank_scale_minus1_1(df: pd.DataFrame, cols, group_col: str) -> pd.DataFrame:
    """Rank each feature within each cross-section and map the percentile ranks
    linearly into (-1, 1]; any remaining NaNs map to 0.0 (the center)."""
    pct = df.groupby(group_col)[cols].transform(lambda s: s.rank(method="average", pct=True))
    df[cols] = (2.0 * pct - 1.0).fillna(0.0)
    return df


def _demean_cs(df: pd.DataFrame, y_col: str, group_col: str, out_col: str) -> pd.DataFrame:
    """Write y_col minus its cross-sectional (per-group) mean to out_col;
    rows where the group mean is undefined get 0.0."""
    df[out_col] = (df[y_col] - df.groupby(group_col)[y_col].transform("mean")).fillna(0.0)
    return df
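

# Illustrative sketch (not called anywhere in the pipeline): the three
# cross-sectional transforms above applied to a toy two-month panel. The
# feature names "f1"/"f2" are hypothetical; only the column constants come
# from this module.
def _demo_preprocessing() -> pd.DataFrame:
    toy = pd.DataFrame({
        ID_COL: [1, 2, 3, 1, 2, 3],
        EOM_COL: pd.to_datetime(["2020-01-31"] * 3 + ["2020-02-29"] * 3),
        "f1": [0.5, np.nan, 1.5, 2.0, 3.0, np.nan],
        "f2": [1.0, 2.0, 3.0, np.nan, np.nan, np.nan],
        TARGET_COL: [0.01, -0.02, 0.03, 0.00, 0.02, -0.01],
    })
    toy = _impute_cs_median(toy, ["f1", "f2"], EOM_COL)     # NaN -> per-month median (or 0.0)
    toy = _rank_scale_minus1_1(toy, ["f1", "f2"], EOM_COL)  # ranks mapped into (-1, 1]
    toy = _demean_cs(toy, TARGET_COL, EOM_COL, "ret_cs_demeaned")  # sums to 0 per month
    return toy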


def _mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error."""
    d = y_true - y_pred
    return float(np.mean(d * d))


def _select_hparams(train_df: pd.DataFrame, val_df: pd.DataFrame, feature_cols):
    """Grid-search ElasticNet (alpha, l1_ratio) on a fixed train/validation
    split; return the (alpha, l1_ratio, val_mse) triple with the lowest
    validation MSE."""
    X_tr = train_df[feature_cols].to_numpy()
    y_tr = train_df["ret_cs_demeaned"].to_numpy()
    X_va = val_df[feature_cols].to_numpy()
    y_va = val_df["ret_cs_demeaned"].to_numpy()

    best = (None, None, np.inf)
    for a in ALPHA_GRID:
        for l1 in L1_GRID:
            m = ElasticNet(alpha=a, l1_ratio=l1, fit_intercept=True, max_iter=3000)
            m.fit(X_tr, y_tr)
            mse = _mse(y_va, m.predict(X_va))
            if mse < best[2]:
                best = (a, l1, mse)
    return best
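

# Illustrative sketch of the tuning step on synthetic data (not called anywhere
# in the pipeline): with a pure-noise target, the search should favor heavy
# regularization (a large alpha). The feature names are hypothetical.
def _demo_select_hparams():
    rng = np.random.default_rng(1)
    cols = ["f1", "f2", "f3"]

    def make(n: int) -> pd.DataFrame:
        d = pd.DataFrame(rng.normal(size=(n, 3)), columns=cols)
        d["ret_cs_demeaned"] = rng.normal(0.0, 0.05, n)  # noise target
        return d

    alpha, l1_ratio, val_mse = _select_hparams(make(500), make(100), cols)
    return alpha, l1_ratio, val_mse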


def _fit_model(fit_df: pd.DataFrame, feature_cols, alpha: float, l1_ratio: float) -> ElasticNet:
    """Fit ElasticNet on the cross-sectionally demeaned target with the
    selected hyperparameters."""
    m = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, fit_intercept=True, max_iter=3000)
    m.fit(fit_df[feature_cols].to_numpy(), fit_df["ret_cs_demeaned"].to_numpy())
    return m


def _inv_var_weights(daily: pd.DataFrame, ids: np.ndarray, end_date: pd.Timestamp) -> pd.Series:
    """Inverse-variance weights over `ids`, estimated from daily returns in the
    trailing DAILY_VAR_YEARS window ending at `end_date`. Assets with fewer
    than MIN_DAILY_OBS_PER_ASSET observations or non-positive variance are
    dropped; if nothing qualifies, fall back to equal weights over `ids`.
    The returned weights sum to 1."""
    ids = np.asarray(ids, dtype=int)
    if ids.size == 0:
        return pd.Series(dtype=float)

    start_date = end_date - pd.DateOffset(years=DAILY_VAR_YEARS)
    dsub = daily.loc[
        (daily[DAILY_DATE_COL] > start_date)
        & (daily[DAILY_DATE_COL] <= end_date)
        & (daily[ID_COL].isin(ids)),
        [ID_COL, DAILY_RET_COL],
    ]
    if dsub.empty:
        return pd.Series(1.0 / ids.size, index=ids)

    # Keep only assets with enough history and a strictly positive variance.
    stats = dsub.groupby(ID_COL)[DAILY_RET_COL].agg(["var", "count"])
    stats = stats[(stats["count"] >= MIN_DAILY_OBS_PER_ASSET) & (stats["var"] > 0)]
    inv = (1.0 / stats["var"]).replace([np.inf, -np.inf], np.nan).dropna()
    if inv.empty:
        return pd.Series(1.0 / ids.size, index=ids)

    return inv / inv.sum()
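

# Illustrative sketch of _inv_var_weights on synthetic daily data (not called
# anywhere in the pipeline): the low-volatility asset should receive the larger
# weight. The asset ids and return magnitudes are hypothetical.
def _demo_inv_var_weights() -> pd.Series:
    rng = np.random.default_rng(0)
    dates = pd.bdate_range("2018-01-01", "2019-12-31")  # ~520 obs > MIN_DAILY_OBS_PER_ASSET
    daily = pd.DataFrame({
        ID_COL: np.repeat([1, 2], len(dates)),
        DAILY_DATE_COL: np.tile(dates, 2),
        DAILY_RET_COL: np.concatenate([
            rng.normal(0.0, 0.01, len(dates)),  # asset 1: low variance
            rng.normal(0.0, 0.03, len(dates)),  # asset 2: high variance
        ]),
    })
    w = _inv_var_weights(daily, np.array([1, 2]), end_date=pd.Timestamp("2019-12-31"))
    # Expect w[1] > w[2] and w.sum() == 1.0 (up to floating point).
    return w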


def _equal_weight_month(test_month_df: pd.DataFrame) -> pd.DataFrame:
    """Equal-weight fallback portfolio for a single test month."""
    out = test_month_df[[ID_COL, EOM_COL]].copy()
    n = len(out)
    if n == 0:
        return pd.DataFrame(columns=[ID_COL, EOM_COL, "w"])
    out["w"] = 1.0 / n
    return out


def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """Build monthly long-only portfolio weights.

    For each test month: preprocess features cross-sectionally, (re)tune and
    fit an ElasticNet return model on a rolling ROLLING_MONTHS window, select
    the top TOP_PCT of names by predicted return, and weight them by inverse
    variance of trailing daily returns. Falls back to equal weights whenever
    history, features, or predictions are unavailable. Returns a DataFrame
    with columns [ID_COL, EOM_COL, "w"]; weights sum to 1 within each month.
    """
    _start_time = time.time()
    print(f"[CTF-DEBUG] main() started at {time.strftime('%Y-%m-%d %H:%M:%S')}", flush=True)
    df = chars.copy()
    df[EOM_COL] = pd.to_datetime(df[EOM_COL])
    df[TEST_COL] = df[TEST_COL].astype(bool)

    daily = daily_ret[[ID_COL, DAILY_DATE_COL, DAILY_RET_COL]].copy()
    daily[DAILY_DATE_COL] = pd.to_datetime(daily[DAILY_DATE_COL])

    feat_list = features[FEATURES_LIST_COL].tolist()
    feature_cols = [c for c in feat_list if c in df.columns]

    months_test = sorted(df.loc[df[TEST_COL], EOM_COL].dropna().unique())
    print(f"[CTF-DEBUG] Test months: {len(months_test)}, features: {len(feature_cols)}", flush=True)
    if len(months_test) == 0:
        return pd.DataFrame(columns=[ID_COL, EOM_COL, "w"])

    # No usable features: equal-weight every test month.
    if len(feature_cols) == 0:
        outs = []
        for m in months_test:
            outs.append(_equal_weight_month(df[(df[EOM_COL] == m) & (df[TEST_COL])]))
        return pd.concat(outs, ignore_index=True)[[ID_COL, EOM_COL, "w"]]

    # Cross-sectional preprocessing: impute missing features with the per-month
    # median, rank-scale features into (-1, 1], and demean the target per month.
    df = _impute_cs_median(df, feature_cols, EOM_COL)
    df = _rank_scale_minus1_1(df, feature_cols, EOM_COL)
    df = _demean_cs(df, TARGET_COL, EOM_COL, "ret_cs_demeaned")

    months_all = sorted(df[EOM_COL].dropna().unique())
    month_to_idx = {m: i for i, m in enumerate(months_all)}

    # Hyperparameters are tuned lazily on the first eligible month and
    # refreshed every RETUNE_EVERY_MONTHS months thereafter.
    best_alpha = None
    best_l1 = None
    last_retune_idx = None

    outs = []
    _n_months = len(months_test)

    for _i, pred_month in enumerate(months_test):
        if _i % 50 == 0 or _i == _n_months - 1:
            print(f"[CTF-DEBUG] Processing month {_i+1}/{_n_months} ({pred_month})", flush=True)
        idx = month_to_idx.get(pred_month, None)
        test_df = df[(df[EOM_COL] == pred_month) & (df[TEST_COL])].dropna(subset=feature_cols)
        if idx is None or test_df.empty:
            continue

        # Not enough trailing history for a rolling fit yet: equal weights.
        if idx < ROLLING_MONTHS:
            outs.append(_equal_weight_month(test_df))
            continue

        need_retune = (
            best_alpha is None
            or last_retune_idx is None
            or (idx - last_retune_idx) >= RETUNE_EVERY_MONTHS
        )

        if need_retune:
            # Split the trailing window into tuning train and validation months.
            hist = months_all[idx - ROLLING_MONTHS : idx]
            train_months = hist[:TUNE_TRAIN_MONTHS]
            val_months = hist[TUNE_TRAIN_MONTHS : TUNE_TRAIN_MONTHS + TUNE_VAL_MONTHS]

            train_df = df[df[EOM_COL].isin(train_months)].dropna(subset=feature_cols + ["ret_cs_demeaned"])
            val_df = df[df[EOM_COL].isin(val_months)].dropna(subset=feature_cols + ["ret_cs_demeaned"])

            if (not train_df.empty) and (not val_df.empty):
                best_alpha, best_l1, _ = _select_hparams(train_df, val_df, feature_cols)
                last_retune_idx = idx

        if best_alpha is None:
            outs.append(_equal_weight_month(test_df))
            continue

        # Refit on the full trailing window with the selected hyperparameters.
        fit_months = months_all[idx - ROLLING_MONTHS : idx]
        fit_df = df[df[EOM_COL].isin(fit_months)].dropna(subset=feature_cols + ["ret_cs_demeaned"])
        if fit_df.empty:
            outs.append(_equal_weight_month(test_df))
            continue

        model = _fit_model(fit_df, feature_cols, best_alpha, best_l1)
        mu = model.predict(test_df[feature_cols].to_numpy())

        # Rank test-month assets by predicted (demeaned) one-month-ahead return.
        preds = test_df[[ID_COL, EOM_COL]].copy()
        preds["mu_hat"] = mu
        preds = (
            preds.dropna(subset=["mu_hat"])
            .sort_values("mu_hat", ascending=False)
            .reset_index(drop=True)
        )

        n = len(preds)
        if n == 0:
            outs.append(_equal_weight_month(test_df))
            continue

        # Hold the top TOP_PCT of names by predicted return (at least one name).
        k = max(1, int(np.floor(TOP_PCT * n)))
        top_ids = preds.iloc[:k][ID_COL].astype(int).to_numpy()

        # Inverse-variance weights over the selected names.
        w = _inv_var_weights(daily, top_ids, end_date=pd.Timestamp(pred_month))
        out = pd.DataFrame({
            ID_COL: w.index.astype(int),
            EOM_COL: pd.Timestamp(pred_month),
            "w": w.to_numpy(dtype=float),
        })

        # Defensive renormalization; equal-weight if the sum degenerates.
        s = float(out["w"].sum())
        if s > 0:
            out["w"] = out["w"] / s
        else:
            out["w"] = 1.0 / len(out)

        outs.append(out)

    out_df = pd.concat(outs, ignore_index=True) if outs else pd.DataFrame(columns=[ID_COL, EOM_COL, "w"])
    out_df[ID_COL] = out_df[ID_COL].astype(int)
    out_df[EOM_COL] = pd.to_datetime(out_df[EOM_COL])
    out_df["w"] = out_df["w"].astype(float)

    # Keep only (id, eom) pairs that are valid test rows, then renormalize so
    # weights sum to 1 within each month; drop_duplicates guards against
    # duplicated panel rows inflating weights through the merge.
    out_df = out_df.merge(
        df.loc[df[TEST_COL], [ID_COL, EOM_COL]].drop_duplicates(),
        on=[ID_COL, EOM_COL],
        how="inner",
    )
    out_df["w"] = out_df["w"] / out_df.groupby(EOM_COL)["w"].transform("sum")
    _elapsed = time.time() - _start_time
    print(f"[CTF-DEBUG] main() completed in {_elapsed:.1f}s, output rows: {len(out_df)}", flush=True)
    return out_df[[ID_COL, EOM_COL, "w"]]
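

# Minimal smoke test, assuming the input schema implied by the column constants
# above: a monthly characteristics panel with id/eom/test-flag/target columns,
# a one-column `features` frame listing feature names, and a daily returns
# panel. All data below is synthetic and purely illustrative.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    months = pd.date_range("2000-01-31", periods=80, freq="M")
    n_assets = 20
    chars = pd.DataFrame({
        ID_COL: np.tile(np.arange(n_assets), len(months)),
        EOM_COL: np.repeat(months, n_assets),
        TEST_COL: np.repeat(months >= months[ROLLING_MONTHS], n_assets),
        TARGET_COL: rng.normal(0.0, 0.05, n_assets * len(months)),
        "feat_a": rng.normal(size=n_assets * len(months)),
        "feat_b": rng.normal(size=n_assets * len(months)),
    })
    features = pd.DataFrame({FEATURES_LIST_COL: ["feat_a", "feat_b"]})
    days = pd.bdate_range(months[0] - pd.DateOffset(years=2), months[-1])
    daily_ret = pd.DataFrame({
        ID_COL: np.tile(np.arange(n_assets), len(days)),
        DAILY_DATE_COL: np.repeat(days, n_assets),
        DAILY_RET_COL: rng.normal(0.0, 0.02, n_assets * len(days)),
    })
    weights = main(chars, features, daily_ret)
    # Each month's weights should sum to ~1.
    print(weights.groupby(EOM_COL)["w"].sum().describe())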
