# main.py  –  CTF Competition Submission
# ═══════════════════════════════════════════════════════════════════
# Model: Adaptive Ensemble with Dynamic Risk Management
# Author: Angikar Ghosal, Stanford University
# ═══════════════════════════════════════════════════════════════════
#
# CTF Admin Modifications (2026-02-23):
# --------------------------------------
# 1. Modified get_feature_tiers() to handle validation dataset
#    Reason: The validation dataset has different feature names than
#    production. When no features match IC_DATA, the function returned
#    empty lists causing QuantileTransformer to fail with 0 features.
#    Now falls back to using all available features.
#
# 2. Added [CTF-DEBUG] progress statements throughout main()
#    Reason: HPC jobs can run for hours; progress output is essential
#    for debugging failures and monitoring execution.

import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.preprocessing import QuantileTransformer
from sklearn.pipeline import Pipeline
import warnings
warnings.filterwarnings("ignore")

SEED = 42
np.random.seed(SEED)  # seed the global NumPy RNG for reproducibility

# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
RIDGE_ALPHA         = 10.0  # L2 penalty for the ridge model
MIN_TRAIN_MONTHS    = 36    # months of history required before the first fit
RETRAIN_FREQ        = 12    # refit models every N date steps after the first fit
TOP_PCTILE          = 0.10  # fraction of names held long
BOT_PCTILE          = 0.10  # fraction of names held short
MAX_WEIGHT          = 0.05  # per-name absolute weight cap
WINSOR_PCTILE       = 0.01  # two-sided winsorization level for the training target
IC_STRONG           = 0.03  # |IC| threshold for the "strong" feature tier
IC_MEDIUM           = 0.01  # |IC| threshold for the "medium" feature tier
W_RIDGE             = 0.40  # ensemble weight: ridge predictions
W_XGB               = 0.40  # ensemble weight: XGBoost predictions
W_IC                = 0.20  # ensemble weight: IC rank composite
SHORT_SCALE_MIN     = 0.20  # lower clamp for the dynamic short-leg multiplier
SHORT_SCALE_MAX     = 1.20  # upper clamp for the dynamic short-leg multiplier
SHORT_SCALE_DEFAULT = 1.00  # short-leg multiplier until enough leg history exists
SHORT_SCALE_WINDOW  = 36    # months used to compare long vs short leg Sharpe
MAX_NET_EXPOSURE    = 0.30  # cap on |sum of weights| (net exposure)
TARGET_VOL          = 0.10  # annualized portfolio volatility target
VOL_LOOKBACK        = 36    # months of returns used to estimate realized vol
VOL_SCALE_MIN       = 0.25  # lower clamp for the vol-targeting multiplier
VOL_SCALE_MAX       = 1.50  # upper clamp for the vol-targeting multiplier
DD_THRESHOLD_1      = 0.10  # drawdown tier boundaries (fractions of peak wealth)
DD_THRESHOLD_2      = 0.20
DD_THRESHOLD_3      = 0.30
DD_SCALE_1          = 0.75  # exposure multiplier when drawdown in [10%, 20%)
DD_SCALE_2          = 0.50  # exposure multiplier when drawdown in [20%, 30%)
DD_SCALE_3          = 0.25  # exposure multiplier when drawdown >= 30%

# Optional XGBoost dependency: when it cannot be imported, drop the XGB
# ensemble member and split its blend weight evenly between the ridge
# model and the IC composite so the three weights still sum to 1.0.
try:
    from xgboost import XGBRegressor
    HAS_XGB = True
except ImportError:
    HAS_XGB = False
    W_RIDGE += W_XGB / 2
    W_IC    += W_XGB / 2
    W_XGB    = 0.0

# Pre-computed information coefficients per characteristic (sign indicates
# the direction of the historical relationship with forward returns).
# Used by get_feature_tiers() for tiering and as weights in ic_composite().
IC_DATA = {
    "rmax5_21d": -0.0587, "ret_12_1": 0.0546, "ret_1_0": -0.0519,
    "rmax1_21d": -0.0503, "ivol_hxz4_21d": -0.0477, "ivol_capm_21d": -0.0462,
    "ivol_ff3_21d": -0.0460, "rvol_21d": -0.0441, "ret_9_1": 0.0431,
    "rvol_252d": -0.0427, "ivol_capm_252d": -0.0420, "resff3_12_1": 0.0413,
    "eqnpo_me": 0.0412, "ret_6_1": 0.0405, "rmax5_rvol_21d": -0.0403,
    "rvolhl_21d": -0.0399, "mispricing_perf": 0.0393, "ret_12_7": 0.0390,
    "eqpo_me": 0.0389, "ret_12_0": 0.0375, "ret_18_1": 0.0374,
    "cop_mev": 0.0368, "cop_me": 0.0361, "div3m_me": 0.0343,
    "ret_2_0": -0.0341, "div6m_me": 0.0334, "seas_1_1an": 0.0333,
    "cop_bev": 0.0328, "cop_at": 0.0326, "eqnpo_at": 0.0325,
    "div12m_me": 0.0320, "ivol_capm_60m": -0.0320, "cop_atl1": 0.0318,
    "ebit_mev": 0.0315, "ebitda_mev": 0.0314, "ni_me": 0.0312,
    "ocf_mev": 0.0308, "betabab_1260d": -0.0308, "nix_me": 0.0304,
    "ret_24_1": 0.0302,
}

# Characteristics excluded from every feature tier regardless of their IC.
WEAK_FEATURES = {
    "debt_mev", "xido_at", "ol_gr1a", "int_debtlt",
    "txditc_gr3a", "debt_me", "rd_sale", "txditc_gr1a",
    "at_be", "nwc_gr1a",
}

# ═══════════════════════════════════════════════════════════════════
# Risk Management: Volatility Targeting
# ═══════════════════════════════════════════════════════════════════
def compute_vol_scale(portfolio_returns, target_vol=TARGET_VOL, 
                      lookback=VOL_LOOKBACK, min_scale=VOL_SCALE_MIN, 
                      max_scale=VOL_SCALE_MAX):
    """Exposure multiplier that steers realized annualized vol toward target_vol.

    Returns 1.0 until at least 12 monthly returns exist, or when realized
    vol is negligible (< 1%); otherwise target/realized, clipped to
    [min_scale, max_scale].
    """
    if len(portfolio_returns) < 12:
        return 1.0
    # Trailing window; shorter histories are used in full.
    window = portfolio_returns[-lookback:]
    annualized = np.std(window) * np.sqrt(12)
    if annualized < 0.01:
        return 1.0
    return float(np.clip(target_vol / annualized, min_scale, max_scale))

# ═══════════════════════════════════════════════════════════════════
# Risk Management: Drawdown Scaling
# ═══════════════════════════════════════════════════════════════════
def compute_drawdown_scale(portfolio_returns):
    """Map the current drawdown of cumulative wealth to an exposure multiplier.

    Deeper drawdowns yield progressively smaller multipliers; returns 1.0
    when fewer than two returns are available or the drawdown is shallow.
    """
    if len(portfolio_returns) < 2:
        return 1.0
    wealth_curve = np.cumprod(1.0 + np.array(portfolio_returns))
    peak_curve = np.maximum.accumulate(wealth_curve)
    drawdown = abs((wealth_curve[-1] - peak_curve[-1]) / peak_curve[-1])
    # Tiered response: no cut below the first threshold, then step down.
    for bound, scale in ((DD_THRESHOLD_1, 1.00),
                         (DD_THRESHOLD_2, DD_SCALE_1),
                         (DD_THRESHOLD_3, DD_SCALE_2)):
        if drawdown < bound:
            return scale
    return DD_SCALE_3

# ═══════════════════════════════════════════════════════════════════
# Risk Management: Dynamic Short Scale
# ═══════════════════════════════════════════════════════════════════
def compute_dynamic_short_scale(leg_history, as_of_date, window=SHORT_SCALE_WINDOW,
                                 min_scale=SHORT_SCALE_MIN, max_scale=SHORT_SCALE_MAX,
                                 default=SHORT_SCALE_DEFAULT):
    """Size the short leg by the ratio of short-leg to long-leg Sharpe.

    Looks back `window` months over leg_history (list of dicts with "eom",
    "ret_long", "ret_short").  Returns `default` when there is too little
    history or the long-leg Sharpe is too close to zero for a stable
    ratio; otherwise short_sr / long_sr clipped to [min_scale, max_scale].
    """
    if len(leg_history) < 12:
        return default
    hist = pd.DataFrame(leg_history)
    hist["eom"] = pd.to_datetime(hist["eom"])
    window_start = as_of_date - pd.DateOffset(months=window)
    in_window = hist[(hist["eom"] >= window_start) & (hist["eom"] < as_of_date)]
    if len(in_window) < 12:
        return default

    def _annualized_sharpe(series):
        # Degenerate series (too short or flat) count as zero Sharpe.
        clean = series.dropna()
        if len(clean) < 6 or clean.std() < 1e-10:
            return 0.0
        return (clean.mean() * 12) / (clean.std() * np.sqrt(12))

    sr_long = _annualized_sharpe(in_window["ret_long"])
    sr_short = _annualized_sharpe(in_window["ret_short"])
    if abs(sr_long) < 0.1:
        return default
    return float(np.clip(sr_short / sr_long, min_scale, max_scale))

# ═══════════════════════════════════════════════════════════════════
# Feature Selection
# ═══════════════════════════════════════════════════════════════════
def get_feature_tiers(feat_cols):
    """Split feat_cols into strong/medium tiers by |IC|, excluding WEAK_FEATURES.

    Returns (strong_feats, medium_feats, ic_series) where ic_series maps
    every feature in feat_cols to its IC (0.0 when unknown).  When no
    feature clears the medium threshold (e.g. a validation dataset with
    different feature names), falls back to all features with uniform
    pseudo-IC weights so downstream transforms never see an empty set.
    """
    ic_series = pd.Series(IC_DATA).reindex(feat_cols).fillna(0.0)
    abs_ic = ic_series.abs()

    def _tier(threshold):
        return [f for f in ic_series.index[abs_ic >= threshold]
                if f not in WEAK_FEATURES]

    strong_feats = _tier(IC_STRONG)
    medium_feats = _tier(IC_MEDIUM)
    if not medium_feats:
        medium_feats = list(feat_cols)
        ic_series = pd.Series(1.0 / len(feat_cols), index=feat_cols)
    if not strong_feats:
        # Arbitrary but deterministic: first up-to-10 medium features.
        strong_feats = medium_feats[:min(10, len(medium_feats))]
    return strong_feats, medium_feats, ic_series

def winsorise(y, pct=WINSOR_PCTILE):
    """Clip y symmetrically at its pct and (1 - pct) quantiles."""
    lower = np.quantile(y, pct)
    upper = np.quantile(y, 1 - pct)
    return np.clip(y, lower, upper)

def cs_impute(X_raw, med):
    """Impute NaNs in X_raw from the supplied medians.

    Args:
        X_raw: 2-D feature matrix, possibly containing NaNs.
        med: either a 1-D array with one median per column, or a 2-D
             array with the same shape as X_raw carrying a per-row,
             per-column median (e.g. the per-month cross-sectional
             medians aligned to each training observation in main()).

    Returns:
        A copy of X_raw with NaNs replaced by the matching median; any
        NaN that survives (the median itself was NaN) becomes 0.0.

    Bug fix: the previous version flattened a 2-D `med` with ravel() and
    indexed it by column, which silently applied only the FIRST row's
    medians to every row instead of each row's own medians.
    """
    X = X_raw.copy()
    med = np.asarray(med)
    if med.ndim > 1:
        # Per-row medians: element-wise fill from the aligned position.
        mask = np.isnan(X)
        X[mask] = med[mask]
    else:
        for j in range(X.shape[1]):
            col_mask = np.isnan(X[:, j])
            X[col_mask, j] = med[j]
    return np.nan_to_num(X, nan=0.0)

# ═══════════════════════════════════════════════════════════════════
# Models
# ═══════════════════════════════════════════════════════════════════
def make_ridge(alpha):
    """Pipeline: quantile-normalize features to a Gaussian, then ridge-regress."""
    normalizer = QuantileTransformer(
        n_quantiles=500,
        output_distribution="normal",
        random_state=SEED,
    )
    regressor = Ridge(alpha=alpha, fit_intercept=True)
    return Pipeline([("qt", normalizer), ("ridge", regressor)])

def make_xgb():
    """Conservatively regularized gradient-boosted tree regressor."""
    params = dict(
        n_estimators=500,
        max_depth=4,
        learning_rate=0.02,
        subsample=0.8,
        colsample_bytree=0.5,
        min_child_weight=30,
        gamma=1.0,
        reg_lambda=5.0,
        random_state=SEED,
        n_jobs=-1,
        verbosity=0,
    )
    return XGBRegressor(**params)

def ic_composite(X, ic_vals):
    """IC-weighted composite of centered per-column rank transforms.

    Each column of X is rank-transformed into [-0.5, 0.5]; the composite
    is the weighted sum across columns with weights proportional to each
    feature's IC, normalized by the total |IC|.

    Bug fix: the previous version weighted by |IC|, discarding the sign,
    so features with a negative IC (e.g. the volatility measures in
    IC_DATA) pushed the score in the wrong direction.  Weights now keep
    the IC sign so negative-IC features contribute inversely.
    """
    n, f = X.shape
    ranked = np.empty_like(X, dtype=np.float32)
    for j in range(f):
        order = np.argsort(X[:, j])
        r = np.empty(n, dtype=np.float32)
        r[order] = np.arange(n, dtype=np.float32)
        ranked[:, j] = r / max(n - 1, 1) - 0.5
    # Signed weights preserve predictive direction; magnitude normalized
    # by total |IC| (epsilon guards the all-zero case).
    w = ic_vals / (np.abs(ic_vals).sum() + 1e-12)
    return (ranked * w[None, :]).sum(axis=1)

def zscore(x):
    """Standardize x to zero mean / unit std; zeros when std is (near) zero."""
    sigma = x.std()
    if sigma > 1e-10:
        return (x - x.mean()) / sigma
    return np.zeros_like(x)

def blend(p_ridge, p_xgb, p_ic):
    """Fixed-weight combination of the three standardized signal vectors."""
    ridge_part = W_RIDGE * zscore(p_ridge)
    xgb_part = W_XGB * zscore(p_xgb)
    ic_part = W_IC * zscore(p_ic)
    return ridge_part + xgb_part + ic_part

def compute_inv_vol(daily_ret, as_of_date, lookback=252):
    """Per-id inverse annualized volatility over a trailing daily window.

    Uses returns strictly before as_of_date within ~1.5x the lookback in
    calendar days.  Ids with zero or undefined vol default to 1 before
    normalization; the result is rescaled to have mean ~1.
    """
    window_start = as_of_date - pd.Timedelta(days=int(lookback * 1.5))
    in_window = (daily_ret["date"] >= window_start) & (daily_ret["date"] < as_of_date)
    ann_vol = (daily_ret.loc[in_window]
               .groupby("id")["ret_exc"].std()
               .mul(np.sqrt(252))
               .replace(0, np.nan))
    inverse = (1.0 / ann_vol).fillna(1.0)
    return inverse / (inverse.mean() + 1e-12)

# ═══════════════════════════════════════════════════════════════════
# Portfolio Construction
# ═══════════════════════════════════════════════════════════════════
def build_portfolio(signal, ids, inv_vol, short_scale, global_scale,
                    max_net=MAX_NET_EXPOSURE, top_pct=TOP_PCTILE,
                    bot_pct=BOT_PCTILE, max_w=MAX_WEIGHT):
    """Turn a cross-sectional signal into capped long/short weights.

    Pipeline, in order:
      1. Percentile-rank the signal; linear-ramp long leg over the top
         top_pct (sums to +1) and short leg over the bottom bot_pct
         (sums to -short_scale).
      2. Optionally tilt each leg by inverse volatility, restoring the
         leg sums afterwards.
      3. Shrink one leg if net exposure |sum(w)| exceeds max_net.
      4. Iteratively clip to +/-max_w while re-normalizing leg sums
         (up to 10 passes).
      5. Apply global_scale (vol x drawdown multiplier) and clip again.

    Args:
        signal: composite scores, one per name.
        ids: security ids aligned with signal (lookup keys into inv_vol).
        inv_vol: pd.Series of normalized inverse vols, or None to skip.
        short_scale: target gross size of the short leg.
        global_scale: overall exposure multiplier.
        max_net, top_pct, bot_pct, max_w: see module configuration.

    Returns:
        np.ndarray of weights aligned with signal; all zeros when fewer
        than 10 names are available or either leg would be empty.
    """
    n = len(signal)
    if n < 10:
        return np.zeros(n)
    
    # Percentile ranks in [0, 1] via the inverse sort permutation.
    order = np.argsort(signal)
    pct_rank = np.empty(n, dtype=np.float64)
    pct_rank[order] = np.arange(n) / max(n - 1, 1)
    
    long_mask = pct_rank >= (1.0 - top_pct)
    short_mask = pct_rank <= bot_pct
    
    if not long_mask.any() or not short_mask.any():
        return np.zeros(n)
    
    # Linear ramp: weight grows with distance into each tail.
    long_raw = pct_rank[long_mask] - (1.0 - top_pct)
    short_raw = bot_pct - pct_rank[short_mask]
    
    w = np.zeros(n, dtype=np.float64)
    w[long_mask] = long_raw / long_raw.sum()                     # long leg sums to +1
    w[short_mask] = -short_raw / short_raw.sum() * short_scale   # short leg sums to -short_scale
    
    if inv_vol is not None and len(inv_vol) > 0:
        # Inverse-vol tilt within each leg; ids missing from inv_vol get 1.0.
        iv = np.array([inv_vol.get(int(i), 1.0) for i in ids], dtype=np.float64)
        iv /= iv.mean() + 1e-12
        li = np.where(w > 0)[0]
        si = np.where(w < 0)[0]
        if len(li) > 0 and w[li].sum() > 1e-10:
            w[li] *= iv[li]
            w[li] /= w[li].sum()                      # restore long sum to +1
        if len(si) > 0 and (-w[si]).sum() > 1e-10:
            w[si] *= iv[si]
            w[si] /= (-w[si]).sum() / short_scale     # restore short sum to -short_scale
    
    # Net-exposure cap: rescale the opposing leg toward the limit.
    # NOTE(review): the targets assume the long leg still sums to ~1 and
    # only the short leg varies — confirm this holds after the iv tilt.
    current_net = w.sum()
    if current_net > max_net:
        target_short_sum = max_net - 1.0
        current_short_sum = w[w < 0].sum()
        if current_short_sum < -1e-10:
            w[w < 0] *= target_short_sum / current_short_sum
    elif current_net < -max_net:
        target_long_sum = 1.0 + max_net
        current_long_sum = w[w > 0].sum()
        if current_long_sum > 1e-10:
            w[w > 0] *= target_long_sum / current_long_sum
    
    # Clip to the per-name cap, then re-normalize each leg back to the
    # pre-clip leg sums; iterate until stable (clipping can re-violate).
    for _ in range(10):
        prev = w.copy()
        w = np.clip(w, -max_w, max_w)
        long_target = prev[prev > 0].sum()
        short_target = prev[prev < 0].sum()
        ls = w[w > 0].sum()
        ss = w[w < 0].sum()
        if ls > 1e-10 and long_target > 1e-10:
            w[w > 0] *= long_target / ls
        if ss < -1e-10 and short_target < -1e-10:
            w[w < 0] *= short_target / ss
        if np.allclose(w, prev, atol=1e-10):
            break
    
    # Final hard cap, global exposure scaling, and a last safety clip.
    w = np.clip(w, -max_w, max_w)
    w *= global_scale
    w = np.clip(w, -max_w, max_w)
    return w

# ═══════════════════════════════════════════════════════════════════
# MAIN FUNCTION (CTF Entry Point)
# ═══════════════════════════════════════════════════════════════════
def main(chars: pd.DataFrame, features: pd.DataFrame,
         daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    CTF Competition Entry Point

    Walk-forward loop over month-ends: periodically refit a ridge model
    (and XGBoost when available) on all prior non-test observations,
    blend the model predictions with an IC-weighted rank composite, and
    turn the blended signal into long/short weights with inverse-vol
    tilting plus volatility-targeting / drawdown-based exposure scaling.

    Args:
        chars: Stock characteristics (ctff_chars.parquet)
        features: Feature names (ctff_features.parquet)
        daily_ret: Daily returns (ctff_daily_ret.parquet)

    Returns:
        DataFrame with columns [id, eom, w]
    """
    import time
    _start_time = time.time()
    print(f"[CTF-DEBUG] Starting main() at {time.strftime('%Y-%m-%d %H:%M:%S')}", flush=True)

    np.random.seed(SEED)

    feat_cols = features["features"].tolist()
    print(f"[CTF-DEBUG] Loaded {len(feat_cols)} features", flush=True)
    # Work on copies so the caller's frames are not mutated.
    chars = chars.copy()
    chars["eom"] = pd.to_datetime(chars["eom"])
    daily_ret = daily_ret.copy()
    daily_ret["date"] = pd.to_datetime(daily_ret["date"])
    chars.sort_values(["eom", "id"], inplace=True)
    chars.reset_index(drop=True, inplace=True)
    
    dates = np.sort(chars["eom"].unique())
    TARGET = "ret_exc_lead1m"  # next-month excess return (prediction target)
    
    strong_feats, medium_feats, ic_series = get_feature_tiers(feat_cols)
    print(f"[CTF-DEBUG] Using {len(strong_feats)} strong features, {len(medium_feats)} medium features", flush=True)
    ic_weights = ic_series[strong_feats].values.astype(np.float32)
    # Per-month cross-sectional medians, used for NaN imputation.
    medians_all = chars.groupby("eom")[medium_feats].median()
    
    r_model = x_model = last_fit = None
    _strong_idx = []        # positions of strong features within medium_feats
    leg_history = []        # per-month long/short leg returns and risk scales
    portfolio_returns = []  # realized monthly total portfolio returns
    results = []
    n_dates = len(dates)
    print(f"[CTF-DEBUG] Processing {n_dates} dates", flush=True)

    for i, date_np in enumerate(dates):
        if i % 100 == 0 or i == n_dates - 1:
            print(f"[CTF-DEBUG] Processing date {i+1}/{n_dates} ({100*(i+1)/n_dates:.1f}%)", flush=True)
        date = pd.Timestamp(date_np)
        pred_mask = chars["eom"] == date
        pred_df = chars[pred_mask]
        if pred_df.empty:
            continue
        
        # Training set: strictly before the prediction date, excluding
        # test-flagged rows and rows with no realized forward return.
        train_mask = ((chars["eom"] < date) & (~chars["ctff_test"]) & 
                     (chars[TARGET].notna()))
        n_train_months = chars.loc[train_mask, "eom"].nunique()
        
        # Refit once enough history exists, then every RETRAIN_FREQ steps.
        # NOTE(review): cadence is keyed on the date index i, not on time
        # since last_fit — confirm this is the intended schedule.
        should_fit = (n_train_months >= MIN_TRAIN_MONTHS and
                     (last_fit is None or i % RETRAIN_FREQ == 0))
        
        if should_fit:
            print(f"[CTF-DEBUG] Retraining models at {date.strftime('%Y-%m')} ({n_train_months} train months)", flush=True)
            train_df = chars[train_mask]
            # Month-of-observation medians aligned row-by-row to train_df
            # (2-D: one median row per training observation).
            X_med = medians_all.reindex(train_df["eom"].values).values.astype(np.float32)
            X_tr = cs_impute(train_df[medium_feats].values.astype(np.float32), X_med)
            y_tr = winsorise(train_df[TARGET].values.astype(np.float32))

            r_model = make_ridge(RIDGE_ALPHA)
            r_model.fit(X_tr, y_tr)

            _strong_idx = [medium_feats.index(f) for f in strong_feats
                          if f in medium_feats]
            if HAS_XGB:
                x_model = make_xgb()
                x_model.fit(X_tr[:, _strong_idx], y_tr)
            last_fit = date
        
        # Medians for the prediction month; fall back to in-sample medians
        # when the month is missing from the precomputed table.
        if date in medians_all.index:
            med_now = medians_all.loc[date].values.astype(np.float32)
        else:
            med_now = np.nanmedian(pred_df[medium_feats].values.astype(np.float32), 
                                  axis=0)
        
        X_pred = cs_impute(pred_df[medium_feats].values.astype(np.float32), med_now)
        
        # Before the first fit: emit an equal-weight long-only book and
        # track its realized return for the risk scalers.
        if r_model is None:
            n = len(pred_df)
            results.append(pd.DataFrame({
                "id": pred_df["id"].values, "eom": date, 
                "w": np.full(n, 1.0 / n)}))
            equal_w = 1.0 / n
            ret_this_month = float((equal_w * pred_df[TARGET].values).sum())
            portfolio_returns.append(ret_this_month)
            continue
        
        # Risk scaling inputs derived from realized history only.
        short_scale = compute_dynamic_short_scale(leg_history, date)
        vol_scale = compute_vol_scale(portfolio_returns)
        dd_scale = compute_drawdown_scale(portfolio_returns)
        global_scale = vol_scale * dd_scale
        
        p_ridge = r_model.predict(X_pred)
        # XGB member falls back to the ridge signal when unavailable.
        p_xgb = (x_model.predict(X_pred[:, _strong_idx])
                if HAS_XGB and x_model is not None and _strong_idx
                else p_ridge.copy())
        p_ic = ic_composite(X_pred[:, _strong_idx] if _strong_idx else X_pred,
                           ic_weights)
        
        composite = blend(p_ridge, p_xgb, p_ic)
        inv_vol = compute_inv_vol(daily_ret, date)
        w = build_portfolio(composite, pred_df["id"].values, inv_vol, 
                           short_scale, global_scale)
        
        # Realized next-month leg returns, used only by future risk scaling.
        # NOTE(review): if TARGET is NaN for these rows (e.g. held-out test
        # months), NaN propagates into portfolio_returns and the scalers —
        # confirm targets exist wherever this path executes.
        actual_ret = pred_df[TARGET].values
        long_mask_w = w > 0
        short_mask_w = w < 0
        
        ret_long = float((w[long_mask_w] * actual_ret[long_mask_w]).sum()) \
                   if long_mask_w.any() else 0.0
        ret_short = float((w[short_mask_w] * actual_ret[short_mask_w]).sum()) \
                    if short_mask_w.any() else 0.0
        ret_total = ret_long + ret_short
        
        leg_history.append({
            "eom": date, "ret_long": ret_long, "ret_short": ret_short,
            "short_scale": short_scale, "vol_scale": vol_scale,
            "dd_scale": dd_scale, "global_scale": global_scale,
            "net": float(w.sum()),
        })
        portfolio_returns.append(ret_total)
        
        results.append(pd.DataFrame({"id": pred_df["id"].values, 
                                    "eom": date, "w": w}))
    
    if not results:
        print("[CTF-DEBUG] WARNING: No results produced!", flush=True)
        return pd.DataFrame(columns=["id", "eom", "w"])

    out = pd.concat(results, ignore_index=True)
    out["id"] = out["id"].astype(int)
    out["eom"] = pd.to_datetime(out["eom"])
    out["w"] = out["w"].astype(float).fillna(0.0)
    out = out[out["w"] != 0.0].copy()  # drop zero-weight rows from the output

    _elapsed = time.time() - _start_time
    print(f"[CTF-DEBUG] Completed in {_elapsed:.1f}s", flush=True)
    print(f"[CTF-DEBUG] Output: {len(out)} rows, {out['eom'].nunique()} unique dates", flush=True)
    print(f"[CTF-DEBUG] Date range: {out['eom'].min()} to {out['eom'].max()}", flush=True)

    return out[["id", "eom", "w"]]