# ==============================================================================
# MAXSER Portfolio Optimization Model
# ==============================================================================
#
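# This script implements the MAXSER (MAXimum-Sharpe-ratio Estimated and sparse
# Regression) portfolio optimization algorithm. It builds cross-sectional factor
# returns with per-date, no-intercept OLS regressions of ret_exc_lead1m on the
# standardized characteristics, selects factor-portfolio weights with a Lasso
# regression (glmnet, alpha = 1), and converts them into stock weights through
# each stock's characteristic exposures.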
#
# CTF Compatibility Notes:
# - The training window (T) is adaptive: defaults to 360 months but automatically
#   reduces to (available_months - 1) when insufficient historical data exists.
#   This change was made to support CTF validation datasets that may have fewer
#   than 360 months of data, while preserving the original algorithm behavior
#   when sufficient data is available.
# - Note: The original 360-month window reflects the authors' intended training
#   period. The adaptive reduction is solely to pass CTF validation; it does not
#   imply that shorter windows produce statistically meaningful results.
# - Output is filtered to only include rows where eom is in the test period
#   (ctff_test == TRUE). The original algorithm computed weights for all dates
#   including training periods, but CTF validation requires predictions only for
#   test set dates (>= 1989-12-31). Without this filter, validation fails with
#   "EOM date is before 1989-12-31" errors for training period rows.
#
# HPC Parallelization Notes:
# - Uses doSNOW with PSOCK clusters for parallel cross-validation.
# - CRITICAL: BLAS/OpenMP thread counts (OMP_NUM_THREADS, etc.) are set to 1
#   BEFORE loading libraries to prevent thread oversubscription. Libraries like
#   glmnet use OpenMP, which initializes thread pools at library load time.
#   With N parallel workers, if each worker spawns M OpenMP threads, you get
#   N*M threads competing for CPUs. For example: 8 workers × 32 threads = 256
#   threads on 32 CPUs causes severe contention and slowdown.
# - These Sys.setenv() calls override any values inherited from the sbatch
#   environment, so no changes to standard HPC job scripts are required.
#
# ==============================================================================

# CRITICAL: pin every BLAS/OpenMP thread pool to one thread BEFORE any library()
# call. Per the notes in the file header, the pool size is picked up when a
# library is loaded, and each parallel worker must run a single thread
# (8 workers x 32 threads would put 256 threads on 32 CPUs).
Sys.setenv(
    OMP_NUM_THREADS = 1,
    MKL_NUM_THREADS = 1,
    OPENBLAS_NUM_THREADS = 1,
    BLAS_NUM_THREADS = 1
)

# Read the values back from the environment so the job log shows what is
# actually in effect. local() keeps the helper vector out of the global scope.
cat("[MAXSER] Thread settings applied BEFORE library loads:\n")
local({
    thread_vars <- c("OMP_NUM_THREADS", "MKL_NUM_THREADS",
                     "OPENBLAS_NUM_THREADS", "BLAS_NUM_THREADS")
    cat(sprintf("[MAXSER]   %s=%s\n", thread_vars, Sys.getenv(thread_vars)), sep = "")
})

# Print Slurm environment if running on HPC (SLURM_JOB_ID is non-empty)
slurm_job_id <- Sys.getenv("SLURM_JOB_ID")
if (nzchar(slurm_job_id)) {
    cat("[MAXSER] Slurm environment:\n")
    cat(sprintf("[MAXSER]   SLURM_JOB_ID=%s\n", slurm_job_id))
    local({
        slurm_vars <- c("SLURM_CPUS_PER_TASK", "SLURM_MEM_PER_NODE", "SLURM_JOB_PARTITION")
        cat(sprintf("[MAXSER]   %s=%s\n", slurm_vars, Sys.getenv(slurm_vars)), sep = "")
    })
}

library(data.table)
library(glmnet)
library(arrow)
library(parallel)
library(doSNOW)
library(foreach)



# ==================== helper functions ====================
# Compute MAXSER portfolio weights from a matrix of returns.
#
# Arguments:
#   R     - numeric matrix of returns; rows are periods, columns are assets.
#   sigma - risk target used to scale the selected weights (default 0.04).
#   K     - number of cross-validation folds (default 10).
#   seed  - optional RNG seed, applied before the fold assignment is drawn.
#
# Method: a constant response of ones is regressed on R with a Lasso penalty
# (glmnet, alpha = 1, no intercept, no standardization). For every lambda on the
# full-sample path, each fold is refit without its held-out rows and the
# held-out portfolio return is summarized by its standard deviation ("risk")
# and its squared mean/sd ratio. The lambda with the highest fold-averaged
# squared ratio is selected.
#
# Returns the full-sample coefficients at the selected lambda (intercept row
# dropped), multiplied by sigma / (fold-averaged risk at that lambda), in the
# matrix form produced by data.matrix().
MAXSER = function(R, sigma=0.04, K=10, seed=NULL){
  n_obs = nrow(R)
  ones = rep(1, n_obs)

  if (!is.null(seed)) {
    set.seed(seed)
  }
  # Random assignment of the row indices to K folds
  folds = split(sample(1:n_obs), rep(1:K, length.out = n_obs))

  # Full-sample path; its lambda grid is reused for every fold refit
  full_fit = glmnet(R, ones, alpha=1, nlambda=100, standardize=FALSE, intercept=FALSE)
  lambda_path = full_fit$lambda

  # Held-out risk and squared Sharpe ratio of one fold, one entry per lambda
  evaluate_fold = function(k) {
    held_out = folds[[k]]
    fold_fit = glmnet(R[-held_out, , drop=FALSE], ones[-held_out], alpha=1,
                      lambda=lambda_path, standardize=FALSE, intercept=FALSE)
    fold_w = coef(fold_fit, s=lambda_path)[-1, ]
    oos_ret = R[held_out, ] %*% fold_w
    oos_mean = apply(oos_ret, 2, mean, na.rm=TRUE)
    oos_risk = apply(oos_ret, 2, sd, na.rm=TRUE)
    oos_sr = (oos_mean / oos_risk)^2
    # An undefined ratio (e.g. zero risk) must never win the selection
    oos_sr[is.na(oos_sr)] = -Inf
    list(risk = oos_risk, sr = oos_sr)
  }
  fold_stats = lapply(seq(K), evaluate_fold)

  # Pad every per-fold vector to the longest one and bind them as columns
  pad_and_bind = function(per_fold) {
    target_len = max(lengths(per_fold))
    do.call(cbind, lapply(per_fold, function(v) {
      v = unlist(v)
      length(v) = target_len
      v
    }))
  }
  val_sr = pad_and_bind(lapply(fold_stats, function(s) s$sr))
  val_risk = pad_and_bind(lapply(fold_stats, function(s) s$risk))

  avg_risk = apply(val_risk, 1, mean)
  best = which.max(apply(val_sr, 1, mean))
  zeta_star = lambda_path[best]
  leverage = sigma / avg_risk[best]

  leverage * data.matrix(predict(full_fit, type="coefficients", s=zeta_star)[-1, ])
}

# Standardize a numeric vector to zero mean and unit standard deviation.
#
# Arguments:
#   data - numeric vector, may contain NA.
#
# Returns a vector of the same length as `data`:
#   * numeric(0) for empty input;
#   * all NA_real_ when every entry is NA;
#   * all zeros when the standard deviation is 0 or undefined (constant data or
#     a single non-NA value) - note that NA entries also come back as 0 here;
#   * otherwise (data - mean) / sd computed with NAs ignored, so NA entries
#     stay NA.
zscore_transform = function(data) {
    n = length(data)
    if (n == 0) {
        return(numeric(0))
    }
    if (all(is.na(data))) {
        return(rep(NA_real_, n))
    }

    centre = mean(data, na.rm = TRUE)
    spread = sd(data, na.rm = TRUE)

    # No usable dispersion: fall back to a vector of zeros
    degenerate = is.na(spread) || spread == 0
    if (degenerate) rep(0, n) else (data - centre) / spread
}

# Standardize and impute every feature column within each 'eom' cross-section.
#
# Arguments:
#   chars    - table holding an 'eom' column and one column per feature.
#   features - character vector with the names of the columns to transform.
#
# For each feature and each 'eom' group the values are z-scored with
# zscore_transform(); entries that are still NA afterwards are replaced by the
# group median of the z-scored values. A group without any observed value is
# left entirely NA. Returns `chars` with the feature columns replaced.
prepare_data = function(chars, features) {
    group_col = "eom"

    # Transformation applied to one cross-section of one feature
    standardize_and_fill = function(x) {
        z = zscore_transform(x)
        if (all(is.na(z))) {
            return(z)
        }
        z[is.na(z)] = median(z, na.rm = TRUE)
        z
    }

    for (feature in features) {
        chars[[feature]] = ave(chars[[feature]], chars[[group_col]], FUN = standardize_and_fill)
    }
    chars
}

# Split `chars` into one RDS file per distinct `eom_ret` value.
#
# Arguments:
#   chars  - data.table with an `eom_ret` column.
#   folder - output directory; created (recursively) when it does not exist.
#
# Each slice is written to "<folder>/chars[<date>].rds", the naming scheme that
# loadChars() reads back. A file that already exists is left untouched, so the
# output of a previous split is reused rather than regenerated.
# Returns NULL invisibly.
splitChars = function(chars, folder) {
  if (!dir.exists(folder))  dir.create(folder, recursive = TRUE)
  cat("Starting to split chars by Date...\n")

  dates = unique(chars$eom_ret)

  # txtProgressBar() refuses max == min, so an empty table needs an early exit
  if (length(dates) == 0) {
    cat("Finished splitting chars.\n\n")
    return(invisible(NULL))
  }

  pb = txtProgressBar(min=0, max=length(dates), style=3)
  # Make sure the bar is terminated if a write fails part-way; a second close()
  # on the normal path is a no-op.
  on.exit(close(pb), add = TRUE)

  for (i in seq_along(dates)) {
    out_path = sprintf("%s/chars[%s].rds", folder, dates[i])
    if (!file.exists(out_path)) {
      # Build the row mask outside of `[.data.table`: inside the brackets a
      # column of `chars` would take precedence over a local variable with the
      # same name. A lone symbol in `i` is looked up in the calling scope, and
      # NA entries of a logical `i` are treated as FALSE.
      in_slice = chars$eom_ret == dates[i]
      saveRDS(chars[in_slice], file=out_path)
    }
    setTxtProgressBar(pb, i)
  }

  close(pb)
  cat("Finished splitting chars.\n\n")
}

# Read back the per-date slice stored as "<folder>/chars[<date>].rds".
#
# Arguments:
#   date   - date identifying the slice; it is formatted with "%s" into the
#            file name.
#   folder - directory that holds the split files.
#
# Returns the object stored in the file, or NULL (after printing a notice)
# when the file does not exist.
loadChars = function(date, folder) {
  rds_path = sprintf("%s/chars[%s].rds", folder, date)
  if (file.exists(rds_path)) {
    return(readRDS(rds_path))
  }
  cat(sprintf("File for date %s does not exist in folder %s.\n", date, folder))
  NULL
}

# Build cross-sectional factor returns, one row per `eom_ret` date.
#
# Arguments:
#   chars       - data.table with an `eom_ret` column; only the distinct dates
#                 are used here, the data itself is read from the split files.
#   folder      - directory holding the per-date files written by splitChars().
#   formula     - formula whose right-hand-side variables name the candidate
#                 characteristics (the first variable is skipped as response).
#   ratio_thres - a characteristic is kept for a date only if its share of
#                 zero-or-NA values is below this threshold.
#   VIF_thres   - optional; when given, and while more than 6 characteristics
#                 remain, the one with the largest VIF is dropped as long as
#                 that VIF exceeds 1 / (1 - VIF_thres).
#   file        - path of the .RData file that receives `cs.factors` and
#                 `valid.chars` via save().
#   ncores      - number of PSOCK workers.
#
# For each date the retained characteristics enter a no-intercept OLS regression
# of ret_exc_lead1m; the coefficients are that date's factor values, and
# characteristics that were not retained stay NA.
#
# Returns list(cs.factors = <data.table: eom + one column per characteristic>,
#              valid.chars = <list of retained names, named by date>).
process_factors = function(chars, folder, formula, ratio_thres=0.5, VIF_thres=NULL, file, ncores=40) {
    cat("Building Cross-Section Factors...\n")
    cl = makeCluster(ncores)
    # Shut the workers down however the function exits; without this a failing
    # worker would leave the whole cluster running.
    on.exit(stopCluster(cl), add = TRUE)
    registerDoSNOW(cl)
    unique_dates = unique(chars$eom_ret)

    # progress bar
    pb = txtProgressBar(max=length(unique_dates), style=3)   
    progress = function(n) setTxtProgressBar(pb, n)
    opts = list(progress=progress)

    # Run the process in parallel using foreach
    result_list = foreach(date=unique_dates, .packages="data.table", 
                            .options.snow=opts, .export=c("formula", "loadChars")) %dopar% {
        data.d = loadChars(date, folder)
        if (is.null(data.d)) {
            # loadChars() only prints a notice on the worker; fail with a clear message
            stop(sprintf("No split file for date %s in folder %s.", date, folder))
        }
        formula.parts = all.vars(formula)
        chars.names = formula.parts[-1]
        chars.d = data.d[, ..chars.names]

        cs.factors = data.table(eom=unique(data.d$eom_ret))
        cs.factors[, (chars.names) := NA_real_]

        # Remove columns with high zero/NA ratio
        valid.columns = names(chars.d)[vapply(chars.d, function(col) {
            mean(col == 0 | is.na(col)) < ratio_thres
        }, logical(1))]

        if (!is.null(VIF_thres)){
            # Drop the most collinear column, one at a time
            while (length(valid.columns) > 6) {
                # Create a temporary dataset with valid columns for VIF calculation
                temp_data = copy(chars.d[, ..valid.columns])
                
                # Calculate the correlation and covariance matrices
                corr = cor(matrix(as.numeric(as.matrix(temp_data)), nrow=nrow(temp_data)), use="pairwise.complete.obs")
                Sig = diag(apply(temp_data, 2, sd, na.rm=TRUE)) %*% corr %*% diag(apply(temp_data, 2, sd, na.rm=TRUE))
                
                # Calculate VIF values
                if (min(eigen(Sig)$values) < 0.000001) {
                    # Manually calculate VIF if covariance matrix is near singular
                    VIFs = vapply(seq_len(ncol(temp_data)), function(j) {
                        lmfit = lm(temp_data[[j]] ~ as.matrix(temp_data[, -j, with=FALSE]))
                        1 / (1 - summary(lmfit)$r.squared)
                    }, numeric(1))
                } else {
                    # Use precision matrix to calculate VIF if covariance matrix is non-singular
                    # (VIF_j = Sig_jj * Precision_jj)
                    Precision = solve(Sig)
                    VIFs = drop(diag(Sig) %*% diag(diag(Precision)))
                }
                
                # Identify the column with the highest VIF and remove it if it exceeds the threshold
                if (max(VIFs) > 1 / (1 - VIF_thres)) {
                    highest_vif_idx = which.max(VIFs)
                    rm_column = valid.columns[highest_vif_idx]
                    print(paste("Removing column:", rm_column, "with VIF:", VIFs[highest_vif_idx]))
                    valid.columns = valid.columns[-highest_vif_idx]
                } else {
                    break
                }
            }
        }

        if (length(valid.columns) > 0) {
            # The formula is only built once at least one column survived: with
            # an empty right-hand side as.formula() would fail to parse.
            formula.valid = as.formula(paste("ret_exc_lead1m ~", paste(valid.columns, collapse = " + ")))
            # Perform OLS regression without intercept
            coefficients = coef(lm(update(formula.valid, . ~ . - 1), data=data.d))
            cs.factors[, (valid.columns) := as.list(coefficients)]
        } 
        
        res = list(cs.factors=cs.factors, valid.chars=valid.columns)
        res
    }

    # Combine cs factors
    cs.factors = rbindlist(lapply(result_list, function(x) x$cs.factors), fill=TRUE)
    valid.chars = lapply(result_list, function(x) x$valid.chars)
    names(valid.chars) = as.character(cs.factors$eom)

    save(cs.factors, valid.chars, file=file)
    
    cat("\nFinished.\n\n")
    return(list(cs.factors=cs.factors, valid.chars=valid.chars))
}




# ==================== main function ====================
# Main function: prepare the data, build (or reload) the cross-sectional
# factors, fit MAXSER factor weights on a rolling window and translate them
# into per-stock weights.
#
# Arguments:
#   chars     - data.frame from ctff_chars.parquet
#   features  - data.frame from ctff_features.parquet; its `features` column
#               lists the characteristic names
#   daily_ret - data.frame from ctff_daily_ret.parquet (not used in this function)
#
# Returns a data.table of stock weights (columns id, eom, w) restricted to the
# eom values flagged ctff_test == TRUE in `chars`.
#
# Side effects: writes the per-date split files under "data/split" and caches
# the factor table in "data/factors.RData"; both are reused when they already
# exist. Starts a PSOCK cluster that is stopped when the function exits.
main <- function(chars, features, daily_ret) {

    features = as.character(features$features)
    chars = prepare_data(chars, features)
    setDT(chars)
    
    dir.create("data/split", recursive=TRUE, showWarnings=FALSE)
    splitChars(chars, "data/split")
    fct_file = "data/factors.RData"
    ncores = 8

    # Print parallelization diagnostics
    cat("[MAXSER] Parallelization setup:\n")
    cat(sprintf("[MAXSER]   detectCores()=%d (what R sees)\n", parallel::detectCores()))
    cat(sprintf("[MAXSER]   ncores=%d (workers we will use)\n", ncores))
    cat(sprintf("[MAXSER]   Effective threads = %d workers × 1 OMP thread = %d total\n", ncores, ncores))

    # build cross-section factors (or reuse the cached ones)
    if (!file.exists(fct_file)) {
        formula = as.formula(paste("ret_exc_lead1m ~", paste(features, collapse = " + ")))
        fct_res = process_factors(chars, "data/split", formula, ratio_thres=0.5, VIF_thres=0.9, file=fct_file, ncores=ncores)
        factors = fct_res$cs.factors
        valid_chars = fct_res$valid.chars
    } else {
        # load() brings cs.factors and valid.chars into this function's scope
        load(fct_file)
        factors = cs.factors
        valid_chars = valid.chars
    }
    setorder(factors, eom)
    cal = factors[, .(eom)]

    # training timeline
    T_default = 360
    cal_unique = unique(cal$eom)
    n_months = length(cal_unique)

    # Adapt training window if insufficient data available.
    # Every window needs one more month after its train_end (the `latest` month
    # used below), so the default window is only usable with strictly more than
    # T_default months. With exactly T_default months the single full-length
    # window would have no following month and no weights could be produced.
    if (n_months <= T_default) {
        n_train = n_months - 1
        warning(sprintf("Insufficient data: only %d months available. Reducing training window from %d to %d months.",
                        n_months, T_default, n_train))
    } else {
        n_train = T_default
    }

    if (n_train < 12) {
        stop(sprintf("Insufficient data for training: need at least 12 months, but only %d available.",
                     n_months))
    }

    timeline = data.table(
        train_begin = cal_unique[1:(n_months - n_train + 1)],
        train_end = cal_unique[n_train:n_months]
    )
    print(timeline)

    # Create cluster for parallel computation
    cat("[MAXSER] Creating PSOCK cluster...\n")
    cluster_start = Sys.time()
    cl = makeCluster(ncores)
    # Stop the workers however the function exits, so a failing foreach task
    # does not leave the cluster running.
    on.exit(stopCluster(cl), add = TRUE)
    registerDoSNOW(cl)
    cat(sprintf("[MAXSER] Cluster created in %.1f seconds\n", as.numeric(difftime(Sys.time(), cluster_start, units="secs"))))

    # Compute factor portfolio weights
    cat(paste("Computing factor weights for", nrow(timeline), "periods...\n"))
    phase1_start = Sys.time()
    pb1 = txtProgressBar(max=nrow(timeline), style=3)
    progress1 = function(n) setTxtProgressBar(pb1, n)
    opts1 = list(progress=progress1)
    
    fct_wts = rbindlist(foreach(i=seq_len(nrow(timeline)), .packages=c("data.table", "glmnet"),
                               .options.snow=opts1, .export=c("MAXSER")) %dopar% {
        tb = timeline[i]$train_begin
        te = timeline[i]$train_end
        # every column except the leading eom column is a factor
        pool = colnames(factors)[-1]
        wts_dt = data.table(eom=te)
        wts_dt[, (pool) := 0]
        
        # Get all T months data
        train_full = factors[eom >= tb & eom <= te, !"eom", with = FALSE]
        
        # Get the most recent T/2 months to evaluate NA percentage
        train_dates = factors[eom >= tb & eom <= te, eom]
        half_T = floor(length(train_dates) / 2)
        recent_dates = tail(train_dates, half_T)
        train_half = factors[eom %in% recent_dates, !"eom", with = FALSE]
        
        # Select columns with NA percentage <= 10% in recent T/2 months
        na_pct = colSums(is.na(train_half)) / nrow(train_half)
        valid_cols = names(na_pct)[na_pct <= 0.10]
        
        # Use these columns from full T months data and fill NA with 0;
        # factors that are not selected keep a weight of 0
        if (length(valid_cols) > 0) {
            train = train_full[, valid_cols, with = FALSE]
            train[is.na(train)] = 0
            train = data.matrix(train)
            w = MAXSER(R = train, sigma = 0.04, K = 10, seed = 123)
            valid_pool = rownames(w)
            wts_dt[, (valid_pool) := as.list(w[, 1])]
        }
        wts_dt
    })
    
    close(pb1)
    phase1_elapsed = as.numeric(difftime(Sys.time(), phase1_start, units="mins"))
    cat(sprintf("\n[MAXSER] Factor weights completed in %.1f minutes\n", phase1_elapsed))

    # `latest` is the date following each train_end; the last row has none and
    # is removed by na.omit()
    fct_wts[, latest := shift(eom, n=1, type='lead', fill=NA)]
    fct_wts = na.omit(fct_wts)

    # Parallel computation of stock weights
    cat(paste("Computing final weights for", nrow(fct_wts), "periods...\n"))
    phase2_start = Sys.time()
    pb2 = txtProgressBar(max=nrow(fct_wts), style=3)
    progress2 = function(n) setTxtProgressBar(pb2, n)
    opts2 = list(progress=progress2)
    
    result = rbindlist(foreach(t=seq_len(nrow(fct_wts)), .packages="data.table",
                               .options.snow=opts2, .export=c("loadChars")) %dopar% {
        te = fct_wts[t]$eom
        latest = fct_wts[t]$latest
        # characteristics that entered the cross-sectional regression for `latest`
        chars_sel = valid_chars[[as.character(latest)]]
        chars_ = loadChars(latest, 'data/split')
        stock_pool = chars_$id
        if (length(stock_pool) > 0) {
            chars_T = data.matrix(chars_[, ..chars_sel])
            # (X'X)^{-1} X' : maps stock returns to OLS factor returns, so
            # factor weights times this matrix are the implied stock weights
            coef = solve(t(chars_T) %*% chars_T) %*% t(chars_T)
            maxser_s_d = data.matrix(fct_wts[t, ..chars_sel])
            data.table(id = stock_pool, eom = te, w = drop(maxser_s_d %*% coef))
        }else{
            NULL
        }
    })
    
    close(pb2)
    phase2_elapsed = as.numeric(difftime(Sys.time(), phase2_start, units="mins"))
    cat(sprintf("\n[MAXSER] Final weights completed in %.1f minutes\n", phase2_elapsed))

    cat(sprintf("[MAXSER] Total parallel computation: %.1f minutes\n", phase1_elapsed + phase2_elapsed))

    # Filter to only include test period dates (ctff_test == TRUE)
    # CTF validation only accepts predictions for the test set
    test_dates = unique(chars[ctff_test == TRUE, eom])
    n_before = nrow(result)
    # An empty result has no eom column to filter on, so it is returned as is
    if (n_before > 0) {
        result = result[eom %in% test_dates]
    }
    n_after = nrow(result)
    if (n_before > n_after) {
        cat(sprintf("[MAXSER] Filtered output to test period (ctff_test == TRUE): %d -> %d rows (removed %d training rows)\n",
                    n_before, n_after, n_before - n_after))
    }

    return(result)
}