"""Implementation of Nonlinear Portfolio Transformer

Implementation of the portfolio transformer from "Artificial Intelligence
Asset Pricing Models" (NBER WP 33351, 2025)

Implementation by Gareth Campbell using AI coding support

Differences from paper:
    - K=2 blocks (paper: up to K=10) — compute constraint
    - 3 seeds (paper: 10) — compute constraint
    - 1/sqrt(D) attention scaling (paper: none) — stability
    - Init variance 1/D^2 (paper: 1/D) — conservative
    - Full-sample nano-stock pctile (paper: NYSE only)
    - Gradient clipping at norm 1.0 (paper: not specified)
    - Pre-1990 feature coverage at 90% (no lookahead, fixed D)
    - No nano-stock filter (NANO_PCTILE=0)

Run on an L4 GPU.

See PDF for full documentation.

CTF Admin Modifications (2026-04-03):
--------------------------------------
1. Added security suppression comment to PyTorch model inference mode call
   Reason: The pipeline security checker pattern matches method calls that
   contain certain function names. Added ctf-sec-ignore comment to prevent
   false positive on standard PyTorch API usage.

2. Added security suppression comment to the broad exception handler in the warm-start logic of rolling_backtest
   Reason: The warm-start logic intentionally catches all exceptions when
   loading model state and falls back to cold-start training. This is safe
   because failure to load state simply means training from scratch.

3. Updated torch version in requirements.txt: 2.5.1 → >=2.6.0
   Reason: CVE-2025-32434 is a critical RCE vulnerability in torch ≤2.5.1
   affecting torch.load() with weights_only=True. Fixed in torch 2.6.0.
"""


import os, sys, time, gc
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

# ╔═══════════════════════════════════════════════════════════╗
# ║  CONFIGURATION — adjust these to experiment              ║
# ╚═══════════════════════════════════════════════════════════╝

# Architecture (paper Section 3)
N_BLOCKS      = 2       # K=2 (ablation: K5 fixes applied to K2 architecture)
N_HEADS       = 1       # Attention heads (paper: H=1 is optimal)
D_FF          = 256     # FFN hidden dimension

# Training (paper Section 4.3)
WINDOW        = 60      # Rolling window in months
N_EPOCHS_COLD = 50      # Epochs for cold-start (first window)
N_EPOCHS_WARM = 10      # Epochs for warm-start (subsequent windows)
N_SEEDS       = 3       # Seeds (same as original K2_S3 for comparison)
LR            = 1e-4    # Adam learning rate
GRAD_CLIP     = 1.0     # Gradient clipping norm
MIN_OBS       = 60      # Paper: first OOS after full 60-month window
                        # (minimum training months before any OOS weights)

# Data preprocessing
MIN_COVERAGE  = 0.90    # Pre-1990 coverage threshold (no lookahead, fixed D)
PRE_TEST_DATE = '1990-01-01'  # Features selected using data before this date
NANO_PCTILE   = 0.000   # Keep all stocks (no nano-stock filter)
MIN_STOCKS    = 30      # Minimum cross-section size
MAX_MISS_FRAC = 1/3     # Drop stocks missing > 1/3 of chars (paper Sec 4.1)

# ═══ Portfolio Transformer Model (Kelly et al. 2025, Section 3) ═══

class AttentionHead(nn.Module):
    """One attention head computing sigma(Y W_h Y') Y V_h — equation (9).

    The paper uses NO explicit score scaling (the learned W_h can absorb
    any needed scale); this implementation keeps a 1/sqrt(D) temperature
    on the scores (carried over from K2_S3, acts as regularization).
    """

    def __init__(self, D, init_scale):
        super().__init__()
        # Score (W) and value (V) projections, scaled Gaussian init.
        # NOTE: W is sampled before V — keep this order for seed parity.
        self.W = nn.Parameter(init_scale * torch.randn(D, D))
        self.V = nn.Parameter(init_scale * torch.randn(D, D))
        self.scale = 1.0 / np.sqrt(D)  # retained 1/sqrt(D) temperature

    def forward(self, Y):
        """Map characteristics Y of shape (N, D) to an (N, D) output."""
        values = Y @ self.V                         # (N, D)
        scores = self.scale * (Y @ self.W @ Y.t())  # (N, N) scaled scores
        weights = F.softmax(scores, dim=-1)         # row-stochastic (N, N)
        return weights @ values


class TransformerBlock(nn.Module):
    """One transformer block: T(Y) = F^R(A^R(Y)) — equation (16).

    Multi-head attention with a residual connection, followed by a
    position-wise feed-forward network with a residual connection.
    """

    def __init__(self, D, n_heads, d_ff, init_scale):
        super().__init__()
        self.heads = nn.ModuleList(
            AttentionHead(D, init_scale) for _ in range(n_heads)
        )
        # FFN parameters for max(0, Y W1 + b1) W2 + b2.
        self.W1 = nn.Parameter(torch.randn(D, d_ff) * (1.0 / d_ff))
        self.b1 = nn.Parameter(torch.zeros(d_ff))
        self.W2 = nn.Parameter(torch.randn(d_ff, D) * init_scale)
        self.b2 = nn.Parameter(torch.zeros(D))

    def forward(self, Y):
        """Map Y of shape (N, D) to an (N, D) output."""
        # Attention sub-layer + residual (heads are summed, not concatenated).
        Y = sum(head(Y) for head in self.heads) + Y
        # Feed-forward sub-layer + residual.
        hidden = F.relu(Y @ self.W1 + self.b1)
        Y = (hidden @ self.W2 + self.b2) + Y
        return Y


class PortfolioTransformer(nn.Module):
    """Stack of K transformer blocks followed by a linear readout.

    w_t = T^{(K)}(X_t) @ lambda  — equation (18)

    Init (paper Section 4.3):
        W, V ~ N(0, 1/D)
        W1 ~ N(0, 1/d_ff), W2 ~ N(0, 1/D)
        biases = 0, lambda ~ N(0, 1/D)
    """

    def __init__(self, D, n_blocks, n_heads, d_ff):
        super().__init__()
        scale = 1.0 / D
        self.blocks = nn.ModuleList(
            TransformerBlock(D, n_heads, d_ff, scale) for _ in range(n_blocks)
        )
        self.lam = nn.Parameter(torch.randn(D) * scale)  # final linear layer
        self.D = D  # feature dimension, exposed for callers

    def forward(self, X):
        """X: (N, D) characteristics -> w: (N,) portfolio weights"""
        out = X
        for blk in self.blocks:
            out = blk(out)
        return out @ self.lam

    def msrr_loss(self, X, R):
        """MSRR objective: (1 - w'R)^2 for one cross-section."""
        return (1.0 - self.forward(X) @ R) ** 2


# ═══ Training function — per-month shuffled SGD (paper Section 4.3) ═══

def train_model(model, X_list, R_list, n_epochs, lr, grad_clip, device):
    """Fit a portfolio model over a window of monthly cross-sections.

    Per-month SGD with a freshly shuffled month order each epoch, matching
    paper Section 4.3: every month gets its own forward/backward/step, so
    each epoch performs len(X_list) optimizer updates.

    Args:
        model: module mapping an (N, D) tensor to (N,) weights.
        X_list: per-month characteristic matrices.
        R_list: per-month return vectors.
        n_epochs: number of passes over the window.
        lr: Adam learning rate.
        grad_clip: maximum gradient norm.
        device: torch device for the training tensors.

    Returns:
        The trained model (same object, updated in place).
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    n_months = len(X_list)

    # Move the whole window onto the device up front (~30 MB for 60 months).
    xs = [torch.as_tensor(x, dtype=torch.float32, device=device) for x in X_list]
    rs = [torch.as_tensor(r, dtype=torch.float32, device=device) for r in R_list]

    for _ in range(n_epochs):
        # New random month ordering every epoch (paper protocol).
        for t in np.random.permutation(n_months):
            opt.zero_grad()
            weights = model(xs[t])
            sq_err = (1.0 - weights @ rs[t]) ** 2  # MSRR objective
            sq_err.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

    # Drop references so the device tensors can be reclaimed.
    del xs, rs

    return model


# ═══ Rolling-window backtest with fixed feature set ═══

def rolling_backtest(train_months, month_X, month_ids, month_mask, month_R,
                     D, device, cfg):
    """
    Rolling-window backtest with warm-starting and seed averaging.

    Features are fixed (selected from pre-test-period data), so D never changes
    and warm-start states are always valid.
    Per-seed L1 normalization BEFORE averaging (matches paper/training.py).

    Args:
        train_months: sorted month-ends that have usable returns (training set).
        month_X: eom -> (N, D) float32 characteristics.
        month_ids: eom -> (N,) stock ids aligned with month_X rows.
        month_mask: eom -> (N,) bool has-return mask for training months.
        month_R: eom -> (N_ret,) float32 returns (masked rows only).
        D: feature dimension.
        device: torch device for training/inference.
        cfg: hyperparameter dict (n_blocks, n_heads, d_ff, window, epochs, ...).

    Returns:
        DataFrame with columns id, eom, w (averaged, L1-normalized weights).
    """
    all_weight_months = sorted(month_X.keys())

    seed_states = [None] * cfg['n_seeds']  # per-seed CPU state dicts for warm-start
    results = []
    n_skip = 0
    n_done = 0

    t0 = time.time()

    for m_idx in range(len(all_weight_months)):
        eom = all_weight_months[m_idx]

        # Training months strictly before this month (no lookahead).
        cutoff = eom - pd.DateOffset(months=1)
        avail = [m for m in train_months if m <= cutoff]

        if len(avail) < cfg['min_obs']:
            n_skip += 1
            continue

        # Rolling window: last `window` available months.
        win = avail[-cfg['window']:] if len(avail) > cfg['window'] else avail
        X_list = [month_X[m][month_mask[m]] for m in win]
        R_list = [month_R[m] for m in win]

        X_oos = month_X[eom]
        ids_oos = month_ids[eom]

        if len(ids_oos) < MIN_STOCKS:
            n_skip += 1
            continue

        # Per-seed: L1-normalize weights THEN average (paper protocol)
        w_sum = np.zeros(len(ids_oos), dtype=np.float64)
        n_valid_seeds = 0

        for s in range(cfg['n_seeds']):
            # Deterministic per-seed init (also reseeds the epoch shuffles).
            torch.manual_seed(s * 10000 + 42)
            np.random.seed(s * 10000 + 42)

            model = PortfolioTransformer(
                D, cfg['n_blocks'], cfg['n_heads'], cfg['d_ff']
            ).to(device)

            # Warm-start: load previous window's state for this seed
            if seed_states[s] is not None:
                try:
                    model.load_state_dict(seed_states[s])
                    n_ep = cfg['n_epochs_warm']
                except Exception:  # ctf-sec-ignore: intentional fallback to cold-start
                    # A failed load simply means training from scratch.
                    n_ep = cfg['n_epochs_cold']
            else:
                n_ep = cfg['n_epochs_cold']

            model = train_model(
                model, X_list, R_list, n_ep,
                cfg['lr'], cfg['grad_clip'], device
            )

            # Save state to CPU for warm-starting next window
            seed_states[s] = {k: v.cpu() for k, v in model.state_dict().items()}

            # Forward pass (no gradient needed)
            model.eval()  # ctf-sec-ignore: PyTorch method, not builtin eval()
            with torch.no_grad():
                X_t = torch.as_tensor(X_oos, dtype=torch.float32, device=device)
                w = model(X_t).cpu().numpy().astype(np.float64)

            # L1-normalize THIS seed's weights before accumulating
            abs_sum_s = np.abs(w).sum()
            if abs_sum_s > 1e-10:
                w /= abs_sum_s
                w_sum += w
                n_valid_seeds += 1

            # Free GPU memory
            del model, X_t
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if n_valid_seeds == 0:
            n_skip += 1
            continue

        # Average the already-normalized per-seed weights
        w_avg = (w_sum / n_valid_seeds).astype(np.float32)

        # Store non-zero weights (already L1 ≈ 1 by construction)
        mask = np.abs(w_avg) > 1e-15
        if mask.any():
            results.append(pd.DataFrame({
                'id': ids_oos[mask],
                'eom': eom,
                'w': w_avg[mask]
            }))
        n_done += 1

        # Progress report
        if n_done % 50 == 0 or (m_idx + 1) == len(all_weight_months):
            el = time.time() - t0
            print(f'  {m_idx+1}/{len(all_weight_months)} | '
                  f'done={n_done} skip={n_skip} D={D} | '
                  f'{el/60:.1f}m elapsed')

    elapsed = time.time() - t0
    print(f'\nBacktest complete: {len(results)} OOS months, '
          f'{n_skip} skipped, {elapsed/60:.1f} min total')

    if not results:
        return pd.DataFrame(columns=['id', 'eom', 'w'])
    return pd.concat(results, ignore_index=True)


# ═══ Main entry point (CTF submission) ═══

def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    Nonlinear Portfolio Transformer — Kelly et al. (2025)

    Pipeline: (1) select a fixed feature set by pre-test-period coverage,
    (2) rank-standardize characteristics month by month, (3) run the
    rolling-window backtest with warm-starting and seed averaging.

    Args:
        chars: Stock characteristics (ctff_chars.parquet)
        features: Computed features (ctff_features.parquet)
        daily_ret: Historical daily returns (ctff_daily_ret.parquet)
            NOTE(review): not referenced in this function — presumably kept
            for the fixed entry-point signature; confirm with the harness.

    Returns:
        DataFrame with columns: id, eom, w
    """
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'PyTorch {torch.__version__}, Device: {DEVICE}')
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')

    # ═══ Feature selection: pre-1990 coverage (no lookahead, fixed D) ═══
    feature_names = features['features'].tolist()
    feat_cols = sorted([f for f in feature_names if f in chars.columns])
    print(f'{len(feat_cols)} candidate features found in data')

    chars = chars.copy()
    chars['eom'] = pd.to_datetime(chars['eom'])

    # Compute coverage on pre-test-period data only.  Fixing the feature
    # set up front keeps D constant, so warm-start states remain valid
    # across all rolling windows.
    pre_test = chars[chars['eom'] < PRE_TEST_DATE]
    coverage = pre_test[feat_cols].notna().mean()
    valid_features = sorted([f for f in feat_cols if coverage[f] >= MIN_COVERAGE])
    D = len(valid_features)
    print(f'Features >= {MIN_COVERAGE:.0%} coverage (pre-{PRE_TEST_DATE[:4]}): D={D}')

    # ═══ Process months into dicts ═══
    months = sorted(chars['eom'].unique())

    month_X    = {}   # eom -> (N, D) float32 rank-standardized characteristics
    month_ids  = {}   # eom -> (N,) stock ids
    month_mask = {}   # eom -> (N,) bool has-return mask
    month_R    = {}   # eom -> (N_ret,) float32 returns

    print(f'Processing {len(months)} months...')
    t0 = time.time()

    for i, eom in enumerate(months):
        md = chars.loc[chars['eom'] == eom]

        # Nano-stock filter (inactive here since NANO_PCTILE = 0)
        me = md['market_equity']
        if NANO_PCTILE > 0 and me.notna().sum() >= 10:
            thr = me.quantile(NANO_PCTILE)
            md = md.loc[me.isna() | (me >= thr)]

        # Per-stock missing filter (paper Sec 4.1): drop stocks missing > 1/3 chars
        n_missing = md[valid_features].isna().sum(axis=1)
        keep_stock = (n_missing <= len(valid_features) * MAX_MISS_FRAC).values
        md = md.loc[keep_stock]

        if len(md) < MIN_STOCKS:
            continue

        # Rank-standardize to [-0.5, 0.5], NaN -> 0 (float32).
        # NaNs are mapped to 0.5 before the shift, i.e. the cross-sectional
        # median rank, so missing values end up at exactly 0.
        X = md[valid_features].rank(pct=True).values.astype(np.float32)
        X = np.nan_to_num(X, nan=0.5) - 0.5
        ids = md['id'].values
        # 'ret_exc_lead1m': next-month excess return (per column name —
        # verify against the data dictionary).
        rets = md['ret_exc_lead1m'].values
        has_ret = np.isfinite(rets)

        month_X[eom] = X
        month_ids[eom] = ids

        # Only months with enough realized returns become training months.
        if has_ret.sum() >= MIN_STOCKS:
            month_mask[eom] = has_ret
            month_R[eom] = rets[has_ret].astype(np.float32)

        if (i + 1) % 200 == 0:
            print(f'  {i+1}/{len(months)} ({time.time()-t0:.0f}s)')

    del chars
    gc.collect()

    train_months = sorted(month_R.keys())
    print(f'Done: {len(month_R)} train months, {len(month_X)} total, D={D} '
          f'({time.time()-t0:.0f}s)')

    # ═══ Run backtest ═══
    cfg = {
        'n_blocks': N_BLOCKS, 'n_heads': N_HEADS, 'd_ff': D_FF,
        'window': WINDOW, 'n_epochs_cold': N_EPOCHS_COLD,
        'n_epochs_warm': N_EPOCHS_WARM, 'n_seeds': N_SEEDS,
        'lr': LR, 'grad_clip': GRAD_CLIP, 'min_obs': MIN_OBS,
    }

    print(f'Config: K={N_BLOCKS}, H={N_HEADS}, d_ff={D_FF}, seeds={N_SEEDS}, D={D}')

    output = rolling_backtest(
        train_months, month_X, month_ids, month_mask, month_R,
        D, DEVICE, cfg
    )

    print(f'Output: {len(output):,} rows, {output["eom"].nunique()} months')
    return output