import pandas as pd
import numpy as np
import time

def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    Constructs a Minimum Variance Portfolio (MVP) using dynamic
    Ledoit-Wolf-style shrinkage and leverage constraints.

    Args:
        chars: Stock characteristics (ctff_chars.parquet)
        features: Computed features (ctff_features.parquet); not used by this strategy
        daily_ret: Historical daily returns (ctff_daily_ret.parquet)

    Returns:
        DataFrame with columns: id, eom, w
    """

    # --- Configuration Parameters ---
    H_min = 14          # Minimum lookback (trading days) for covariance estimation
    H_max = 252         # Maximum lookback (1 trading year)
    N_min = 42          # Ramp width: extra pairwise overlaps needed for full sample confidence
    N_low = 14          # Pairwise-overlap count below which the sample gets zero weight
    lambda_max = 0.8    # Upper bound on the weight placed on the sample covariance

    # --- Debug: Entry Point ---
    print(f"[DEBUG] main: Starting MVP calculation", flush=True)
    print(f"[DEBUG] main: chars shape: {chars.shape}", flush=True)
    print(f"[DEBUG] main: features shape: {features.shape}", flush=True)
    print(f"[DEBUG] main: daily_ret shape: {daily_ret.shape}", flush=True)
    print(f"[DEBUG] main: Config - H_min={H_min}, H_max={H_max}, N_min={N_min}, N_low={N_low}, lambda_max={lambda_max}", flush=True)

    start_time = time.time()

    def mvp_weights(sigma, max_leverage=3.0):
        """
        Calculates weights for the Minimum Variance Portfolio.
        Uses pseudoinverse as a fallback for singular matrices.
        Includes a shrinkage-to-equal-weight mechanism to enforce leverage constraints.
        """
        n = sigma.shape[0]
        ones = np.ones(n)

        try:
            # Efficiently solve for Sigma * x = 1 instead of explicit inversion
            x = np.linalg.solve(sigma, ones)
        except np.linalg.LinAlgError:
            # Moore-Penrose pseudoinverse handles singular/ill-conditioned matrices
            sigma_pinv = np.linalg.pinv(sigma)
            x = np.dot(sigma_pinv, ones)

        x_sum = np.sum(x)

        # Fall back to equal weights when the denominator is near zero or negative
        # (a negative sum indicates the shrunk covariance is not positive definite)
        if x_sum < 1e-8:
            ew = np.ones(n) / n
            return ew

        weights = x / x_sum
        current_leverage = np.abs(weights).sum()

        if current_leverage > max_leverage:
            # Shrink the MVP weights toward Equal Weight (EW). Since EW has gross
            # leverage 1, the blend's leverage is bounded by
            # alpha * current_leverage + (1 - alpha) * 1 = max_leverage.
            alpha = (max_leverage - 1) / (current_leverage - 1)
            ew = np.ones(n) / n
            return alpha * weights + (1 - alpha) * ew

        return weights


    def identity_target(S):
        """
        Generates a target shrinkage matrix (T) using the Identity matrix
        scaled by the average variance of assets.
        """
        vars_diag = np.diag(S)
        avg_var = np.nanmean(vars_diag)

        if np.isnan(avg_var):
            avg_var = 1.0
        return np.eye(S.shape[0]) * avg_var
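
    # Alternative shrinkage target, included as an illustrative sketch only; the run
    # below still uses identity_target. Pass target_func=constant_correlation_target
    # to run_monthly_rebalance_dynamic if you want to experiment with it.
    def constant_correlation_target(S):
        """
        Ledoit-Wolf-style constant-correlation target: keeps each asset's own
        variance but replaces every pairwise correlation with the average
        off-diagonal correlation.
        """
        stds = np.sqrt(np.clip(np.diag(S), 1e-12, None))
        corr = S / np.outer(stds, stds)
        n = S.shape[0]
        # Average off-diagonal correlation (0.0 if only one asset)
        avg_corr = (np.nansum(corr) - n) / (n * (n - 1)) if n > 1 else 0.0
        T = avg_corr * np.outer(stds, stds)
        np.fill_diagonal(T, np.diag(S))
        return T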


    def run_monthly_rebalance_dynamic(
            data_wide: pd.DataFrame,
            H_min: int,
            H_max: int,
            N_min: int,
            N_low: int,
            lambda_max: float,
            target_func,
            loop_start_time: float
    ) -> pd.DataFrame:

        # Work in log-return space for covariance estimation; log1p keeps NaNs as NaNs
        log_returns = np.log1p(data_wide)
        # Last trading day of each calendar month (drop empty months, which resample returns as NaT)
        rebalance_dates = data_wide.index[H_min:].to_series().resample('ME').max().dropna()
        results = []

        total_dates = len(rebalance_dates)
        print(f"[PROGRESS] Starting rebalancing for {total_dates} dates", flush=True)

        for idx, date in enumerate(rebalance_dates, 1):
            # --- Data Slicing & Filtering ---
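            # Use up to H_max rows of history, strictly before the rebalance date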
            current_idx = data_wide.index.get_loc(date)
            hist_data = log_returns.iloc[max(0, current_idx - H_max): current_idx]

            if len(hist_data) < H_min:
                continue

            valid_counts = hist_data.count()
            valid_stocks = valid_counts[valid_counts >= H_min].index

            if len(valid_stocks) < 2:
                continue

            sub_returns = hist_data[valid_stocks].copy()

            # --- Data Cleaning ---
            # Remove stale price sequences (consecutive zeros) which bias volatility downwards
            is_zero = sub_returns.abs() < 1e-8
            is_streak = (is_zero & is_zero.shift(1, fill_value=False)) | (is_zero & is_zero.shift(-1, fill_value=False))
            sub_returns[is_streak] = np.nan

            # --- Dynamic Shrinkage Intensity (lambda) ---
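            # lambda_conf is the weight placed on the sample covariance; it grows with
            # observations per variable and is capped at lambda_max.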
            n_obs, p_vars = sub_returns.shape
            lambda_conf = min(np.sqrt(n_obs / p_vars), lambda_max)

            # --- Observation Weighting Matrix ---
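            # Pairwise counts of days on which both assets have a usable (non-zero, non-NaN)
            # return; confidence ramps linearly from 0 at N_low overlaps to 1 at N_low + N_min.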
            mask = (abs(sub_returns) > 1e-8).astype(int)
            count_matrix = mask.T @ mask
            obs_weight = ((count_matrix - N_low) / N_min).clip(0, 1)

            # --- Covariance Estimation ---
            # Compute Sample (S) and Target (T) matrices (monthly horizon)
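            # Mean-impute remaining NaNs so .cov() is computed on a common, complete sample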
            sub_returns = sub_returns.fillna(sub_returns.mean())
            S = sub_returns.cov().values * 21
            T = target_func(S)

            # Fallbacks for NaNs in Target or Sample
            if np.isnan(T).any():
                avg_var = np.nanmean(np.diag(S)) if not np.all(np.isnan(np.diag(S))) else 1.0
                T = np.where(np.isnan(T), np.eye(len(valid_stocks)) * avg_var, T)
            if np.isnan(S).any():
                S = np.where(np.isnan(S), T, S)

            # --- Shrink Covariance Matrix ---
            # Entry-wise convex combination: sparsely observed pairs stay close to the
            # target T, while well-observed pairs lean toward the sample S.
            eff_weight = lambda_conf * obs_weight
            Sigma_shrunk = T + eff_weight * (S - T)

            # --- Calculate MVP weights ---
            try:
                w = mvp_weights(Sigma_shrunk)
            except np.linalg.LinAlgError:
                print(f"[DEBUG] Error on {date.date()}: falling back to equal weights", flush=True)
                w = np.ones(len(valid_stocks)) / len(valid_stocks)

            # Store results
            results.append(pd.DataFrame({
                'eom': date,
                'id': valid_stocks,
                'w': w
            }))

            # Progress reporting
            elapsed = time.time() - loop_start_time
            pct_complete = (idx / total_dates) * 100
            if idx > 1 and pct_complete > 0:
                eta_seconds = (elapsed / pct_complete) * (100 - pct_complete)
                if eta_seconds >= 60:
                    eta_str = f"ETA {eta_seconds / 60:.1f} min"
                else:
                    eta_str = f"ETA {eta_seconds:.0f}s"
            else:
                eta_str = "ETA calculating..."
            print(f"[PROGRESS] {idx}/{total_dates} ({pct_complete:.1f}%) - {date.date()} - {len(valid_stocks)} stocks - {elapsed:.1f}s elapsed - {eta_str}", flush=True)

        df_weights = pd.concat(results, ignore_index=True) if results else pd.DataFrame(columns=['eom', 'id', 'w'])
        return df_weights

    # --- Pre-processing & Execution ---
    print(f"[DEBUG] main: Clipping extreme returns...", flush=True)
    daily_ret["ret_exc"] = daily_ret["ret_exc"].clip(lower=-0.9999)
    print(f"[DEBUG] main: Pivoting to wide format...", flush=True)
    daily_ret_wide = daily_ret.pivot(index="date", columns="id", values="ret_exc").sort_index()
    daily_ret_wide.index = pd.to_datetime(daily_ret_wide.index)
    print(f"[DEBUG] main: daily_ret_wide shape: {daily_ret_wide.shape}", flush=True)

    print(f"[DEBUG] main: Starting rebalancing loop...", flush=True)
    pf_weights = run_monthly_rebalance_dynamic(
        daily_ret_wide,
        H_min=H_min,
        H_max=H_max,
        N_min=N_min,
        N_low=N_low,
        lambda_max=lambda_max,
        target_func=identity_target,
        loop_start_time=start_time,
    )
    # Post-processing for submission format
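    # MonthEnd(0) snaps each last trading day to the corresponding calendar month end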
    pf_weights['eom'] = pd.to_datetime(pf_weights['eom']) + pd.offsets.MonthEnd(0)
    pf_weights['id'] = pf_weights['id'].astype(int)
    pf_weights['w'] = pf_weights['w'].astype(float)

    # Filter to test period only (dates where ctff_test == True)
    # This reduces output file size significantly for full-dataset runs
    test_dates = pd.to_datetime(chars.loc[chars['ctff_test'] == True, 'eom'].unique())
    pre_filter_count = len(pf_weights)
    pf_weights = pf_weights[pf_weights['eom'].isin(test_dates)]
    print(f"[DEBUG] main: Filtered to test period ({len(test_dates)} dates): {pre_filter_count} -> {len(pf_weights)} rows", flush=True)

    # Final summary
    total_elapsed = time.time() - start_time
    print(f"[DEBUG] main: Post-processing complete", flush=True)
    print(f"[DEBUG] main: Completed - {len(pf_weights)} total rows in {total_elapsed:.1f}s", flush=True)
    if len(pf_weights) > 0:
        date_range = f"{pf_weights['eom'].min().date()} to {pf_weights['eom'].max().date()}"
        print(f"[DEBUG] main: Date range: {date_range}", flush=True)

    return pf_weights[['id', 'eom', 'w']]
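

# Minimal usage sketch (an assumption about the surrounding pipeline, not part of the
# submitted main() entry point): load the three parquet files named in the docstring
# and write the resulting weights. The output filename is illustrative only.
if __name__ == "__main__":
    chars = pd.read_parquet("ctff_chars.parquet")
    features = pd.read_parquet("ctff_features.parquet")
    daily_ret = pd.read_parquet("ctff_daily_ret.parquet")
    weights = main(chars, features, daily_ret)
    weights.to_parquet("portfolio_weights.parquet", index=False)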
