# %%
# Method: MVO portfolio of IPCA factors (Kelly, Pruitt, Su (JFE, 2019))
# PARALLELIZED VERSION using joblib for multi-core utilization
import os
import sys
import time
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
from ipca import InstrumentedPCA
from joblib import Parallel, delayed

# Data caching and output buffering are handled by the infrastructure.
# This allows user code to remain simple while benefiting from I/O optimizations.


def ecdf(data: pd.Series) -> pd.Series:
    """Empirical CDF transform: map each value to its percentile rank in (0, 1]."""
    if data.empty:
        return data

    # rank(method="max", pct=True) gives, for each value, the fraction of
    # non-missing observations that are <= that value; NaNs stay NaN.
    return data.rank(method="max", pct=True)
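
# Example (values are illustrative): ecdf(pd.Series([30.0, 10.0, 20.0])) returns
# 1.0, 1/3 and 2/3 respectively -- each observation's percentile rank in the series.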


def prepare_data(chars: pd.DataFrame, features: list[str], eom: str) -> pd.DataFrame:
    """Rank-transform each feature cross-sectionally, grouped by the `eom` column.

    Each feature is mapped to its within-month ECDF value, zeros are preserved,
    missing values are imputed at the median rank (0.5), and the result is
    centered so that all features lie in [-0.5, 0.5].
    """
    for feature in features:
        is_zero = chars[feature] == 0  # Preserve zeros
        chars[feature] = chars.groupby(eom)[feature].transform(ecdf)
        chars.loc[is_zero, feature] = 0  # Restore zeros
        chars[feature] = chars[feature].fillna(0.5)  # Impute missing values at the median rank

    chars[features] -= 0.5
    return chars
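
# Minimal usage sketch (toy data, illustrative only): after prepare_data, non-zero
# values become centered within-month percentile ranks, original zeros end up at
# -0.5, and originally missing values are imputed to 0.0.
#
#   toy = pd.DataFrame({"eom": ["2020-01-31"] * 4, "size": [1.0, 2.0, 0.0, np.nan]})
#   toy = prepare_data(toy, ["size"], "eom")
#   # toy["size"] -> [1/6, 0.5, -0.5, 0.0]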


def process_single_date(
    d,
    chars: pd.DataFrame,
    features: list[str],
    window_length: int,
    date_idx: int | None = None,
    total_dates: int | None = None,
) -> tuple:
    """Process a single date for IPCA - designed to be parallelized.

    This function is called in parallel for each date, enabling multi-core utilization.
    Returns: (ids, eoms, weights, date_idx, elapsed_time, date) for main-process logging.
    """
    start_time = time.time()

    try:
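        # Rolling estimation window: keep rows with eom_ret in
        # (d - window_length months, d]. Rows dated exactly d are excluded from
        # the training set below and used only to form out-of-sample weights.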
        chars_train = chars[
            (chars["eom_ret"] <= d)
            & (
                chars["eom_ret"]
                >= (
                    d
                    + timedelta(days=1)
                    - pd.DateOffset(months=window_length)
                    - timedelta(days=1)
                )
            )
        ]
        train = chars_train[chars_train["eom_ret"] < d][
            ["id", "eom", "ret_exc_lead1m", *features]
        ]
        train = train.set_index(["id", "eom"])
        py_X = train.drop(columns=["ret_exc_lead1m"])
        py_y = train["ret_exc_lead1m"]

        # InstrumentedPCA with n_jobs=1 since we're already parallelizing at the date level.
        # Estimate a 5-factor IPCA model (Gamma and the latent factor returns are fit jointly).
        regr = InstrumentedPCA(n_factors=5, intercept=False, n_jobs=1)
        regr.fit(py_X, py_y, quiet=True)
        # gamma: (n_characteristics x n_factors) loading matrix;
        # fct_ret: (T x n_factors) time series of estimated factor returns.
        gamma, factors = regr.get_factors(label_ind=True)
        gamma = np.array(gamma)
        fct_ret = factors.T

        # Mean-variance optimal weights on the IPCA factors: w_f = Sigma_f^{-1} mu_f
        mu = fct_ret.mean()
        inv = np.linalg.inv(fct_ret.cov())
        wf = inv @ np.array(mu)

        # Map factor weights back to asset weights through the instrumented loadings:
        # w = X_t Gamma (Gamma' X_t' X_t Gamma)^{-1} w_f
        X_t = np.array(chars.loc[chars["eom_ret"] == d, features])
        w = X_t @ gamma @ np.linalg.inv(gamma.T @ X_t.T @ X_t @ gamma) @ wf

        # Extract raw data for faster serialization (avoid returning DataFrame)
        date_slice = chars.loc[chars["eom_ret"] == d, ["id", "eom"]]
        ids = date_slice["id"].tolist()
        eoms = date_slice["eom"].tolist()
        weights = w.tolist()

        elapsed = time.time() - start_time

        # Write progress marker file (atomic, works across process boundaries)
        progress_file = f"/dev/shm/ipca_progress/{d}.done"
        # [infrastructure directive]
        with Path(progress_file).open("w") as f:
            f.write(f"completed in {elapsed:.1f}s\n")

        # Return raw data instead of DataFrame for faster serialization
        return (ids, eoms, weights, date_idx, elapsed, d)

    except Exception as e:
        print(
            f"[ERROR] process_single_date: Failed processing date {d} "
            f"(index {date_idx}/{total_dates}): {type(e).__name__}: {e!s}",
            flush=True,
        )
        import traceback

        traceback.print_exc()
        raise


def run_ipca(
    chars: pd.DataFrame, features: list[str], window_length: int
) -> pd.DataFrame:
    """Run IPCA in parallel."""
    print("[DEBUG] run_ipca: Starting PARALLELIZED function", flush=True)
    sys.stdout.flush()

    pf_dates = chars.loc[chars.ctff_test, "eom_ret"].sort_values().unique()
    print(f"[DEBUG] run_ipca: Found {len(pf_dates)} portfolio dates", flush=True)
    sys.stdout.flush()

    # Get number of joblib workers from environment (separate from BLAS threading)
    # CTF_N_JOBS: Number of parallel joblib workers for processing dates
    # OMP_NUM_THREADS/OPENBLAS_NUM_THREADS: BLAS threads per worker (set to 1 to prevent oversubscription)
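    # Example launch (hypothetical values):
    #   CTF_N_JOBS=16 OMP_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 python <this_script>.py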
    n_jobs = int(os.environ.get("CTF_N_JOBS", -1))
    print(
        f"[DEBUG] run_ipca: Using {n_jobs} parallel jobs "
        f"({'all cores' if n_jobs == -1 else f'{n_jobs} cores'})",
        flush=True,
    )
    sys.stdout.flush()

    # Create progress tracking directory
    # [infrastructure directive]
    progress_dir = "/dev/shm/ipca_progress"
    Path(progress_dir).mkdir(parents=True, exist_ok=True)
    print(
        f"[IPCA-PROGRESS] Progress marker files will be written to {progress_dir}",
        flush=True,
    )
    print(f"[IPCA-PROGRESS] Monitor with: ls {progress_dir} | wc -l", flush=True)
    sys.stdout.flush()

    print("[DEBUG] run_ipca: Entering PARALLEL IPCA processing", flush=True)
    sys.stdout.flush()

    total = len(pf_dates)

    # Get joblib backend from environment (threading for memory efficiency, loky for isolation)
    # JOBLIB_BACKEND=threading: Shares memory across workers (prevents OOM with many workers)
    # JOBLIB_BACKEND=loky: Process isolation (more memory but better fault tolerance)
    backend = os.environ.get("JOBLIB_BACKEND", "threading")
    print(
        f"[IPCA-PROGRESS] Starting parallel processing of {total} dates "
        f"with {n_jobs} workers ({backend} backend)",
        flush=True,
    )
    print("[IPCA-PROGRESS] Processing ALL dates at once for optimal load balancing", flush=True)
    sys.stdout.flush()

    try:
        # Process ALL dates at once - joblib handles work distribution dynamically
        # This eliminates idle workers and maximizes CPU utilization
        all_results = Parallel(n_jobs=n_jobs, verbose=0, backend=backend)(
            delayed(process_single_date)(
                d, chars, features, window_length, idx, total
            )
            for idx, d in enumerate(pf_dates)
        )

        # Extract results from all dates
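        # Each element is the tuple returned by process_single_date:
        # (ids, eoms, weights, date_idx, elapsed, date).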
        all_ids = []
        all_eoms = []
        all_weights = []

        for ids, eoms, weights, date_idx, elapsed, date in all_results:
            all_ids.extend(ids)
            all_eoms.extend(eoms)
            all_weights.extend(weights)
            print(
                f"[IPCA-PROGRESS] Completed date {date_idx + 1}/{total}: {date} "
                f"in {elapsed:.1f}s",
                flush=True,
            )
            sys.stdout.flush()

        # Count completed dates from progress directory
        completed_count = len(list(Path(progress_dir).iterdir()))
        print(
            f"[IPCA-PROGRESS] All dates processed! "
            f"Collected {len(all_ids)} total rows",
            flush=True,
        )
        print(
            f"[IPCA-PROGRESS] Progress: {completed_count}/{total} dates completed "
            f"({100 * completed_count / total:.1f}%)",
            flush=True,
        )
        sys.stdout.flush()

    except Exception as e:
        print(
            f"[ERROR] run_ipca: Processing failed: {type(e).__name__}: {e!s}",
            flush=True,
        )
        import traceback

        traceback.print_exc()
        raise

    print(
        f"[DEBUG] run_ipca: All dates processed! Collected {len(all_ids)} total rows "
        f"from {total} dates",
        flush=True,
    )
    sys.stdout.flush()

    try:
        print("[DEBUG] run_ipca: Building final DataFrame from raw data...", flush=True)
        sys.stdout.flush()

        # Build final DataFrame from accumulated raw data
        final_results = pd.DataFrame({"id": all_ids, "eom": all_eoms, "w": all_weights})

        print(
            f"[DEBUG] run_ipca: DataFrame built! Result shape: {final_results.shape}",
            flush=True,
        )
        sys.stdout.flush()
    except Exception as e:
        print(
            f"[ERROR] run_ipca: Failed to build DataFrame: {type(e).__name__}: {e!s}",
            flush=True,
        )
        print(f"[ERROR] Rows collected: {len(all_ids)}", flush=True)
        import traceback

        traceback.print_exc()
        raise

    print("[DEBUG] run_ipca: About to return final_results...", flush=True)
    sys.stdout.flush()

    return final_results


def main(
    chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame  # noqa: ARG001
) -> pd.DataFrame:
    """Main function to load packages, prepare data, train model, and calculate portfolio weights.

    PARALLELIZED VERSION:
    - Uses joblib to parallelize date processing across all available CPU cores
    - Infrastructure automatically handles data caching and output buffering for optimal I/O

    Args:
        chars (pd.DataFrame): DataFrame containing characteristics data.
        features (pd.DataFrame): DataFrame containing feature names.
        daily_ret (pd.DataFrame): DataFrame containing daily returns data.

    Returns:
        pd.DataFrame: DataFrame with columns 'id', 'eom', and 'w'.
    """
    print("[DEBUG] main: Starting PARALLELIZED main function")

    # Extract feature names from DataFrame (assuming column name is 'features')
    feature_list = features["features"].tolist()
    print(f"[DEBUG] main: Extracted {len(feature_list)} features")

    # Convert date columns from strings to datetime objects (WRDS format: YYYY-MM-DD)
    print("[DEBUG] main: Converting date columns to datetime...")
    chars["eom"] = pd.to_datetime(chars["eom"])
    chars["eom_ret"] = pd.to_datetime(chars["eom_ret"])
    print("[DEBUG] main: Date columns converted")

    # Prepare the data
    print("[DEBUG] main: About to call prepare_data...")
    chars = prepare_data(chars, feature_list, "eom")
    print("[DEBUG] main: prepare_data completed")

    # Run IPCA with parallel processing
    print("[DEBUG] main: About to call run_ipca (PARALLEL)...")
    pf_ipca = run_ipca(chars, feature_list, window_length=120)
    print(f"[DEBUG] main: run_ipca returned! Result shape: {pf_ipca.shape}")

    # Output
    print("[DEBUG] main: About to return pf_ipca...")
    return pf_ipca


if __name__ == "__main__":
    # This block is NOT used in the CTF infrastructure
    # The CTF infrastructure calls main() directly via [runner]
    # Data caching and output buffering are handled by [runner]
    # This block is kept for local testing only

    # For local testing: load data directly
    chars = pd.read_parquet("/data/ctff_chars.parquet")
    features = pd.read_parquet("/data/ctff_features.parquet")
    daily_ret = pd.read_parquet("/data/ctff_daily_ret.parquet")

    # Run the parallelized IPCA algorithm
    output = main(chars, features, daily_ret)

    # Write output (infrastructure would handle buffering in production)
    output.to_csv("/outputs/output.csv", index=False)
    print("IPCA BENCHMARK COMPLETED")
