################################################################################
#Loading Packages
################################################################################
#
# CTF Admin Modifications (2026-02-19):
# --------------------------------------
# 1. Added [CTF-DEBUG] progress statements throughout main() function
#    Reason: HPC jobs can run for hours; progress output enables monitoring
#    and debugging of long-running jobs
#
# 2. Updated core package versions in renv.lock to match base image:
#    - rlang: 1.1.6 → 1.1.7
#    - vctrs: 0.6.5 → 0.7.1
#    - lifecycle: 1.0.4 → 1.0.5
#    Reason: Build failed with version conflicts - base image pre-installs
#    arrow/jsonlite with latest deps, which conflict with older lockfile versions
#

library(data.table)   # For high performance data manipulation
library(glmnet)       # lasso / ridge
library(arrow)        # to open parquet files

################################################################################
#_______________________________________________________________________________
#===============================================================================
#Creating functions
#===============================================================================
#_______________________________________________________________________________
################################################################################

#######################
# Cleaning data
#######################
# Adapted from:
# https://github.com/theisij/common-task-framework-SDF/blob/main/models_R/factor-ml/factor_ml.R
#
# Prepare feature columns for modelling. `data` is modified by reference
# (data.table `:=` / setnafill) and also returned.
#
# Args:
#   data:       data.table with an `eom` column and every column in `features`.
#   features:   character vector of feature column names.
#   feat_prank: if TRUE, replace each feature by its percentile rank within
#               each `eom` (ties take the max rank), with exact zeros forced
#               to rank 0, then centred by subtracting 0.5, giving values in
#               [-0.5, 0.5]. NAs stay NA.
#   impute:     if TRUE, fill NAs: with 0 when features were ranked (the
#               centred 50th-percentile rank), otherwise with the per-`eom`
#               median of the column.
#
# Returns:
#   `data`; feature columns are converted to double whenever they are ranked
#   or imputed.
prepare_pred_data <- function(data, features, feat_prank, impute) {
  # Ranks and medians are generally non-integer, so coerce up front whenever
  # either transformation will write back into the columns. Otherwise an
  # integer column would reject or truncate the new values.
  if (feat_prank || impute) {
    data[, (features) := lapply(.SD, as.double), .SDcols = features]
  }

  if (feat_prank) {
    cat(sprintf("Percentile-ranking %d features...\n", length(features)))
    data[, (features) := lapply(.SD, function(x) {
      non_na <- !is.na(x)
      is_zero <- non_na & (x == 0)
      x[non_na] <- frank(x[non_na], ties.method = "max") / sum(non_na)
      # Exact zeros are pinned to rank 0 (as in the reference implementation).
      x[is_zero] <- 0
      x - 0.5
    }), .SDcols = features, by = .(eom)]
  }

  if (impute) {
    if (feat_prank) {
      # After centring, 0 corresponds to the cross-sectional median rank.
      setnafill(data, fill = 0, cols = features)
    } else {
      data[, (features) := lapply(.SD, function(x) {
        fifelse(is.na(x), median(x, na.rm = TRUE), x)
      }), .SDcols = features, by = .(eom)]
    }
  }

  data
}

#######################
# Ridge/LASSO function
#######################
# Fit a glmnet model, choose lambda on a validation split, refit on the full
# in-sample data and convert predictions into monthly long-short weights.
#
# Args:
#   train, valid:  data.tables used only to choose lambda (read, not modified).
#   insample:      data.table the final model is refitted on (presumably the
#                  union of train and valid -- confirm against the caller).
#   test:          data.table to predict on.
#   features_vec:  character vector of feature column names.
#   target:        name of the target column.
#   alpha_val:     glmnet mixing parameter (1 = LASSO, 0 = ridge).
#   q:             tail fraction per leg. Within each `eom`, stocks whose
#                  prediction is >= the (1 - q) quantile go long and those
#                  <= the q quantile go short; each leg is equally weighted
#                  to +1 / -1.
#
# Returns: list with
#   insample, test: copies of the inputs with added `pred` and `w` columns,
#   port_in:        per-`eom` in-sample portfolio return `rp`,
#   best_lambda:    the lambda minimising validation MSE.
fit_glmnet_model <- function(train, valid, insample, test,
                             features_vec, target,
                             alpha_val, q = 0.18) {

  # 0) Copy only the tables we add columns to. `train` and `valid` are read
  #    only, so copying them would just duplicate (potentially large) data.
  insample <- copy(insample)
  test     <- copy(test)

  # 1) Fit the lambda path on train, select lambda by validation MSE.
  x_train <- as.matrix(train[, ..features_vec])
  y_train <- train[[target]]

  x_valid <- as.matrix(valid[, ..features_vec])
  y_valid <- valid[[target]]

  fit_path <- glmnet(x_train, y_train, alpha = alpha_val)

  # pred_valid has one column per lambda; y_valid recycles down each column.
  pred_valid <- predict(fit_path, newx = x_valid)
  mse <- colMeans((pred_valid - y_valid)^2, na.rm = TRUE)
  best_lambda <- fit_path$lambda[which.min(mse)]

  # 2) Refit on the full in-sample data at the selected lambda.
  x_insample <- as.matrix(insample[, ..features_vec])
  y_insample <- insample[[target]]

  final_fit <- glmnet(x_insample, y_insample,
                      alpha = alpha_val,
                      lambda = best_lambda)

  # 3) Predict in-sample (reusing the matrix built above) and on test.
  insample[, pred := as.numeric(predict(final_fit, newx = x_insample))]

  test[, pred := as.numeric(predict(final_fit,
                                    newx = as.matrix(test[, ..features_vec])))]

  # 4) Build long-short weights per month.
  make_weights <- function(dt) {
    dt[, w := {
      lo <- quantile(pred, q, na.rm = TRUE)
      hi <- quantile(pred, 1 - q, na.rm = TRUE)

      is_long  <- pred >= hi
      is_short <- pred <= lo
      # If predictions are heavily tied (e.g. the selected lambda zeroes every
      # coefficient, so every prediction equals the intercept), lo == hi and a
      # stock can qualify for both legs. Such stocks carry no ranking signal,
      # so keep them out of both legs rather than letting the short
      # assignment silently overwrite the long one.
      in_both <- is_long & is_short
      long  <- which(is_long & !in_both)   # which() drops NA predictions
      short <- which(is_short & !in_both)

      w <- rep(0, .N)
      if (length(long)  > 0) w[long]  <-  1 / length(long)
      if (length(short) > 0) w[short] <- -1 / length(short)

      w
    }, by = eom]
    dt
  }

  insample <- make_weights(insample)
  test     <- make_weights(test)

  port_in <- insample[, .(rp = sum(w * get(target), na.rm = TRUE)), by = eom]

  list(
    insample = insample,
    test = test,
    port_in = port_in,
    best_lambda = best_lambda
  )
}

  #######################
  # Main
  #######################
  # Entry point: prepare features, fit LASSO and ridge long-short strategies,
  # and combine them into a minimum-variance ensemble of portfolio weights.
  #
  # Args:
  #   chars:     characteristics table; must contain `id`, `eom`, `ctff_test`,
  #              `ret_exc_lead1m` and every feature column.
  #   features:  table whose `features` column lists the feature names.
  #   daily_ret: daily returns table; converted with setDT() but otherwise
  #              unused by this model.
  #
  # `features`, `chars` and `daily_ret` are converted to data.tables in place
  # via setDT(); `chars` itself is copied before any transformation.
  #
  # Returns: data.frame with columns id (integer), eom (IDate) and w (ensemble
  # weight) for test rows (ctff_test == TRUE) with a non-missing target.
  main <- function(chars, features, daily_ret) {

    start_time <- Sys.time()
    cat(sprintf("[CTF-DEBUG] main() started at %s\n", start_time))

    set.seed(42)

    table_ctff_features <- data.table::setDT(features)
    table_ctff_chars    <- data.table::setDT(chars)
    # daily_ret is not used by this model; converted only for parity with
    # the other inputs.
    data.table::setDT(daily_ret)

    # Work on a copy so ranking/imputation does not alter `chars`.
    prepped_dta  <- data.table::copy(table_ctff_chars)
    features_vec <- as.character(table_ctff_features$features)
    prepped_dta  <- prepare_pred_data(prepped_dta, features_vec,
                                      feat_prank = TRUE, impute = TRUE)

    target <- "ret_exc_lead1m"
    id_col <- "id"

    # NOTE(review): rows with a missing target are dropped before the
    # train/test split, so test stocks whose next-month return is missing get
    # no weight at all. Confirm this is intended for the test set.
    dt <- prepped_dta[
      !is.na(get(target)),
      c(id_col, "ctff_test", "eom", features_vec, target),
      with = FALSE
    ]
    dt[, eom := as.Date(eom)]

    # Test set and in-sample set; in-sample is split further into training
    # and validation parts below.
    test     <- dt[ctff_test == TRUE]
    insample <- dt[ctff_test == FALSE]
    if (nrow(insample) == 0L) {
      stop("No in-sample rows (ctff_test == FALSE) with a non-missing target")
    }
    if (nrow(test) == 0L) {
      stop("No test rows (ctff_test == TRUE) with a non-missing target")
    }

    # The last `valid_frac` of in-sample months form the validation set.
    valid_frac       <- 0.25
    insample_months  <- sort(unique(insample$eom))
    valid_start_idx  <- floor((1 - valid_frac) * length(insample_months)) + 1
    valid_start_date <- insample_months[valid_start_idx]

    train <- insample[eom <  valid_start_date]
    valid <- insample[eom >= valid_start_date]
    if (nrow(train) == 0L || nrow(valid) == 0L) {
      stop("Empty train or validation split; need at least 2 in-sample months")
    }

    # Sanity output (bare expressions inside a function print nothing).
    cat(sprintf("[CTF-DEBUG] Train: %d rows, %s to %s\n", nrow(train),
                format(min(train$eom)), format(max(train$eom))))
    cat(sprintf("[CTF-DEBUG] Valid: %d rows, %s to %s\n", nrow(valid),
                format(min(valid$eom)), format(max(valid$eom))))
    cat(sprintf("[CTF-DEBUG] Test: %d rows\n", nrow(test)))

    # Running Lasso and Ridge
    ############################################################################
    cat("[CTF-DEBUG] Fitting LASSO model...\n")
    lasso_res <- fit_glmnet_model(train, valid, insample, test,
                                  features_vec, target,
                                  alpha_val = 1)
    cat(sprintf("[CTF-DEBUG] LASSO complete (best lambda: %.6f)\n",
                lasso_res$best_lambda))

    cat("[CTF-DEBUG] Fitting Ridge model...\n")
    ridge_res <- fit_glmnet_model(train, valid, insample, test,
                                  features_vec, target,
                                  alpha_val = 0)
    cat(sprintf("[CTF-DEBUG] Ridge complete (best lambda: %.6f)\n",
                ridge_res$best_lambda))

    # Creating ensemble
    ############################################################################
    cat("[CTF-DEBUG] Building ensemble portfolio...\n")
    # In-sample monthly returns of both strategies, aligned by month.
    rets_in <- merge(
      lasso_res$port_in[, .(eom, r_lasso = rp)],
      ridge_res$port_in[, .(eom, r_ridge = rp)],
      by = "eom", all = FALSE
    )[order(eom)]

    n_port_months <- nrow(rets_in)
    if (n_port_months < 6) {
      warning("Very few in-sample months; covariance estimates may be noisy")
    }

    Sigma <- cov(as.matrix(rets_in[, .(r_lasso, r_ridge)]),
                 use = "pairwise.complete.obs")
    sigma_L2 <- Sigma[1, 1]
    sigma_R2 <- Sigma[2, 2]
    sigma_LR <- Sigma[1, 2]

    # Two-asset minimum-variance weights (wL + wR = 1; either may be < 0):
    #   wL = (sR^2 - sLR) / (sL^2 + sR^2 - 2 sLR)
    den <- sigma_L2 + sigma_R2 - 2 * sigma_LR
    tol <- 1e-10

    if (!is.finite(den) || abs(den) < tol) {
      warning("Covariance degenerate or tiny denominator — falling back to equal weights")
      wL <- 0.5; wR <- 0.5
    } else if (sigma_L2 < tol) {
      # A strategy with ~zero return variance (e.g. one holding no positions)
      # would absorb the whole minimum-variance allocation and leave the
      # ensemble without exposure; use the other strategy instead.
      warning("LASSO portfolio has ~zero in-sample variance — using Ridge only")
      wL <- 0; wR <- 1
    } else if (sigma_R2 < tol) {
      warning("Ridge portfolio has ~zero in-sample variance — using LASSO only")
      wL <- 1; wR <- 0
    } else {
      wL <- (sigma_R2 - sigma_LR) / den
      wR <- 1 - wL
    }

    # Combine per-stock test weights; a stock missing from one strategy
    # contributes 0 from that strategy.
    test_lasso_w <- lasso_res$test[, .(id, eom, w_lasso = w)]
    test_ridge_w <- ridge_res$test[, .(id, eom, w_ridge = w)]

    ens_test <- merge(test_lasso_w, test_ridge_w,
                      by = c("id", "eom"), all = TRUE)
    ens_test[is.na(w_lasso), w_lasso := 0]
    ens_test[is.na(w_ridge), w_ridge := 0]

    ens_test[, w := wL * w_lasso + wR * w_ridge]
    ens_test[, id := as.integer(id)]
    ens_test[, eom := data.table::as.IDate(eom)]
    ens_test[is.na(w), w := 0]

    out_df <- as.data.frame(ens_test[, .(id, eom, w)])

    # Final validation
    if (nrow(out_df) == 0) stop("Output empty")
    if (!all(c("id", "eom", "w") %in% names(out_df))) {
      stop("Output must contain id,eom,w")
    }
    if (any(is.na(out_df$id) | is.na(out_df$eom) | is.na(out_df$w))) {
      stop("Output contains NAs")
    }

    elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
    cat(sprintf("[CTF-DEBUG] main() completed in %.1fs\n", elapsed))
    cat(sprintf("[CTF-DEBUG] Output: %d rows, date range %s to %s\n",
                nrow(out_df), min(out_df$eom), max(out_df$eom)))
    cat(sprintf("[CTF-DEBUG] Ensemble weights: LASSO=%.3f, Ridge=%.3f\n",
                wL, wR))
    out_df
  }

