"""
ClusterSharpe Selection + MeanVar Weighting + Linear LS
=========================================================================

Three-stage pipeline:
  1. Compute factor long-short returns using linear rank weights:
       For each feature, stock weight = percentile_rank - 0.5, which is only
       approximately mean-zero (pandas pct ranks run from 1/n to 1, so the
       mean weight is 1/(2n)).
       LS return = sum(w_i * ret_i) / sum(|w_i|).
  2. Every 6 months (rebalance), use the full expanding history of LS returns
     to select 10 factors via ClusterSharpe:
       a. Compute correlation matrix of factor LS returns.
       b. Hierarchical clustering (average linkage, distance = 1 - |corr|).
     c. Cut into max(3 * N_factors, 20) clusters (30 with the defaults, capped
        at the number of eligible factors); rank clusters by their best Sharpe.
       d. Pick single best-Sharpe factor from top 10 clusters.
     Weight factors via mean-variance optimization:
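       w* proportional to Sigma^{-1} mu, where Sigma is shrunk toward a
       scaled identity with fixed intensity lambda=0.3 (a Ledoit-Wolf-style
       target, not an estimated intensity); negative weights are clipped to
       zero (no factor shorting) and the rest renormalized to sum to 1.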
  3. Map portfolio-level factor weights to individual stock positions:
       For factor k with weight w_k:
         stock i gets  w_k * (rank_pct_ik - 0.5) / Z_k
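       where Z_k = sum_i |rank_pct_ik - 0.5| normalizes so that
       sum_i(stock_weight_ik * ret_i) = w_k * LS_return_k when both rank the
       same universe. Stock weights rank every stock at t, whereas LS returns
       rank only stocks with a non-missing ret_exc_lead1m, so in practice the
       match is approximate.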

Configuration: Window=expanding, Rebalance=6mo, N_factors=10,
               Selection=ClusterSharpe, Weighting=MeanVar

Rules compliance:
  Rule 1:  Temporal integrity -- LS returns for selection end at t-1.
           Stock rank weights use only characteristics available at t.
  Rule 2:  Feature selection is entirely algorithmic (cluster + Sharpe).
           No manual factor choices based on known historical performance.
  Rule 3:  Uses only the provided CTF dataset; no external data.
  Rule 6:  Monthly stock-level positions; factor selection every 6 months.
  Rule 11: Defines main(chars, features, daily_ret) -> DataFrame[id, eom, w].
  Rule 12: Output has columns id, eom, w with no missing values.
"""

import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
REBAL_MONTHS: int = 6       # months between successive factor re-selections
N_FACTORS: int = 10         # factors picked at each rebalance
N_CLUSTERS_MULT: int = 3    # cluster count = max(N_FACTORS * this, 20), capped
MIN_OBS: int = 60           # valid months a factor needs to stay eligible
MIN_STOCKS: int = 10        # per-month stock floor (enforced as 2 * MIN_STOCKS)
SHRINKAGE: float = 0.3      # fixed covariance shrinkage toward a scaled identity


# ---------------------------------------------------------------------------
# Step 1: Compute factor linear-rank LS returns from stock-level data
# ---------------------------------------------------------------------------
def compute_factor_ls_returns(chars, feature_names, min_stocks, min_obs):
    """
    Compute monthly linear-rank long-short (LS) returns for every feature.

    Methodology, per month and per feature:
      - Only stocks with a non-missing ``ret_exc_lead1m`` are used.
      - Cross-sectional percentile rank via pandas ``rank(pct=True)``.
      - Weight = rank - 0.5. Pandas pct ranks run from 1/n to 1, so weights
        lie in [1/n - 0.5, 0.5] and have a small positive mean of 1/(2n)
        (approximately, not exactly, mean-zero).
      - LS return = sum(w_i * ret_i) / sum(|w_i|).
      - A month is evaluated only with >= 2*min_stocks stocks, and a feature
        only when >= 2*min_stocks of those stocks have a non-NaN value.

    Args:
        chars: stock-month panel with ``eom``, ``ret_exc_lead1m`` and feature
            columns.
        feature_names: candidate feature names; names missing from ``chars``
            are ignored.
        min_stocks: half of the per-month stock-count floor.
        min_obs: features with fewer valid months than this are dropped.

    Returns:
        DataFrame with an ascending DatetimeIndex (eom) and one column per
        surviving feature; months lacking enough data are NaN.
    """
    import time as _time
    print("  Computing linear-rank factor LS returns...")
    t0 = _time.time()

    valid_features = [f for f in feature_names if f in chars.columns]
    df = chars[["eom", "ret_exc_lead1m"] + valid_features]
    df = df[df["ret_exc_lead1m"].notna()]
    n_features = len(valid_features)
    min_count = min_stocks * 2

    # One groupby pass instead of re-scanning the whole panel with a boolean
    # mask for every month (O(rows * months)); groups come sorted by eom.
    grouped = df.groupby("eom", sort=True)
    n_months = grouped.ngroups
    months = []
    ls_matrix = np.full((n_months, n_features), np.nan)

    for m_idx, (eom, month_data) in enumerate(grouped):
        months.append(eom)
        if len(month_data) < min_count:
            continue  # row stays NaN

        ret = month_data["ret_exc_lead1m"].values
        ranks_pct = month_data[valid_features].rank(pct=True).values

        # rank() leaves missing feature values as NaN; those stocks get zero
        # weight. Deriving the mask from the float ranks also works for
        # feature columns whose raw dtype np.isnan cannot handle.
        not_nan = ~np.isnan(ranks_pct)
        rank_weights = np.where(not_nan, ranks_pct - 0.5, 0.0)

        weighted_ret = (ret[:, None] * rank_weights).sum(axis=0)
        abs_weight_sum = np.abs(rank_weights).sum(axis=0)

        valid = (not_nan.sum(axis=0) >= min_count) & (abs_weight_sum > 1e-10)
        ls_matrix[m_idx, valid] = weighted_ret[valid] / abs_weight_sum[valid]

        if (m_idx + 1) % 200 == 0:
            print(f"    {m_idx+1}/{n_months} months ({_time.time()-t0:.1f}s)")

    ls_df = pd.DataFrame(ls_matrix, index=pd.to_datetime(months),
                         columns=valid_features).sort_index()

    # Drop features with too few valid months
    valid_cols = ls_df.columns[ls_df.notna().sum() >= min_obs]
    ls_df = ls_df[valid_cols]

    print(f"  LS returns: {ls_df.shape[0]} months x "
          f"{ls_df.shape[1]} features ({_time.time()-t0:.1f}s)")
    return ls_df


# ---------------------------------------------------------------------------
# Step 2: Build expanding-window rebalance schedule (ClusterSharpe + MeanVar)
# ---------------------------------------------------------------------------
def _sharpe_series(tc):
    """Annualized Sharpe ratio for each column of a returns DataFrame."""
    stds = tc.std()
    stds = stds.replace(0, np.nan)
    return ((tc.mean() / stds) * np.sqrt(12)).fillna(0)


def _cluster_sharpe_select(ls_avail, sel_start, sel_end, n_factors, min_obs):
    """
    Select up to ``n_factors`` factors, at most one per correlation cluster.

    Steps:
      1. Factors are eligible if they have at least
         min(min_obs, max(6, T // 2)) non-NaN months up to ``sel_end``,
         where T is the number of months up to ``sel_end``.
      2. The [sel_start, sel_end] window is zero-filled and the eligible
         factors are clustered (average linkage, distance 1 - |corr|) into
         min(max(n_factors * N_CLUSTERS_MULT, 20), n_eligible) clusters.
      3. Each cluster keeps its highest-Sharpe factor; the best
         ``n_factors`` of those are returned in descending Sharpe order.

    Falls back to plain top-Sharpe selection when the window has fewer than
    12 months, fewer than ``n_factors`` factors are eligible, or clustering
    raises.
    """
    history = ls_avail.loc[:sel_end]
    min_count = min(min_obs, max(6, len(history) // 2))
    eligible = list(history.columns[history.notna().sum() >= min_count])

    window = ls_avail.loc[sel_start:sel_end][eligible].fillna(0)
    n_pick = min(n_factors, len(eligible))

    def _top_sharpe():
        # Plain Sharpe ranking, used whenever clustering is not possible.
        return list(_sharpe_series(window).nlargest(n_pick).index)

    if len(window) < 12 or len(eligible) < n_factors:
        return _top_sharpe()

    # Symmetric, non-negative distance matrix with a zero diagonal.
    abs_corr = window.corr().fillna(0).clip(-1, 1).abs()
    distance = (1.0 - abs_corr).values.copy()
    np.fill_diagonal(distance, 0)
    distance = np.clip((distance + distance.T) / 2, 0, None)

    try:
        condensed = np.nan_to_num(
            squareform(distance, checks=False),
            nan=1.0, posinf=1.0, neginf=0.0)
        tree = linkage(condensed, method='average')
        k = min(max(n_factors * N_CLUSTERS_MULT, 20), len(eligible))
        labels = fcluster(tree, t=k, criterion='maxclust')
    except Exception:
        # Best-effort: any clustering failure degrades to plain Sharpe.
        return _top_sharpe()

    scores = _sharpe_series(window)
    table = pd.DataFrame({
        'factor': eligible,
        'cluster': labels,
        'score': scores.fillna(-999).values,
    })

    winners = table.loc[table.groupby('cluster')['score'].idxmax()]
    winners = winners[winners['score'] > -998]
    winners = winners.sort_values('score', ascending=False)
    return winners.head(n_factors)['factor'].tolist()


def _mean_var_weight(ls_avail, sel_start, sel_end, factors):
    """
    Mean-variance optimised weights: w* = Sigma^{-1} mu (scaled).

    Uses Ledoit-Wolf-style shrinkage on the covariance matrix for stability.
    Constraints: w >= 0 (long-only in factor space), sum(w) = 1.
    Falls back to EW if optimisation is degenerate.
    """
    tc = ls_avail.loc[sel_start:sel_end][factors].fillna(0)
    n = len(factors)
    if len(tc) < max(12, n + 2):
        return np.full(n, 1.0 / n)

    mu = tc.mean().values
    cov = tc.cov().values

    # Ledoit-Wolf shrinkage toward diagonal
    trace_cov = np.trace(cov) / n
    cov_shrunk = ((1 - SHRINKAGE) * cov
                  + SHRINKAGE * trace_cov * np.eye(n))

    try:
        inv_cov = np.linalg.inv(cov_shrunk)
        w_raw = inv_cov @ mu
    except np.linalg.LinAlgError:
        return np.full(n, 1.0 / n)

    # Clip negatives (no shorting factors) and normalize
    w_raw = np.maximum(w_raw, 0)
    total = w_raw.sum()
    if total < 1e-10:
        return np.full(n, 1.0 / n)
    return w_raw / total


def build_rebalance_schedule(ls_returns, rebal_months, n_factors, min_obs):
    """
    Select and weight factors at every ``rebal_months``-th month.

    At each rebalance point the expanding window of LS returns is used to:
      1. Select up to ``n_factors`` factors via ClusterSharpe.
      2. Weight them by MeanVar optimisation.
    Rebalance points with fewer than 24 non-empty months of history, or at
    which no factor is selected, are skipped.

    Returns dict: rebal_date -> (factors_list, weights_array)

    Temporal integrity (Rule 1):
      The window holds every month of ``ls_returns`` strictly before the
      rebalance date. It runs from the first available month through the
      previous month-end, so no same-month LS return is used. (The LS return
      at month t is built from ``ret_exc_lead1m`` -- presumably the return
      over month t+1, which is therefore realised by the rebalance date.)
    """
    import time as _time
    print("  Building rebalance schedule (ClusterSharpe + MeanVar)...")
    t0 = _time.time()

    # Label slicing below assumes a sorted index.
    ls_sorted = ls_returns.sort_index()
    all_months = ls_sorted.index
    rebal_points = all_months[::rebal_months]

    schedule = {}

    for rp in rebal_points:
        # Use a strict "< rp" filter rather than "<= rp - DateOffset(months=1)".
        # Subtracting a calendar month from a 30-day month-end (e.g.
        # 2020-04-30 -> 2020-03-30) silently dropped the previous month-end
        # (2020-03-31) from the window.
        ls_avail = ls_sorted.loc[all_months < rp]
        if len(ls_avail.dropna(how='all')) < 24:
            continue
        sel_start = ls_avail.index[0]   # expanding window
        sel_end = ls_avail.index[-1]    # previous month-end

        # ClusterSharpe factor selection
        factors = _cluster_sharpe_select(
            ls_avail, sel_start, sel_end, n_factors, min_obs)
        if not factors:
            continue

        # MeanVar weighting
        weights = _mean_var_weight(ls_avail, sel_start, sel_end, factors)

        schedule[rp] = (factors, weights)

    print(f"  Schedule: {len(schedule)} active rebalance points "
          f"({_time.time()-t0:.1f}s)")

    # Log first and last few rebalance points
    dates = sorted(schedule.keys())
    if dates:
        print(f"  First: {dates[0].strftime('%Y-%m')}, "
              f"Last: {dates[-1].strftime('%Y-%m')}")
        for d in dates[:3]:
            f, w = schedule[d]
            print(f"    {d.strftime('%Y-%m')}: {f[:5]}... "
                  f"w=[{', '.join(f'{x:.3f}' for x in w[:5])}...]")

    return schedule


# ---------------------------------------------------------------------------
# Step 3: Compute stock-level portfolio weights (linear rank)
# ---------------------------------------------------------------------------
def _linear_rank_month_weights(month_data, factors, weights, min_valid):
    """
    Sum the linear-rank stock weights of ``factors`` for one month.

    Factor k with portfolio weight fw contributes
    fw * (rank_pct - 0.5) / sum|rank_pct - 0.5|, so its gross exposure is
    |fw|. A factor is skipped if it is absent from ``month_data``, has fewer
    than ``min_valid`` non-NaN values, or has a zero normaliser.

    Returns a float64 array aligned with the rows of ``month_data``.
    """
    stock_w = np.zeros(len(month_data), dtype=np.float64)
    for factor, fw in zip(factors, weights):
        if factor not in month_data.columns:
            continue
        ranks = month_data[factor].rank(pct=True).values
        has_rank = ~np.isnan(ranks)
        if has_rank.sum() < min_valid:
            continue
        # Linear weight rank - 0.5; stocks missing the factor get zero.
        rank_w = np.where(has_rank, ranks - 0.5, 0.0)
        abs_sum = np.abs(rank_w).sum()
        if abs_sum < 1e-10:
            continue
        stock_w += fw * rank_w / abs_sum
    return stock_w


def compute_stock_weights(chars, schedule, min_stocks=None, min_universe=50):
    """
    Map the factor schedule to monthly stock-level weights (linear rank).

    Each month uses the most recent rebalance on or before it. For factor k
    with portfolio weight w_k:
      stock i gets  w_k * (rank_pct_i - 0.5) / Z_k
    where Z_k = sum_j |rank_pct_j - 0.5| over stocks with a valid rank on
    factor k in that month.

    This mirrors the portfolio-level linear LS return
    sum_i(stock_w_i * ret_i) = sum_k(w_k * LS_k) only approximately. Here
    ranks are taken over all stocks, whereas the LS returns used for
    selection rank only stocks with a non-missing ``ret_exc_lead1m``. Using
    all stocks avoids look-ahead bias: at time t we don't know which stocks
    will have valid returns. (Rule 1)

    Args:
        chars: stock-month panel with ``id``, ``eom`` and factor columns.
        schedule: dict rebal_date -> (factors, weights).
        min_stocks: a factor contributes in a month only if at least
            2 * min_stocks stocks have a non-NaN value; None uses MIN_STOCKS.
        min_universe: months with fewer stocks than this are skipped.

    Returns:
        DataFrame with columns id, eom, w (only non-zero weights).
    """
    import time as _time
    from bisect import bisect_right

    print("  Computing stock-level weights (linear rank)...")
    t0 = _time.time()

    rebal_dates = sorted(schedule.keys())
    if not rebal_dates:
        return pd.DataFrame(columns=["id", "eom", "w"])
    if min_stocks is None:
        min_stocks = MIN_STOCKS
    min_valid = min_stocks * 2

    # Only the factor columns that some rebalance actually uses.
    all_factors = sorted({f for factors, _ in schedule.values() for f in factors})
    available = [f for f in all_factors if f in chars.columns]
    df_work = chars[["id", "eom"] + available]

    # One groupby pass instead of a full-panel boolean scan per month.
    grouped = df_work.groupby("eom", sort=True)
    n_months = grouped.ngroups
    results = []

    for m_idx, (eom, month_data) in enumerate(grouped):
        if (m_idx + 1) % 200 == 0:
            print(f"    {m_idx+1}/{n_months} months ({_time.time()-t0:.1f}s)")

        # Most recent rebalance on or before this month (binary search).
        pos = bisect_right(rebal_dates, eom)
        if pos == 0:
            continue
        factors, weights = schedule[rebal_dates[pos - 1]]
        if not factors or len(month_data) < min_universe:
            continue

        stock_w = _linear_rank_month_weights(
            month_data, factors, weights, min_valid)

        # Collect non-zero weights
        nonzero = np.abs(stock_w) > 1e-15
        if nonzero.any():
            results.append(pd.DataFrame({
                "id": month_data["id"].values[nonzero],
                "eom": eom,
                "w": stock_w[nonzero],
            }))

    if not results:
        return pd.DataFrame(columns=["id", "eom", "w"])

    output = pd.concat(results, ignore_index=True)
    elapsed = _time.time() - t0
    nm = output["eom"].nunique()
    n_long = (output["w"] > 0).sum()
    n_short = (output["w"] < 0).sum()
    print(f"  Weights: {len(output):,} rows, {nm} months, "
          f"{output['id'].nunique():,} unique stocks ({elapsed:.1f}s)")
    print(f"  Avg long/month: {n_long/max(nm,1):.0f}, "
          f"avg short/month: {n_short/max(nm,1):.0f}")
    return output


# ===================================================================
# main() -- CTF entry point (Rules 11-12)
# ===================================================================
def main(chars: pd.DataFrame,
         features: pd.DataFrame,
         daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    CTF-compliant strategy entry point.

    Pipeline: factor LS returns -> expanding-window ClusterSharpe/MeanVar
    schedule -> stock-level linear-rank weights.

    Args:
        chars:     Stock characteristics (ctff_chars.parquet); the pipeline
                   reads ``id``, ``eom`` and ``ret_exc_lead1m`` from it.
        features:  Feature list (ctff_features.parquet) with a ``features``
                   column of characteristic names.
        daily_ret: Daily returns (ctff_daily_ret.parquet) -- not used

    Returns:
        DataFrame with columns: id, eom, w  (no missing values)

    Raises:
        AssertionError: if the output has the wrong columns, missing values,
            or no rows. These are raised explicitly so the Rule 12 checks
            still run under ``python -O``.
    """
    import time
    t0 = time.time()

    # --- Prepare data ---
    feature_names = features["features"].tolist()
    chars = chars.copy()
    chars["eom"] = pd.to_datetime(chars["eom"])
    n_months = chars["eom"].nunique()
    print(f"Data: {len(chars):,} stock-months, {len(feature_names)} features, "
          f"{n_months} months")

    # --- Adapt thresholds for small datasets (validation mode = 123 months) ---
    if n_months < 200:
        min_obs = max(6, n_months // 5)
        min_stocks = 5
        # NOTE(review): max(3, REBAL_MONTHS) equals REBAL_MONTHS whenever
        # REBAL_MONTHS >= 3, so small-dataset mode keeps the normal cadence.
        rebal = max(3, REBAL_MONTHS)
        n_fac = min(N_FACTORS, 5)
        print(f"  Small dataset mode: {n_months} months, min_obs={min_obs}, "
              f"n_factors={n_fac}")
    else:
        min_obs = MIN_OBS
        min_stocks = MIN_STOCKS
        rebal = REBAL_MONTHS
        n_fac = N_FACTORS

    # --- Step 1: Compute factor LS returns from stock-level data ---
    ls_returns = compute_factor_ls_returns(
        chars, feature_names, min_stocks, min_obs)

    # --- Step 2: Build expanding-window rebalance schedule ---
    schedule = build_rebalance_schedule(
        ls_returns, rebal, n_fac, min_obs)

    # --- Step 3: Compute stock-level weights ---
    output = compute_stock_weights(chars, schedule)

    # --- Validate output (Rule 12) ---
    # Explicit raises instead of `assert`, which `python -O` strips.
    if set(output.columns) != {"id", "eom", "w"}:
        raise AssertionError("Bad columns")
    if not output["w"].notna().all():
        raise AssertionError("NaN weights")
    if not output[["id", "eom"]].notna().all().all():
        raise AssertionError("Missing id/eom values")
    if len(output) == 0:
        raise AssertionError("Empty output")

    elapsed = time.time() - t0
    print(f"\nTotal time: {elapsed:.1f}s")
    print(f"Output: {len(output):,} rows")
    return output
