#!/usr/bin/env python3
"""
A long/short equity strategy that adjusts its behavior based on whether
the market looks like it's in bull or bear mode.

The rough idea: train two separate Ridge regression models (one for each
regime), use cross-sectional factor signals to rank stocks, then build
a long/short portfolio scaled to the current market environment.

Steps (nothing here peeks at future data):
  1. Per month: fill missing values with the cross-sectional median, clip outliers (1st/99th pct), z-score
  2. Optionally combine raw features into 6 economic composites
  3. Estimate a value-weighted market return each month
  4. Decide if we're in a bull or bear market (sum of the trailing 12 monthly market returns + hysteresis)
  5. Keep expanding-window Ridge models for each regime, refit annually
  6. Turn predictions into long/short portfolio weights
  7. Scale down exposure when markets are volatile
  8. Blend weights with last month's to reduce churn
  9. Return DataFrame with columns (id, eom, w)

Only needs numpy, pandas, and pyarrow — no sklearn.

CTF Admin Modifications (2026-02-20):
--------------------------------------
1. Added main(chars, features, daily_ret) function with required CTF signature
   Reason: Original script used run_strategy() with no arguments and loaded data
   from hardcoded file paths. CTF pipeline requires the standard entry point.

2. Renamed output columns from (date, id, weight) to (id, eom, w)
   Reason: CTF Rule 12 requires specific column names for scoring.

3. Modified run_strategy() to accept optional DataFrame parameters instead of loading files
   (it still falls back to the hardcoded file paths when they are omitted)
   Reason: CTF pipeline provides data via function arguments, not file I/O.

4. Added flush=True to all print() statements
   Reason: Container output buffering can cause logs to be lost on crashes.

5. Removed file-writing code from run_strategy() (now returns DataFrame directly)
   Reason: CTF pipeline handles output file I/O automatically.
"""

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import time
import sys
import os

# default file locations (used when data is not passed in as DataFrames)
DATA_PATH     = "ctff_chars.parquet"
FEATURES_PATH = "ctff_features.csv"
OUTPUT_PATH   = "weights_output.csv"  # NOTE(review): not read by the code visible here; run_strategy returns a DataFrame instead of writing files

# use 6 composite factors or go with raw features?
USE_COMPOSITES = False       # False = raw features (or subset below)
USE_FEATURE_SUBSET = True    # when USE_COMPOSITES is False, restrict the model to FEATURE_SUBSET (40 names)
MODEL_TYPE     = "ridge"     # only ridge is implemented (NOTE(review): not read by the code visible here)

# training settings
BURN_IN_MONTHS = 120    # the first 120 months are training-only (10 years before the first prediction)
REFIT_EVERY    = 12     # refit every 12 prediction months (once a year)
CV_FOLDS       = 5      # time-blocked CV folds for picking the ridge penalty; CV only runs with >= 2*CV_FOLDS months of a regime
ALPHA_GRID     = [0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]  # candidate ridge penalties; the middle entry is the pre-CV default

# column names in the parquet
DATE_COL = "eom"
ID_COL   = "id"
RET_COL  = "ret_exc_lead1m"  # this is already shifted: value at t is the return from t to t+1
ME_COL   = "market_equity"   # weights for the value-weighted market return

# how much to be long/short depending on regime
GROSS_BULL     = 1.2    # target sum of |w| in bull
NET_BULL       = 0.4    # target sum of w (net long bias) in bull
GROSS_BEAR     = 0.8    # smaller book in bear
NET_BEAR       = -0.1   # slight net short bias in bear
MAX_ABS_WEIGHT = 0.03   # no single stock gets more than 3% (|w| is clipped to this)
RESCALE_ITERS  = 3      # scale -> shift -> clip passes used to approach the gross/net targets

# winsorization quantiles applied within each month before z-scoring
WINSOR_P = (0.01, 0.99)

# the ~40 features that tend to work best (used when USE_COMPOSITES=False and USE_FEATURE_SUBSET=True)
FEATURE_SUBSET = [
    # good long signals (high IC)
    "cop_mev", "ocf_at", "ocf_mev", "mispricing_mgmt", "ope_me",
    "mispricing_perf", "cop_at", "resff3_12_1", "cop_me", "ocf_me",
    "ret_12_1", "eqnpo_me", "qmj_prof", "ebitda_mev", "eqnpo_at",
    "ebitda_me", "ocf_be", "fcf_be", "fcf_ppen", "fcf_sale",
    # good short signals (negative IC)
    "inv_gr1", "opex_gr1", "ppeinv_gr1a", "at_gr1", "seas_2_5na",
    "eqnetis_at", "ivol_ff3_21d", "ivol_capm_252d", "ivol_hxz4_21d",
    "oa_gr1a", "ivol_capm_21d", "ret_1_0", "noa_gr1a", "fincf_mev",
    "rvol_21d", "rvol_252d", "netis_at", "rmax1_21d", "fincf_at",
    "rmax5_21d",
]

# regime detection settings
REGIME_WINDOW  = 12     # trailing months of market return summed to label bull/bear
HYSTERESIS     = True   # require HYSTERESIS_K consecutive opposite signals before flipping
HYSTERESIS_K   = 2      # need 2 months in a row to officially flip regimes

# smooth out the bull/bear transition instead of hard-switching
USE_REGIME_BLEND       = True   # gross/net targets come from the trailing return (the regime label still picks the model)
BLEND_BAND             = 0.04   # full bull/bear targets once the trailing 12m summed return is beyond +/-4%
EARLY_BULL_GROSS_BONUS = 0.2    # extra gross just above a zero trailing return, tapering to 0 at +BLEND_BAND

# how to estimate and target volatility
VOL_WINDOW = 12         # months of market returns used for the trailing vol estimate

# exponential blending of weights across months to reduce turnover
SMOOTHING_ALPHA = 0.5   # weight on this month's new weights: 0 = keep old weights entirely, 1 = use new weights entirely

# the 6 composite factors and their constituent signals
# pos signals get added, neg signals get subtracted before averaging, so once
# signed, all members of one composite should point the same economic way
COMPOSITE_DEFS = {
    "VALUE": {
        "pos": ["be_me", "sale_me", "cash_mev", "be_mev", "sale_bev"],
        "neg": [],
    },
    "QUALITY": {
        "pos": ["gp_at", "ni_at", "gp_sale", "ebit_sale", "ebitda_sale"],
        "neg": ["earnings_variability"],
    },
    "MOMENTUM": {
        "pos": ["ret_12_1", "ret_6_1", "ret_36_12"],
        "neg": [],
    },
    "INVESTMENT": {
        "pos": [],
        "neg": ["at_gr1", "capx_at", "sale_gr1"],
    },
    "DISTRESS": {
        # high = more distressed: the O-score is added and the Z- and
        # F-scores are subtracted, so leverage (more debt = more distress)
        # must be added too. Subtracting it would make debt cancel the other
        # members. (Assumes the usual conventions: higher O-score = riskier,
        # higher Altman Z / Piotroski F = healthier — confirm against the
        # data dictionary.)
        "pos": ["o_score", "debt_at", "debt_be"],
        "neg": ["z_score", "f_score"],
    },
    "LOWRISK": {
        "pos": [],
        "neg": ["beta_60m", "ivol_capm_60m", "rvol_252d", "ivol_capm_252d"],
    },
}
COMPOSITE_NAMES = list(COMPOSITE_DEFS.keys())


def load_data(
    data_path: str | None = None,
    features_path: str | None = None,
    features_override: list | None = None,
    chars_df: pd.DataFrame | None = None,
    features_df: pd.DataFrame | None = None,
):
    """
    Load the stock characteristics and the list of feature names.

    Can either load from file paths (standalone mode) or accept DataFrames
    directly (CTF pipeline mode).

    Args:
        data_path: Parquet file with the characteristics; only read when
            ``chars_df`` is None. ``None`` falls back to ``DATA_PATH``.
        features_path: CSV with a ``features`` column; only read when
            ``features_df`` is None. ``None`` falls back to ``FEATURES_PATH``.
        features_override: Optional list restricting (and ordering) the
            features; names not in the feature list are dropped.
        chars_df: In-memory characteristics (CTF pipeline mode). It is
            copied, never modified.
        features_df: In-memory feature metadata (CTF pipeline mode). Names
            come from a ``features`` or ``feature`` column, or else from its
            column names minus the date/id columns.

    Returns:
        (df, features): the data sorted by (date, stock id), with the date
        column parsed to datetime and integer feature columns cast to
        float64, plus the feature names that are actually present in ``df``.
    """
    # Get feature names from either DataFrame or file
    if features_df is not None:
        # CTF pipeline mode: features_df is a DataFrame with feature metadata
        if "features" in features_df.columns:
            features = features_df["features"].tolist()
        elif "feature" in features_df.columns:
            features = features_df["feature"].tolist()
        else:
            # Assume column names are the features (excluding metadata cols)
            features = [c for c in features_df.columns if c not in (DATE_COL, ID_COL)]
    else:
        # Standalone mode: load from file
        feat_path = features_path if features_path is not None else FEATURES_PATH
        features = pd.read_csv(feat_path)["features"].tolist()

    if features_override is not None:
        available = set(features)  # build once, not once per override entry
        features = [f for f in features_override if f in available]

    # Get main data from either DataFrame or file
    if chars_df is not None:
        # CTF pipeline mode: use provided DataFrame
        df = chars_df.copy()
        print(f"[CTF-DEBUG] Using provided chars DataFrame ({len(df):,} rows)", flush=True)
    else:
        # Standalone mode: load from file
        parquet_path = data_path if data_path is not None else DATA_PATH
        # market_equity can appear in both the features list and the meta columns,
        # so be careful not to load it twice (pyarrow would give a 2-column result)
        meta_cols = [DATE_COL, ID_COL, ME_COL, RET_COL]
        meta_set  = set(meta_cols)
        load_cols = meta_cols + [f for f in features if f not in meta_set]

        print(f"Loading parquet: {parquet_path} ({len(load_cols)} columns)...", flush=True)
        table = pq.read_table(parquet_path, columns=load_cols)
        df = table.to_pandas()

    df[DATE_COL] = pd.to_datetime(df[DATE_COL])

    # make sure all feature columns are float so math works uniformly
    int_cols = [c for c in features if c in df.columns and pd.api.types.is_integer_dtype(df[c])]
    for c in int_cols:
        df[c] = df[c].astype(np.float64)

    df = df.sort_values([DATE_COL, ID_COL]).reset_index(drop=True)

    # Keep only the features present in the data, and report any that get
    # dropped so a silently shrinking feature set shows up in the logs.
    dropped = [f for f in features if f not in df.columns]
    if dropped:
        print(
            f"  Warning: {len(dropped)} features not present in data, dropping: {dropped[:5]}",
            flush=True,
        )
    features = [f for f in features if f in df.columns]

    assert RET_COL in df.columns, f"Missing target column {RET_COL}"
    assert ME_COL  in df.columns, f"Missing market equity column {ME_COL}"

    n_dates  = df[DATE_COL].nunique()
    n_stocks = df[ID_COL].nunique()
    print(f"  Loaded: {len(df):,} rows | {n_dates} dates | {n_stocks} unique stocks", flush=True)
    print(f"  Date range: {df[DATE_COL].min().date()} to {df[DATE_COL].max().date()}", flush=True)
    return df, features


def preprocess_month(df_t: pd.DataFrame, features: list, winsor_p=(0.01, 0.99)) -> pd.DataFrame:
    """
    Clean up one month of data, feature by feature.

    For each feature: fill missing values with the cross-sectional median,
    clip at the ``winsor_p`` percentiles, then z-score (population std).

    Any non-finite value (NaN or +/-inf) counts as missing. If inf were kept
    as data, the median, percentiles and std would become inf/NaN and the
    whole column would turn into NaN.

    A feature that is missing for every stock, or is constant after
    clipping, becomes all zeros.

    This is done purely within the month — no information from other months
    leaks in here.

    Args:
        df_t: rows for exactly one date (asserted).
        features: feature columns to transform; all other columns are copied
            through unchanged.
        winsor_p: (low, high) clipping quantiles in [0, 1].

    Returns a copy of ``df_t`` with the feature columns replaced by float64
    z-scores.
    """
    assert df_t[DATE_COL].nunique() == 1, (
        f"preprocess_month got {df_t[DATE_COL].nunique()} dates — needs exactly one"
    )

    df_out = df_t.copy()
    X = df_out[features].values.astype(np.float64)
    lo_pct = winsor_p[0] * 100.0
    hi_pct = winsor_p[1] * 100.0

    for j in range(X.shape[1]):
        col = X[:, j]  # view into X: the median fill below writes through

        # fill missing (NaN and +/-inf) with the median of the finite values
        missing = ~np.isfinite(col)
        if missing.all():
            X[:, j] = 0.0
            continue
        col[missing] = np.median(col[~missing])

        # clip outliers
        lo, hi = np.percentile(col, [lo_pct, hi_pct])
        col = np.clip(col, lo, hi)

        # standardize; a constant column carries no cross-sectional signal
        sigma = col.std(ddof=0)
        X[:, j] = 0.0 if sigma < 1e-10 else (col - col.mean()) / sigma

    df_out[features] = X
    return df_out


def build_composites(df_proc: pd.DataFrame, features_in_data: set) -> pd.DataFrame:
    """
    Collapse z-scored features into the economic composites of COMPOSITE_DEFS.

    A composite is the NaN-ignoring mean of its members. Members listed under
    "pos" enter as-is and members under "neg" enter with their sign flipped.
    For example, QUALITY subtracts earnings_variability, since a more
    variable earnings stream means lower quality.

    Members that are not in ``features_in_data`` are skipped; a composite
    with no available members is set to 0.0.

    Returns a new frame with whichever of the id/date/market-equity/return
    columns exist in ``df_proc``, plus one column per composite.
    """
    base_cols = [c for c in (ID_COL, DATE_COL, ME_COL, RET_COL) if c in df_proc.columns]
    result = df_proc[base_cols].copy()

    for comp_name, defn in COMPOSITE_DEFS.items():
        signed = [df_proc[c].values for c in defn["pos"] if c in features_in_data]
        signed += [-df_proc[c].values for c in defn["neg"] if c in features_in_data]

        if signed:
            result[comp_name] = np.nanmean(np.stack(signed, axis=1), axis=1)
        else:
            result[comp_name] = 0.0

    return result


def compute_market_returns(df: pd.DataFrame) -> pd.Series:
    """
    Estimate the market's monthly return using a value-weighted average.

    Rm[t] = sum(market_cap * return) / sum(market_cap) over the stocks in
    month t that have a finite return and a finite, positive market cap.

    If no stock in a month qualifies, it falls back to the equal-weighted
    mean of the finite returns, and to 0.0 if there are none.

    The per-date sums use one vectorized ``np.bincount`` pass per quantity
    rather than one full-length boolean mask per date.

    Returns a series named "Rm", indexed by end-of-month date, sorted.
    """
    dates_arr = df[DATE_COL].values
    ret_arr   = df[RET_COL].values.astype(np.float64)
    me_arr    = df[ME_COL].values.astype(np.float64)

    unique_dates, inverse = np.unique(dates_arr, return_inverse=True)
    inverse   = inverse.ravel()  # guard against an input-shaped inverse in some numpy versions
    n_dates_u = len(unique_dates)

    finite_ret = np.isfinite(ret_arr)
    valid      = finite_ret & np.isfinite(me_arr) & (me_arr > 0)

    # zero out excluded rows before multiplying so no NaN/inf leaks into the sums
    me_v  = np.where(valid, me_arr, 0.0)
    ret_v = np.where(valid, ret_arr, 0.0)
    vw_num  = np.bincount(inverse, weights=me_v * ret_v, minlength=n_dates_u)
    vw_den  = np.bincount(inverse, weights=me_v, minlength=n_dates_u)
    n_valid = np.bincount(inverse[valid], minlength=n_dates_u)

    # equal-weight fallback ingredients
    ew_sum = np.bincount(inverse, weights=np.where(finite_ret, ret_arr, 0.0), minlength=n_dates_u)
    n_fin  = np.bincount(inverse[finite_ret], minlength=n_dates_u)

    Rm_vals = np.zeros(n_dates_u, dtype=np.float64)
    has_vw  = n_valid > 0
    Rm_vals[has_vw] = vw_num[has_vw] / vw_den[has_vw]
    use_ew  = ~has_vw & (n_fin > 0)
    Rm_vals[use_ew] = ew_sum[use_ew] / n_fin[use_ew]

    return pd.Series(Rm_vals, index=pd.to_datetime(unique_dates), name="Rm").sort_index()


def compute_regime(
    Rm_series: pd.Series,
    window: int = 12,
    hysteresis: bool = True,
    hysteresis_k: int = 2,
) -> pd.Series:
    """
    Tag every month as 'bull' or 'bear' from the market's recent returns.

    Raw signal at position i: add up the finite values among the `window`
    returns before i (``Rm_series.iloc[i-window:i]``). This is a plain sum,
    not a compounded return. A strictly positive total is 'bull'; zero or
    negative is 'bear'. The first `window` positions, and any position whose
    lookback holds no finite value, get None.

    With `hysteresis`, the label only flips after `hysteresis_k` consecutive
    raw signals disagree with the current label. A raw signal that agrees
    resets that count. A None signal leaves the count untouched and is itself
    reported as None.

    Position i never reads return i or later, so there is no look-ahead.

    Returns an object-dtype Series named "regime" on Rm_series's index.
    """
    labels  = Rm_series.index.tolist()
    returns = Rm_series.values.astype(np.float64)
    count   = len(labels)

    raw = [None] * count
    for pos in range(window, count):
        lookback = returns[pos - window : pos]
        lookback = lookback[np.isfinite(lookback)]
        if len(lookback):
            raw[pos] = "bull" if lookback.sum() > 0.0 else "bear"

    if not hysteresis:
        return pd.Series(raw, index=labels, name="regime", dtype=object)

    smoothed  = []
    state     = None  # label currently in force
    candidate = None  # opposite label we may be about to switch to
    streak    = 0     # consecutive raw signals equal to `candidate`

    for signal in raw:
        if signal is None:
            smoothed.append(None)
            continue

        if state is None:
            state, candidate, streak = signal, None, 0
        elif signal == state:
            candidate, streak = None, 0
        else:
            streak = streak + 1 if signal == candidate else 1
            candidate = signal
            if streak >= hysteresis_k:
                state, candidate, streak = signal, None, 0

        smoothed.append(state)

    return pd.Series(smoothed, index=labels, name="regime", dtype=object)


def _spearman_ic(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Rank correlation between true and predicted returns (pure numpy).
    Returns 0.0 if either input is constant.
    """
    n = len(y_true)
    if n < 2:
        return 0.0

    def _rank(arr):
        order = arr.argsort()
        r = np.empty(n, dtype=np.float64)
        r[order] = np.arange(n, dtype=np.float64)
        return r

    r1 = _rank(y_true.astype(np.float64))
    r2 = _rank(y_pred.astype(np.float64))

    r1 -= r1.mean()
    r2 -= r2.mean()

    s1 = np.sqrt((r1 ** 2).sum())
    s2 = np.sqrt((r2 ** 2).sum())

    if s1 < 1e-12 or s2 < 1e-12:
        return 0.0

    return float((r1 * r2).sum() / (s1 * s2))


def _ridge_solve(XTX: np.ndarray, Xty: np.ndarray, alpha: float) -> np.ndarray:
    """
    Solve the Ridge regression normal equations: coef = (X'X + alpha*I)^{-1} X'y.

    We use np.linalg.solve instead of inverting the matrix directly,
    which is more numerically stable. With alpha > 0, the system is
    always positive definite so this should never fail.
    """
    P = XTX.shape[0]
    A = XTX + alpha * np.eye(P, dtype=XTX.dtype)
    try:
        return np.linalg.solve(A, Xty)
    except np.linalg.LinAlgError:
        return np.zeros(P, dtype=np.float64)


def time_blocked_cv_ridge(
    X: np.ndarray,
    y: np.ndarray,
    months: np.ndarray,
    alpha_grid: list,
    n_folds: int = 5,
) -> float:
    """
    Pick the best Ridge regularization strength using time-series cross-validation.

    Folds: the sorted unique months are cut into ``n_folds`` consecutive
    blocks, and the last block absorbs the remainder. Each block is a
    validation set for a model trained only on the months before it — never
    later ones — which mimics how the model is used live. A block with fewer
    than two earlier months (always the first one) is skipped.

    Training sums: X'X and X'y are accumulated forward, month by month, as
    the validation block moves later. Each month is multiplied out once and
    nothing is ever subtracted, so there is no cancellation error. Each
    fold's validation rows are gathered once and scored for every alpha.

    Args:
        X: (n_rows, P) design matrix.
        y: (n_rows,) targets.
        months: (n_rows,) month label of every row; the labels must sort
            chronologically.
        alpha_grid: candidate ridge penalties.
        n_folds: number of month blocks.

    Returns:
        The alpha with the best mean validation IC. A fold's IC is one
        Spearman correlation over all rows of its validation block pooled
        together.
        - Ties keep the earlier alpha in the grid.
        - If no fold could be scored, the first alpha is returned.
        - With fewer than ``n_folds + 1`` months, the middle of the grid is
          returned without running CV.
    """
    unique_months = np.unique(months)
    n_unique      = len(unique_months)

    if n_unique < n_folds + 1:
        # not enough data for proper CV, just pick the middle of the grid
        return alpha_grid[len(alpha_grid) // 2]

    P = X.shape[1]

    # Group rows by month without materializing one boolean mask per month.
    # After a stable argsort, month k's rows are
    # order[row_bounds[k]:row_bounds[k+1]], in their original relative order.
    order = np.argsort(months, kind="stable")
    _, run_starts = np.unique(months[order], return_index=True)
    row_bounds = np.append(run_starts, len(months))

    def _month_block_rows(first, stop):
        """Row indices of months unique_months[first:stop]."""
        return order[row_bounds[first]:row_bounds[stop]]

    fold_size = n_unique // n_folds
    fold_ics  = [[] for _ in alpha_grid]  # fold_ics[a]: validation ICs for alpha_grid[a]

    train_XTX     = np.zeros((P, P), dtype=np.float64)
    train_Xty     = np.zeros(P,      dtype=np.float64)
    n_accumulated = 0  # months unique_months[:n_accumulated] are already in train_*

    for fold_i in range(n_folds):
        val_start = fold_i * fold_size
        val_end   = (fold_i + 1) * fold_size if fold_i < n_folds - 1 else n_unique

        # roll the training sums forward to cover exactly the months before the val block
        while n_accumulated < val_start:
            rows = _month_block_rows(n_accumulated, n_accumulated + 1)
            X_m  = X[rows].astype(np.float64)
            y_m  = y[rows].astype(np.float64)
            train_XTX += X_m.T @ X_m
            train_Xty += X_m.T @ y_m
            n_accumulated += 1

        if val_start < 2:
            continue  # need at least two earlier months to train on

        val_rows = _month_block_rows(val_start, val_end)
        if len(val_rows) < 2:
            continue

        X_val = X[val_rows].astype(np.float64)
        y_val = y[val_rows].astype(np.float64)

        for a_i, alpha in enumerate(alpha_grid):
            coef = _ridge_solve(train_XTX, train_Xty, alpha)
            fold_ics[a_i].append(_spearman_ic(y_val, X_val @ coef))

    best_alpha = alpha_grid[0]
    best_score = -np.inf
    for alpha, ics in zip(alpha_grid, fold_ics):
        if ics:
            mean_ic = float(np.mean(ics))
            if mean_ic > best_score:
                best_score = mean_ic
                best_alpha = alpha

    return best_alpha


def _rescale_to_targets(
    w: np.ndarray,
    gross_target: float,
    net_target: float,
    max_abs_weight: float = MAX_ABS_WEIGHT,
    n_iter: int = RESCALE_ITERS,
) -> np.ndarray:
    """
    Steer weights toward a gross (sum |w|) and a net (sum w) target while
    keeping every position within +/- max_abs_weight.

    Each of the ``n_iter`` passes does three things:
      1. Multiply all weights so that sum |w| equals ``gross_target``
         (skipped when the book is essentially empty).
      2. Subtract the same amount from every weight so that sum w equals
         ``net_target``.
      3. Clip each weight to the cap.

    Because the clip comes last, the result is only approximately on target;
    repeating the passes gets closer.

    Returns a new float64 array (``w`` is not modified). An empty input comes
    back as an empty copy.
    """
    out = w.copy().astype(np.float64)
    size = len(out)
    if size == 0:
        return out

    for _pass in range(n_iter):
        book = np.abs(out).sum()
        if book > 1e-12:
            out *= gross_target / book

        out -= (out.sum() - net_target) / size

        out = np.clip(out, -max_abs_weight, max_abs_weight)

    return out


def _blend_targets(
    cum12: float,
    gross_bull: float,
    net_bull: float,
    gross_bear: float,
    net_bear: float,
    blend_band: float = BLEND_BAND,
    early_bull_bonus: float = EARLY_BULL_GROSS_BONUS,
) -> tuple[float, float]:
    """
    Interpolate gross/net targets between the bear and bull settings.

    ``cum12 / blend_band`` is clipped to [-1, 1] ("strength"). Strength -1
    gives the bear targets, +1 gives the bull targets, and values in between
    are linearly interpolated. While strength is strictly positive, gross
    also gets ``early_bull_bonus * (1 - strength)``, a bonus that fades to 0
    as the market moves deeper into bull territory.

    NOTE: because the bonus switches on as soon as strength > 0, gross jumps
    by ``early_bull_bonus`` when cum12 crosses zero. The blend is continuous
    everywhere else.

    A non-finite ``cum12``, or ``blend_band <= 0``, returns the bull targets
    unchanged.

    Returns (gross_target, net_target).
    """
    if not np.isfinite(cum12) or blend_band <= 0:
        return gross_bull, net_bull

    strength = min(1.0, max(-1.0, cum12 / blend_band))
    bull_share = 0.5 * (strength + 1.0)  # 0 = all bear, 1 = all bull

    def _lerp(bear_value, bull_value):
        return bear_value + bull_share * (bull_value - bear_value)

    gross_t = _lerp(gross_bear, gross_bull)
    net_t = _lerp(net_bear, net_bull)

    if strength > 0.0:
        gross_t += early_bull_bonus * (1.0 - strength)

    return gross_t, net_t


def make_weights(
    scores: np.ndarray,
    regime: str,
    gross_bull: float = GROSS_BULL,
    net_bull:   float = NET_BULL,
    gross_bear: float = GROSS_BEAR,
    net_bear:   float = NET_BEAR,
    gross_target: float | None = None,
    net_target:   float | None = None,
    max_abs_weight: float = MAX_ABS_WEIGHT,
    n_iter: int = RESCALE_ITERS,
) -> np.ndarray:
    """
    Convert model scores into long/short portfolio weights.

    Only the ordering of ``scores`` matters: stocks are ranked, the ranks
    are spread evenly over [-0.5, 0.5] and standardized. Higher-scored
    stocks therefore lean long and lower-scored ones lean short, and
    outlying scores have no extra influence. Ties are broken by argsort
    order.

    The signal is then passed to _rescale_to_targets. The targets are
    ``gross_target``/``net_target`` when both are given; otherwise the bull
    pair when ``regime == "bull"`` and the bear pair for anything else.

    Fewer than two stocks yields all-zero weights.
    """
    count = len(scores)
    if count < 2:
        return np.zeros(count, dtype=np.float64)

    rank_of = np.empty(count, dtype=np.float64)
    rank_of[scores.argsort()] = np.arange(count, dtype=np.float64)

    centered = rank_of / (count - 1) - 0.5
    spread = centered.std(ddof=0)
    if spread < 1e-10:
        return np.zeros(count, dtype=np.float64)
    signal = (centered - centered.mean()) / spread

    if gross_target is None or net_target is None:
        gross_t, net_t = (gross_bull, net_bull) if regime == "bull" else (gross_bear, net_bear)
    else:
        gross_t, net_t = gross_target, net_target

    return _rescale_to_targets(signal, gross_t, net_t, max_abs_weight, n_iter)


def _clean_month_arrays(entry: dict) -> tuple[np.ndarray, np.ndarray]:
    """
    Return one preprocessed month's (X, y) as fresh float64 training arrays.

    ``entry`` is a value of the ``processed`` dict built in run_strategy
    (keys "X" and "y"). Non-finite feature values become 0 and rows whose
    target is non-finite are dropped. The cached month is never modified.
    """
    X = np.nan_to_num(entry["X"].astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    y = entry["y"].astype(np.float64)
    keep = np.isfinite(y)
    return X[keep], y[keep]


def run_strategy(
    chars_df: pd.DataFrame | None = None,
    features_df: pd.DataFrame | None = None,
) -> pd.DataFrame:
    """
    Run the whole strategy from raw data to output weights.

    Args:
        chars_df: Optional DataFrame with stock characteristics (CTF pipeline mode).
                  If None, loads from DATA_PATH (standalone mode).
        features_df: Optional DataFrame with feature names (CTF pipeline mode).
                     Only consulted together with chars_df; in standalone mode
                     the feature list is read from FEATURES_PATH.

    Returns a dataframe with columns [id, eom, w] where eom is the
    end-of-month at which the weights are set (to be held through the next month).

    Timing: at prediction month t, the following are all computed from
    months strictly before t:
      - the trailing market return and trailing market vol,
      - the vol target and the regime label,
      - the training months.
    RET_COL at month tau is the tau -> tau+1 return, so any month tau < t is
    fully realized by t.
    """
    t_start = time.time()
    print(f"[CTF-DEBUG] Starting strategy execution at {time.strftime('%Y-%m-%d %H:%M:%S')}", flush=True)

    # --- load data ---

    features_override = None
    if not USE_COMPOSITES and USE_FEATURE_SUBSET:
        features_override = FEATURE_SUBSET

    if chars_df is not None:
        # CTF pipeline mode
        df, features = load_data(
            chars_df=chars_df,
            features_df=features_df,
            features_override=features_override,
        )
    else:
        # Standalone mode
        df, features = load_data(DATA_PATH, FEATURES_PATH, features_override=features_override)
    features_set = set(features)
    if features_override is not None:
        missing = [f for f in features_override if f not in features_set]
        if missing:
            print(f"Warning: {len(missing)} requested features not found in parquet: {missing[:5]}", flush=True)

    dates        = sorted(df[DATE_COL].unique())
    date_to_idx  = {d: i for i, d in enumerate(dates)}
    n_dates      = len(dates)
    print(f"  {n_dates} dates total; burn-in={BURN_IN_MONTHS}, refit every={REFIT_EVERY}", flush=True)

    print("Computing market returns...", flush=True)
    Rm = compute_market_returns(df)

    print("Labeling bull/bear regimes...", flush=True)
    regime_series = compute_regime(
        Rm, window=REGIME_WINDOW, hysteresis=HYSTERESIS, hysteresis_k=HYSTERESIS_K
    )
    regime_dict = regime_series.to_dict()  # dict lookup is faster than series indexing

    # Trailing market statistics keyed by date. Rm is indexed like `dates`
    # and Rm[tau] is the tau -> tau+1 return, so Rm[idx-W : idx] is fully
    # realized at dates[idx].
    Rm_vals = Rm.values

    # trailing 12-month summed return (used for blending bull/bear targets)
    cum12_map = {}
    for idx, day in enumerate(dates):
        if idx >= REGIME_WINDOW:
            seg = Rm_vals[idx - REGIME_WINDOW : idx]
            seg = seg[np.isfinite(seg)]
            cum12_map[day] = float(seg.sum()) if len(seg) > 0 else np.nan
        else:
            cum12_map[day] = np.nan

    # trailing 12-month market volatility (used for vol scaling)
    sigma_m = {}
    for idx, day in enumerate(dates):
        if idx >= VOL_WINDOW:
            seg = Rm_vals[idx - VOL_WINDOW : idx]
            seg = seg[np.isfinite(seg)]
            sigma_m[day] = float(np.std(seg, ddof=1)) if len(seg) >= 2 else np.nan
        else:
            sigma_m[day] = np.nan

    # Vol target at each date = median of every trailing-vol estimate known by
    # then. A median over the whole prediction period would let future vol
    # levels decide how much each month gets scaled down.
    sigma_star_map = {}
    vol_history = []
    for day in dates:
        if np.isfinite(sigma_m[day]):
            vol_history.append(sigma_m[day])
        sigma_star_map[day] = float(np.median(vol_history)) if vol_history else np.nan
    print(f"  target vol: expanding median of trailing {VOL_WINDOW}m market vol", flush=True)

    # --- preprocess every month upfront ---

    P = len(COMPOSITE_NAMES) if USE_COMPOSITES else len(features)
    print(f"Preprocessing {n_dates} months (composites={USE_COMPOSITES}, features={P})...", flush=True)

    processed = {}
    for i, (month_key, df_t) in enumerate(df.groupby(DATE_COL, sort=True)):
        month_ts = pd.Timestamp(month_key)
        if i % 100 == 0:
            elapsed = time.time() - t_start
            print(f"  ... {i}/{n_dates} months done  ({elapsed:.0f}s)", end="\r", flush=True)

        df_proc = preprocess_month(df_t, features, WINSOR_P)

        if USE_COMPOSITES:
            df_feat = build_composites(df_proc, features_set)
            X_t     = df_feat[COMPOSITE_NAMES].values.astype(np.float64)
            y_t     = df_feat[RET_COL].values.astype(np.float64)
            ids_t   = df_feat[ID_COL].values
        else:
            X_t   = df_proc[features].values.astype(np.float32)  # float32 saves RAM
            y_t   = df_proc[RET_COL].values.astype(np.float64)
            ids_t = df_proc[ID_COL].values

        processed[month_ts] = {"X": X_t, "y": y_t, "ids": ids_t}

    print(f"\n  Done preprocessing in {time.time()-t_start:.1f}s", flush=True)

    # --- set up running accumulators for each regime ---
    # we maintain X'X and X'y as running sums so we don't have to restack
    # the full training dataset every time we refit

    cumulative_XTX = {"bull": np.zeros((P, P)), "bear": np.zeros((P, P))}
    cumulative_Xty = {"bull": np.zeros(P),      "bear": np.zeros(P)}
    month_list     = {"bull": [],               "bear": []}

    coef       = {"bull": None, "bear": None}
    best_alpha = {"bull": ALPHA_GRID[len(ALPHA_GRID) // 2],
                  "bear": ALPHA_GRID[len(ALPHA_GRID) // 2]}

    # for turnover smoothing
    w_smoothed_prev = None
    ids_prev        = None

    # --- main prediction loop ---

    output_records = []
    next_train_idx = 0  # dates[:next_train_idx] are already in the accumulators
    print(f"Generating weights for dates {BURN_IN_MONTHS} through {n_dates-1}...", flush=True)

    for t_idx in range(BURN_IN_MONTHS, n_dates):
        t        = dates[t_idx]
        regime_t = regime_dict.get(t)

        # Add every not-yet-seen month before t to its regime's accumulators:
        # all burn-in months on the first date, then the month just completed.
        # This runs before any `continue` below, so a skipped prediction month
        # can never drop a month of training data.
        while next_train_idx < t_idx:
            tau = dates[next_train_idx]
            next_train_idx += 1
            assert tau < t, f"Training date {tau.date()} >= prediction date {t.date()} — look-ahead!"
            regime_tau = regime_dict.get(tau)
            if regime_tau is None:
                continue

            X_tau, y_tau = _clean_month_arrays(processed[tau])
            if len(y_tau) == 0:
                continue

            cumulative_XTX[regime_tau] += X_tau.T @ X_tau
            cumulative_Xty[regime_tau] += X_tau.T @ y_tau
            month_list[regime_tau].append(date_to_idx[tau])

        if regime_t is None:
            print(f"  WARNING: no regime label at {t.date()}, skipping", flush=True)
            continue

        # refit every REFIT_EVERY months, counting from the first prediction date
        should_refit = ((t_idx - BURN_IN_MONTHS) % REFIT_EVERY == 0)

        if should_refit:
            for reg in ["bull", "bear"]:
                m_list       = month_list[reg]
                n_reg_months = len(m_list)

                if n_reg_months == 0:
                    print(f"  [{t.date()}] {reg}: no training data yet", flush=True)
                    continue

                if n_reg_months >= CV_FOLDS * 2:
                    # assemble raw arrays for cross-validation scoring
                    X_parts, y_parts, m_parts = [], [], []
                    for m_int in m_list:
                        X_m, y_m = _clean_month_arrays(processed[dates[m_int]])
                        X_parts.append(X_m)
                        y_parts.append(y_m)
                        m_parts.append(np.full(len(y_m), m_int, dtype=np.int32))

                    X_cv = np.vstack(X_parts)
                    y_cv = np.concatenate(y_parts)
                    m_cv = np.concatenate(m_parts)

                    best_alpha[reg] = time_blocked_cv_ridge(
                        X_cv, y_cv, m_cv, ALPHA_GRID, CV_FOLDS
                    )
                    del X_cv, y_cv, m_cv, X_parts, y_parts, m_parts

                coef[reg] = _ridge_solve(
                    cumulative_XTX[reg], cumulative_Xty[reg], best_alpha[reg]
                )

                print(
                    f"  [{t.date()}] {reg}: {n_reg_months} months, "
                    f"alpha={best_alpha[reg]:.3g}, current regime={regime_t}",
                    flush=True
                )

        # grab the right model for today's regime
        active_coef = coef[regime_t]
        if active_coef is None:
            # fall back to the other regime's model if we don't have one yet
            other = "bear" if regime_t == "bull" else "bull"
            active_coef = coef[other]
        if active_coef is None:
            continue  # no model at all yet, skip this month

        X_t   = np.nan_to_num(
            processed[t]["X"].astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0
        )
        ids_t = processed[t]["ids"]

        mu_hat = X_t @ active_coef

        # This month's gross/net targets. The blend must be keyed on t: any
        # other date would size month t with another month's trailing return.
        if USE_REGIME_BLEND:
            gross_t, net_t = _blend_targets(
                cum12_map.get(t, np.nan), GROSS_BULL, NET_BULL, GROSS_BEAR, NET_BEAR
            )
        elif regime_t == "bull":
            gross_t, net_t = GROSS_BULL, NET_BULL
        else:
            gross_t, net_t = GROSS_BEAR, NET_BEAR

        w_raw = make_weights(mu_hat, regime_t, gross_target=gross_t, net_target=net_t)

        # scale down (never up) when trailing vol exceeds its historical median
        sig_t  = sigma_m.get(t, np.nan)
        star_t = sigma_star_map.get(t, np.nan)
        if np.isfinite(sig_t) and sig_t > 1e-12 and np.isfinite(star_t):
            scale_t = min(1.0, star_t / sig_t)
        else:
            scale_t = 1.0
        w_scaled = w_raw * scale_t

        # blend with last month's weights to reduce turnover
        if w_smoothed_prev is not None and ids_prev is not None:
            id_to_prev = dict(zip(ids_prev, w_smoothed_prev))
            w_prev_aligned = np.array(
                [id_to_prev.get(sid, 0.0) for sid in ids_t], dtype=np.float64
            )
            w_blended = (1.0 - SMOOTHING_ALPHA) * w_prev_aligned + SMOOTHING_ALPHA * w_scaled
        else:
            w_blended = w_scaled  # no previous weights on the first date

        # Re-hit the targets after blending (blending perturbs them). The
        # targets carry the vol scale; rescaling to the unscaled gross would
        # undo the vol scaling applied above.
        w_final = _rescale_to_targets(w_blended, gross_t * scale_t, net_t * scale_t)

        w_smoothed_prev = w_final
        ids_prev        = ids_t

        for sid, wt in zip(ids_t, w_final):
            if abs(wt) > 1e-8:
                output_records.append((t, int(sid), float(wt)))

    # --- prepare output ---

    print(f"[CTF-DEBUG] Generated {len(output_records):,} weight records", flush=True)
    output_df = pd.DataFrame(output_records, columns=["eom", "id", "w"])
    output_df["eom"] = pd.to_datetime(output_df["eom"])
    output_df["id"] = output_df["id"].astype(int)
    output_df["w"] = output_df["w"].astype(float)

    # Reorder columns to match CTF format: id, eom, w
    output_df = output_df[["id", "eom", "w"]]

    elapsed = time.time() - t_start
    print(f"[CTF-DEBUG] Strategy completed in {elapsed:.1f}s", flush=True)
    if output_df.empty:
        # min()/max() of an empty column is NaT, whose .date() raises
        print(f"[CTF-DEBUG] Output shape: {output_df.shape}, no weights generated", flush=True)
    else:
        print(f"[CTF-DEBUG] Output shape: {output_df.shape}, date range: {output_df['eom'].min().date()} to {output_df['eom'].max().date()}", flush=True)

    return output_df


def verify_output(output_df: pd.DataFrame, regime_series: pd.Series) -> None:
    """
    Quick sanity check — print gross/net/max-weight stats by regime.

    Args:
        output_df: strategy output with at least the columns ``eom`` and ``w``.
        regime_series: mapping from month-end Timestamp to a regime label
            ("bull" / "bear"). A pandas Series or a plain dict is accepted.
            Dates missing from it are labeled "?" and are left out of the
            per-regime summaries.

    Prints, per regime, the mean gross (sum |w|) and net (sum w) exposure
    next to the configured targets, the average number of names per month,
    and the largest absolute weight. It then prints how many individual
    weights exceed MAX_ABS_WEIGHT by more than 0.001.
    """
    print("\n-- Output verification --", flush=True)

    by_date = output_df.groupby("eom")["w"].agg(
        gross=lambda w: w.abs().sum(),
        net="sum",
        n_stocks="count",
        max_w=lambda w: w.abs().max(),
    ).reset_index()

    _reg_dict = regime_series.to_dict() if hasattr(regime_series, "to_dict") else regime_series
    by_date["regime"] = by_date["eom"].map(lambda d: _reg_dict.get(pd.Timestamp(d), "?"))

    bull_rows = by_date[by_date["regime"] == "bull"]
    bear_rows = by_date[by_date["regime"] == "bear"]

    def _stats(rows, label, target_gross, target_net):
        # Targets are passed explicitly. The previous `"bull" in label` check
        # was case-sensitive and never matched "Bull", so bull months were
        # compared against the bear targets.
        # NOTE(review): when USE_REGIME_BLEND is on, run_strategy blends the
        # targets from the trailing return, so these pure-regime targets are
        # only a reference point.
        if len(rows) == 0:
            print(f"  {label}: no dates", flush=True)
            return
        print(f"  {label} ({len(rows)} dates):", flush=True)
        print(f"    gross: mean={rows['gross'].mean():.4f}  (target {target_gross:.2f})", flush=True)
        print(f"    net:   mean={rows['net'].mean():.4f}  (target {target_net:.2f})", flush=True)
        print(f"    stocks: avg {rows['n_stocks'].mean():.0f}/month", flush=True)
        print(f"    max |w|: {rows['max_w'].max():.6f}  (cap={MAX_ABS_WEIGHT})", flush=True)

    _stats(bull_rows, "Bull", GROSS_BULL, NET_BULL)
    _stats(bear_rows, "Bear", GROSS_BEAR, NET_BEAR)

    # small tolerance so float noise right at the cap is not counted
    cap_violations = (output_df["w"].abs() > MAX_ABS_WEIGHT + 0.001).sum()
    print(f"  Weight cap violations: {cap_violations}", flush=True)
    print(flush=True)


def compute_sharpe(
    weights_df: pd.DataFrame,
    data_path: str = DATA_PATH,
    regime_series: pd.Series = None,
    periods_per_year: int = 12,
) -> pd.Series:
    """
    Compute the annualized Sharpe ratio for the strategy.

    The portfolio return in month t+1 is the dot product of the weights set
    at t with the returns earned from t to t+1. Both live in the same row of
    the data (RET_COL at eom=t), so the join is on (eom, id) with no shift.

    Args:
        weights_df: strategy output with columns ``id``, ``eom``, ``w``.
        data_path: parquet file containing DATE_COL, ID_COL and RET_COL.
        regime_series: optional mapping month-end Timestamp -> "bull"/"bear".
            When given, a Sharpe ratio per regime is printed as well.
        periods_per_year: return periods per year. Used for the sqrt-of-time
            annualization of the Sharpe and volatility, and for the
            annualized mean.

    Prints a summary and optionally a bull/bear breakdown.

    Returns:
        Per-period portfolio returns: a Series named "port_ret", indexed by
        eom and sorted. It is empty when no weight row matched a return row.
    """
    ret_df = pd.read_parquet(data_path, columns=[DATE_COL, ID_COL, RET_COL])
    ret_df[DATE_COL] = pd.to_datetime(ret_df[DATE_COL])
    # Align the key names with the weights frame ("eom", "id"). Renaming
    # ID_COL is a no-op when it is already "id".
    ret_df = ret_df.rename(columns={DATE_COL: "eom", ID_COL: "id"})

    weights_df = weights_df.copy()
    weights_df["eom"] = pd.to_datetime(weights_df["eom"])

    merged = weights_df.merge(ret_df, on=["eom", "id"], how="inner")

    # Vectorized sum of w * r per month. Like Series.sum(), groupby-sum skips
    # NaN returns.
    port_ret = (
        (merged["w"] * merged[RET_COL])
        .groupby(merged["eom"])
        .sum()
        .rename("port_ret")
        .sort_index()
    )

    print("\n-- Sharpe Ratio --", flush=True)
    if port_ret.empty:
        # nothing to summarize; avoid NaT.date() on an empty index below
        print("  No overlapping (eom, id) rows between weights and returns.", flush=True)
        print(flush=True)
        return port_ret

    ann    = periods_per_year ** 0.5
    mean_r = port_ret.mean()
    std_r  = port_ret.std(ddof=1)
    # std is NaN with a single observation; NaN > 0 is False, so the Sharpe is NaN
    sharpe = mean_r / std_r * ann if std_r > 0 else np.nan

    print(f"  Period:       {port_ret.index.min().date()} to {port_ret.index.max().date()}", flush=True)
    print(f"  Months:       {len(port_ret)}", flush=True)
    print(f"  Mean return:  {mean_r*100:.3f}%/month  ({mean_r*periods_per_year*100:.2f}%/year)", flush=True)
    print(f"  Volatility:   {std_r*100:.3f}%/month  ({std_r*ann*100:.2f}%/year)", flush=True)
    print(f"  Sharpe:       {sharpe:.4f}  (annualized)", flush=True)

    if regime_series is not None:
        _reg = regime_series.to_dict() if hasattr(regime_series, "to_dict") else regime_series
        port_ret_df = port_ret.reset_index()
        port_ret_df.columns = ["eom", "port_ret"]
        port_ret_df["regime"] = port_ret_df["eom"].map(lambda d: _reg.get(pd.Timestamp(d)))

        for reg in ["bull", "bear"]:
            sub = port_ret_df.loc[port_ret_df["regime"] == reg, "port_ret"]
            if len(sub) < 2:
                continue
            sub_std = sub.std(ddof=1)
            # guard against a zero-variance regime (would print inf)
            s_sharpe = sub.mean() / sub_std * ann if sub_std > 0 else np.nan
            print(f"  {reg.capitalize():5s} Sharpe: {s_sharpe:.4f}  "
                  f"(n={len(sub)}, mean={sub.mean()*100:.3f}%/month)", flush=True)

    print(flush=True)
    return port_ret


def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    CTF pipeline entry point for the bull/bear regime-aware strategy.

    Logs the shape of each input frame, then delegates to run_strategy with
    the characteristics and feature-metadata frames.

    Args:
        chars: Stock characteristics DataFrame (from ctff_chars.parquet).
        features: Feature metadata DataFrame (from ctff_features.parquet).
        daily_ret: Daily returns DataFrame (from ctff_daily_ret.parquet).
            Only its shape is logged; it is not passed to the strategy.

    Returns:
        Whatever run_strategy returns: a DataFrame of portfolio weights with
        columns id, eom, w.
    """
    print("[CTF-DEBUG] main() called - CTF pipeline mode", flush=True)

    inputs = (("chars", chars), ("features", features), ("daily_ret", daily_ret))
    for frame_name, frame in inputs:
        print(f"[CTF-DEBUG] {frame_name} shape: {frame.shape}", flush=True)

    return run_strategy(chars_df=chars, features_df=features)


if __name__ == "__main__":
    # Standalone mode for local testing. Work from the script's directory so
    # the relative paths in the module constants resolve against it.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

    rule = "=" * 60
    print(rule, flush=True)
    print("Long/Short Equity Strategy — Bull/Bear Regime-Aware", flush=True)
    print(rule, flush=True)
    settings_report = (
        f"USE_COMPOSITES  = {USE_COMPOSITES}",
        f"BURN_IN_MONTHS  = {BURN_IN_MONTHS}",
        f"REFIT_EVERY     = {REFIT_EVERY}",
        f"HYSTERESIS      = {HYSTERESIS} (k={HYSTERESIS_K})",
        f"GROSS targets   = bull={GROSS_BULL}, bear={GROSS_BEAR}",
        f"NET targets     = bull={NET_BULL}, bear={NET_BEAR}",
        f"MAX_ABS_WEIGHT  = {MAX_ABS_WEIGHT}",
        f"SMOOTHING_ALPHA = {SMOOTHING_ALPHA}",
    )
    for setting_line in settings_report:
        print(setting_line, flush=True)
    print(rule, flush=True)

    result_df = run_strategy()

    # the CTF pipeline writes output itself; standalone mode writes a CSV
    print(f"\nWriting {len(result_df):,} records to {OUTPUT_PATH}...", flush=True)
    result_df.to_csv(OUTPUT_PATH, index=False)

    # Rebuild the regime labels for the checks below. Only the date, id,
    # market-equity and return columns are loaded.
    df_small = pd.read_parquet(DATA_PATH, columns=[DATE_COL, ID_COL, ME_COL, RET_COL])
    df_small[DATE_COL] = pd.to_datetime(df_small[DATE_COL])
    Rm_chk = compute_market_returns(df_small)
    reg_chk = compute_regime(Rm_chk, REGIME_WINDOW, HYSTERESIS, HYSTERESIS_K)

    verify_output(result_df, reg_chk)
    port_ret = compute_sharpe(result_df, DATA_PATH, reg_chk)

    print("Sample output (first 10 rows):", flush=True)
    print(result_df.head(10).to_string(index=False), flush=True)
    print(f"\nTotal rows: {len(result_df):,}", flush=True)
    eom_col = result_df["eom"]
    print(f"Date range: {eom_col.min().date()} to {eom_col.max().date()}", flush=True)
    print(f"Output: {os.path.abspath(OUTPUT_PATH)}", flush=True)
