########## Cross-Sectional Ridge-Forest Ensemble ##########
# Author: Hassan Mir
#
# CTF Admin Modifications (2026-02-23):
# --------------------------------------
# 1. Added [CTF-DEBUG] progress statements throughout main()
#    Reason: HPC jobs can run for hours; progress output is essential
#    for monitoring and debugging if the script hangs or crashes.
#
# 2. Added security rule suppression for false positive (line 92)
#    Reason: SEC34A flagged paste() with collapse in error message
#    construction, which is legitimate code, not obfuscation.
#
##########################################################

suppressPackageStartupMessages({
  library(data.table)
  library(glmnet)
  library(ranger)
})

# Helper: cross-sectional rank transform of a numeric vector onto [-1, 1].
rank_to_unit <- function(x) {
  # Map a numeric vector to average ranks, then rescale linearly onto [-1, 1].
  # NA inputs keep an NA rank (na.last = "keep") and end up at 0, the midpoint.
  ranks <- data.table::frank(x, ties.method = "average", na.last = "keep")
  n_obs <- sum(!is.na(ranks))
  # Degenerate case: zero or one usable observation -> no spread to rank over,
  # so place every element at the midpoint.
  if (n_obs <= 1) {
    return(rep(0, length(x)))
  }
  unit <- (ranks - 1) / (n_obs - 1)  # ranks rescaled onto [0, 1]
  signed <- 2 * unit - 1             # shifted onto [-1, 1]
  signed[is.na(signed)] <- 0
  signed
}

standardize_month_dt <- function(dt_month, feature_cols) {
  # Within one month's cross-section: median-impute missing feature values,
  # then replace each feature column with its rank transform on [-1, 1].
  # NOTE: assignment is done via `[[<-` (not data.table::set) on purpose —
  # this function is also called on the locked `.SD` inside a grouped
  # data.table query, where by-reference modification is not allowed.
  for (col_name in feature_cols) {
    col_vals <- dt_month[[col_name]]
    col_med <- median(col_vals, na.rm = TRUE)
    # An all-NA column yields an NA median; fall back to 0 so imputation
    # still produces finite values.
    if (is.na(col_med)) col_med <- 0
    col_vals[is.na(col_vals)] <- col_med
    dt_month[[col_name]] <- rank_to_unit(col_vals)
  }
  dt_month
}

# Apply the per-month standardization to every month (eom group) of a panel.
# Returns a new data.table: each group's `.SD` (all columns except eom) is
# standardized, and data.table re-attaches the eom grouping column.
# NOTE(review): group output order follows data.table's by-group ordering
# (order of first appearance of each eom) — callers here re-sort as needed.
standardize_panel <- function(dt_panel, feature_cols) {
  dt_panel[, standardize_month_dt(.SD, feature_cols), by = eom]
}

# Select a single ridge penalty via one holdout split inside the calibration
# window (months <= calib_end): the last 24 calibration months are validation,
# everything earlier is training. Returns the lambda with the lowest
# validation MSE, or the grid midpoint when fewer than 48 months exist.
#
# Args:
#   dt_all       data.table panel with `eom`, features, and the target column.
#   feature_cols character vector of feature column names.
#   ycol         name of the target column.
#   lambda_grid  decreasing numeric grid of candidate penalties.
#   calib_end    last month (Date) usable for calibration.
choose_lambda_once <- function(dt_all, feature_cols, ycol, lambda_grid, calib_end) {
  train_all <- dt_all[eom <= calib_end]
  months <- sort(unique(train_all$eom))
  # Too little history for a meaningful split: fall back to the grid midpoint.
  if (length(months) < 48) return(lambda_grid[round(length(lambda_grid) / 2)])

  val_months <- tail(months, 24)
  tr_months  <- setdiff(months, val_months)

  tr <- train_all[eom %in% tr_months]
  va <- train_all[eom %in% val_months]

  # Copy before standardizing: standardize_panel mutates feature columns.
  tr_std <- standardize_panel(data.table::copy(tr), feature_cols)
  va_std <- standardize_panel(data.table::copy(va), feature_cols)

  X_tr <- as.matrix(tr_std[, ..feature_cols])
  y_tr <- tr_std[[ycol]]
  X_va <- as.matrix(va_std[, ..feature_cols])
  y_va <- va_std[[ycol]]

  # Features are already rank-standardized, so standardize = FALSE.
  fit <- glmnet(X_tr, y_tr, alpha = 0, lambda = lambda_grid, standardize = FALSE)

  # predict() returns one column per value of fit$lambda; y_va recycles
  # column-wise, giving per-lambda validation MSE.
  pred_va <- predict(fit, X_va)
  mse <- colMeans((pred_va - y_va)^2)

  # Index into fit$lambda, not lambda_grid: glmnet may reorder or drop
  # entries of a user-supplied grid, so column order follows fit$lambda.
  fit$lambda[which.min(mse)]
}

# Main entry point: builds monthly long-short portfolio weights from the
# ridge/random-forest ensemble predictions.
# Run the full backtest: for each test month, fit a ridge model on a rolling
# window (plus a periodically retrained random forest), blend predictions,
# and emit equal-weighted long-short weights for the top/bottom quintile.
#
# Args:
#   chars     data.frame/data.table with id, eom, ctff_test, the target
#             column, and feature columns.
#   features  data.frame with a `features` column naming the feature columns.
#   daily_ret unused here; kept for interface compatibility with the caller.
#
# Returns: data.frame with columns id, eom, w (one row per held position).
main <- function(chars, features, daily_ret) {
  start_time <- Sys.time()
  cat("[CTF-DEBUG] Starting main() at", format(start_time, "%Y-%m-%d %H:%M:%S"), "\n")

  set.seed(26)

  ycol <- "ret_exc_lead1m"

  ROLL_MONTHS     <- 120      # rolling training window length (months)
  LONG_SHORT_FRAC <- 0.20     # fraction of names on each side of the book

  W_RIDGE <- 0.7              # ensemble weight on the ridge prediction

  RETRAIN_EVERY <- 12         # retrain the random forest every N test months
  MAX_TRAIN_N   <- 150000     # subsample cap for the forest's training set
  NUM_TREES     <- 300

  # FIX: as.data.table() does NOT copy when the input is already a
  # data.table, so the `:=` calls below would mutate the caller's `chars`
  # by reference. Take an explicit copy in that case.
  dt0 <- if (is.data.table(chars)) copy(chars) else as.data.table(chars)

  req_cols <- c("id", "eom", "ctff_test", ycol)
  miss <- setdiff(req_cols, names(dt0))
  if (length(miss) > 0) stop(paste("Missing required columns in chars:", paste(miss, collapse = ", ")))

  dt0[, eom := as.Date(eom)]
  dt0[, ctff_test := as.logical(ctff_test)]
  # Rows without a realized forward return cannot train or be evaluated.
  dt0 <- dt0[!is.na(get(ycol))]

  if (!("features" %in% names(features))) stop("features input must have a column named 'features'")
  feature_cols <- as.character(features$features)
  feature_cols <- feature_cols[feature_cols %in% names(dt0)]
  if (length(feature_cols) == 0) stop("No valid feature columns found in chars based on features$features.")

  dt0[, (feature_cols) := lapply(.SD, as.numeric), .SDcols = feature_cols]

  test_months <- sort(unique(dt0[ctff_test == TRUE, eom]))
  if (length(test_months) == 0) stop("No test months found where ctff_test == TRUE.")

  # The calibration window ends one month before the first test month, so
  # lambda selection never sees test-period data.
  first_test <- test_months[1]
  all_months <- sort(unique(dt0$eom))
  first_idx  <- match(first_test, all_months)
  if (is.na(first_idx) || first_idx <= 1) stop("Cannot form calibration end month (need >=1 month before first test).")
  calib_end <- all_months[first_idx - 1]

  lambda_grid <- 10^seq(3, -3, length.out = 25)
  cat("[CTF-DEBUG] Selecting best lambda via cross-validation...\n")
  best_lambda <- choose_lambda_once(dt0, feature_cols, ycol, lambda_grid, calib_end)
  cat(sprintf("[CTF-DEBUG] Best lambda selected: %.6f\n", best_lambda))
  cat(sprintf("[CTF-DEBUG] Processing %d test months\n", length(test_months)))

  out_list <- vector("list", length(test_months))
  rf_model <- NULL

  for (k in seq_along(test_months)) {
    t_month <- test_months[k]
    if (k == 1 || k %% 48 == 0 || k == length(test_months)) {
      cat(sprintf("[CTF-DEBUG] Processing month %d/%d (%s) - %.1f%%\n",
                  k, length(test_months), as.character(t_month),
                  100 * k / length(test_months)))
    }

    # Strictly past data only for training; the test month for prediction.
    train_dt <- dt0[eom < t_month]
    pred_dt  <- dt0[eom == t_month]

    # Skip months with too few names to rank or too little history to fit.
    if (nrow(pred_dt) < 50 || nrow(train_dt) < 5000) next

    # Trim the training panel to the most recent ROLL_MONTHS months.
    train_months <- sort(unique(train_dt$eom))
    if (length(train_months) > ROLL_MONTHS) {
      cutoff <- train_months[length(train_months) - ROLL_MONTHS + 1]
      train_dt <- train_dt[eom >= cutoff]
    }
    if (nrow(train_dt) < 5000) next

    # Copy before standardizing: the standardizers mutate feature columns.
    train_std <- standardize_panel(data.table::copy(train_dt), feature_cols)
    pred_std  <- standardize_month_dt(data.table::copy(pred_dt), feature_cols)

    X_train <- as.matrix(train_std[, ..feature_cols])
    y_train <- train_std[[ycol]]
    X_pred  <- as.matrix(pred_std[, ..feature_cols])

    fit_ridge <- glmnet(X_train, y_train, alpha = 0, lambda = best_lambda, standardize = FALSE)
    pred_ridge <- as.numeric(predict(fit_ridge, X_pred))

    # Retrain the forest on the first month and then every RETRAIN_EVERY
    # months; in between, reuse the previous model (it is never NULL after k=1).
    do_retrain <- is.null(rf_model) || ((k - 1) %% RETRAIN_EVERY == 0)
    if (do_retrain) {
      cat(sprintf("[CTF-DEBUG] Retraining Random Forest (month %d)...\n", k))
      ntr <- nrow(train_std)
      if (ntr > MAX_TRAIN_N) {
        # Month-dependent seed keeps the subsample reproducible per retrain.
        set.seed(1000 + k)
        idx <- sample.int(ntr, MAX_TRAIN_N)
        train_use <- train_std[idx]
      } else {
        train_use <- train_std
      }

      train_df <- data.frame(y = train_use[[ycol]], train_use[, ..feature_cols])

      rf_model <- ranger(
        y ~ .,
        data = train_df,
        num.trees = NUM_TREES,
        mtry = floor(sqrt(length(feature_cols))),
        min.node.size = 50,
        sample.fraction = 0.7,
        seed = 123
      )
    }

    pred_df <- data.frame(pred_std[, ..feature_cols])
    pred_rf <- as.numeric(predict(rf_model, data = pred_df)$predictions)

    # Blend BEFORE sorting: pred_ridge/pred_rf are aligned with pred_std's
    # current row order.
    pred_std[, pred := W_RIDGE * pred_ridge + (1 - W_RIDGE) * pred_rf]

    # Equal-weight the bottom q names short and the top q names long.
    # With n >= 50 and q = floor(0.2 * n), the two tails never overlap.
    setorder(pred_std, pred)
    n <- nrow(pred_std)
    q <- max(1, floor(LONG_SHORT_FRAC * n))

    pred_std[, w := 0.0]
    pred_std[1:q, w := -1 / q]
    pred_std[(n - q + 1):n, w := 1 / q]

    out_list[[k]] <- pred_std[, .(id, eom, w)]
  }

  # NULL entries from skipped months are dropped by rbindlist.
  output <- rbindlist(out_list, use.names = TRUE, fill = TRUE)
  output <- output[eom %in% test_months]
  setkey(output, id, eom)
  output <- unique(output, by = c("id", "eom"))

  if (nrow(output) == 0) stop("Output is empty after processing.")
  if (anyNA(output$id) || anyNA(output$eom) || anyNA(output$w)) stop("Output contains missing values.")
  if (!all(c("id", "eom", "w") %in% names(output))) stop("Output missing required columns.")

  elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "mins"))
  cat(sprintf("[CTF-DEBUG] Completed in %.1f minutes\n", elapsed))
  cat(sprintf("[CTF-DEBUG] Output: %d rows, dates %s to %s\n",
              nrow(output), min(output$eom), max(output$eom)))

  return(as.data.frame(output[, .(id, eom, w)]))
}