"""
CTF Strategy V6: Rolling Factor Selection with RankSharpe Weighting
====================================================================

  1. Compute factor long-short (Q5-Q1) quintile returns from stock-level data.
  2. Every 3 months (rebalance), use trailing 10-year window of LS returns to:
     a. Select top 5 factors by annualized Sharpe ratio.
     b. Weight them by RankSharpe (rank of Sharpe, normalized to sum to 1).
  3. Map portfolio-level factor weights to individual stock positions:
     For each selected factor k with weight w_k at month t:
       stock i gets  +w_k / n_Q5  if in top quintile     (rank > 0.8)
                     -w_k / n_Q1  if in bottom quintile   (rank <= 0.2)
                     0            otherwise
     Net stock weight = sum across selected factors.

  This exactly replicates the portfolio-level weighted-LS return at the
  individual stock level: sum_i(w_i * ret_i) = sum_k(w_k * LS_k).

Configuration: Window=10yr, Rebalance=3mo, N_factors=5,
               Selection=Sharpe, Weighting=RankSharpe

Rules compliance:
  Rule 1:  Temporal integrity — LS returns for selection end at t-1.
           Stock quintile rankings use only characteristics available at t.
  Rule 2:  Feature selection is entirely algorithmic (rolling Sharpe ranking).
           No manual factor choices based on known historical performance.
  Rule 3:  Uses only the provided CTF dataset; no external data.
  Rule 6:  Monthly stock-level positions; factor selection every 3 months.
  Rule 11: Defines main(chars, features, daily_ret) -> DataFrame[id, eom, w].
  Rule 12: Output has columns id, eom, w with no missing values.
"""

import numpy as np
import pandas as pd

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
WINDOW_YEARS = 10          # rolling lookback for factor selection
REBAL_MONTHS = 3           # rebalance factor selection every N months
N_FACTORS = 5              # number of factors to select at each rebalance
MIN_OBS = 60               # minimum non-NaN months for a factor to be valid (5 years)
MIN_STOCKS_QUINTILE = 10   # minimum stocks per quintile leg


# ---------------------------------------------------------------------------
# Step 1: Compute factor long-short returns from stock-level data
# ---------------------------------------------------------------------------
def compute_factor_ls_returns(chars, feature_names, min_stocks, min_obs,
                              min_stocks_quintile=None):
    """
    For each feature, form monthly equal-weight long-short quintile returns.

    Methodology (matches JKP quintile construction):
      - Cross-sectional percentile rank via pandas rank(pct=True)
      - Q5 (long) = rank > 0.8;  Q1 (short) = rank <= 0.2
      - Both quintiles must have >= min_stocks_quintile stocks
      - LS return = mean(ret | Q5) - mean(ret | Q1)

    Args:
        chars: stock-month panel with columns 'eom', 'ret_exc_lead1m' and
            the feature columns.
        feature_names: candidate feature names; names absent from `chars`
            are skipped.
        min_stocks: minimum cross-section size for a month to be used.
        min_obs: minimum non-NaN months for a feature column to be kept.
        min_stocks_quintile: minimum stocks per quintile leg. Defaults to
            the module-level MIN_STOCKS_QUINTILE (previously hard-coded;
            now a keyword so callers/tests can override it).

    Returns:
        DataFrame with DatetimeIndex (eom) and one column per surviving
        feature, holding the monthly LS return (NaN where invalid).
    """
    import time as _time
    if min_stocks_quintile is None:
        min_stocks_quintile = MIN_STOCKS_QUINTILE
    print("  Computing factor LS returns from stock-level data...")
    t0 = _time.time()

    valid_features = [f for f in feature_names if f in chars.columns]
    cols_needed = ["eom", "ret_exc_lead1m"] + valid_features
    df = chars[cols_needed].copy()
    # Only months with a realized next-month return can produce LS returns.
    df = df[df["ret_exc_lead1m"].notna()]

    months = sorted(df["eom"].unique())
    n_features = len(valid_features)
    ls_matrix = np.full((len(months), n_features), np.nan)

    for m_idx, eom in enumerate(months):
        month_data = df[df["eom"] == eom]
        if len(month_data) < min_stocks:
            continue

        ret = month_data["ret_exc_lead1m"].values
        ranks = month_data[valid_features].rank(pct=True).values

        # NaN characteristics map to 0.5, landing in neither quintile.
        q1_mask = np.nan_to_num(ranks, nan=0.5) <= 0.2
        q5_mask = np.nan_to_num(ranks, nan=0.5) > 0.8

        q5_count = q5_mask.sum(axis=0)
        q1_count = q1_mask.sum(axis=0)
        q5_sum = (ret[:, None] * q5_mask).sum(axis=0)
        q1_sum = (ret[:, None] * q1_mask).sum(axis=0)

        # Both legs must be adequately populated for a valid LS return.
        valid = ((q5_count >= min_stocks_quintile)
                 & (q1_count >= min_stocks_quintile))
        ls_row = np.full(n_features, np.nan)
        ls_row[valid] = (q5_sum[valid] / q5_count[valid]
                         - q1_sum[valid] / q1_count[valid])
        ls_matrix[m_idx] = ls_row

        if (m_idx + 1) % 200 == 0:
            print(f"    {m_idx+1}/{len(months)} months ({_time.time()-t0:.1f}s)")

    ls_df = pd.DataFrame(ls_matrix, index=pd.to_datetime(months),
                         columns=valid_features).sort_index()

    # Drop features with too few valid months
    valid_cols = ls_df.columns[ls_df.notna().sum() >= min_obs]
    ls_df = ls_df[valid_cols]

    print(f"  LS returns: {ls_df.shape[0]} months x "
          f"{ls_df.shape[1]} features ({_time.time()-t0:.1f}s)")
    return ls_df


# ---------------------------------------------------------------------------
# Step 2: Build rolling rebalance schedule
# ---------------------------------------------------------------------------
def build_rebalance_schedule(ls_returns, window_years, rebal_months,
                             n_factors, min_obs):
    """
    Produce the factor selection/weighting plan at each rebalance point.

    Every `rebal_months` months, look back `window_years` of LS returns
    (ending the month BEFORE the rebalance date, so no same-month data
    enters the selection -- Rule 1) and:
      1. Keep the top `n_factors` factors by annualized Sharpe ratio.
      2. Weight them by RankSharpe: rank of Sharpe within the selected
         set, normalized to sum to one.

    Returns:
        dict mapping rebalance date -> (list_of_factors, weight_array),
        with factors ordered from highest to lowest Sharpe.
    """
    import time as _time
    print("  Building rebalance schedule...")
    start = _time.time()

    month_index = ls_returns.index.sort_values()
    schedule = {}

    for rebal_date in month_index[::rebal_months]:
        hist_end = rebal_date - pd.DateOffset(months=1)
        hist_start = hist_end - pd.DateOffset(years=window_years)

        # Trailing window; require at least two years of usable rows.
        hist = ls_returns.loc[hist_start:hist_end]
        if len(hist.dropna(how="all")) < 24:
            continue

        # A factor needs enough non-NaN observations inside the window.
        need = min(min_obs, max(6, len(hist) // 2))
        obs_counts = hist.notna().sum()
        usable = [c for c in hist.columns if obs_counts[c] >= need]
        if not usable:
            continue

        # Annualized Sharpe per usable factor (NaNs treated as zero return).
        filled = hist[usable].fillna(0)
        vol = filled.std().replace(0, np.nan)
        sharpe = ((filled.mean() / vol) * np.sqrt(12)).dropna()
        if sharpe.empty:
            continue

        # Top N by Sharpe, then RankSharpe weights (rank 1..N, normalized).
        top = sharpe.nlargest(min(n_factors, len(sharpe)))
        rank_arr = top.rank().values
        schedule[rebal_date] = (list(top.index), rank_arr / rank_arr.sum())

    n_active = len(schedule)
    print(f"  Schedule: {n_active} active rebalance points "
          f"({_time.time()-start:.1f}s)")

    # Log first and last few rebalance points
    dates = sorted(schedule)
    if dates:
        print(f"  First: {dates[0].strftime('%Y-%m')}, "
              f"Last: {dates[-1].strftime('%Y-%m')}")
        for d in dates[:3]:
            f, w = schedule[d]
            print(f"    {d.strftime('%Y-%m')}: {f} "
                  f"w=[{', '.join(f'{x:.3f}' for x in w)}]")

    return schedule


# ---------------------------------------------------------------------------
# Step 3: Compute stock-level portfolio weights
# ---------------------------------------------------------------------------
def compute_stock_weights(chars, schedule, min_stocks=50,
                          min_stocks_quintile=None):
    """
    For each month, look up the applicable rebalance schedule, then assign
    stock-level weights that exactly replicate the portfolio-level strategy.

    For factor k with portfolio weight w_k:
      stock i in Q5 (top quintile):    +w_k / n_Q5
      stock i in Q1 (bottom quintile): -w_k / n_Q1
      stock i in Q2-Q4 or NaN:          0

    Net stock weight = sum across K selected factors.

    Uses all stocks (not just those with ret_exc_lead1m) to avoid look-ahead
    bias: at time t we don't know which stocks will have valid returns at t+1.
    Quintile ranks are based on all stocks with valid characteristics. (Rule 1)

    Args:
        chars: stock-month panel with 'id', 'eom' and factor columns.
        schedule: dict mapping rebalance date -> (factors, weights), as
            produced by build_rebalance_schedule.
        min_stocks: minimum cross-section size for a month to be traded
            (default 50, the previously hard-coded value).
        min_stocks_quintile: minimum stocks per quintile leg. Defaults to
            the module-level MIN_STOCKS_QUINTILE (previously hard-coded;
            now a keyword so callers/tests can override it).

    Returns:
        DataFrame with columns id, eom, w containing non-zero weights only.
    """
    import time as _time
    print("  Computing stock-level weights...")
    t0 = _time.time()

    if min_stocks_quintile is None:
        min_stocks_quintile = MIN_STOCKS_QUINTILE

    rebal_dates = sorted(schedule.keys())
    if not rebal_dates:
        return pd.DataFrame(columns=["id", "eom", "w"])

    # Gather all factors that appear in any rebalance
    all_factors = set()
    for factors, _ in schedule.values():
        all_factors.update(factors)
    all_factors = sorted(all_factors)

    # Work with required columns only
    available = [f for f in all_factors if f in chars.columns]
    cols = ["id", "eom"] + available
    df_work = chars[cols].copy()
    months = sorted(df_work["eom"].unique())

    results = []
    active_months = 0

    for m_idx, eom in enumerate(months):
        # Find most recent rebalance on or before this month
        applicable = [d for d in rebal_dates if d <= eom]
        if not applicable:
            continue
        factors, weights = schedule[applicable[-1]]
        if not factors:
            continue

        # Get stocks for this month
        month_data = df_work.loc[df_work["eom"] == eom]
        n_stocks = len(month_data)
        if n_stocks < min_stocks:
            continue

        ids = month_data["id"].values
        stock_w = np.zeros(n_stocks, dtype=np.float64)

        for factor, fw in zip(factors, weights):
            if factor not in month_data.columns:
                continue

            # Percentile rank within this month; NaN ranks -> neither quintile.
            ranks = month_data[factor].rank(pct=True).values
            nan_mask = np.isnan(ranks)
            q5 = (ranks > 0.8) & ~nan_mask
            q1 = (ranks <= 0.2) & ~nan_mask

            n_q5 = q5.sum()
            n_q1 = q1.sum()
            if n_q5 < min_stocks_quintile or n_q1 < min_stocks_quintile:
                continue

            # Factor contribution: weighted by this factor's RankSharpe weight
            stock_w[q5] += fw / n_q5
            stock_w[q1] += -fw / n_q1

        # Collect non-zero weights
        nonzero = stock_w != 0.0
        if nonzero.any():
            results.append(pd.DataFrame({
                "id": ids[nonzero],
                "eom": eom,
                "w": stock_w[nonzero],
            }))
            active_months += 1

        if (m_idx + 1) % 200 == 0:
            print(f"    {m_idx+1}/{len(months)} months ({_time.time()-t0:.1f}s)")

    if not results:
        return pd.DataFrame(columns=["id", "eom", "w"])

    output = pd.concat(results, ignore_index=True)
    elapsed = _time.time() - t0
    nm = output["eom"].nunique()
    n_long = (output["w"] > 0).sum()
    n_short = (output["w"] < 0).sum()
    print(f"  Weights: {len(output):,} rows, {nm} months, "
          f"{output['id'].nunique():,} unique stocks ({elapsed:.1f}s)")
    print(f"  Avg long/month: {n_long/max(nm,1):.0f}, "
          f"avg short/month: {n_short/max(nm,1):.0f}")
    return output


# ===================================================================
# main() -- CTF entry point (Rules 11-12)
# ===================================================================
def main(chars: pd.DataFrame,
         features: pd.DataFrame,
         daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    CTF-compliant strategy entry point (Rules 11-12).

    Pipeline: factor LS returns -> rolling rebalance schedule ->
    stock-level weights, with thresholds relaxed on small datasets.

    Args:
        chars:     Stock characteristics (ctff_chars.parquet)
        features:  Feature list (ctff_features.parquet)
        daily_ret: Daily returns (ctff_daily_ret.parquet) -- not used

    Returns:
        DataFrame with columns: id, eom, w  (no missing values)
    """
    import time
    started = time.time()

    # --- Prepare data ---
    feature_names = features["features"].tolist()
    panel = chars.copy()
    panel["eom"] = pd.to_datetime(panel["eom"])
    n_months = panel["eom"].nunique()
    print(f"Data: {len(panel):,} stock-months, {len(feature_names)} features, "
          f"{n_months} months")

    # --- Adapt thresholds for small datasets (validation mode = 123 months) ---
    if n_months < 200:
        min_obs = max(6, n_months // 5)
        min_stocks = 20
        print(f"  Small dataset mode: {n_months} months, min_obs={min_obs}")
    else:
        min_obs, min_stocks = MIN_OBS, 50

    # --- Step 1: Compute factor LS returns from stock-level data ---
    ls_returns = compute_factor_ls_returns(panel, feature_names,
                                           min_stocks, min_obs)

    # --- Step 2: Build rolling rebalance schedule ---
    schedule = build_rebalance_schedule(ls_returns, WINDOW_YEARS,
                                        REBAL_MONTHS, N_FACTORS, min_obs)

    # --- Step 3: Compute stock-level weights ---
    output = compute_stock_weights(panel, schedule)

    # --- Validate output (Rule 12) ---
    assert set(output.columns) == {"id", "eom", "w"}, "Bad columns"
    assert output["w"].notna().all(), "NaN weights"
    assert len(output) > 0, "Empty output"

    elapsed = time.time() - started
    print(f"\nTotal time: {elapsed:.1f}s")
    print(f"Output: {len(output):,} rows")
    return output
