# =============================================================================
# Stock Return Prediction - Common Task Framework (Simplified Version)
# =============================================================================
#
# CHANGES FROM ORIGINAL (CTF_StockReturn_Prediction_Submission.py):
# - Changed main() signature from no args to main(chars, features, daily_ret)
# - Replaced pd.read_parquet() calls with function arguments
# - Fixed features column name: 'feature' -> 'features'
# - Replaced date-based train/test split with ctff_test column filtering
# - Changed output from file write (to_csv) to return statement
# - Removed if __name__ == "__main__" block
#
# =============================================================================

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler

# Module-level plotting configuration (side effects at import time):
# font fallbacks that include CJK glyphs, correct minus-sign rendering
# with those fonts, and seaborn's whitegrid theme for any figures.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei']
plt.rcParams['axes.unicode_minus'] = False
sns.set_style('whitegrid')


def main(chars, features, daily_ret):
    """Run the CTF stock-return prediction pipeline and return test weights.

    Trains two models on the non-test rows (a rank-based factor strategy used
    only for diagnostics, and a Random Forest that produces the actual
    predictions), then returns the standardized Random Forest predictions as
    portfolio weights for the rows flagged by ``ctff_test == 1``.

    Parameters
    ----------
    chars : pandas.DataFrame
        Stock panel containing 'id', 'eom', 'ctff_test', 'ret_exc_lead1m'
        and the feature columns named in ``features``.
    features : pandas.DataFrame
        Must contain a 'features' column listing candidate feature names.
    daily_ret : pandas.DataFrame
        Daily returns; accepted for interface compatibility but unused here.

    Returns
    -------
    pandas.DataFrame
        Columns ['id', 'eom', 'w'] for the test rows, where 'w' is the
        cross-sectionally standardized predicted excess return.
    """
    print("="*80)
    print("📊 Stock Return Prediction System - CTF Competition")
    print("="*80)
    print(f"⏰ Start Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

    # =============================================================================
    # Step 1: Data Loading
    # =============================================================================
    print("📁 [1/6] Loading Data...")
    all_features = features['features'].tolist()
    print(f"   ✓ Number of Features: {len(all_features)}")

    df = chars.copy()
    print(f"   ✓ Data Shape: {df.shape}")
    print(f"   ✓ Date Range: {df['eom'].min()} to {df['eom'].max()}")

    # =============================================================================
    # Step 2: Data Preprocessing
    # =============================================================================
    print("\n🔧 [2/6] Data Preprocessing...")

    df['eom'] = pd.to_datetime(df['eom'])
    # The competition marks held-out rows with ctff_test == 1.
    train_data = df[df['ctff_test'] != 1].copy()
    test_data = df[df['ctff_test'] == 1].copy()

    print(f"   ✓ Training Set: {train_data['eom'].min().date()} to {train_data['eom'].max().date()}")
    print(f"   ✓ Test Set: {test_data['eom'].min().date()} to {test_data['eom'].max().date()}")
    print(f"   ✓ Training Samples: {len(train_data):,} | Test Samples: {len(test_data):,}")

    print("   - Selecting valid features...")
    # Keep only features observed in at least half of the training rows.
    missing_rate = train_data[all_features].isnull().mean()
    valid_features = missing_rate[missing_rate < 0.5].index.tolist()
    print(f"   ✓ Retained Features: {len(valid_features)}/{len(all_features)}")

    print("   - Filling missing values...")
    # FIX: the original used chained `train_data[col].fillna(..., inplace=True)`,
    # which is deprecated and silently ineffective under pandas copy-on-write.
    # Fill via direct assignment instead, using TRAINING medians for both splits
    # so no test-period information leaks into preprocessing.
    train_medians = train_data[valid_features].median()
    train_data[valid_features] = train_data[valid_features].fillna(train_medians)
    test_data[valid_features] = test_data[valid_features].fillna(train_medians)

    # Rows without a realized next-month return cannot be used for fit/eval.
    train_data = train_data.dropna(subset=['ret_exc_lead1m'])
    test_data = test_data.dropna(subset=['ret_exc_lead1m'])
    print(f"   ✓ Cleaned - Training: {len(train_data):,} | Test: {len(test_data):,}")

    # =============================================================================
    # Step 3: Model 1 - Simple Factor Model (diagnostic only: its monthly
    # returns are computed and printed but do not feed the submitted weights)
    # =============================================================================
    print("\n📈 [3/6] Model 1: Simple Factor Model...")

    def simple_factor_strategy(data, candidate_features):
        """Long-short strategy from an equal-weighted cross-sectional rank signal.

        Returns a DataFrame with one row per month: 'date' and 'return'
        (top-20% minus bottom-20% mean next-month excess return).
        """
        data = data.copy()
        factor_keywords = ['size', 'market', 'equity', 'book', 'value', 'momentum',
                           'ret', 'profit', 'roa', 'roe']
        # Cap at 15 features whose names suggest classic factor content.
        selected_features = [f for f in candidate_features
                             if any(kw in f.lower() for kw in factor_keywords)][:15]
        print(f"   - Using {len(selected_features)} factor features")

        # Per-month percentile ranks, averaged into a single composite signal.
        for feat in selected_features:
            data[f'{feat}_rank'] = data.groupby('eom')[feat].rank(pct=True)
        rank_cols = [f'{f}_rank' for f in selected_features]
        data['signal'] = data[rank_cols].mean(axis=1)

        portfolio_returns = []
        for date in data['eom'].unique():
            date_data = data[data['eom'] == date].sort_values('signal')
            n = len(date_data)
            # FIX: guarantee at least one stock per leg; the original
            # int(n * 0.2) produced empty buckets (NaN returns) when n < 5.
            leg_size = max(1, int(n * 0.2))
            long_ret = date_data.tail(leg_size)['ret_exc_lead1m'].mean()
            short_ret = date_data.head(leg_size)['ret_exc_lead1m'].mean()
            portfolio_returns.append({'date': date, 'return': long_ret - short_ret})
        return pd.DataFrame(portfolio_returns)

    train_returns_model1 = simple_factor_strategy(train_data, valid_features)
    test_returns_model1 = simple_factor_strategy(test_data, valid_features)
    print("   ✓ Completed")

    # =============================================================================
    # Step 4: Model 2 - Random Forest
    # =============================================================================
    print("\n🤖 [4/6] Model 2: Machine Learning Model (Random Forest)...")

    # NOTE(review): takes the first 50 valid features in column order, not the
    # 50 most informative — kept for compatibility with the original pipeline.
    use_features = valid_features[:50]
    print(f"   - Using {len(use_features)} features")

    X_train = train_data[use_features].values
    y_train = train_data['ret_exc_lead1m'].values
    X_test = test_data[use_features].values

    # Standardize using training statistics only (scaler is fit on train).
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    print("   - Training Random Forest model...")
    # Shallow, heavily regularized forest: noisy monthly returns overfit easily.
    rf_model = RandomForestRegressor(
        n_estimators=50, max_depth=5, min_samples_split=100,
        min_samples_leaf=50, random_state=42, n_jobs=-1
    )
    rf_model.fit(X_train_scaled, y_train)
    print("   ✓ Training completed")

    test_data['pred_return'] = rf_model.predict(X_test_scaled)

    # =============================================================================
    # Step 5: MVO Ensemble (diagnostic only — hard-coded moments; the computed
    # weights are printed but never applied to the output)
    # =============================================================================
    print("\n🎯 [5/6] MVO Ensemble...")

    mean_returns = np.array([0.01, 0.02])
    cov_matrix = np.array([[0.0025, 0.001], [0.001, 0.002]])

    # Unconstrained mean-variance solution, renormalized to absolute weights.
    inv_cov = np.linalg.inv(cov_matrix)
    weights = inv_cov @ mean_returns
    weights = np.abs(weights) / np.abs(weights).sum()
    print(f"   ✓ Model1 Weight: {weights[0]:.3f}")
    print(f"   ✓ Model2 Weight: {weights[1]:.3f}")

    # =============================================================================
    # Step 6: Generate Output File (id, eom, w)
    # =============================================================================
    print("\n💾 [6/6] Generating Required Output File (id, eom, w)...")
    output_df = test_data[['id', 'eom']].copy()
    pred = test_data['pred_return']
    spread = pred.std()
    # FIX: guard against a constant (or single-row) prediction vector, where
    # the original unconditional division by std() produced inf/NaN weights.
    output_df['w'] = (pred - pred.mean()) / spread if spread > 0 else 0.0

    print("\n✅ All Tasks Completed!")
    print(f"⏰ End Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*80)

    return output_df
