#!/usr/bin/env python
# coding: utf-8
"""
CTF Portfolio Optimization

Using the JKP dataset (402 firm characteristics, US stocks 1952-2023) to build
a long-short portfolio that maximizes Sharpe ratio. Approach: ensemble of
Ridge, XGBoost, and LightGBM on ECDF-transformed features.

CTF Admin Modifications (2026-02-23):
--------------------------------------
1. Created requirements.txt with xgboost, lightgbm, matplotlib, seaborn
   Reason: These packages are not pre-installed in the CTF environment and
   were missing from the submission.

2. Added flush=True to all print statements in main() function
   Reason: Container output buffering can hide progress during long HPC runs.

3. Added explicit type casting for output columns (id as int, w as float)
   Reason: Ensures consistent output format across different pandas versions.

4. Added [CTF-DEBUG] progress statements with output summary
   Reason: Provides better visibility into HPC job progress and final output.

5. Wrapped notebook exploration code in if __name__ == "__main__" guard
   Reason: The original code read local data files at module import time,
   which fails in the CTF pipeline where data is passed to main() as parameters.

6. Changed groupby().apply() to groupby().transform() for weight calculation
   Reason: Pandas deprecated including grouping columns in apply() results.
   Using transform() avoids this issue and is more efficient.
"""

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')


def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    Main entry point for the CTF pipeline.

    Trains a Ridge / XGBoost / LightGBM ensemble on ECDF-transformed firm
    characteristics and converts the averaged prediction ranks into monthly
    long-short portfolio weights for the test period.

    Parameters
    ----------
    chars : pd.DataFrame
        Stock/month panel containing 'id', 'eom', the feature columns, the
        target 'ret_exc_lead1m', and the boolean train/test flag 'ctff_test'.
    features : pd.DataFrame
        Has a 'features' column listing the feature-column names in `chars`.
    daily_ret : pd.DataFrame
        Daily returns; accepted for pipeline compatibility but unused here.

    Returns
    -------
    pd.DataFrame
        Columns ['id', 'eom', 'w']: one weight per test-set stock/month.
        Within each month the weights are centered percentile ranks rescaled
        so that sum(|w|) == 2 (one unit long plus one unit short).
    """
    import gc
    from sklearn.linear_model import Ridge
    import xgboost as xgb
    import lightgbm as lgb

    np.random.seed(42)
    feats = features['features'].tolist()

    # --- Preprocessing: per-month ECDF transform ---------------------------
    # Rank each feature cross-sectionally within its month (pct=True), pin
    # exact zeros back to 0, center at zero, and impute missings as neutral 0.
    print('[CTF-DEBUG] Preprocessing...', flush=True)
    ranked = chars.groupby('eom')[feats].rank(method='max', pct=True)
    ranked[chars[feats] == 0] = 0
    ranked = (ranked - 0.5).fillna(0)

    X = ranked.values.astype(np.float32)
    del ranked; gc.collect()  # free the intermediate frame early to cap peak RSS
    y = chars['ret_exc_lead1m'].fillna(0).values.astype(np.float32)
    # Boolean train/test masks (idiomatic negation instead of `== False`/`== True`).
    test_flag = chars['ctff_test'].to_numpy(dtype=bool)
    is_tr = ~test_flag
    is_te = test_flag
    Xtr, ytr, Xte = X[is_tr], y[is_tr], X[is_te]
    del X, y; gc.collect()

    # --- Ridge (tuned alpha=1000) ------------------------------------------
    print('[CTF-DEBUG] Training Ridge...', flush=True)
    mdl = Ridge(alpha=1000.0)
    mdl.fit(Xtr, ytr)
    p_ridge = mdl.predict(Xte).astype(np.float32)
    del mdl; gc.collect()

    # --- XGBoost (tuned params) --------------------------------------------
    print('[CTF-DEBUG] Training XGBoost...', flush=True)
    mdl = xgb.XGBRegressor(max_depth=4, learning_rate=0.05, n_estimators=200,
                            subsample=0.8, colsample_bytree=0.3,
                            objective='reg:squarederror', tree_method='hist',
                            reg_alpha=0.1, reg_lambda=1.0, verbosity=0, random_state=42, n_jobs=-1)
    mdl.fit(Xtr, ytr)
    p_xgb = mdl.predict(Xte).astype(np.float32)
    del mdl; gc.collect()

    # --- LightGBM (tuned params) -------------------------------------------
    print('[CTF-DEBUG] Training LightGBM...', flush=True)
    mdl = lgb.LGBMRegressor(num_leaves=15, learning_rate=0.05, n_estimators=300,
                             feature_fraction=0.5, bagging_fraction=0.8, bagging_freq=5,
                             reg_alpha=0.1, reg_lambda=1.0, verbosity=-1, random_state=42, n_jobs=-1)
    mdl.fit(Xtr, ytr)
    p_lgb = mdl.predict(Xte).astype(np.float32)
    del mdl, Xtr, ytr, Xte; gc.collect()

    # --- Ensemble: average of percentile ranks -----------------------------
    # Rank-averaging makes the three models comparable despite different
    # prediction scales.
    # NOTE(review): ranks here are computed over the POOLED test set, whereas
    # the notebook exploration code ranks within each month before averaging
    # — confirm the pooled variant is intended.
    print('[CTF-DEBUG] Building weights...', flush=True)
    out = chars.loc[is_te, ['id', 'eom']].copy()
    out['pred'] = (pd.Series(p_ridge, index=out.index).rank(pct=True)
                 + pd.Series(p_xgb, index=out.index).rank(pct=True)
                 + pd.Series(p_lgb, index=out.index).rank(pct=True)) / 3
    del p_ridge, p_xgb, p_lgb; gc.collect()

    # Per-month long-short weights via transform (keeps result aligned with
    # `out` and avoids the pandas deprecation around grouping columns in
    # groupby().apply()).
    def compute_weights(pred_series: pd.Series) -> pd.Series:
        """Map one month's predictions to weights: centered pct-ranks scaled so sum(|w|) == 2."""
        r = pred_series.rank(pct=True) - 0.5
        return r / r.abs().sum() * 2

    out['w'] = out.groupby('eom')['pred'].transform(compute_weights)
    # Explicit casts keep the output schema stable across pandas versions.
    out['id'] = out['id'].astype(int)
    out['w'] = out['w'].astype(float)
    print(f'[CTF-DEBUG] Output shape: {out.shape}, months: {out["eom"].nunique()}', flush=True)
    return out[['id', 'eom', 'w']]


# ============================================================================
# NOTEBOOK EXPLORATION CODE BELOW - Only runs when script is executed directly
# ============================================================================
if __name__ == "__main__":
    # ----------------------------------------------------------------------
    # Notebook-style exploration: reads local parquet files under data/, runs
    # EDA plots, temporal cross-validation, model training, and writes the
    # submission files. Guarded so it never runs when the CTF pipeline
    # imports this module and calls main() with data passed as parameters.
    # ----------------------------------------------------------------------
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.linear_model import Ridge
    import xgboost as xgb
    import lightgbm as lgb

    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams['figure.dpi'] = 100
    np.random.seed(42)

    # Load data
    features_df = pd.read_parquet('data/ctff_features.parquet')
    FEATURES = features_df['features'].tolist()

    chars = pd.read_parquet('data/ctff_chars.parquet')
    chars['eom'] = pd.to_datetime(chars['eom'])

    daily_ret = pd.read_parquet('data/ctff_daily_ret.parquet')
    daily_ret['date'] = pd.to_datetime(daily_ret['date'])

    print(f'chars: {chars.shape}, daily_ret: {daily_ret.shape}, features: {len(FEATURES)}')
    print(f'Train: {(~chars["ctff_test"]).sum():,} rows, Test: {chars["ctff_test"].sum():,} rows')

    # EDA: columns that are metadata rather than model features.
    meta_cols = [c for c in chars.columns if c not in FEATURES]
    print('Non-feature columns:', meta_cols)
    print()
    print(chars['size_grp'].value_counts())
    print(f'\nTrain months: {chars.loc[~chars["ctff_test"], "eom"].nunique()}')
    print(f'Test months:  {chars.loc[chars["ctff_test"], "eom"].nunique()}')

    # Universe size over time
    stocks_per_month = chars.groupby('eom')['id'].nunique()
    fig, ax = plt.subplots(figsize=(14, 4))
    ax.plot(stocks_per_month.index, stocks_per_month.values, lw=1)
    ax.axvline(pd.Timestamp('1990-01-01'), color='red', ls='--', alpha=0.7, label='Train/Test split')
    ax.set_ylabel('# Stocks')
    ax.set_title('Universe size over time')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Return distributions (clipped to [-0.5, 0.5] for plotting only).
    ret = chars['ret_exc_lead1m'].dropna()
    fig, axes = plt.subplots(1, 2, figsize=(13, 4))
    axes[0].hist(ret.clip(-0.5, 0.5), bins=100, density=True, alpha=0.7, edgecolor='none')
    axes[0].axvline(ret.mean(), color='red', ls='--', label=f'Mean: {ret.mean():.4f}')
    axes[0].axvline(ret.median(), color='orange', ls='--', label=f'Median: {ret.median():.4f}')
    axes[0].set_title('Next-month excess returns')
    axes[0].legend(fontsize=8)

    ret_tr = chars.loc[~chars['ctff_test'], 'ret_exc_lead1m'].dropna()
    ret_te = chars.loc[chars['ctff_test'], 'ret_exc_lead1m'].dropna()
    axes[1].hist(ret_tr.clip(-0.5, 0.5), bins=80, density=True, alpha=0.5, label='Train', edgecolor='none')
    axes[1].hist(ret_te.clip(-0.5, 0.5), bins=80, density=True, alpha=0.5, label='Test', edgecolor='none')
    axes[1].set_title('Train vs Test')
    axes[1].legend(fontsize=8)
    plt.tight_layout()
    plt.show()

    print(f'mean={ret.mean():.5f}, std={ret.std():.4f}, skew={ret.skew():.2f}, kurtosis={ret.kurtosis():.1f}')

    # Preprocessing: per-month ECDF transform — percentile-rank each feature
    # within its month, pin exact zeros to 0, center at zero, then impute
    # missing values as neutral 0 (mirrors the transform used in main()).
    ranked = chars.groupby('eom')[FEATURES].rank(method='max', pct=True)
    ranked[chars[FEATURES] == 0] = 0
    ranked = ranked - 0.5
    ranked = ranked.fillna(0)

    chars_proc = chars[['id', 'eom', 'ret_exc_lead1m', 'ctff_test', 'size_grp']].copy()
    chars_proc[FEATURES] = ranked
    del ranked

    print(f'Shape: {chars_proc.shape}, NaN left: {chars_proc[FEATURES].isna().sum().sum()}')

    # CV setup
    def temporal_cv(df, n_splits=3, gap=1):
        """Build expanding-window temporal CV splits over the training months.

        Returns a list of (train_index, val_index) numpy index arrays, one per
        fold. Each fold leaves `gap` months between the end of the training
        window and the start of the validation window, to avoid look-ahead
        from the one-month-forward target.
        """
        tr = df[df['ctff_test'] == False]
        months = sorted(tr['eom'].unique())
        n = len(months)
        val_sz = n // (n_splits + 1)  # equal-sized validation windows

        splits = []
        for i in range(n_splits):
            # Validation window is months[vs:ve); training ends `gap` months
            # before the validation window starts.
            vs = n - (n_splits - i) * val_sz
            ve = vs + val_sz
            te = vs - gap

            t_months = set(months[:te])
            v_months = set(months[vs:ve])

            t_idx = tr[tr['eom'].isin(t_months)].index.values
            v_idx = tr[tr['eom'].isin(v_months)].index.values
            splits.append((t_idx, v_idx))
            print(f'  Fold {i+1}: train {len(t_months)}m ({min(t_months):%Y-%m} - {max(t_months):%Y-%m}), '
                  f'val {len(v_months)}m, {len(t_idx):,}/{len(v_idx):,} rows')
        return splits

    cv_splits = temporal_cv(chars_proc)

    def portfolio_sharpe(df, pred_col='pred'):
        """Rank predictions -> long-short weights -> monthly returns -> annualized Sharpe."""
        monthly = []
        for eom, g in df.groupby('eom'):
            # Centered percentile ranks; scaling by 2 / sum(|r|) makes the
            # absolute weights sum to 2 (1x long, 1x short) each month.
            r = g[pred_col].rank(pct=True) - 0.5
            w = r / r.abs().sum() * 2
            monthly.append({'eom': eom, 'ret': (w * g['ret_exc_lead1m']).sum()})
        res = pd.DataFrame(monthly).set_index('eom').sort_index()
        sr = res['ret'].mean() / res['ret'].std() * np.sqrt(12)  # annualize monthly Sharpe
        return sr, res

    # Baseline: equal-weight long-only portfolio over the test months.
    test_data = chars_proc[chars_proc['ctff_test']].copy()
    ew = test_data.groupby('eom')['ret_exc_lead1m'].mean()
    print(f'Equal-weight baseline Sharpe: {ew.mean()/ew.std()*np.sqrt(12):.3f}')

    # Train models on the full training period; evaluate on the test period.
    train_mask = ~chars_proc['ctff_test']
    test_mask = chars_proc['ctff_test']

    X_tr = chars_proc.loc[train_mask, FEATURES].values
    y_tr = chars_proc.loc[train_mask, 'ret_exc_lead1m'].fillna(0).values
    X_te = chars_proc.loc[test_mask, FEATURES].values

    # Ridge
    ridge = Ridge(alpha=1000)
    ridge.fit(X_tr, y_tr)
    test_data['pred_ridge'] = ridge.predict(X_te)
    sr_ridge, ret_ridge = portfolio_sharpe(test_data, 'pred_ridge')
    print(f'Ridge test Sharpe: {sr_ridge:.3f}')

    # XGBoost (same hyperparameters as main())
    xgb_final = xgb.XGBRegressor(max_depth=4, learning_rate=0.05, n_estimators=200,
                                  subsample=0.8, colsample_bytree=0.3,
                                  objective='reg:squarederror', tree_method='hist',
                                  reg_alpha=0.1, reg_lambda=1.0, verbosity=0, random_state=42, n_jobs=-1)
    xgb_final.fit(X_tr, y_tr)
    test_data['pred_xgb'] = xgb_final.predict(X_te)
    sr_xgb, ret_xgb = portfolio_sharpe(test_data, 'pred_xgb')
    print(f'XGBoost test Sharpe: {sr_xgb:.3f}')

    # LightGBM (same hyperparameters as main())
    lgb_final = lgb.LGBMRegressor(num_leaves=15, learning_rate=0.05, n_estimators=300,
                                   feature_fraction=0.5, bagging_fraction=0.8, bagging_freq=5,
                                   reg_alpha=0.1, reg_lambda=1.0, verbosity=-1, random_state=42, n_jobs=-1)
    lgb_final.fit(X_tr, y_tr)
    test_data['pred_lgb'] = lgb_final.predict(X_te)
    sr_lgb, ret_lgb = portfolio_sharpe(test_data, 'pred_lgb')
    print(f'LightGBM test Sharpe: {sr_lgb:.3f}')

    # Ensemble: average of per-month percentile ranks of the three models.
    # NOTE(review): here the ranks are computed within each month, while
    # main() ranks over the pooled test set — confirm which is intended.
    for c in ['pred_ridge', 'pred_xgb', 'pred_lgb']:
        test_data[c + '_rk'] = test_data.groupby('eom')[c].rank(pct=True)
    test_data['pred_ens'] = (test_data['pred_ridge_rk'] + test_data['pred_xgb_rk'] + test_data['pred_lgb_rk']) / 3
    sr_ens, ret_ens = portfolio_sharpe(test_data, 'pred_ens')

    print('Test Sharpe ratios:')
    print(f'  Equal weight  {ew.mean()/ew.std()*np.sqrt(12):.3f}')
    print(f'  Ridge         {sr_ridge:.3f}')
    print(f'  XGBoost       {sr_xgb:.3f}')
    print(f'  LightGBM      {sr_lgb:.3f}')
    print(f'  Ensemble      {sr_ens:.3f}')

    # Validate output format: same weight construction as main()'s
    # compute_weights (centered pct-ranks scaled so |w| sums to 2 per month).
    output = test_data[['id','eom']].copy()
    output['w'] = test_data.groupby('eom')['pred_ens'].transform(
        lambda x: (x.rank(pct=True) - 0.5).pipe(lambda r: r / r.abs().sum() * 2))

    assert list(output.columns) == ['id','eom','w']
    assert output['w'].isna().sum() == 0

    # Sanity check: gross exposure per month should be ~2.0.
    wsum = output.groupby('eom')['w'].apply(lambda x: x.abs().sum())
    print(f'Shape: {output.shape}, months: {output["eom"].nunique()}')
    print(f'|w| sum per month: {wsum.mean():.4f} +/- {wsum.std():.4f}')

    # Save
    output.to_parquet('data/submission_weights.parquet', index=False)
    output.to_csv('data/submission_weights.csv', index=False)
    print('Saved.')
