import pandas as pd
import numpy as np

def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    Constructs a monthly-rebalanced Minimum Variance Portfolio (MVP) using
    dynamic Ledoit-Wolf-style covariance shrinkage and a gross-leverage cap.

    Args:
        chars: Stock characteristics (ctff_chars.parquet)
        features: Computed features (ctff_features.parquet)
        daily_ret: Historical daily returns (ctff_daily_ret.parquet)

    Returns:
        DataFrame with columns: id, eom, w
    """

    # --- Configuration Parameters ---
    H_min = 14          # Minimum lookback for covariance estimation
    H_max = 252         # Maximum lookback (1 trading year)
    N_min = 42          # Ramp width for pairwise observation confidence
    N_low = 14          # Minimum overlap before a pair earns any confidence
    lambda_max = 0.8    # Upper bound for the shrinkage intensity
    ctff_test_data = True  # Restrict the universe to test-flagged stocks

    def mvp_weights(sigma, max_leverage=3.0):
        """
        Calculates weights for the Minimum Variance Portfolio.
        Uses pseudoinverse as a fallback for singular matrices.
        Includes a shrinkage-to-equal-weight mechanism to enforce leverage constraints.
        """
        n = sigma.shape[0]
        ones = np.ones(n)
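        # Closed-form solution: w = Sigma^{-1} 1 / (1' Sigma^{-1} 1)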

        try:
            # Efficiently solve for Sigma * x = 1 instead of explicit inversion
            x = np.linalg.solve(sigma, ones)
        except np.linalg.LinAlgError:
            # Moore-Penrose pseudoinverse handles singular/ill-conditioned matrices
            sigma_pinv = np.linalg.pinv(sigma)
            x = np.dot(sigma_pinv, ones)

        x_sum = np.sum(x)

        # Degenerate case: a near-zero or negative denominator means the
        # solve produced no usable MVP direction; fall back to equal weights
        if x_sum < 1e-8:
            ew = np.ones(n) / n
            return ew

        weights = x / x_sum
        current_leverage = np.abs(weights).sum()

        if current_leverage > max_leverage:
            # Linear interpolation (shrinkage) between MVP and Equal Weight (EW).
            # By the triangle inequality the blend's gross leverage is at most
            # alpha * current_leverage + (1 - alpha) * 1 = max_leverage.
            alpha = (max_leverage - 1) / (current_leverage - 1)
            ew = np.ones(n) / n
            return alpha * weights + (1 - alpha) * ew

        return weights
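
    # Illustrative check (hypothetical numbers): for sigma = diag(0.04, 0.01),
    # solving Sigma x = 1 gives x = (25, 100), so w = (0.2, 0.8); the
    # lower-variance asset receives the larger weight, as an MVP should.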

    def identity_target(S):
        """
        Generates a target shrinkage matrix (T) using the Identity matrix
        scaled by the average variance of assets.
        """
        vars_diag = np.diag(S)
        avg_var = np.nanmean(vars_diag)

        if np.isnan(avg_var):
            avg_var = 1.0
        return np.eye(S.shape[0]) * avg_var
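
    # e.g. for diag(S) = (0.04, 0.02) the target is 0.03 * I, i.e. the
    # scaled-identity target used in Ledoit-Wolf-style shrinkage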

    def run_monthly_rebalance_dynamic(
            log_returns: pd.DataFrame,
            universe_map: dict,
            H_min: int,
            H_max: int,
            N_min: int,
            N_low: int,
            lambda_max: float,
            target_func
    ) -> pd.DataFrame:
        """
        Iterates through monthly rebalance dates, selecting the valid universe
        from universe_map and calculating MVP weights.
        """
        # Identify rebalance dates (month ends) from the available return data;
        # dropna() guards months with no observations, where resample yields NaT
        rebalance_dates = log_returns.index[H_min:].to_series().resample('ME').max().dropna()
        results = []

        for date in rebalance_dates:
            # --- Data Slicing & Filtering ---
            # Look up valid universe for this specific month
            lookup_key = (date + pd.offsets.MonthEnd(0)).date()
            candidate_ids = universe_map.get(lookup_key, set())

            if not candidate_ids:
                continue

            # Slice Historical Data
            current_idx = log_returns.index.get_loc(date)
            start_idx = max(0, current_idx - H_max)

            # Intersection of "Stocks allowed this month" AND "Stocks with return data"
            valid_cols = list(candidate_ids.intersection(log_returns.columns))
            if len(valid_cols) < 2:
                continue

            # Slice history window (returns strictly before the rebalance date)
            hist_data = log_returns.iloc[start_idx: current_idx][valid_cols]

            # Filter for Minimum History (H_min)
            valid_counts = hist_data.count()
            valid_stocks = valid_counts[valid_counts >= H_min].index
            if len(valid_stocks) < 2:
                continue

            sub_returns = hist_data[valid_stocks].copy()

            # --- Data Cleaning ---
            # Remove stale price sequences (consecutive zeros) which bias volatility downwards
            is_zero = sub_returns.abs() < 1e-8
            is_streak = is_zero & (is_zero.shift(1, fill_value=False)
                                   | is_zero.shift(-1, fill_value=False))
            sub_returns[is_streak] = np.nan

            # --- Dynamic Shrinkage Intensity (lambda) ---
            n_obs, p_vars = sub_returns.shape
            lambda_conf = min(np.sqrt(n_obs / p_vars), lambda_max)
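            # e.g. n_obs = 126 and p_vars = 200 give sqrt(0.63) ~= 0.79;
            # plentiful history relative to the number of names is capped
            # at lambda_max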

            # --- Observation Weighting Matrix ---
            mask = (abs(sub_returns) > 1e-8).astype(int)
            count_matrix = mask.T @ mask
            obs_weight = ((count_matrix - N_low) / N_min).clip(0, 1)
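            # e.g. with N_low = 14 and N_min = 42, a pair overlapping on 35
            # days gets (35 - 14) / 42 = 0.5; full confidence requires at
            # least N_low + N_min = 56 joint non-zero observations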

            # --- Covariance Estimation ---
            # Mean-impute remaining NaNs, then compute the Sample (S) and Target (T)
            # matrices; S is scaled from daily to a 21-day (monthly) horizon
            sub_returns = sub_returns.fillna(sub_returns.mean())
            S = sub_returns.cov().values * 21
            T = target_func(S)

            # Fallbacks for NaNs in Target or Sample
            if np.isnan(T).any():
                avg_var = np.nanmean(np.diag(S)) if not np.all(np.isnan(np.diag(S))) else 1.0
                T = np.where(np.isnan(T), np.eye(len(valid_stocks)) * avg_var, T)
            if np.isnan(S).any():
                S = np.where(np.isnan(S), T, S)

            # --- Shrink Covariance Matrix ---
            eff_weight = lambda_conf * obs_weight
            Sigma_shrinked = T + eff_weight * (S - T)
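            # Entrywise: Sigma_ij = T_ij + lambda * w_ij * (S_ij - T_ij). Unlike a
            # scalar blend, per-entry weights can break positive semi-definiteness;
            # the solve/pinv fallbacks in mvp_weights absorb that case.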

            # --- Calculate MVP weights ---
            try:
                w = mvp_weights(Sigma_shrinked)
            except np.linalg.LinAlgError:
                print(f"\rError on {date.date()}: Fallback to equal weights")
                w = np.ones(len(valid_stocks)) / len(valid_stocks)

            # Store results
            results.append(pd.DataFrame({
                'eom': date,
                'id': valid_stocks,
                'w': w
            }))

            print(f"\rRebalanced {date.date()}: {len(valid_stocks)} stocks (Lambda: {lambda_conf:.2f})", end="")

        if results:
            return pd.concat(results, ignore_index=True)
        return pd.DataFrame(columns=['eom', 'id', 'w'])


    # --- Main Execution Setup ---
    chars['eom'] = pd.to_datetime(chars['eom'])
    daily_ret['date'] = pd.to_datetime(daily_ret['date'])

    # Prepare Universe Map
    if ctff_test_data and 'ctff_test' in chars.columns:
        valid_universe = chars[chars['ctff_test'] == True].copy()
    else:
        valid_universe = daily_ret[['id', 'date']].copy()
        valid_universe['eom'] = valid_universe['date'] + pd.offsets.MonthEnd(0)
        valid_universe = valid_universe[['id', 'eom']].drop_duplicates()
    valid_universe['eom'] = (pd.to_datetime(valid_universe['eom']) + pd.offsets.MonthEnd(0)).dt.date

    # Lookup dictionary {Date -> Set(IDs)}
    universe_map = valid_universe.groupby('eom')['id'].apply(set).to_dict()
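    # e.g. {datetime.date(2020, 1, 31): {101, 102, ...}} (ids illustrative)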

    # Prepare Returns Data
    daily_ret['ret_exc'] = daily_ret['ret_exc'].clip(lower=-0.999)
    daily_ret_wide = daily_ret.pivot(index="date", columns="id", values="ret_exc").sort_index()
    log_returns = np.log1p(daily_ret_wide)
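    # log1p converts simple to log returns (0.05 -> ~0.0488); the clip above
    # keeps log1p finite for total-loss (-100%) observations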

    # Run Rebalance
    pf_weights = run_monthly_rebalance_dynamic(
        log_returns=log_returns,
        universe_map=universe_map,
        H_min=H_min,
        H_max=H_max,
        N_min=N_min,
        N_low=N_low,
        lambda_max=lambda_max,
        target_func=identity_target,
    )

    # Final Formatting
    pf_weights['eom'] = pd.to_datetime(pf_weights['eom']) + pd.offsets.MonthEnd(0)
    pf_weights['id'] = pf_weights['id'].astype(int)
    pf_weights['w'] = pf_weights['w'].astype(float)

    return pf_weights[['id', 'eom', 'w']]
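

# Example driver (assumes the parquet files named in the docstring are
# available in the working directory; the file names are taken from the
# docstring, everything else here is illustrative).
if __name__ == "__main__":
    chars_df = pd.read_parquet("ctff_chars.parquet")
    features_df = pd.read_parquet("ctff_features.parquet")
    daily_ret_df = pd.read_parquet("ctff_daily_ret.parquet")
    pf = main(chars_df, features_df, daily_ret_df)
    print(pf.head())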