##############################################################################
"""
PARSIMONIOUS IMPUTED EXPANDING LASSO

CTF Admin Modifications (2026-02-16):
--------------------------------------
1. Removed Jupyter notebook magic commands (!uv init, !uv add)
   Reason: These cause SyntaxError when running as a Python script.

2. Removed unused imports (keyring, sqlalchemy, psycopg2-binary)
   Reason: Database/credential libraries not available in HPC environment.

3. Fixed main() signature: changed 'eom' parameter to 'daily_ret'
   Reason: CTF harness expects main(chars, features, daily_ret) signature.

4. Added [CTF-DEBUG] progress statements
   Reason: Enable HPC job monitoring with start/end timestamps.

5. Removed unnecessary dependencies from pyproject.toml
   Reason: keyring, sqlalchemy, psycopg2-binary, ipywidgets not needed for HPC execution.

6. Fixed feature extraction from features DataFrame
   Reason: features parameter is a DataFrame with 'features' column, not a Series.
           Code was doing list(DataFrame) which returns column names, not row values.
"""
##############################################################################

# Imports
import time
import os
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error
from sklearn.exceptions import ConvergenceWarning
from joblib import Parallel, delayed
import warnings

# Functions

def prepare_data(chars: pd.DataFrame, features: pd.Series, eom: str) -> pd.DataFrame:
    """
    Parallel month-by-month standardization and MICE imputation of firm characteristics.

    For each value of ``eom`` the feature columns are z-scored cross-sectionally,
    missing values are filled with ``IterativeImputer`` (MICE), anything still
    missing is set to 0, and the result is z-scored again within the month.

    Args:
        chars: Stock characteristics; must contain the ``eom`` column, an ``"id"``
            column and every name in ``features``.
        features: List-like (list or Series) of feature column names.
        eom: String name of the date column.

    Returns:
        DataFrame: A copy of ``chars`` with ``eom`` parsed to datetime, sorted by
                   (``eom``, ``"id"``) with a fresh RangeIndex, and the feature
                   columns cast to float64, imputed and standardized as above.
                   An empty input is returned (sorted and cast) without
                   running the imputer.
    """
    # Leave a few cores free for the parent process / other work on the node.
    n_jobs = max(1, (os.cpu_count() or 1) - 4)

    feat = features
    characteristics = chars.copy()
    characteristics[eom] = pd.to_datetime(characteristics[eom])

    characteristics = characteristics.sort_values([eom, "id"]).reset_index(drop=True)
    characteristics[feat] = characteristics[feat].astype("float64")

    if characteristics.empty:
        # No months to process, and pd.concat([]) below would raise ValueError.
        return characteristics

    def _impute_month(g):
        # Private copy: when joblib runs tasks in-process (n_jobs=1) we must not
        # write into a frame produced by the groupby iteration.
        g = g.copy()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            warnings.filterwarnings("ignore", message="Skipping features.*")

            # First pass: constant columns (std 0) are only demeaned. A column
            # with fewer than two observations has NaN std, stays all-NaN here,
            # is excluded from imputation below and ends up zero-filled.
            mean = g[feat].mean()
            std = g[feat].std().replace(0.0, 1.0)
            X = (g[feat] - mean) / std

            # IterativeImputer is only given columns with at least one observed value.
            cols_with_data = X.columns[X.notna().any()].tolist()
            if cols_with_data:
                imputer = IterativeImputer(
                    max_iter=5,
                    random_state=42,
                    sample_posterior=False,
                    initial_strategy="median",
                    skip_complete=True,
                    n_nearest_features=50,
                )
                X.loc[:, cols_with_data] = imputer.fit_transform(X[cols_with_data])

            X = X.fillna(0.0)
            X = X - X.mean()
            # After fillna there are no NaNs, but a single-row month still has an
            # undefined (NaN) sample std; treat it like a zero std so features
            # become 0 rather than NaN (NaN features would break Lasso later).
            final_std = X.std().replace(0.0, 1.0).fillna(1.0)
            X = X / final_std

            g.loc[:, feat] = X
            return g

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="A worker stopped while some jobs were given to the executor",
            category=UserWarning,
        )

        out = Parallel(
            n_jobs=n_jobs,
            backend="loky",
            max_nbytes="10M"
        )(
            delayed(_impute_month)(g)
            for _, g in tqdm(characteristics.groupby(eom),
                             desc=f"Parallel Imputation ({n_jobs} cores)")
        )

    characteristics = (
        pd.concat(out)
        .sort_values([eom, "id"])
        .reset_index(drop=True)
    )

    return characteristics

def lasso(train: pd.DataFrame, features: pd.Series, eom: str) -> Lasso:
    """
    Train a Lasso model, choosing alpha by expanding-window validation.

    The target ``ret_exc_lead1m`` is z-scored cross-sectionally within each month;
    a month whose target std is zero or undefined is only demeaned. Rows whose
    target is missing are dropped first, because ``Lasso.fit`` rejects NaN targets.

    Alpha is chosen from 50 log-spaced values in [1e-3, 1]. Training starts with
    the first 60 months, and every alpha is scored (MSE) on the following 12
    months. The training window then grows by 12 months, and this repeats while a
    full 12-month validation block is still available. The alpha with the lowest
    mean MSE is selected. If there are too few months for even one validation
    block, a RuntimeWarning is emitted and the smallest alpha in the grid is used.

    Args:
        train: Training set containing stock characteristics and excess returns.
        features: List-like of feature column names.
        eom: String name of the date column.

    Returns:
        Lasso: The final Lasso model fitted on the full (non-missing-target)
               training set using the selected alpha.
    """
    y_col = "ret_exc_lead1m"
    feat = features

    # Lasso.fit raises on NaN targets, so keep only rows with a realized return.
    train = train.loc[train[y_col].notna()].copy()

    def _zscore(x: pd.Series) -> pd.Series:
        sd = x.std()
        # NaN > 0 is False, so single-observation months are only demeaned.
        return (x - x.mean()) / (sd if sd > 0 else 1.0)

    train[y_col] = train.groupby(eom)[y_col].transform(_zscore)

    all_months = sorted(train[eom].unique())
    alphas = np.logspace(-3, 0, 50)

    initial_train_size = 60
    validation_step = 12
    alpha_performance = {a: [] for a in alphas}

    # A window starting at len(all_months) - validation_step still has a full
    # validation block, so the exclusive stop must be one past it.
    window_range = range(
        initial_train_size,
        len(all_months) - validation_step + 1,
        validation_step,
    )

    for i in tqdm(window_range, desc="Expanding Training Set for Hyperparameter"):
        tr_df = train[train[eom].isin(all_months[:i])]
        va_df = train[train[eom].isin(all_months[i : i + validation_step])]

        X_tr, y_tr = tr_df[feat].values, tr_df[y_col].values
        X_va, y_va = va_df[feat].values, va_df[y_col].values

        for a in alphas:
            model = Lasso(alpha=a, max_iter=5000, random_state=42)
            model.fit(X_tr, y_tr)
            mse = mean_squared_error(y_va, model.predict(X_va))
            alpha_performance[a].append(mse)

    if len(window_range) > 0:
        avg_mse = {a: np.mean(scores) for a, scores in alpha_performance.items()}
        best_alpha = min(avg_mse, key=avg_mse.get)
    else:
        # No validation window fits: make the fallback explicit instead of
        # averaging empty lists into NaN.
        warnings.warn(
            f"lasso(): only {len(all_months)} months available; at least "
            f"{initial_train_size + validation_step} are needed for one "
            f"validation window. Falling back to alpha={alphas[0]:g}.",
            RuntimeWarning,
            stacklevel=2,
        )
        best_alpha = alphas[0]

    X_full, y_full = train[feat].values, train[y_col].values
    final_model = Lasso(alpha=best_alpha, fit_intercept=True, max_iter=10000, random_state=42)
    final_model.fit(X_full, y_full)

    return final_model

def predict_returns(model: Lasso, test: pd.DataFrame, features: pd.Series, eom: str) -> pd.DataFrame:
    """
    Predict standardized returns using the trained Lasso model.

    Args:
        model: The fitted Lasso model from the lasso() function.
        test: The test set DataFrame.
        features: Series of feature column names.
        eom: String name of the date column (e.g., 'eom').

    Returns:
        DataFrame: Contains columns 'id', 'eom', and 'pred_lead1m' with 
                   the out-of-sample predictions.
    """
    test = test.copy()
    feat = features
    
    test = test.sort_values([eom, "id"]).reset_index(drop=True)
    X = test[feat].to_numpy()
    
    preds = model.predict(X)

    lasso_predictions = test[["id", eom]].copy()
    lasso_predictions["pred_lead1m"] = preds.astype(float)
    return lasso_predictions

def create_ls_pf_w(pred: pd.DataFrame, test: pd.DataFrame) -> pd.DataFrame:
    """
    Build long-short portfolio weights based on predictions using a decile-based
    equal-weighting strategy.

    Each month (``eom``), stocks are split into up to 10 quantile bins of
    ``pred_lead1m``. This uses ``pd.qcut`` with ``duplicates='drop'``, so tied
    predictions can yield fewer bins. The highest bin present that month is held
    long at +1/n_long each, and the lowest bin is held short at -1/n_short each.
    A month with fewer than two bins (e.g. all predictions equal) gets zero
    weights, so the two legs never overlap.

    Args:
        pred: Predictions with columns 'id', 'eom' and 'pred_lead1m'.
        test: The test set DataFrame; only (id, eom) pairs present here are kept.

    Returns:
        DataFrame: A clean table with columns ['id', 'eom', 'w'], where 'w'
                   represents the assigned portfolio weight (Long-Short).
    """
    n_deciles = 10

    pf_creation = pred.merge(test[['id', 'eom']], on=['id', 'eom'])

    pf_creation['decile'] = pf_creation.groupby('eom')['pred_lead1m'].transform(
        lambda x: pd.qcut(x, n_deciles, labels=False, duplicates='drop')
    )

    # With duplicates='drop', ties can leave fewer than n_deciles bins, so the top
    # label may be below n_deciles - 1. Comparing against a fixed 9 would then
    # drop the long leg and leave a net-short month, so use each month's realised
    # top bin instead.
    top_bin = pf_creation.groupby('eom')['decile'].transform('max')
    # Distinct long and short legs need at least two bins (NaN > 0 is False).
    has_spread = top_bin > 0

    pf_creation['is_long'] = ((pf_creation['decile'] == top_bin) & has_spread).astype(int)
    pf_creation['is_short'] = ((pf_creation['decile'] == 0) & has_spread).astype(int)

    counts = pf_creation.groupby('eom')[['is_long', 'is_short']].transform('sum')

    pf_creation['w'] = 0.0
    pf_creation.loc[pf_creation['is_long'] == 1, 'w'] = 1.0 / counts['is_long']
    pf_creation.loc[pf_creation['is_short'] == 1, 'w'] = -1.0 / counts['is_short']

    portfolio_weights = pf_creation[['id', 'eom', 'w']].copy()

    return portfolio_weights

def main(chars: pd.DataFrame, features: pd.DataFrame, daily_ret: pd.DataFrame) -> pd.DataFrame:
    """
    Parsimonious Imputed Expanding Lasso entry point.

    Pipeline: impute/standardize characteristics, split on ``ctff_test``,
    fit the Lasso on the training rows, predict the test rows, and turn the
    predictions into long-short decile weights. Progress lines prefixed with
    ``[CTF-DEBUG]`` are printed (flushed) for HPC job monitoring.

    Args:
        chars: Stock characteristics (ctff_chars.parquet).
        features: Feature names, as a DataFrame with a 'features' column, a
            Series, or any iterable of column names (ctff_features.parquet).
        daily_ret: Daily returns (ctff_daily_ret.parquet) - not used in this model.

    Returns:
        DataFrame: Final portfolio weights with columns ['id', 'eom', 'w'].
    """
    # CTF-FIX: start timestamp for HPC monitoring
    print(f"[CTF-DEBUG] Starting main() at {time.strftime('%Y-%m-%d %H:%M:%S')}", flush=True)
    t0 = time.time()

    date_col = 'eom'
    panel = chars.copy()
    panel[date_col] = pd.to_datetime(panel[date_col])

    # CTF-FIX: the harness passes feature names as a DataFrame column; also
    # accept a Series or a plain iterable of names.
    if isinstance(features, pd.DataFrame):
        feature_names = features["features"].tolist()
    elif isinstance(features, pd.Series):
        feature_names = features.tolist()
    else:
        feature_names = list(features)

    print(f"[CTF-DEBUG] Preparing data with {len(feature_names)} features...", flush=True)
    prepared = prepare_data(panel, feature_names, date_col)

    # Rows whose flag is neither True nor False (e.g. NaN) land in neither split.
    test_flag = prepared['ctff_test']
    train = prepared[test_flag.eq(False)]
    test = prepared[test_flag.eq(True)].copy()
    print(f"[CTF-DEBUG] Train: {len(train)} rows, Test: {len(test)} rows", flush=True)

    print("[CTF-DEBUG] Training Lasso model...", flush=True)
    fitted = lasso(train, feature_names, date_col)

    print("[CTF-DEBUG] Predicting returns...", flush=True)
    predictions = predict_returns(fitted, test, feature_names, date_col)

    print("[CTF-DEBUG] Creating portfolio weights...", flush=True)
    weights = create_ls_pf_w(predictions, test)

    # CTF-FIX: completion summary for HPC monitoring
    runtime = time.time() - t0
    print(f"[CTF-DEBUG] Completed main() in {runtime:.1f}s", flush=True)
    months = weights[date_col]
    print(f"[CTF-DEBUG] Output: {len(weights)} rows, {months.nunique()} unique months, "
          f"date range {months.min()} to {months.max()}", flush=True)

    return weights
