"""
Group Lasso Factor Model for CTF Competition

This model applies Group Lasso regression to predict stock returns using
ECDF-transformed features. Features are organized into thematic groups
(Size, Value, Momentum, etc.) for group-level regularization.

Strategy:
- Uses Group Lasso to select informative feature groups
- Hyperparameter tuning via validation set (pre-1990 data)
- Decile-based long-short portfolio construction
- Optimizes long/short multipliers for Sharpe ratio

CTF Admin Modifications (2026-02-19):
--------------------------------------
1. Restructured from Jupyter notebook export to CTF-compliant main() function
   Reason: Original submission used WRDS database access and had no main() entry point

2. Removed keyring, sqlalchemy, psycopg2 dependencies for external database access
   Reason: CTF rules prohibit external data access; data is provided via function arguments

3. Added [CTF-DEBUG] progress statements throughout
   Reason: HPC jobs can run for hours; progress output is essential for debugging

4. Changed output column from 'asset_weight' to 'w'
   Reason: CTF output format requires columns (id, eom, w)

5. Added flush=True to all print statements
   Reason: Container output buffering can hide progress in HPC environments
"""

import numpy as np
import pandas as pd
from collections import Counter
from group_lasso import GroupLasso
from sklearn.metrics import r2_score
import time


def ecdf(data: pd.Series) -> pd.Series:
    """
    Compute empirical CDF values (percentile ranks in (0, 1]) for a series.

    Ties share the maximum rank (``method="max"``), NaNs keep NaN ranks,
    and the result is aligned to the input's original index.

    Note: ranking is order-independent, so the original sort-then-realign
    round trip was unnecessary; a direct rank produces identical values.
    """
    if data.empty:
        return data
    return data.rank(method="max", pct=True)


def prepare_data(chars: pd.DataFrame, features: list[str], eom_col: str) -> pd.DataFrame:
    """
    Cross-sectionally rank each feature with the ECDF, month by month.

    Exact zeros are preserved as 0, missing values are set to the median
    rank (0.5), and every feature is finally centered by subtracting 0.5.
    The input frame is not mutated; a transformed copy is returned.
    """
    out = chars.copy()

    for col in features:
        # Remember exact zeros before ranking so they can be restored.
        zero_mask = out[col] == 0
        out[col] = out.groupby(eom_col)[col].transform(ecdf)
        out.loc[zero_mask, col] = 0
        out[col] = out[col].fillna(0.5)

    # Center all ranks around zero.
    out[features] = out[features] - 0.5
    return out


def assign_feature_groups(feature_list: list[str]) -> tuple[np.ndarray, dict]:
    """
    Map each feature name to a thematic group id for Group Lasso.

    Rules are evaluated in a fixed order and the first matching rule wins;
    group ids are assigned sequentially, skipping rules that match nothing.
    Features matched by no rule fall into a trailing "Other" group.

    Returns:
        groups: array of group ids aligned with ``feature_list``.
        group_map: dict mapping group id -> group name.
    """
    # Ordered (name, predicate) rules — order defines matching priority.
    rules = [
        ("Size", lambda f: f.endswith("_me") or f.endswith("_mev") or f == "market_equity"),
        ("Value", lambda f: "bev" in f or "intrinsic" in f or "ival" in f),
        ("Momentum", lambda f: f.startswith("ret_") or "resff3" in f
                               or "seas_" in f or "rmax" in f),
        ("Risk_Vol", lambda f: f.startswith("beta_") or "ivol_" in f or "rvol_" in f
                               or "iskew" in f or "coskew" in f or "corr_" in f),
        ("Liquidity", lambda f: "turnover" in f or "dolvol" in f or "zero_trades" in f
                                or "bidask" in f or "ami_" in f or "adv_" in f),
        ("Investment_Growth", lambda f: "_gr" in f or "_chg" in f or "_ch5" in f),
        ("Issuance", lambda f: "netis" in f or "eqis" in f or "eqpo" in f or "fincf" in f),
        ("Accruals", lambda f: "accrual" in f or "noa" in f
                               or "cash_conversion" in f or "dsale_" in f),
        ("Leverage", lambda f: "debt" in f or "kz_index" in f or "int_debt" in f),
        ("Profitability", lambda f: "_at" in f or "_be" in f or "_sale" in f
                                    or "roe" in f or "roa" in f
                                    or "profit" in f or "qmj_prof" in f),
        ("Quality", lambda f: "qmj" in f or "mispricing" in f
                              or "f_score" in f or "o_score" in f or "z_score" in f),
    ]

    feature_to_group: dict[str, int] = {}
    group_map: dict[int, str] = {}
    next_id = 0

    for name, matches in rules:
        hits = [f for f in feature_list if f not in feature_to_group and matches(f)]
        if hits:
            for f in hits:
                feature_to_group[f] = next_id
            group_map[next_id] = name
            next_id += 1

    # Everything unmatched goes into a catch-all group.
    remaining = [f for f in feature_list if f not in feature_to_group]
    if remaining:
        for f in remaining:
            feature_to_group[f] = next_id
        group_map[next_id] = "Other"

    groups = np.array([feature_to_group[f] for f in feature_list])
    return groups, group_map


def tune_group_lasso(X_train: np.ndarray, y_train: np.ndarray,
                     X_val: np.ndarray, y_val: np.ndarray,
                     groups: np.ndarray) -> float:
    """
    Pick the Group Lasso group penalty that maximizes validation R2.

    A small fixed grid of penalties is evaluated; each candidate model is
    fit on the training arrays and scored on the validation arrays.
    """
    penalty_grid = np.linspace(0.0005, 0.0003, 3)
    best_reg = penalty_grid[0]
    best_r2 = -np.inf

    for idx, penalty in enumerate(penalty_grid, start=1):
        print(f"[CTF-DEBUG] Tuning penalty {idx}/{len(penalty_grid)}: reg={penalty:.6f}", flush=True)
        model = GroupLasso(
            groups=groups, group_reg=penalty, l1_reg=0.0,
            scale_reg=None, n_iter=1500, tol=1e-4,
            fit_intercept=True, supress_warning=True
        )
        model.fit(X_train, y_train)
        r2_val = r2_score(y_val, model.predict(X_val))

        print(f"[CTF-DEBUG]   Validation R2: {r2_val:.6f}", flush=True)

        if r2_val > best_r2:
            best_r2, best_reg = r2_val, penalty

    print(f"[CTF-DEBUG] Best group_reg: {best_reg} (Val R2: {best_r2:.6f})", flush=True)
    return best_reg


def sharpe(x: pd.Series) -> float:
    """
    Annualized Sharpe ratio of a monthly return series (sqrt(12) scaling).

    Returns NaN when fewer than 5 non-missing observations are available
    or when the sample standard deviation is zero.
    """
    returns = pd.Series(x).dropna()
    stdev = returns.std(ddof=1)
    if len(returns) < 5 or stdev == 0:
        return np.nan
    return returns.mean() / stdev * np.sqrt(12)


def normalize_gross(w: np.ndarray) -> np.ndarray:
    """
    Scale a weight vector so its gross exposure (sum of |w|) equals 1.

    An all-zero vector is returned unchanged (as float) to avoid
    division by zero.
    """
    weights = np.asarray(w, dtype=float)
    gross = np.abs(weights).sum()
    if gross == 0:
        return weights
    return weights / gross


def ensure_10cols(dec: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy whose columns are exactly the integer deciles 0-9.

    Missing decile columns are filled with NaN; non-integer columns are
    dropped and the remainder is returned in sorted order.
    """
    out = dec.copy()
    for decile in range(10):
        if decile not in out.columns:
            out[decile] = np.nan
    int_cols = sorted(c for c in out.columns if isinstance(c, (int, np.integer)))
    return out[int_cols]


def make_decile_ret(df_with_pred: pd.DataFrame) -> pd.DataFrame:
    """
    Build a months-by-deciles table of average next-month excess returns.

    Each month's predictions are ranked and bucketed into 10 deciles, then
    'ret_exc_lead1m' is averaged per (eom, decile) cell. The result always
    has columns 0-9 (missing deciles filled with NaN).
    """
    panel = df_with_pred.copy()

    def _to_deciles(preds: pd.Series) -> pd.Series:
        # rank(method="first") breaks ties so qcut gets unique values.
        return pd.qcut(preds.rank(method="first"), 10, labels=False, duplicates="drop")

    panel["decile"] = panel.groupby("eom")["pred"].transform(_to_deciles)
    cell_means = panel.groupby(["eom", "decile"])["ret_exc_lead1m"].mean()
    return ensure_10cols(cell_means.unstack())


def eval_weights(panel: pd.DataFrame, w: np.ndarray) -> tuple[float, np.ndarray]:
    """
    Score a decile weighting scheme on a months-by-deciles return panel.

    The weights are gross-normalized first; returns the annualized Sharpe
    ratio and the monthly portfolio return series.
    """
    normed = normalize_gross(w)
    monthly_returns = panel.values @ normed
    return sharpe(monthly_returns), monthly_returns


def optimize_weights(decile_panel: pd.DataFrame) -> tuple[str, np.ndarray]:
    """
    Grid-search long/short decile multipliers for the best in-sample Sharpe.

    The grid covers:
    - k in 1..9 top deciles held long with multiplier L in {1,2,3,4,5,7,10}
    - m in 1..9 bottom deciles held short with multiplier S in {1,2,3,4,5,7,10}

    Returns the winning configuration's name and its gross-normalized
    10-element weight vector.
    """
    multipliers = [1, 2, 3, 4, 5, 7, 10]

    def _build_weights(n_long: int, n_short: int, long_m: int, short_m: int) -> np.ndarray:
        # Long the top n_long deciles, short the bottom n_short deciles.
        w = np.zeros(10, dtype=float)
        w[10 - n_long:] = long_m
        w[:n_short] = -short_m
        return w

    candidates = [
        (f"Top{k}x{L}_Bot{m}x{S}", _build_weights(k, m, L, S))
        for k in range(1, 10)
        for m in range(1, 10)
        for L in multipliers
        for S in multipliers
    ]

    print(f"[CTF-DEBUG] Evaluating {len(candidates)} weight configurations", flush=True)

    rows = []
    for name, w in candidates:
        sr_is, _ = eval_weights(decile_panel, w)
        if np.isfinite(sr_is):
            rows.append((name, sr_is, normalize_gross(w)))

    res_is = pd.DataFrame(rows, columns=["name", "Sharpe_InSample", "weights"])
    res_is = res_is.sort_values("Sharpe_InSample", ascending=False).reset_index(drop=True)

    best = res_is.iloc[0]
    print(f"[CTF-DEBUG] Best strategy: {best['name']} (Sharpe: {best['Sharpe_InSample']:.3f})", flush=True)

    return best["name"], best["weights"].astype(float)


def compute_portfolio_weights(df_with_pred: pd.DataFrame, best_w: np.ndarray) -> pd.DataFrame:
    """
    Compute per-asset portfolio weights based on decile assignments.

    Each month, predictions are ranked into 10 deciles; the decile weight
    best_w[d] is split equally among that decile's assets. Assets in
    deciles with zero weight receive w = 0.

    Performance note: the original per-month Python loop (with a nested
    per-decile loop and a group copy each month) is replaced by one
    vectorized groupby-transform, which computes identical weights.

    Args:
        df_with_pred: frame with columns 'id', 'eom', 'pred'.
        best_w: 10-element decile weight vector.

    Returns:
        DataFrame with columns: id, eom, w.
    """
    df = df_with_pred.copy()
    df["decile"] = (
        df.groupby("eom")["pred"]
        .transform(lambda x: pd.qcut(
            x.rank(method="first"),  # unique ranks so qcut bins never collide
            10,
            labels=False,
            duplicates="drop"
        ))
    )

    # Only deciles with a non-zero target weight receive allocations.
    active_deciles = [d for d in range(10) if best_w[d] != 0]

    df["w"] = 0.0
    active = df["decile"].isin(active_deciles)
    if active.any():
        # Per (month, decile) asset count for the equal split.
        n_assets = df.loc[active].groupby(["eom", "decile"])["decile"].transform("size")
        decile_weight = (
            df.loc[active, "decile"]
            .astype(int)
            .map(dict(enumerate(np.asarray(best_w, dtype=float))))
        )
        df.loc[active, "w"] = decile_weight / n_assets

    return df[["id", "eom", "w"]].copy()


def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    CTF submission entry point for Group Lasso factor model.

    Pipeline:
        1. ECDF-rank all features cross-sectionally per month.
        2. Assign features to thematic groups for Group Lasso.
        3. Tune the group penalty on a pre-1985 train / 1985-89 validation split.
        4. Refit on train+val and grid-search decile long/short multipliers
           in-sample.
        5. Apply the fitted model and weight scheme to the post-1989 test period.

    Args:
        chars: Stock characteristics from ctff_chars.parquet
            (must contain 'id', 'eom', 'ret_exc_lead1m' and the feature columns —
            established by the column accesses below)
        features: Feature metadata from ctff_features.parquet
            (must contain a 'features' column listing feature names)
        daily_ret: Daily returns from ctff_daily_ret.parquet
            # NOTE(review): daily_ret is never used below; it is accepted only
            # to satisfy the CTF entry-point signature.

    Returns:
        DataFrame with columns: id, eom, w (portfolio weights)
    """
    start_time = time.time()
    print("[CTF-DEBUG] Starting Group Lasso factor model", flush=True)
    print(f"[CTF-DEBUG] Input shapes - chars: {chars.shape}, features: {features.shape}, daily_ret: {daily_ret.shape}", flush=True)

    # Extract feature list
    feature_list = features["features"].tolist()
    print(f"[CTF-DEBUG] Number of features: {len(feature_list)}", flush=True)

    # Prepare data: per-month ECDF ranks, zeros preserved, NaN -> 0.5, centered.
    print("[CTF-DEBUG] Preparing data with ECDF transformation", flush=True)
    df = chars.copy()
    df["eom"] = pd.to_datetime(df["eom"])
    df = prepare_data(df, feature_list, "eom")
    df = df.sort_values("eom")

    # Assign feature groups for the group-level penalty.
    print("[CTF-DEBUG] Assigning feature groups", flush=True)
    groups, group_map = assign_feature_groups(feature_list)
    counts = Counter(groups)
    print(f"[CTF-DEBUG] Number of groups: {len(group_map)}", flush=True)
    for gid in sorted(group_map.keys()):
        print(f"[CTF-DEBUG]   {group_map[gid]:<20}: {counts[gid]} features", flush=True)

    # Define time splits (train <= 1984-12, val through 1989-11, test after).
    train_end = pd.to_datetime("1984-12-31")
    val_end = pd.to_datetime("1989-11-30")

    # Split data
    df_train = df[df["eom"] <= train_end]
    df_val = df[(df["eom"] > train_end) & (df["eom"] <= val_end)]
    df_test = df[df["eom"] > val_end]

    print(f"[CTF-DEBUG] Data splits - Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}", flush=True)

    # Prepare arrays; target is next month's excess return.
    X_train = df_train[feature_list].values
    y_train = df_train["ret_exc_lead1m"].values
    X_val = df_val[feature_list].values
    y_val = df_val["ret_exc_lead1m"].values
    X_trainval = df[df["eom"] <= val_end][feature_list].values
    y_trainval = df[df["eom"] <= val_end]["ret_exc_lead1m"].values
    X_test = df_test[feature_list].values

    # Tune hyperparameters on the held-out validation window.
    print("[CTF-DEBUG] Tuning Group Lasso hyperparameters", flush=True)
    best_reg = tune_group_lasso(X_train, y_train, X_val, y_val, groups)

    # Train final model on train+val with the selected penalty.
    print("[CTF-DEBUG] Training final model on train+val data", flush=True)
    gl_final = GroupLasso(
        groups=groups, group_reg=best_reg, l1_reg=0.0,
        scale_reg=None, n_iter=1500, tol=1e-4,
        fit_intercept=True, supress_warning=True
    )
    gl_final.fit(X_trainval, y_trainval)
    print(f"[CTF-DEBUG] Active coefficients: {np.sum(np.abs(gl_final.coef_) > 1e-8)}", flush=True)

    # Generate in-sample predictions for weight optimization.
    df_trainval = df[df["eom"] <= val_end].copy()
    df_trainval["pred"] = gl_final.predict(X_trainval)

    # Optimize decile weights on the same train+val window (in-sample Sharpe).
    print("[CTF-DEBUG] Optimizing decile weights on train+val period", flush=True)
    decile_trainval = make_decile_ret(df_trainval)
    best_name, best_w = optimize_weights(decile_trainval)

    # Generate test predictions with the frozen model.
    print("[CTF-DEBUG] Generating predictions for test period", flush=True)
    df_test_pred = df_test.copy()
    df_test_pred["pred"] = gl_final.predict(X_test)

    # Compute per-asset portfolio weights from the optimized decile scheme.
    print("[CTF-DEBUG] Computing portfolio weights", flush=True)
    output = compute_portfolio_weights(df_test_pred, best_w)
    output = output.sort_values(["eom", "id"]).reset_index(drop=True)

    # Ensure correct types for the required (id, eom, w) output format.
    output["id"] = output["id"].astype(int)
    output["w"] = output["w"].astype(float)

    elapsed = time.time() - start_time
    print(f"[CTF-DEBUG] Completed in {elapsed:.1f}s", flush=True)
    print(f"[CTF-DEBUG] Output shape: {output.shape}", flush=True)
    print(f"[CTF-DEBUG] Date range: {output['eom'].min()} to {output['eom'].max()}", flush=True)
    print(f"[CTF-DEBUG] Non-zero weights: {(output['w'] != 0).sum()}", flush=True)

    return output
