diff --git a/.Rbuildignore b/.Rbuildignore index b28eb65..73d8b09 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,17 +1,21 @@ ^.*\.Rproj$ ^\.Rproj\.user$ -^\.git$ -^\.gitignore$ -^\.gitattributes$ +^\.Rhistory$ +^\.RData$ +^\.RDataTmp$ +^output$ +^\.renv$ ^renv$ ^renv\.lock$ +# Ignore vignettes and tests for CRAN +#^vignettes$ +#^tests$ ^01dev_config.R$ ^02document_genera.R$ ^README\.Rmd$ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ -# 你特定的Demo文件 ^RHealth---medcode Module Demo\.html$ ^RHealth---medcode Module Demo\.qmd$ ^LICENSE\.md$ diff --git a/DESCRIPTION b/DESCRIPTION index 5afdde2..306e3d6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,13 @@ Package: RHealth Type: Package Title: A Deep Learning Toolkit for Healthcare Predictive Modeling -Version: 0.1.0 +Version: 0.2.0 Authors@R: c( + person( + "Junyi", "Gao", + email = "junyii.gao@gmail.com", + role = c("aut", "cre") + ), person( "Ji", "Song", email = "eaad0907@163.com", @@ -42,31 +47,18 @@ Authors@R: c( "Ewen", "Harrison", email = "ewen.harrison@ed.ac.uk", role = c("aut") - ), - person( - "Junyi", "Gao", - email = "junyii.gao@gmail.com", - role = c("aut", "cre") ) ) Description: RHealth is an open-source R package specifically designed to bring comprehensive deep learning toolkits to the R community for healthcare predictive modeling. License: MIT + file LICENSE Encoding: UTF-8 LazyData: true -URL: https://v1xerunt.github.io/dev/RHealth, https://github.com/v1xerunt/RHealth +URL: https://v1xerunt.github.io/RHealth/dev, https://github.com/v1xerunt/RHealth BugReports: https://github.com/v1xerunt/RHealth/issues Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 Depends: R (>= 4.1.0) -Suggests: - devtools (>= 2.4.5), - usethis, - testthat (>= 3.0.0), - lintr (>= 3.2.0), - styler (>= 1.10.3), - pkgdown (>= 2.1.1), - roxygen2 (>= 7.3.2) -Imports: +Imports: R6, httr, readr, @@ -78,18 +70,29 @@ Imports: yaml, psych, progressr, - polars, - pROC, + DBI, metrica, lubridate, future, future.apply, futile.logger, - PRROC, MLmetrics, tidyr, checkmate, fs, - duckdb -Additional_repositories: https://community.r-multiverse.org, https://rpolars.r-universe.dev/ + duckdb, + stringr +Suggests: + devtools (>= 2.4.5), + usethis, + testthat (>= 3.0.0), + lintr (>= 3.2.0), + styler (>= 1.10.3), + pkgdown (>= 2.1.1), + roxygen2 (>= 7.3.2), + knitr, + rmarkdown, + ps +VignetteBuilder: knitr +Additional_repositories: https://community.r-multiverse.org Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 50c1cb8..17268a0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,6 +8,7 @@ export(DatasetProcessor) export(EmbeddingModel) export(Event) export(FeatureProcessor) +export(InHospitalMortalityMIMIC3) export(InHospitalMortalityMIMIC4) export(JoinConfig) export(MIMIC3Dataset) @@ -15,6 +16,8 @@ export(MIMIC4EHRDataset) export(MIMIC4NoteDataset) export(MultiClassLabelProcessor) export(MultiLabelProcessor) +export(NextMortalityMIMIC3) +export(NextMortalityMIMIC4) export(Patient) export(Processor) export(RNN) @@ -41,6 +44,7 @@ export(get_descendants) export(get_metrics_fn) export(get_processor) export(is_best) +export(load_sample_dataset) export(load_yaml_config) export(lookup_code) export(map_code) @@ -54,16 +58,32 @@ import(R6) import(checkmate) import(dplyr) import(lubridate) -import(polars) import(tidyr) import(torch) import(yaml) +importFrom(DBI,dbConnect) +importFrom(DBI,dbDisconnect) +importFrom(DBI,dbExecute) +importFrom(DBI,dbExistsTable) +importFrom(DBI,dbListTables) +importFrom(DBI,dbRemoveTable) importFrom(MLmetrics,Accuracy) importFrom(MLmetrics,F1_Score) importFrom(MLmetrics,Precision) importFrom(MLmetrics,Recall) -importFrom(PRROC,pr.curve) importFrom(R6,R6Class) +importFrom(dplyr,collect) +importFrom(dplyr,distinct) +importFrom(dplyr,filter) +importFrom(dplyr,left_join) +importFrom(dplyr,mutate) +importFrom(dplyr,pull) +importFrom(dplyr,rename) +importFrom(dplyr,rename_with) +importFrom(dplyr,select) +importFrom(dplyr,tbl) +importFrom(dplyr,union_all) +importFrom(duckdb,duckdb) importFrom(fs,dir_create) importFrom(fs,file_exists) importFrom(fs,path) @@ -82,8 +102,6 @@ importFrom(lubridate,as.duration) importFrom(lubridate,dhours) importFrom(metrica,balacc) importFrom(metrica,jaccardindex) -importFrom(pROC,auc) -importFrom(polars,pl) importFrom(progressr,handlers) importFrom(progressr,progressor) importFrom(psych,cohen.kappa) @@ -92,6 +110,9 @@ importFrom(purrr,map) importFrom(rappdirs,user_cache_dir) importFrom(readr,cols) importFrom(readr,read_csv) +importFrom(stats,aggregate) +importFrom(stats,setNames) +importFrom(stringr,str_pad) importFrom(torch,nn_dropout) importFrom(torch,nn_linear) importFrom(torch,nn_module) diff --git a/R/Dataset_BaseDataset.R b/R/Dataset_BaseDataset.R index 481aea2..a06a1e3 100644 --- a/R/Dataset_BaseDataset.R +++ b/R/Dataset_BaseDataset.R @@ -2,32 +2,35 @@ #' BaseDataset — R6 infrastructure for clinical event datasets #' #' The **BaseDataset** class mirrors rhealth's `BaseDataset`, providing a -#' fully‑featured, YAML driven loader that converts multi‑table electronic +#' fully-featured, YAML driven loader that converts multi-table electronic #' health records into a single *event* table. It supports: #' \itemize{ -#' \item URL or local‐file ingestion (with automatic `.csv` / `.csv.gz` +#' \item URL or local-file ingestion (with automatic `.csv` / `.csv.gz` #' fallback). -#' \item Per‑table joins as declared in the config. -#' \item Flexible timestamp parsing (single or multi‑column). +#' \item Per-table joins as declared in the config. +#' \item Flexible timestamp parsing (single or multi-column). #' \item A \code{dev} mode that caps the number of patients for rapid #' prototyping. -#' \item Multi‑threaded sample generation with progress bars. +#' \item Multi-threaded sample generation with progress bars. #' } #' -#' Down‑stream, it cooperates with \code{BaseTask} (task definition), -#' \code{Patient} (per‑subject wrapper), and \code{SampleDataset} (collection of +#' Down-stream, it cooperates with \code{BaseTask} (task definition), +#' \code{Patient} (per-subject wrapper), and \code{SampleDataset} (collection of #' input/output pairs). #' #' @section Dependencies: #' Polars is used via the \code{polars} R package. Parallelism and progress #' reporting require \code{future}, \code{future.apply}, and \code{progressr}. -#' @importFrom polars pl +#' @importFrom R6 R6Class +#' @importFrom dplyr tbl collect filter select mutate left_join union_all distinct pull rename rename_with +#' @importFrom DBI dbConnect dbDisconnect dbExecute dbListTables dbExistsTable dbRemoveTable +#' @importFrom duckdb duckdb +#' @importFrom glue glue #' @importFrom future plan multisession #' @importFrom future.apply future_lapply #' @importFrom progressr handlers progressor -#' @importFrom R6 R6Class #' @export BaseDataset <- R6::R6Class( "BaseDataset", @@ -40,20 +43,23 @@ BaseDataset <- R6::R6Class( #' @field tables Character vector of table names to ingest. tables = NULL, - #' @field dataset_name Human‑readable dataset label. + #' @field dataset_name Human-readable dataset label. dataset_name = NULL, #' @field config Parsed YAML configuration list. config = NULL, - #' @field dev Logical flag — when TRUE limits to 1 000 patients. + #' @field dev Logical flag — when TRUE limits to 1000 patients. dev = FALSE, + + #' @field con a duckdb connection + con = NULL, - #' @field global_event_df A polars LazyFrame with all events combined. + #' @field global_event_df A duckdb lazy query with all events combined. global_event_df = NULL, #-------------------------------------------------------------------- - # Private‑cache fields ---------------------------------------------- + # Private-cache fields ---------------------------------------------- #' @field .collected_global_event_df Polars dataframe storing all global events. .collected_global_event_df = NULL, #' @field .unique_patient_ids Character vector of unique patient IDs. @@ -74,6 +80,12 @@ BaseDataset <- R6::R6Class( dataset_name = NULL, config_path = NULL, dev = FALSE) { + + self$con <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") + reg.finalizer(self, function(e) { + message("Auto-disconnecting duckdb") + DBI::dbDisconnect(e$con, shutdown = TRUE) + }, onexit = TRUE) self$root <- root self$tables <- unique(tolower(tables)) @@ -87,43 +99,20 @@ BaseDataset <- R6::R6Class( #-------------------------------------------------------------------- #' @description - #' Materialise (collect) the lazy event dataframe. In dev‑mode only the + #' Materialise (collect) the lazy event dataframe. In dev-mode only the #' first 1000 patients are kept. - #' @return A polars DataFrame containing all selected events. + #' @return A dataframe containing all selected events. collected_global_event_df = function() { - if (is.null(self$.collected_global_event_df)) { - message("[info] Collecting global event dataframe …") - - df <- self$global_event_df # LazyFrame - - # ---------------- Development mode: limit to 1 000 patients ---------------- - schema <- df$schema - dt_cols <- names( - Filter(function(dtype) { - tryCatch({ - dtype$kind() == "Datetime" - }, error = function(e) FALSE) - }, schema) - ) - - if (length(dt_cols) > 0) { - cast_exprs <- lapply(dt_cols, function(colname) { - pl$col(colname)$cast(pl$Utf8)$alias(colname) - }) - df <- df$with_columns(cast_exprs) + if (is.null(self$.collected_global_event_df)) { + message("[info] Collecting global event dataframe ...") + + df <- self$global_event_df # duckdb query + + self$.collected_global_event_df <- df %>% dplyr::collect() } - - self$.collected_global_event_df <- df$collect() - - # shp <- self$.collected_global_event_df$shape() - # message( - # sprintf("[info] Collected dataframe shape: (%d rows, %d columns)", - # shp[1], shp[2]) - # ) - } - - return(self$.collected_global_event_df) - }, + + return(self$.collected_global_event_df) + }, @@ -132,113 +121,154 @@ BaseDataset <- R6::R6Class( #' Load one table, apply joins, lowercase columns, and standardise to the #' event schema. #' @param table_name Character key present in \code{config$tables}. - #' @return A polars LazyFrame in event format. + #' @return A dplyr lazy query in event format. load_table = function(table_name) { - - if (!(table_name %in% names(self$config$tables))) - stop(sprintf("Table %s not in config", table_name)) - - cfg <- self$config$tables[[table_name]] - csv_path <- .clean_path(file.path(self$root, cfg$file_path)) - parquet_path <- .csv2parquet_path(csv_path) - - # one-time conversion; subsequent runs hit the cached parquet - .ensure_parquet(csv_path, parquet_path) - - message(sprintf("Scanning %s", parquet_path)) - lf <- pl$scan_parquet(parquet_path, streaming = TRUE) - - # ── optional joins (each join source is treated the same way) ── - if (!is.null(cfg$join)) { - for (j in cfg$join) { - join_csv <- .clean_path(file.path(self$root, j$file_path)) - join_parq <- .csv2parquet_path(join_csv) - .ensure_parquet(join_csv, join_parq) - - join_df <- pl$scan_parquet(join_parq, streaming = TRUE)$ - select(j$on, j$columns) - - lf <- lf$join(join_df, on = j$on, how = j$how) + + if (!(table_name %in% names(self$config$tables))) + stop(sprintf("Table %s not in config", table_name)) + + cfg <- self$config$tables[[table_name]] + + base_path <- .clean_path(file.path(self$root, cfg$file_path)) + file_info <- .find_path_with_fallback(base_path) + csv_path <- file_info$path + separator <- file_info$separator + + parquet_path <- .csv2parquet_path(csv_path) + + # one-time conversion; subsequent runs hit the cached parquet + .ensure_parquet(csv_path, parquet_path, separator = separator) + + message(sprintf("Scanning %s", parquet_path)) + + view_name <- tools::file_path_sans_ext(basename(table_name)) + DBI::dbExecute(self$con, glue::glue("CREATE OR REPLACE VIEW \"{view_name}\" AS SELECT * FROM '{parquet_path}';")) + lf <- dplyr::tbl(self$con, view_name) %>% dplyr::rename_with(tolower) + + #── optional joins (each join source is treated the same way) ── + if (!is.null(cfg$join)) { + for (j in cfg$join) { + join_base_path <- .clean_path(file.path(self$root, j$file_path)) + join_file_info <- .find_path_with_fallback(join_base_path) + join_csv <- join_file_info$path + join_separator <- join_file_info$separator + + join_parq <- .csv2parquet_path(join_csv) + .ensure_parquet(join_csv, join_parq, separator = join_separator) + + join_table_name <- tools::file_path_sans_ext(basename(j$file_path)) + + DBI::dbExecute(self$con, glue::glue("CREATE OR REPLACE VIEW \"{join_table_name}\" AS SELECT * FROM '{join_parq}';")) + + join_df <- dplyr::tbl(self$con, join_table_name) %>% + dplyr::rename_with(tolower) %>% + dplyr::select(dplyr::all_of(j$on), dplyr::all_of(j$columns)) + + lf <- dplyr::left_join(lf, join_df, by = j$on) + } } - } - - # ── timestamp expression ───────────────────────────────────── - ts_expr <- if (!is.null(cfg$timestamp)) { - if (is.list(cfg$timestamp)) { - pl$concat_str(lapply(cfg$timestamp, pl$col)) + + # ── timestamp expression ───────────────────────────────────── + ts_col <- if (!is.null(cfg$timestamp)) { + if (is.list(cfg$timestamp)) { + # This is tricky to replicate directly without more info on what concat_str does. + # Assuming it concatenates columns to form a string. + # We can do this with paste. The equivalent in SQL is CONCAT. + # For now, I will assume it creates a string representation. + # This might need adjustment based on the exact polars behavior. + rlang::parse_expr(paste0("paste(", paste0(cfg$timestamp, collapse=", "), ")")) + } else { + rlang::sym(cfg$timestamp) + } } else { - pl$col(cfg$timestamp) + NA_character_ } - } else { - pl$lit(NA_character_) - } - - # ── patient-id handling ────────────────────────────────────── - pid_expr <- if (!is.null(cfg$patient_id)) - pl$col(cfg$patient_id)$cast(pl$Utf8) - else - pl$int_range(0, pl$count())$cast(pl$Utf8) - - # ── attribute columns (prefixed with table name) ───────────── - attrs <- lapply(cfg$attributes, function(a) - pl$col(a)$alias(paste0(table_name, "/", a))) - - # ── final event-schema LazyFrame ───────────────────────────── - exprs_all <- c( - list(pid_expr$alias("patient_id")), - list(pl$lit(table_name)$cast(pl$Utf8)$alias("event_type")), - list(ts_expr$alias("timestamp")), - attrs - ) - - lf <- lf$select(exprs_all) + + # ── patient-id handling ────────────────────────────────────── + pid_col <- if (!is.null(cfg$patient_id)) { + rlang::sym(cfg$patient_id) + } else { + # polars `pl$int_range(0, pl$count())` creates a sequence from 0 to N-1 + # The equivalent in dplyr/SQL is row_number() - 1 + rlang::expr(row_number() - 1) + } + + # ── attribute columns (prefixed with table name) ───────────── + attrs <- cfg$attributes + + # ── final event-schema LazyFrame ───────────────────────────── + lf <- lf %>% + dplyr::mutate( + patient_id = !!pid_col, + event_type = table_name, + timestamp = !!ts_col + ) %>% + dplyr::select( + dplyr::all_of(c("patient_id", "event_type", "timestamp")), + dplyr::all_of(attrs) + ) %>% + dplyr::rename_with(~paste0(table_name, "/", .x), .cols = dplyr::all_of(attrs)) + + lf }, - + #-------------------------------------------------------------------- #' @description #' Load every configured table, returning a single \emph{lazy} frame. - #' @return A polars LazyFrame. + #' @return A duckdb lazy query. load_data = function() { - ## 1. build a list of LazyFrames (one per table) - frames <- lapply(self$tables, self$load_table) # each item is LazyFrame - + ## 1. build a list of lazy queries (one per table) + frames <- lapply(self$tables, self$load_table) + ## 2. concatenate lazily - # * "diagonal_relaxed": keeps columns even if some tables lack them - # * rechunk = FALSE : avoid an eager rechunk that can eat memory - # * parallel = TRUE : allow multi-threaded scan - df <- pl$concat( - frames, - how = "diagonal_relaxed", - rechunk = FALSE, - parallel = TRUE - ) - + all_cols <- unique(unlist(lapply(frames, function(df) colnames(df)))) + + frames_aligned <- lapply(frames, function(df) { + missing_cols <- setdiff(all_cols, colnames(df)) + + add_cols_exprs <- stats::setNames( + lapply(missing_cols, function(c) rlang::expr(NA)), + missing_cols + ) + + if (length(add_cols_exprs) > 0) { + df <- df %>% dplyr::mutate(!!!add_cols_exprs) + } + + df %>% dplyr::select(dplyr::all_of(all_cols)) + }) + + df <- purrr::reduce(frames_aligned, dplyr::union_all) + ## 3. dev-mode: early down-sampling to speed up prototyping - if (isTRUE(self$dev) && "patient_id" %in% names(df$schema)) { + if (isTRUE(self$dev) && "patient_id" %in% colnames(df)) { message("[dev] Limiting to 1000 patients (early filter)") - patient_ids <- df$ - select("patient_id")$ - unique()$ - limit(1000)$collect()$to_series()$to_list() - patient_ids <- unlist(patient_ids) - df <- df$filter(pl$col("patient_id")$is_in(patient_ids)) + patient_ids <- df %>% + dplyr::select("patient_id") %>% + dplyr::distinct() %>% + head(1000) %>% + dplyr::pull() + + df <- df %>% dplyr::filter(patient_id %in% patient_ids) } - - ## 4. return the final LazyFrame + + df <- df %>% dplyr::arrange(patient_id, timestamp) + ## 4. return the final lazy query return(df) }, - + #-------------------------------------------------------------------- #' @description #' Retrieve (and cache) the vector of unique patient IDs. #' @return Character vector of patient IDs. unique_patient_ids = function() { if (is.null(self$.unique_patient_ids)) - self$.unique_patient_ids <- self$collected_global_event_df()$ - select("patient_id")$unique()$to_series()$to_list() + self$.unique_patient_ids <- self$collected_global_event_df() %>% + dplyr::distinct(patient_id) %>% + dplyr::pull(patient_id) self$.unique_patient_ids }, - + #-------------------------------------------------------------------- #' @description #' Construct a \code{Patient} object for one subject. @@ -246,50 +276,61 @@ BaseDataset <- R6::R6Class( #' @return A new \code{Patient} R6 instance. get_patient = function(patient_id) { stopifnot(patient_id %in% self$unique_patient_ids()) - sub_df <- self$collected_global_event_df()$ - filter(pl$col("patient_id") == patient_id) + sub_df <- self$collected_global_event_df() %>% + dplyr::filter(patient_id == !!patient_id) Patient$new(patient_id = patient_id, data_source = sub_df) }, - + #-------------------------------------------------------------------- #' @description #' Iterate over all patients (optionally a filtered dataframe). - #' @param df Optional polars DataFrame (already collected). + #' @param df Optional dataframe (already collected). #' @return List of \code{Patient} objects. iter_patients = function(df = NULL) { - ids <- self$unique_patient_ids() + if (is.null(df)) { + df <- self$collected_global_event_df() + } + + # Group by patient_id and split into a list of data frames + # using group_split + patient_dfs <- df %>% + dplyr::group_by(patient_id) %>% + dplyr::group_split() + if (self$dev) { - message("[dev] Limiting to 100 patients for rapid prototyping") - ids <- head(ids, 100) + message("[dev] Limiting to 1000 patients for rapid prototyping") + patient_dfs <- head(patient_dfs, 1000) } - - progressr::handlers(global = TRUE) - p <- progressr::progressor(steps = length(ids)) - - lapply(seq_along(ids), function(i) { - id <- ids[[i]] - percent <- sprintf("%.1f%%", (i / length(ids)) * 100) - + + p <- progressr::progressor(steps = length(patient_dfs)) + + # Iterate over the list of data frames + lapply(seq_along(patient_dfs), function(i) { + patient_df <- patient_dfs[[i]] + # All rows in patient_df have the same patient_id, so we can take the first one. + id <- patient_df$patient_id[1] + percent <- sprintf("%.1f%%", (i / length(patient_dfs)) * 100) + p(message = sprintf("[%s] Processing patient: %s", percent, id)) - + Patient$new( patient_id = id, - data_source = df$filter(pl$col("patient_id")$eq(id)) + data_source = patient_df ) }) }, - - + + #-------------------------------------------------------------------- #' @description - #' Print dataset‑level statistics. - #' @return Invisible NULL (called for side‑effects). + #' Print dataset-level statistics. + #' @return Invisible NULL (called for side-effects). stats = function() { df <- self$collected_global_event_df() cat(sprintf("Dataset : %s\n", self$dataset_name)) cat(sprintf("Dev mode : %s\n", self$dev)) cat(sprintf("Patients : %d\n", length(self$unique_patient_ids()))) - cat(sprintf("Events : %d\n", df$height)) + cat(sprintf("Events : %d\n", nrow(df))) invisible(NULL) }, @@ -305,38 +346,84 @@ BaseDataset <- R6::R6Class( #' @param task A \code{BaseTask} instance; if NULL, \code{default_task()} is #' used. #' @param num_workers Integer ≥1. Number of parallel workers. + #' @param chunk_size Integer. Number of patients to process in each chunk. + #' @param cache_dir Optional path to a directory for caching samples. If set, + #' processed samples will be saved to an `.rds` file and reloaded on + #' subsequent runs, skipping the generation step. #' @return A populated \code{SampleDataset}. - set_task = function(task = NULL, num_workers = 1) { + set_task = function(task = NULL, num_workers = 1, chunk_size = 1000, cache_dir = NULL) { task <- task %||% self$default_task() - stopifnot(!is.null(task)) - cat(task$task_name, "\n") + + message(sprintf("Setting task %s for %s", task$task_name, self$dataset_name)) + + if (!is.null(cache_dir)) { + cache_file <- file.path(cache_dir, "sd_object.rds") + if (file.exists(cache_file)) { + message(sprintf("[cache] Loading cached SampleDataset from %s", cache_dir)) + return(load_sample_dataset(cache_dir)) + } + } + message(sprintf("Generating samples for task %s", task$task_name)) df <- self$collected_global_event_df() filtered_df <- task$pre_filter(df) - patients <- self$iter_patients(filtered_df) + patient_ids <- filtered_df %>% + dplyr::distinct(patient_id) %>% + dplyr::pull(patient_id) - progressr::handlers(global = TRUE) + if (self$dev) { + message("[dev] Limiting to 1000 patients for rapid prototyping") + patient_ids <- head(patient_ids, 1000) + } - if (num_workers == 1) { - p <- progressr::progressor(along = patients) - samples <- unlist(lapply(patients, function(pat) { p(); task$call(pat) }), - recursive = FALSE) - } else { + id_chunks <- split(patient_ids, ceiling(seq_along(patient_ids) / chunk_size)) + all_samples <- list() + + if (num_workers > 1) { future::plan(future::multisession, workers = num_workers) - samples <- future.apply::future_lapply(patients, task$call) - samples <- unlist(samples, recursive = FALSE) } + p <- progressr::progressor(steps = length(id_chunks)) + + for (chunk_ids in id_chunks) { + p(message = sprintf("Processing chunk of %d patients", length(chunk_ids))) + + chunk_df <- filtered_df %>% dplyr::filter(patient_id %in% chunk_ids) + + patient_dfs <- chunk_df %>% + dplyr::group_by(patient_id) %>% + dplyr::group_split() + + if (num_workers == 1) { + patients <- lapply(patient_dfs, function(pdf) { + Patient$new(patient_id = pdf$patient_id[1], data_source = pdf) + }) + chunk_samples <- unlist(lapply(patients, task$call), recursive = FALSE) + } else { + task_runner <- .create_task_runner(task) + chunk_samples_list <- future.apply::future_lapply(patient_dfs, task_runner, future.seed = TRUE) + chunk_samples <- unlist(chunk_samples_list, recursive = FALSE) + } + + all_samples <- c(all_samples, chunk_samples) + rm(chunk_df, patient_dfs, chunk_samples) + if (exists("patients")) rm(patients) + gc() + } + + samples <- all_samples + message(sprintf("Generated %d samples", length(samples))) result <- SampleDataset( samples = samples, input_schema = task$input_schema, output_schema = task$output_schema, dataset_name = self$dataset_name, - task_name = task + task_name = task, + save_path = cache_dir ) message("[info] Task set successfully") return(result) @@ -348,6 +435,28 @@ BaseDataset <- R6::R6Class( # Internal helper wrappers (prefixed with .) ---------------------------------- +#' Helper function to create a clean closure for parallel processing +#' +#' This function acts as a factory. It creates and returns another function +#' (a closure) that is suitable for use with `future.apply`. The returned +#' function's environment is intentionally minimal, containing only the `task` +#' object. This prevents large objects from the parent environment (like the +#' full event dataframe) from being accidentally exported to parallel workers. +#' +#' @param task The task object to be used inside the closure. +#' @return A function that takes a patient dataframe (`pdf`) and applies the task. +#' @keywords internal +.create_task_runner <- function(task) { + force(task) # Ensure task is evaluated in this clean environment + function(pdf) { + # This function now has a very small closure environment, + # only containing 'task'. 'Patient' is found in the global scope. + patient <- Patient$new(patient_id = pdf$patient_id[1], data_source = pdf) + task$call(patient) + } +} + + #' Determines whether a path is an HTTP(S) URL. #' @param path A character string. #' @return Logical scalar indicating if it's a valid URL. @@ -377,59 +486,6 @@ BaseDataset <- R6::R6Class( } -#' Lazily loads a `.csv` or `.csv.gz` file and returns a Polars LazyFrame. -#' Automatically tries the alternate extension if the primary path fails. -#' @param path File path or URL ending in `.csv` or `.csv.gz`. -#' @return A polars LazyFrame. -#' @keywords internal -# Load CSV or CSV.GZ as a LazyFrame, with fallback and case-insensitive matching -.scan_csv_gz_or_csv <- function(path) { - # Attempt to match actual file case (case-insensitive filesystem) - path <- .match_actual_case(path) - - # If the path exists, scan it using Polars with disabled schema inference - if (.path_exists(path)) { - message("Loading file: ", path) - lf <- pl$scan_csv(path, infer_schema_length = 0, streaming = TRUE,) - } else { - # If original path doesn't exist, try fallback between .csv and .csv.gz - alt <- if (grepl("\\.csv\\.gz$", path, ignore.case = TRUE)) { - sub("\\.gz$", "", path, ignore.case = TRUE) - } else if (grepl("\\.csv$", path, ignore.case = TRUE)) { - paste0(path, ".gz") - } else { - stop("Unsupported file extension: ", path) - } - - alt <- .match_actual_case(alt) - - if (.path_exists(alt)) { - message("Fallback to: ", alt) - lf <- pl$scan_csv(alt, infer_schema_length = 0, streaming = TRUE) - path <- alt # update to actual path - } else { - stop("Neither ", path, " nor ", alt, " exist.") - } - } - - # Read the first line of the CSV (header) to extract column names - con <- gzfile(path, open = "rt") - header_line <- readLines(con, n = 1) - close(con) - col_names <- strsplit(header_line, split = ",")[[1]] - - # Generate Polars expressions with lowercase aliases - exprs <- lapply(col_names, function(x) { - pl$col(x)$alias(tolower(x)) - }) - - # Return both the LazyFrame and expression list - return(list( - lazy_frame = lf, - exprs = exprs - )) -} - .match_actual_case <- function(path) { dir <- dirname(path) base <- basename(path) @@ -441,24 +497,26 @@ BaseDataset <- R6::R6Class( #' @keywords internal -.ensure_parquet <- function(csv_path, parquet_path) { +.ensure_parquet <- function(csv_path, parquet_path, separator = ",") { if (file.exists(parquet_path) && file.info(parquet_path)$size > 0) return(invisible(parquet_path)) - message("[cache] DuckDB COPY → ", basename(parquet_path)) + message("[cache] DuckDB COPY -> ", basename(parquet_path)) tmp_parq <- tempfile(fileext = ".parquet") on.exit(unlink(tmp_parq), add = TRUE) con <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") on.exit(DBI::dbDisconnect(con, shutdown = TRUE), add = TRUE) + # duckdb's read_csv needs tabs to be escaped as '\\t' in the sprintf format string + db_separator <- if (separator == "\t") "\\t" else separator sql <- sprintf( " COPY ( SELECT * FROM read_csv( '%s', - delim = ',', + delim = '%s', quote = '\"', escape = '\"', header = TRUE, @@ -471,6 +529,7 @@ BaseDataset <- R6::R6Class( (FORMAT PARQUET, COMPRESSION ZSTD); ", normalizePath(csv_path, winslash = "/"), + db_separator, normalizePath(tmp_parq, winslash = "/") ) @@ -486,9 +545,60 @@ BaseDataset <- R6::R6Class( #' Given a *.csv(.gz) path, return *.parquet path in a /subset folder +#' @param csv_path Path to the csv file. +#' @return A character string. +#' @keywords internal .csv2parquet_path <- function(csv_path) { dir.create(file.path(dirname(csv_path), "subset"), showWarnings = FALSE) file.path(dirname(csv_path), "subset", sub("\\.csv(\\.gz)?$", ".parquet", basename(csv_path), ignore.case = TRUE)) } + + +#' Find an existing data file path with fallback for .gz extension. +#' +#' This function checks for the existence of a path and its alternative with/without +#' `.gz`. It also determines the separator based on the file extension (.csv or .tsv). +#' +#' @param path A character path to a .csv, .csv.gz, .tsv, or .tsv.gz file. +#' @return A list with `path` to an existing file and `separator` (',' or a tab). +#' @keywords internal +.find_path_with_fallback <- function(path) { + + scan_file <- function(file_path) { + if (.path_exists(file_path)) { + separator <- if (grepl("\\.tsv", file_path, ignore.case = TRUE)) "\t" else "," + return(list(path = file_path, separator = separator)) + } + return(NULL) + } + + result <- scan_file(path) + if (!is.null(result)) { + return(result) + } + + # Try the alternative extension + alt_path <- NULL + if (endsWith(path, ".csv.gz")) { + alt_path <- sub("\\.gz$", "", path) # Remove .gz -> try .csv + } else if (endsWith(path, ".csv")) { + alt_path <- paste0(path, ".gz") # Add .gz -> try .csv.gz + } else if (endsWith(path, ".tsv.gz")) { + alt_path <- sub("\\.gz$", "", path) # Remove .gz -> try .tsv + } else if (endsWith(path, ".tsv")) { + alt_path <- paste0(path, ".gz") # Add .gz -> try .tsv.gz + } else { + stop(sprintf("Path does not have expected extension: %s", path)) + } + + alt_result <- scan_file(alt_path) + if (!is.null(alt_result)) { + message(sprintf("Original path does not exist. Using alternative: %s", alt_path)) + return(alt_result) + } + + stop(sprintf("Neither path exists: %s or %s", path, alt_path)) +} + diff --git a/R/Dataset_MIMIC3Dataset.R b/R/Dataset_MIMIC3Dataset.R index 4e2ad0b..ac99810 100644 --- a/R/Dataset_MIMIC3Dataset.R +++ b/R/Dataset_MIMIC3Dataset.R @@ -4,10 +4,6 @@ #' It ensures key tables like patients, admissions, and icustays are loaded, #' and allows appending additional tables. Also provides per-table preprocessing. #' -#' -#' @docType class -#' @method initialize MIMIC3Dataset -#' @usage \method{MIMIC3Dataset}{initialize}(...) #' @export MIMIC3Dataset <- R6::R6Class( "MIMIC3Dataset", @@ -19,7 +15,7 @@ MIMIC3Dataset <- R6::R6Class( #' @param dataset_name Optional dataset name. #' @param config_path Optional path to YAML config file. #' @param dev Logical flag for dev mode. - #' @param ... Additional arguments passed to `BaseDataset`. + #' @param ... Additional arguments passed to `BaseDataset$initialize`. initialize = function(root, tables = character(), dataset_name = NULL, @@ -50,20 +46,6 @@ MIMIC3Dataset <- R6::R6Class( if ("prescriptions" %in% tables) { warning("Timestamp granularity of prescriptions is not enough.") } - }, - - #' @description - #' Table-specific preprocessing for noteevents. - #' If `charttime` is missing, fills it with `chartdate` + " 00:00:00". - #' @param df A polars LazyFrame. - #' @return A modified LazyFrame. - preprocess_noteevents = function(df) { - df$with_columns( - pl$when(pl$col("charttime")$is_null())$ - then(pl$col("chartdate") + pl$lit(" 00:00:00"))$ - otherwise(pl$col("charttime"))$ - alias("charttime") - ) } ) ) diff --git a/R/Dataset_MIMIC4EHRDataset.R b/R/Dataset_MIMIC4EHRDataset.R index 6c2295c..ea2f8b3 100644 --- a/R/Dataset_MIMIC4EHRDataset.R +++ b/R/Dataset_MIMIC4EHRDataset.R @@ -4,9 +4,6 @@ #' This class inherits from BaseDataset and is specialized for handling MIMIC-IV EHR data. #' It ensures key tables like patients, admissions, and icustays are included, #' and allows appending additional tables. It also logs memory usage if needed. -#' @docType class -#' @method initialize MMIMIC4EHRDataset -#' @usage \method{MIMIC4EHRDataset}{initialize}(...) #' @export MIMIC4EHRDataset <- R6::R6Class( "MIMIC4EHRDataset", @@ -19,7 +16,7 @@ MIMIC4EHRDataset <- R6::R6Class( #' @param dataset_name Optional dataset name. Default is "mimic4_ehr". #' @param config_path Optional path to YAML config file. #' @param dev Logical flag for dev mode. - #' @param ... Additional arguments passed to `BaseDataset`. + #' @param ... Additional arguments passed to `BaseDataset$initialize`. initialize = function(root, tables = character(), dataset_name = "mimic4_ehr", diff --git a/R/Dataset_MIMIC4NoteDataset.R b/R/Dataset_MIMIC4NoteDataset.R index 51625cc..07cd073 100644 --- a/R/Dataset_MIMIC4NoteDataset.R +++ b/R/Dataset_MIMIC4NoteDataset.R @@ -4,9 +4,6 @@ #' This class inherits from BaseDataset and is specialized for handling MIMIC-IV Clinical Notes data. #' It includes tables such as discharge, discharge_detail, and radiology. #' -#' @docType class -#' @method initialize MIMIC4NoteDataset -#' @usage \method{MIMIC4NoteDataset}{initialize}(...) #' @export MIMIC4NoteDataset <- R6::R6Class( "MIMIC4NoteDataset", @@ -19,7 +16,7 @@ MIMIC4NoteDataset <- R6::R6Class( #' @param dataset_name Optional dataset name. Default is "mimic4_note". #' @param config_path Optional path to YAML config file. #' @param dev Logical flag for dev mode. - #' @param ... Additional arguments passed to `BaseDataset`. + #' @param ... Additional arguments passed to `BaseDataset$initialize`. initialize = function(root, tables = character(), dataset_name = "mimic4_note", diff --git a/R/Dataset_SampleDataset.R b/R/Dataset_SampleDataset.R index a37b82c..2ae2020 100644 --- a/R/Dataset_SampleDataset.R +++ b/R/Dataset_SampleDataset.R @@ -2,6 +2,7 @@ #' @description Sample dataset class for handling and processing data samples. #' @import R6 #' @import torch +#' @importFrom progressr progressor #' @export SampleDataset <- torch::dataset( "SampleDataset", @@ -38,11 +39,13 @@ SampleDataset <- torch::dataset( #' @param output_schema Named list specifying types for outputs #' @param dataset_name Optional dataset name #' @param task_name Optional task name + #' @param save_path Optional path to save the processed dataset. initialize = function(samples, input_schema, output_schema, dataset_name = "", - task_name = "") { + task_name = "", + save_path = NULL) { self$samples <- samples self$input_schema <- input_schema self$output_schema <- output_schema @@ -53,81 +56,123 @@ SampleDataset <- torch::dataset( self$patient_to_index <- list() self$record_to_index <- list() - for (i in seq_along(samples)) { - sample <- samples[[i]] - patient_id <- sample[["patient_id"]] - if (!is.null(patient_id)) { - if (is.null(self$patient_to_index[[patient_id]])) { - self$patient_to_index[[patient_id]] <- c() - } - self$patient_to_index[[patient_id]] <- c(self$patient_to_index[[patient_id]], i) - } - - record_id <- sample[["record_id"]] - if (is.null(record_id)) { - record_id <- sample[["visit_id"]] - } - if (!is.null(record_id)) { - if (is.null(self$record_to_index[[record_id]])) { - self$record_to_index[[record_id]] <- c() - } - self$record_to_index[[record_id]] <- c(self$record_to_index[[record_id]], i) + # Efficiently create the patient_to_index mapping + if (length(samples) > 0) { + patient_ids <- sapply(samples, function(s) as.character(s[["patient_id"]])) + self$patient_to_index <- split(seq_along(samples), patient_ids) + + record_ids <- sapply(samples, function(s) { + s[["record_id"]] %||% s[["visit_id"]] %||% s[["admission_id"]] %||% NA + }) + valid_indices <- !is.na(record_ids) + if (any(valid_indices)) { + self$record_to_index <- split(seq_along(samples)[valid_indices], + as.character(record_ids[valid_indices])) } } + self$validate() - self$build() + self$build_and_process(save_path = save_path) message("samples built") - }, #' @description Check that all samples contain required schema fields validate = function() { + if (length(self$samples) == 0) return() input_keys <- names(self$input_schema) output_keys <- names(self$output_schema) - - for (sample in self$samples) { - sample_keys <- names(sample) - stopifnot(all(input_keys %in% sample_keys)) - stopifnot(all(output_keys %in% sample_keys)) - } + sample_keys <- names(self$samples[[1]]) + stopifnot(all(input_keys %in% sample_keys)) + stopifnot(all(output_keys %in% sample_keys)) }, #' @description Build processors and transform all samples - build = function() { + build_and_process = function(save_path = NULL) { + # Build processors + message("Building input processors...") + p_input <- progressr::progressor(steps = length(names(self$input_schema))) for (k in names(self$input_schema)) { processor_type <- self$input_schema[[k]] processor <- get_processor(processor_type)$new() processor$fit(self$samples, k) self$input_processors[[k]] <- processor + p_input() } + + message("Building output processors...") + p_output <- progressr::progressor(steps = length(names(self$output_schema))) for (k in names(self$output_schema)) { processor_type <- self$output_schema[[k]] processor <- get_processor(processor_type)$new() - - processor$fit(self$samples, k) - self$output_processors[[k]] <- processor + p_output() + } + tensors_path <- NULL + if (!is.null(save_path)) { + if (!dir.exists(save_path)) { + dir.create(save_path, recursive = TRUE) + } + tensors_path <- file.path(save_path, "tensors") + if (!dir.exists(tensors_path)) { + dir.create(tensors_path) + } } + + # Process all samples upfront + message("Processing samples...") + p <- progressr::progressor(steps = length(self$samples)) + samples_for_saving <- rlang::duplicate(self$samples, shallow = FALSE) + for (i in seq_along(self$samples)) { sample <- self$samples[[i]] + sample_for_save <- samples_for_saving[[i]] + for (k in names(sample)) { + processed_val <- NULL if (!is.null(self$input_processors[[k]])) { - sample[[k]] <- self$input_processors[[k]]$process(sample[[k]]) + processed_val <- self$input_processors[[k]]$process(sample[[k]]) } else if (!is.null(self$output_processors[[k]])) { - sample[[k]] <- self$output_processors[[k]]$process(sample[[k]]) + processed_val <- self$output_processors[[k]]$process(sample[[k]]) + } + + if (!is.null(processed_val)) { + sample[[k]] <- processed_val # Keep tensor in memory + if (!is.null(save_path) && inherits(processed_val, "torch_tensor")) { + tensor_file <- file.path(tensors_path, paste0("sample_", i, "_", k, ".pt")) + torch::torch_save(processed_val, tensor_file) + sample_for_save[[k]] <- list(.is_tensor_placeholder = TRUE, path = tensor_file) + } else { + sample_for_save[[k]] <- processed_val + } } } self$samples[[i]] <- sample + samples_for_saving[[i]] <- sample_for_save + p() + } + + if (!is.null(save_path)) { + clone <- self$clone() + clone$samples <- samples_for_saving + saveRDS(clone, file.path(save_path, "sd_object.rds")) + message(sprintf("SampleDataset saved to: %s", save_path)) } }, #' @description Get a sample by index - #' @param index Integer index #' @return Named list representing the sample .getitem = function(index) { sample <- self$samples[[index]] + + for (k in names(sample)) { + item <- sample[[k]] + if (is.list(item) && isTRUE(item$.is_tensor_placeholder)) { + sample[[k]] <- torch::torch_load(item$path) + } + } + input_keys <- names(self$input_processors) output_keys <- names(self$output_processors) selected_keys <- union(input_keys, output_keys) @@ -142,9 +187,65 @@ SampleDataset <- torch::dataset( }, #' @description Printable description of dataset - #' @param ... Ignored print = function(...) { cat(sprintf("Sample dataset %s %s\n", self$dataset_name, self$task_name)) - } - + }, + public = list( + save_to_disk = function(path) { + if (!dir.exists(path)) { + dir.create(path, recursive = TRUE) + } + tensors_path <- file.path(path, "tensors") + if (!dir.exists(tensors_path)) { + dir.create(tensors_path) + } + + clone <- self$clone() + + message("Saving tensors...") + p <- progressr::progressor(steps = length(clone$samples)) + for (i in seq_along(clone$samples)) { + for (k in names(clone$samples[[i]])) { + if (inherits(clone$samples[[i]][[k]], "torch_tensor")) { + tensor_file <- file.path(tensors_path, paste0("sample_", i, "_", k, ".pt")) + torch::torch_save(clone$samples[[i]][[k]], tensor_file) + clone$samples[[i]][[k]] <- list(.is_tensor_placeholder = TRUE, path = tensor_file) + } + } + p() + } + + saveRDS(clone, file.path(path, "sd_object.rds")) + message(sprintf("SampleDataset saved to: %s", path)) + } + ) ) + +#' @title Load a SampleDataset object from a directory +#' @description This function reconstructs a SampleDataset object from a directory. +#' @param path The directory path from where to load the dataset. +#' @return The reconstructed SampleDataset object. +#' @export +load_sample_dataset <- function(path) { + rds_file <- file.path(path, "sd_object.rds") + if (!file.exists(rds_file)) { + stop("Saved SampleDataset object not found at the specified path: ", rds_file) + } + + sd_object <- readRDS(rds_file) + + message("Loading tensors...") + p <- progressr::progressor(steps = length(sd_object$samples)) + for (i in seq_along(sd_object$samples)) { + for (k in names(sd_object$samples[[i]])) { + item <- sd_object$samples[[i]][[k]] + if (is.list(item) && isTRUE(item$.is_tensor_placeholder)) { + sd_object$samples[[i]][[k]] <- torch::torch_load(item$path) + } + } + p() + } + + message("SampleDataset loaded successfully.") + return(sd_object) +} diff --git a/R/Dataset_data.R b/R/Dataset_data.R index 486eafb..dabbb36 100644 --- a/R/Dataset_data.R +++ b/R/Dataset_data.R @@ -65,10 +65,10 @@ Event <- R6::R6Class( #' #' @description #' The `Patient` class manages all clinical events for a single patient. -#' It supports efficient event-type partitioning, fast time-range slicing, and flexible multi-condition filtering using rpolars. +#' It supports efficient event-type partitioning, fast time-range slicing, and flexible multi-condition filtering. #' #' @details -#' - Data is held as a polars DataFrame. +#' - Data is held as a data.frame. #' - Events can be retrieved as either raw data frames or Event object lists. #' #' @export @@ -78,89 +78,83 @@ Patient <- R6::R6Class( #' @field patient_id Character. Unique identifier for the patient. patient_id = NULL, - #' @field data_source Polars DataFrame. All events for this patient, sorted by timestamp. + #' @field data_source data.frame. All events for this patient, sorted by timestamp. data_source = NULL, - #' @field event_type_partitions List. Mapping event type to corresponding polars DataFrames. + #' @field event_type_partitions List. Mapping event type to corresponding data.frames. event_type_partitions = NULL, #' @description #' Create a Patient object. #' @param patient_id Character. Unique patient identifier. - #' @param data_source Polars DataFrame. All events (must include event_type, timestamp columns). + #' @param data_source data.frame. All events (must include event_type, timestamp columns). #' @return A `Patient` object. initialize = function(patient_id, data_source) { self$patient_id <- patient_id - self$data_source <- data_source$sort("timestamp") - self$event_type_partitions <- self$data_source$partition_by("event_type", maintain_order = TRUE, include_key = TRUE, as_nested_list = TRUE) + self$data_source <- data_source %>% dplyr::arrange(timestamp) + self$event_type_partitions <- split(self$data_source, self$data_source$event_type) }, #' @description #' Filter events by time range (O(n), regular scan). - #' @param df Polars DataFrame. Source event data. + #' @param df data.frame. Source event data. #' @param start Character/POSIXct. (Optional) Start time. #' @param end Character/POSIXct. (Optional) End time. - #' @return Polars DataFrame. Events in specified range. + #' @return data.frame. Events in specified range. filter_by_time_range_regular = function(df, start = NULL, end = NULL) { - if (!is.null(start)) df <- df$filter(pl$col("timestamp") >= as.character(start)) - if (!is.null(end)) df <- df$filter(pl$col("timestamp") <= as.character(end)) + if (!is.null(start)) df <- df %>% dplyr::filter(timestamp >= as.POSIXct(start)) + if (!is.null(end)) df <- df %>% dplyr::filter(timestamp <= as.POSIXct(end)) df }, #' @description #' Efficient time range filter via binary search (O(log n)), requires sorted data. - #' @param df Polars DataFrame. Source event data. + #' @param df data.frame. Source event data. #' @param start Character/POSIXct. (Optional) Start time. #' @param end Character/POSIXct. (Optional) End time. - #' @return Polars DataFrame. Filtered events. + #' @return data.frame. Filtered events. filter_by_time_range_fast = function(df, start = NULL, end = NULL) { if (is.null(start) && is.null(end)) return(df) - df <- df$filter(pl$col("timestamp")$is_not_null()) - ts_col <- df$to_data_frame()[["timestamp"]] + df <- df %>% dplyr::filter(!is.na(timestamp)) + ts_col <- df[["timestamp"]] ts_col <- as.POSIXct(ts_col, tz = "UTC") - start_idx <- 0 + start_idx <- 1 end_idx <- length(ts_col) if (!is.null(start)) { - start <- as.POSIXct(start, tz = "UTC") -1 - start_idx <- as.integer(findInterval(start, ts_col, left.open = FALSE)) - }else { - start_idx <- 0 + start <- as.POSIXct(start, tz = "UTC") - 1 + start_idx <- as.integer(findInterval(start, ts_col, left.open = FALSE)) + 1 } if (!is.null(end)) { end <- as.POSIXct(end, tz = "UTC") + 1 end_idx <- as.integer(findInterval(end, ts_col, left.open = TRUE)) - }else { - end_idx <- length(ts_col) } - return(df$slice(start_idx, end_idx - start_idx)) + if (start_idx > end_idx) return(df[0,]) + return(df[start_idx:end_idx, ]) }, #' @description #' Regular event type filter (O(n)). - #' @param df Polars DataFrame. + #' @param df data.frame. #' @param event_type Character. Type of event. - #' @return Polars DataFrame. + #' @return data.frame. filter_by_event_type_regular = function(df, event_type) { if (!is.null(event_type)) { - df <- df$filter(pl$col("event_type") == event_type) + df <- df %>% dplyr::filter(event_type == !!event_type) } df }, #' @description #' Fast event type filter (O(1)) using partitioned lookup. - #' @param df Polars DataFrame. + #' @param df data.frame. #' @param event_type Character. Type of event. - #' @return Polars DataFrame. Only the given event type. + #' @return data.frame. Only the given event type. filter_by_event_type_fast = function(df, event_type) { if (!is.null(event_type)) { - keystr <- event_type - match_idx <- which(sapply(self$event_type_partitions, \(x) x$key$event_type == keystr)) - - if (length(match_idx) > 0) { - return(self$event_type_partitions[[match_idx]]$data) + if (event_type %in% names(self$event_type_partitions)) { + return(self$event_type_partitions[[event_type]]) } else { - return(df$slice(0,0)) + return(df[0,]) } } else { return(df) @@ -183,32 +177,25 @@ Patient <- R6::R6Class( df <- self$filter_by_time_range_fast(df, start, end) if (is.null(filters)) filters <- list() if (length(filters) > 0 && is.null(event_type)) stop("event_type must be provided if filters are used") - exprs <- list() for (filt in filters) { if (!(is.list(filt) && length(filt) == 3)) stop("Each filter must be a 3-element list: (attr, op, value)") attr <- filt[[1]]; op <- filt[[2]]; val <- filt[[3]] - col_expr <- pl$col(sprintf("%s/%s", event_type, attr)) - exprs <- append(exprs, switch( - op, - "==" = col_expr == val, - "!=" = col_expr != val, - "<" = col_expr < val, - "<=" = col_expr <= val, - ">" = col_expr > val, - ">=" = col_expr >= val, - stop(sprintf("Unsupported operator: %s", op)) - )) - } - if (length(exprs) > 0) { - filter_expr <- purrr::reduce(`&`, exprs) - df <- df$filter(filter_expr) + + # Build a filter condition for dplyr + # The column name is constructed, and then we build the expression + col_name <- sprintf("%s/%s", event_type, attr) + + # This is a bit of metaprogramming to build the filter expression + # It's safer than paste() to avoid SQL injection-like issues, though not a DB here. + filter_expr <- rlang::call2(op, rlang::sym(col_name), val) + df <- df %>% dplyr::filter(!!filter_expr) } - + if (return_df) { return(df) } else { - - datalist <- df$to_data_frame() + + datalist <- df if (nrow(datalist) == 0) { return(list()) } diff --git a/R/Dataset_splitter.R b/R/Dataset_splitter.R index b4fdabd..b18d295 100644 --- a/R/Dataset_splitter.R +++ b/R/Dataset_splitter.R @@ -8,19 +8,67 @@ #' @param dataset A `SampleDataset` object. #' @param ratios A numeric vector of length 3 indicating train/val/test split ratios. Must sum to 1. #' @param seed Optional integer for reproducibility. +#' @param stratify Logical, whether to perform stratified sampling. Default: FALSE. +#' @param stratify_by Character, the name of the field to stratify by (e.g., the label). Required if `stratify` is TRUE. #' @param get_index Logical, whether to return the indices instead of subsets. Default: FALSE. #' @return A list of 3 torch::dataset_subset objects or 3 tensors of indices if get_index = TRUE. #' @export -split_by_sample <- function(dataset, ratios, seed = NULL, get_index = FALSE) { +split_by_sample <- function(dataset, ratios, seed = NULL, stratify = FALSE, stratify_by = NULL, get_index = FALSE) { stopifnot(sum(ratios) == 1.0) if (!is.null(seed)) set.seed(seed) - index <- sample(seq_len(length(dataset))) - n <- length(index) + if (stratify) { + if (is.null(stratify_by)) { + stop("`stratify_by` must be provided when `stratify` is TRUE.") + } - train_index <- index[1:floor(n * ratios[[1]])] - val_index <- index[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] - test_index <- index[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + # Extract stratification values + strata <- purrr::map_vec(dataset$samples, function(s) { + item <- s[[stratify_by]] + if (is.list(item) && isTRUE(item$.is_tensor_placeholder)) { + item <- torch::torch_load(item$path) + } + as.numeric(item) + }) + + # Get indices for each stratum + strata_indices <- split(seq_along(strata), strata) + + train_index <- c() + val_index <- c() + test_index <- c() + + for (indices in strata_indices) { + n_stratum <- length(indices) + shuffled_indices <- sample(indices) + + train_end <- floor(n_stratum * ratios[[1]]) + val_end <- train_end + floor(n_stratum * ratios[[2]]) + + if (train_end > 0) { + train_index <- c(train_index, shuffled_indices[1:train_end]) + } + if (val_end > train_end) { + val_index <- c(val_index, shuffled_indices[(train_end + 1):val_end]) + } + if (n_stratum > val_end) { + test_index <- c(test_index, shuffled_indices[(val_end + 1):n_stratum]) + } + } + + # Shuffle the final indices to mix strata + train_index <- sample(train_index) + val_index <- sample(val_index) + test_index <- sample(test_index) + + } else { + index <- sample(seq_len(length(dataset))) + n <- length(index) + + train_index <- index[1:floor(n * ratios[[1]])] + val_index <- index[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] + test_index <- index[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + } if (get_index) { return(list( @@ -41,19 +89,75 @@ split_by_sample <- function(dataset, ratios, seed = NULL, get_index = FALSE) { #' @param dataset A `SampleDataset` object. #' @param ratios A numeric vector of length 3 indicating train/val/test split ratios. Must sum to 1. #' @param seed Optional integer for reproducibility. -#' @return A list of 3 torch::dataset_subset objects split by patient id. +#' @param stratify Logical, whether to perform stratified sampling. Default: FALSE. +#' @param stratify_by Character, the name of the field to stratify by (e.g., the label). Required if `stratify` is TRUE. +#' @param get_index Logical, whether to return the indices instead of subsets. Default: FALSE. +#' @return A list of 3 torch::dataset_subset objects or 3 tensors of indices if get_index = TRUE, split by patient id. #' @export -split_by_patient <- function(dataset, ratios, seed = NULL) { +split_by_patient <- function(dataset, ratios, seed = NULL, stratify = FALSE, stratify_by = NULL, get_index = FALSE) { stopifnot(sum(ratios) == 1.0) if (!is.null(seed)) set.seed(seed) - patient_ids <- names(dataset$patient_to_index) - n <- length(patient_ids) - shuffled <- sample(patient_ids) + train_ids <- c() + val_ids <- c() + test_ids <- c() - train_ids <- shuffled[1:floor(n * ratios[[1]])] - val_ids <- shuffled[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] - test_ids <- shuffled[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + if (stratify) { + if (is.null(stratify_by)) { + stop("`stratify_by` must be provided when `stratify` is TRUE.") + } + + patient_strata <- list() + for (i in seq_along(dataset$samples)) { + sample <- dataset$samples[[i]] + if (is.null(sample$patient_id)) next + patient_id <- as.character(as.numeric(sample$patient_id)) + + # Assuming higher value is "worse" outcome (e.g., 1 for mortality) + item <- sample[[stratify_by]] + if (is.list(item) && isTRUE(item$.is_tensor_placeholder)) { + item <- torch::torch_load(item$path) + } + stratum <- as.numeric(item) + + if (is.null(stratum) || any(is.na(stratum))) next + + if (!patient_id %in% names(patient_strata)) { + patient_strata[[patient_id]] <- stratum + } else { + patient_strata[[patient_id]] <- max(patient_strata[[patient_id]], stratum, na.rm = TRUE) + } + } + + strata_groups <- split(names(patient_strata), as.factor(unlist(patient_strata))) + + for (group in strata_groups) { + n_group <- length(group) + shuffled_pids <- sample(group) + train_end <- floor(n_group * ratios[[1]]) + val_end <- train_end + floor(n_group * ratios[[2]]) + + if (train_end > 0) { + train_ids <- c(train_ids, shuffled_pids[1:train_end]) + } + if (val_end > train_end) { + val_ids <- c(val_ids, shuffled_pids[(train_end + 1):val_end]) + } + if (n_group > val_end) { + test_ids <- c(test_ids, shuffled_pids[(val_end + 1):n_group]) + } + } + train_ids <- sample(train_ids) + val_ids <- sample(val_ids) + test_ids <- sample(test_ids) + } else { + patient_ids <- names(dataset$patient_to_index) + n <- length(patient_ids) + shuffled <- sample(patient_ids) + train_ids <- shuffled[1:floor(n * ratios[[1]])] + val_ids <- shuffled[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] + test_ids <- shuffled[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + } flatten_indices <- function(ids) { unlist(purrr::map(ids, function(pid) dataset$patient_to_index[[pid]])) @@ -63,6 +167,14 @@ split_by_patient <- function(dataset, ratios, seed = NULL) { val_index <- flatten_indices(val_ids) test_index <- flatten_indices(test_ids) + if (get_index) { + return(list( + torch_tensor(train_index), + torch_tensor(val_index), + torch_tensor(test_index) + )) + } + list( dataset_subset(dataset, indices = train_index), dataset_subset(dataset, indices = val_index), @@ -74,18 +186,82 @@ split_by_patient <- function(dataset, ratios, seed = NULL) { #' @param dataset A `SampleDataset` object. #' @param ratios A numeric vector of length 3 indicating train/val/test split ratios. Must sum to 1. #' @param seed Optional integer for reproducibility. +#' @param stratify Logical, whether to perform stratified sampling. Default: FALSE. +#' @param stratify_by Character, the name of the field to stratify by (e.g., the label). Required if `stratify` is TRUE. #' @return A list of 3 torch::dataset_subset objects. #' @export -split_by_visit <- function(dataset, ratios, seed = NULL) { +split_by_visit <- function(dataset, ratios, seed = NULL, stratify = FALSE, stratify_by = NULL) { stopifnot(sum(ratios) == 1.0) if (!is.null(seed)) set.seed(seed) - index <- sample(seq_len(length(dataset))) - n <- length(index) + if (stratify) { + if (is.null(stratify_by)) { + stop("`stratify_by` must be provided when `stratify` is TRUE.") + } + + visit_strata <- list() + for (i in seq_along(dataset$samples)) { + sample <- dataset$samples[[i]] + visit_id <- sample[["record_id"]] %||% sample[["visit_id"]] %||% sample[["admission_id"]] %||% NA + if (is.na(visit_id)) next + visit_id <- as.character(visit_id) + + item <- sample[[stratify_by]] + if (is.list(item) && isTRUE(item$.is_tensor_placeholder)) { + item <- torch::torch_load(item$path) + } + stratum <- as.numeric(item) - train_index <- index[1:floor(n * ratios[[1]])] - val_index <- index[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] - test_index <- index[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + if (is.null(stratum) || any(is.na(stratum))) next + + if (!visit_id %in% names(visit_strata)) { + visit_strata[[visit_id]] <- stratum + } else { + visit_strata[[visit_id]] <- max(visit_strata[[visit_id]], stratum, na.rm = TRUE) + } + } + + strata_groups <- split(names(visit_strata), as.factor(unlist(visit_strata))) + + train_ids <- c() + val_ids <- c() + test_ids <- c() + + for (group in strata_groups) { + n_group <- length(group) + shuffled_vids <- sample(group) + train_end <- floor(n_group * ratios[[1]]) + val_end <- train_end + floor(n_group * ratios[[2]]) + + if (train_end > 0) { + train_ids <- c(train_ids, shuffled_vids[1:train_end]) + } + if (val_end > train_end) { + val_ids <- c(val_ids, shuffled_vids[(train_end + 1):val_end]) + } + if (n_group > val_end) { + test_ids <- c(test_ids, shuffled_vids[(val_end + 1):n_group]) + } + } + train_ids <- sample(train_ids) + val_ids <- sample(val_ids) + test_ids <- sample(test_ids) + + flatten_indices <- function(ids) { + unlist(purrr::map(ids, function(vid) dataset$record_to_index[[vid]])) + } + + train_index <- flatten_indices(train_ids) + val_index <- flatten_indices(val_ids) + test_index <- flatten_indices(test_ids) + } else { + index <- sample(seq_len(length(dataset))) + n <- length(index) + + train_index <- index[1:floor(n * ratios[[1]])] + val_index <- index[(floor(n * ratios[[1]]) + 1):floor(n * (ratios[[1]] + ratios[[2]]))] + test_index <- index[(floor(n * (ratios[[1]] + ratios[[2]])) + 1):n] + } list( dataset_subset(dataset, indices = train_index), diff --git a/R/Dataset_utils.R b/R/Dataset_utils.R index 996d1a1..fc05912 100644 --- a/R/Dataset_utils.R +++ b/R/Dataset_utils.R @@ -21,6 +21,9 @@ collate_fn_dict_with_padding <- function(batch) { if (values[[1]]$dim() == 0) { collated[[key]] <- torch_stack(values) } else if (values[[1]]$dim() >= 1) { + # Store original lengths before padding + lengths <- torch_tensor(sapply(values, function(v) v$shape[1]), dtype = torch_long()) + collated[[paste0(key, "_len")]] <- lengths collated[[key]] <- nn_utils_rnn_pad_sequence(values, batch_first = TRUE, padding_value = 0) } else { stop(sprintf("Unsupported tensor shape: %s", paste0(values[[1]]$shape, collapse = ","))) diff --git a/R/Medcode_cross_map.R b/R/Medcode_cross_map.R index fe19599..448150e 100644 --- a/R/Medcode_cross_map.R +++ b/R/Medcode_cross_map.R @@ -18,7 +18,10 @@ supported_cross <- function() { c( "ICD9CM_to_CCSCM", "ICD9PROC_to_CCSPROC", - "ICD10CM_to_CCSCM", "ICD10PROC_to_CCSPROC", "NDC_to_ATC" + "ICD10CM_to_CCSCM", "ICD10PROC_to_CCSPROC", "NDC_to_ATC", + "ICD10CM_to_ICD9CM", "ICD9CM_to_ICD10CM", + "ICD10PCS_to_ICD9PCS", "ICD9PCS_to_ICD10PCS", + "ICD10CMPCS_to_ICD9CM" ) } @@ -53,8 +56,43 @@ supported_cross <- function() { #' \code{\link{load_medcode}}, \code{\link{supported_cross}} #' #' @export +#' @importFrom stringr str_pad map_code <- function(code, from = "ICD9CM", to = "CCSCM") { name <- paste0(from, "_to_", to) df <- load_medcode(name) - df[df[[1]] == code, 2, drop = TRUE] + + # Pad numeric-only ICD9 codes with leading zeros, stripping decimals + if (from == "ICD9CM" && grepl("^[0-9.]+$", code)) { + code <- gsub("\\.", "", code) + code <- stringr::str_pad(code, 5, pad = "0", side = "left") + } else if (from == "ICD9PCS" && grepl("^[0-9.]+$", code)) { + code <- gsub("\\.", "", code) + code <- stringr::str_pad(code, 4, pad = "0", side = "left") + } + + # Normalize column names for broader compatibility + from_col <- tolower(gsub("[^A-Za-z0-9]", "", from)) + to_col <- tolower(gsub("[^A-Za-z0-9]", "", to)) + df_cols <- tolower(gsub("[^A-Za-z0-9]", "", names(df))) + + from_idx <- match(from_col, df_cols) + to_idx <- match(to_col, df_cols) + + # Fallback to original column names if not found + if (is.na(from_idx)) from_idx <- match(from, names(df)) + if (is.na(to_idx)) to_idx <- match(to, names(df)) + + # Fallback to default column positions if names don't match + if (is.na(from_idx) || is.na(to_idx)) { + from_idx <- 1 + to_idx <- 2 + } + + # Perform the lookup + matches <- df[[from_idx]] == code + if (any(matches, na.rm = TRUE)) { + return(df[matches, to_idx, drop = TRUE]) + } + + character(0) } diff --git a/R/Medcode_download.R b/R/Medcode_download.R index 2f8219d..e167f45 100644 --- a/R/Medcode_download.R +++ b/R/Medcode_download.R @@ -1,3 +1,25 @@ +.medcode_gdrive_links <- c( + "ATC" = "1LOAf-AheiZ28vkTcAn6K2X89-jpxU-ok", + "CCSCM" = "1FpsaT1EPaeJ2vw9WxCi0fTt3kr1jjmyi", + "CCSPROC" = "1dd6MNzENb9utr_F-YPwoNFV6wheHwBi8", + "DDI" = "15DKDEcENncyeAVieHV_kP8iCZl8hMpRP", + "ICD9CM_to_CCSCM" = "1Yruhix5yEH15C898p0VL_40G9K0dn-cR", + "ICD9CM" = "1UMF66hl5vxZ9SXIAJLCSC8ugxeYwkFV9", + "ICD9PROC_to_CCSPROC" = "16oFsOpgmtlDmMHAr5pW6KUCks1-ir5k6", + "ICD9PROC" = "1Sez38YseXaifokM2frRvZ8chNI8NhvB0", + "ICD10CM_to_CCSCM" = "1utcHE81_mbjqDuEbPvZ9uto9f_yGqr9n", + "ICD10CM" = "1Oe9A6x58O2ZaXhtfqnK5vqjbXMVIAH0W", + "ICD10PROC_to_CCSPROC" = "1gG44ALc8DVGT6Yg9HRUlwgLv9iuJ-6Ql", + "ICD10PROC" = "1ThPq6D16QXnK21fV_kl5JqD5A-rDVHQQ", + "NDC_to_ATC" = "11IQSkVaGjTc6kZ0XFUd_FgQa3uRrvihP", + "NDC" = "11mCQ3AJTkZvkC0WWfxLhaxU5Mg-a3S9L", + "ICD10CM_to_ICD9CM" = "1Ioo_Aq-sXmiO8FKmsupE8EidiWpm8Q_6", + "ICD9CM_to_ICD10CM" = "1_EUiJ8AINq4ktZcbuVGoxXDj8Ddy8QrQ", + "ICD10PCS_to_ICD9PCS" = "1ZQOw2ww73uqJBGHNfJDRpYVzHAXeXWz7", + "ICD9PCS_to_ICD10PCS" = "1pYsgXndNTaRvWieEBw4FQu8AvM25-Qh2", + "ICD10CMPCS_to_ICD9CM" = "12GlFYTmdxOVGjxSLQKR-qMv9q4rOka8e" +) + #' Download and Cache Medical Code CSV #' #' @description @@ -53,7 +75,12 @@ download_medcode <- function(name) { fs::dir_create(cache_dir, recurse = TRUE) dest <- fs::path(cache_dir, paste0(name, ".csv")) if (!fs::file_exists(dest)) { - url <- paste0(getOption("RHealth.medcode_base"), name, ".csv") + if (name %in% names(.medcode_gdrive_links)) { + file_id <- .medcode_gdrive_links[[name]] + url <- paste0("https://drive.google.com/uc?export=download&id=", file_id) + } else { + stop(paste("Medical code", name, "is not supported.")) + } httr::GET(url, httr::write_disk(dest, overwrite = TRUE)) } dest diff --git a/R/Metrics_ECE_calibration.R b/R/Metrics_ECE_calibration.R index 5da2114..a964948 100644 --- a/R/Metrics_ECE_calibration.R +++ b/R/Metrics_ECE_calibration.R @@ -8,10 +8,10 @@ #' containing predicted probabilities for the *positive* class #' (only the first column is used if a matrix is supplied). #' @param label Numeric vector **or** two-column matrix of true labels \ -#' encoded as 0/1 (only the first column is used if a matrix is supplied). +#' encoded as `0/1` (only the first column is used if a matrix is supplied). #' @param bins Integer. Number of bins (default 20). #' @param adaptive Logical. If `FALSE` (default) equal-width bins \ -#' spanning \\([0,1]\\) are used; if `TRUE` each bin contains the \ +#' spanning `0, 1` are used; if `TRUE` each bin contains the \ #' same number of samples (equal-size bins). #' #' @return A single numeric value – the (adaptive) ECE. diff --git a/R/Metrics_binary.R b/R/Metrics_binary.R index 263dddd..de717a7 100644 --- a/R/Metrics_binary.R +++ b/R/Metrics_binary.R @@ -30,8 +30,6 @@ #' y_prob <- runif(100) #' binary_metrics_fn(y_true, y_prob, metrics = c("accuracy", "ECE")) #' -#' @importFrom PRROC pr.curve -#' @importFrom pROC auc #' @importFrom MLmetrics Accuracy F1_Score Precision Recall #' @importFrom metrica balacc jaccardindex #' @importFrom psych cohen.kappa @@ -40,6 +38,10 @@ binary_metrics_fn <- function(y_true, y_prob, metrics = NULL, threshold = 0.5) { + + y_true <- to_numeric_vector(y_true) + y_prob <- to_numeric_vector(y_prob) + stopifnot(length(y_true) == length(y_prob), all(y_true %in% c(0, 1)), is.numeric(y_prob), @@ -60,48 +62,35 @@ binary_metrics_fn <- function(y_true, out[m] <- switch( m, pr_auc = { - y_true_vec <- to_numeric_vector(y_true) - y_prob_vec <- to_numeric_vector(y_prob) - has_pos <- any(y_true_vec == 1) - has_neg <- any(y_true_vec == 0) + has_pos <- any(y_true == 1) + has_neg <- any(y_true == 0) if (!(has_pos && has_neg)) { warning("PR AUC undefined: only one class present."); 0 } else { - PRROC::pr.curve( - scores.class0 = y_prob_vec[y_true_vec == 1], - scores.class1 = y_prob_vec[y_true_vec == 0] - )$auc.integral + MLmetrics::PRAUC(y_prob, y_true) } }, roc_auc = { - y_true_vec <- to_numeric_vector(y_true) - y_prob_vec <- to_numeric_vector(y_prob) - has_pos <- any(y_true_vec == 1) - has_neg <- any(y_true_vec == 0) + has_pos <- any(y_true == 1) + has_neg <- any(y_true == 0) if (!(has_pos && has_neg)) { warning("ROC AUC undefined: only one class present."); 0 } else { - as.numeric(pROC::auc(y_true_vec, y_prob_vec)) + MLmetrics::AUC(y_prob, y_true) } }, - accuracy = { - y_true_vec <- to_numeric_vector(y_true) - y_pred_vec <- to_numeric_vector(ifelse(y_prob >= threshold, 1, 0)) - mean(y_true_vec == y_pred_vec) - Accuracy(y_pred_vec, y_true_vec) - }, - + accuracy = Accuracy(y_pred, y_true), balanced_accuracy = as.numeric(balacc(data = data.frame(pred = y_pred, obs = y_true), pred = "pred", obs = "obs", tidy = FALSE)), - f1 = F1_Score(y_pred, y_true), - precision = Precision(y_pred, y_true), - recall = Recall(y_pred, y_true), + f1 = F1_Score(y_pred, y_true, positive = 1), + precision = Precision(y_pred, y_true, positive = 1), + recall = Recall(y_pred, y_true, positive = 1), cohen_kappa = cohen.kappa(table(y_pred, y_true))$kappa, jaccard = as.numeric(jaccardindex(data = data.frame(pred = y_pred, obs = y_true), pred = "pred", diff --git a/R/Model_BaseModel.R b/R/Model_BaseModel.R index cc0101e..996b887 100644 --- a/R/Model_BaseModel.R +++ b/R/Model_BaseModel.R @@ -12,9 +12,8 @@ BaseModel <- torch::nn_module( classname = "BaseModel", + #' @param dataset A dataset object (must have input_schema, output_schema, output_processors). initialize = function(dataset) { - - #' @param dataset A dataset object (must have input_schema, output_schema, output_processors). self$dataset <- dataset self$feature_keys <- names(dataset$input_schema) @@ -37,8 +36,8 @@ BaseModel <- torch::nn_module( #' @description Selects appropriate loss function based on task type in output schema. #' @return A function such as nnf_binary_cross_entropy_with_logits or nnf_cross_entropy. stopifnot(length(self$label_keys) == 1) - key <- self$label_keys[[1]] - mode <- self$dataset$output_schema[[key]] + label_key <- self$label_keys[[1]] + mode <- self$dataset$output_schema[[label_key]] if (mode == "binary") { return(torch::nnf_binary_cross_entropy_with_logits) @@ -53,12 +52,12 @@ BaseModel <- torch::nn_module( } }, + #' @title Prepare Predicted Probabilities + #' @description Converts logits into predicted probabilities for evaluation. + #' Format depends on task mode (sigmoid or softmax, or raw). + #' This method takes `logits` as input, which is a torch tensor with raw model outputs. + #' @return Torch tensor of probabilities. prepare_y_prob = function(logits) { - #' @title Prepare Predicted Probabilities - #' @description Converts logits into predicted probabilities for evaluation. - #' Format depends on task mode (sigmoid or softmax, or raw). - #' @param logits Torch tensor with raw model outputs. - #' @return Torch tensor of probabilities. stopifnot(length(self$label_keys) == 1) key <- self$label_keys[[1]] mode <- self$dataset$output_schema[[key]] @@ -76,6 +75,5 @@ BaseModel <- torch::nn_module( } return(y_prob) - }, - + } ) diff --git a/R/Model_EmbeddingModel.R b/R/Model_EmbeddingModel.R index 620f1ab..d6610f3 100644 --- a/R/Model_EmbeddingModel.R +++ b/R/Model_EmbeddingModel.R @@ -40,12 +40,11 @@ EmbeddingModel <- torch::nn_module( classname = "EmbeddingModel", inherit = BaseModel, + #' @param dataset A SampleDataset object containing input_processors. + #' @param embedding_dim Integer embedding dimension. Default is 128. initialize = function(dataset, embedding_dim = 128) { #' @description #' Initialize an EmbeddingModel by constructing embedding layers based on input processors. - #' - #' @param dataset A SampleDataset object containing input_processors. - #' @param embedding_dim Integer embedding dimension. Default is 128. #' @return None (initializes fields inside the object). # Call parent (BaseModel) initializer, which is assumed to set up `self$device` and other internals @@ -70,7 +69,7 @@ EmbeddingModel <- torch::nn_module( } else if (inherits(processor, "TimeseriesProcessor")) { # TimeseriesProcessor: use feature size to build a Linear layer mapping to embedding_dim self$module_list[[field_name]] <- nn_linear( - in_features = processor$size, + in_features = processor$size(), out_features = embedding_dim ) } @@ -78,31 +77,38 @@ EmbeddingModel <- torch::nn_module( # in forward(), inputs for that field will be passed through unchanged. } - self$embedding_layers <- nn_module_dict(self$module_list) + self$embedding_layers <- nn_module_dict(self$module_list) + # Manually zero out the embedding for the padding index + for (field_name in names(dataset$input_processors)) { + processor <- dataset$input_processors[[field_name]] + if (inherits(processor, "SequenceProcessor")) { + with_no_grad({ + self$embedding_layers[[field_name]]$weight[1, ] <- 0 + }) + } + } }, + #' @description + #' Perform a forward pass by computing embeddings (or passing through) for each field. + #' This method takes `inputs`, a named list of `torch_tensor` objects, with names matching dataset$input_processors. + #' @return A named list of `torch_tensor` objects after embedding (or passthrough). forward = function(inputs) { - #' @description - #' Perform a forward pass by computing embeddings (or passing through) for each field. - #' - #' @param inputs A named list of `torch_tensor` objects, with names matching dataset$input_processors. - #' @return A named list of `torch_tensor` objects after embedding (or passthrough). - embedded <- list() for (field_name in names(inputs)) { tensor <- inputs[[field_name]] - # 修正非法索引 0(确保 padding_idx = 0 不被作为有效 ID 使用) + # Correct illegal index 0 (ensure padding_idx = 0 is not used as a valid ID) tensor[tensor == 0] <- 1 - # 将输入张量移到 embedding 层的同一设备上 + # Move the input tensor to the same device as the embedding layer embed_device <- self$embedding_layers[[field_name]]$weight$device tensor <- tensor$to(device = embed_device) - # 如果有 embedding 层,则嵌入;否则 passthrough + # If an embedding layer exists, embed; otherwise passthrough if (field_name %in% names(self$module_list)) { embedded[[field_name]] <- self$embedding_layers[[field_name]](tensor) } else { @@ -113,12 +119,10 @@ EmbeddingModel <- torch::nn_module( }, + #' @description + #' Return a concise string representation of the EmbeddingModel, listing its embedding layers. + #' @return A character string representation. .repr = function() { - #' @description - #' Return a concise string representation of the EmbeddingModel, listing its embedding layers. - #' - #' @return A character string representation. - paste0( "EmbeddingModel(embedding_layers = {", paste(names(self$embedding_layers), collapse = ", "), diff --git a/R/Model_RNNModel.R b/R/Model_RNNModel.R index dbbd7db..ef758e3 100644 --- a/R/Model_RNNModel.R +++ b/R/Model_RNNModel.R @@ -52,67 +52,58 @@ RNNLayer <- torch::nn_module( ) self$num_directions <- if (bidirectional) 2 else 1 - self$null_hidden <- nn_parameter(torch_randn(c(self$num_directions * self$num_layers, 1, hidden_size))) + if (bidirectional) { + self$down_projection <- nn_linear(hidden_size * 2, hidden_size) + } }, - forward = function(x, mask = NULL) { - x <- x$to(dtype = torch_float()) + forward = function(x, mask = NULL, lengths = NULL) { x <- self$dropout_layer(x) B <- x$size(1) T <- x$size(2) - lengths_cpu <- if (is.null(mask)) { - torch_full(size = B, fill_value = T, dtype = torch_long()) - } else { - mask <- mask$to(dtype = torch_long()) - torch_sum(mask, dim = -1)$cpu() - } - lengths_cpu <- torch_clamp(lengths_cpu, min = 1) - valid_idx <- lengths_cpu > 0 - num_valid <- as.integer(torch_sum(valid_idx)$item()) - if (num_valid < B) { - warning("Some sequences in the batch have zero length and are being skipped.") + if (is.null(lengths)) { + lengths <- if (is.null(mask)) { + torch_full(size = B, fill_value = T, dtype = torch_long()) + } else { + torch_sum(mask$to(dtype = torch_long()), dim = -1)$cpu() + } } - if (num_valid == 0) { - stop("All sequences in this batch have zero length! Cannot proceed.") - } - - x_v <- x[valid_idx$to(device = x$device), ] - len_v <- lengths_cpu[valid_idx] - + packed <- nn_utils_rnn_pack_padded_sequence( - x_v, len_v$to(dtype = torch_int()), batch_first = TRUE, enforce_sorted = FALSE + x, lengths$to(dtype = torch_int()), batch_first = TRUE, enforce_sorted = FALSE ) out_packed <- self$rnn(packed) - out <- nn_utils_rnn_pad_packed_sequence(out_packed[[1]], batch_first = TRUE, total_length = T)[[1]] - - z <- torch_zeros(c(B, T, self$hidden_size), device = x$device) - - valid_idx_tensor <- torch_where(valid_idx)[[1]] - valid_indices <- valid_idx_tensor$to(dtype = torch_long()) - - # for (i in seq_len(num_valid)) { - # z[as.integer(valid_indices[i]$item()), , ] <- out[i, , ] - # } - valid_indices <- torch_where(valid_idx)[[1]]$to(dtype = torch_long()) - - z <- torch_index_put_( - self = z, - indices = list(valid_indices), - values = out, - accumulate = FALSE - ) - - last_indices <- lengths_cpu$view(c(-1, 1, 1))$ - expand(c(B, 1, self$hidden_size))$ - to(dtype = torch_long(), device = out$device) - - last_outputs <- out$gather(dim = 2, index = last_indices)$squeeze(2) + outputs <- nn_utils_rnn_pad_packed_sequence(out_packed[[1]], batch_first = TRUE, total_length = T)[[1]] + - return(list(outputs = z, last_outputs = last_outputs)) + if (!self$bidirectional) { + H <- outputs$shape[3] + index <- lengths$to(device = outputs$device, dtype = torch_long())$view(c(B, 1, 1))$expand(c(B, 1, H)) + last_outputs <- outputs$gather(dim = 2, index = index)$squeeze(2) + # message(sprintf("RNNLayer: Sequence lengths are %s", paste(as.array(lengths$cpu()), collapse = ", "))) + # message(sprintf("RNNLayer: Shape of last_outputs is %s", paste(last_outputs$shape, collapse = " x "))) + return(list(outputs = outputs, last_outputs = last_outputs)) + } else { + outputs_reshaped <- outputs$view(c(B, T, 2, -1)) + H_half <- outputs_reshaped$shape[4] + + f_outputs <- outputs_reshaped[.., 1, ] + index <- lengths$to(device = outputs$device, dtype = torch_long())$view(c(B, 1, 1))$expand(c(B, 1, H_half)) + f_last_outputs <- f_outputs$gather(dim = 2, index = index)$squeeze(2) + + b_last_outputs <- outputs_reshaped[, 1, 2, ] + + last_outputs <- torch_cat(list(f_last_outputs, b_last_outputs), dim = -1) + + last_outputs <- self$down_projection(last_outputs) + outputs <- self$down_projection(outputs) + + return(list(outputs = outputs, last_outputs = last_outputs)) + } } ) @@ -185,20 +176,24 @@ RNN <- torch::nn_module( }, forward = function(inputs) { - patient_emb <- list() - y_true <- inputs[[self$label_key]] ay_true <- y_true$clone() - embedded <- self$embedding_model(inputs) + feature_inputs <- inputs[self$feature_keys] + embedded <- self$embedding_model(feature_inputs) - for (feature_key in self$feature_keys) { + patient_emb <- lapply(self$feature_keys, function(feature_key) { x <- embedded[[feature_key]] - mask <- (x$sum(dim = -1) != 0)$to(dtype = torch_long()) - lengths <- torch_sum(mask, dim = -1) - result <- self$rnn[[feature_key]](x, mask) - hidden <- result[[2]] - patient_emb[[feature_key]] <- result[[2]] - } + len_key <- paste0(feature_key, "_len") + lengths <- if (len_key %in% names(inputs)) inputs[[len_key]] else NULL + + # The mask is now only needed if lengths are not provided. + # For backwards compatibility or other use cases. + mask <- if (is.null(lengths)) (x$sum(dim = -1)$abs() > 1e-6)$to(dtype = torch_long()) else NULL + + result <- self$rnn[[feature_key]](x = x, mask = mask, lengths = lengths) + result[[2]] + }) + patient_vec <- torch_cat(patient_emb, dim = 2) logits <- self$fc(patient_vec) device <- logits$device @@ -213,13 +208,12 @@ RNN <- torch::nn_module( y_true <- y_true$to(device = device) loss <- self$get_loss_function()(logits, y_true) - dtype = torch_long() y_prob <- self$prepare_y_prob(logits) ay_true <- ay_true$to(device = device) results <- list( loss = loss, y_prob = y_prob, - y_true = ay_true, + y_true = y_true, logit = logits ) diff --git a/R/Processor_binary.R b/R/Processor_binary.R index afa4f65..23ca376 100644 --- a/R/Processor_binary.R +++ b/R/Processor_binary.R @@ -43,9 +43,9 @@ BinaryLabelProcessor <- R6::R6Class("BinaryLabelProcessor", message(sprintf("Label '%s' vocab: %s", field, paste(names(self$label_vocab), collapse = ", "))) }, - #' @description Process a label into a torch tensor [0] or [1]. + #' @description Process a label into a torch tensor `[0]` or `[1]`. #' @param value A single label value. - #' @return A float32 torch tensor of shape [1]. + #' @return A float32 torch tensor of shape `1`. process = function(value) { index <- self$label_vocab[[as.character(value)]] torch::torch_tensor(index, dtype = torch::torch_float()) @@ -57,7 +57,6 @@ BinaryLabelProcessor <- R6::R6Class("BinaryLabelProcessor", return(1) }, - #' @description Print a summary of the processor. #' @description Print a summary of the processor. #' @param ... Ignored. print = function(...) { diff --git a/R/Processor_multilabel.R b/R/Processor_multilabel.R index f19a376..e51e8fc 100644 --- a/R/Processor_multilabel.R +++ b/R/Processor_multilabel.R @@ -35,7 +35,7 @@ MultiLabelProcessor <- R6::R6Class("MultiLabelProcessor", #' @description Process a list of active labels into a one-hot float tensor. #' @param value A character or numeric vector of active labels. - #' @return A torch tensor of shape [num_classes] with 0s and 1s. + #' @return A torch tensor of shape `num_classes` with 0s and 1s. process = function(value) { if (!is.vector(value)) { stop("Expected a vector (label list) for multilabel task.", call. = FALSE) diff --git a/R/Processor_regression.R b/R/Processor_regression.R index 76dec6e..09fb5b8 100644 --- a/R/Processor_regression.R +++ b/R/Processor_regression.R @@ -9,13 +9,13 @@ RegressionLabelProcessor <- R6::R6Class("RegressionLabelProcessor", public = list( #' @description Process a numeric label into a single-element float tensor. #' @param value A numeric value. - #' @return A torch tensor of shape [1]. + #' @return A torch tensor of shape `[1]`. process = function(value) { torch::torch_tensor(as.numeric(value), dtype = torch::torch_float()) }, #' @description Return the size of the processed label (always 1). - #' @return Integer 1 + #' @return Integer `1` size = function() { 1 }, diff --git a/R/Processor_sequence.R b/R/Processor_sequence.R index 30d3857..f94ab69 100644 --- a/R/Processor_sequence.R +++ b/R/Processor_sequence.R @@ -28,22 +28,34 @@ SequenceProcessor <- R6::R6Class("SequenceProcessor", stop("Input to SequenceProcessor must be a vector (sequence of tokens).", call. = FALSE) } - indices <- integer(length(value)) - for (i in seq_along(value)) { - token <- value[[i]] - if (is.null(token)) { - indices[i] <- self$code_vocab[[""]] - } else { - key <- as.character(token) - if (!(key %in% names(self$code_vocab))) { - self$code_vocab[[key]] <- self$.next_index - self$.next_index <- self$.next_index + 1 - } - indices[i] <- self$code_vocab[[key]] - } + # Vectorized handling of NULLs and conversion to character + if (is.list(value)) { + is_null_mask <- vapply(value, is.null, FUN.VALUE = logical(1)) + tokens <- as.character(value) + tokens[is_null_mask] <- "" + } else { + tokens <- as.character(value) + tokens[is.na(tokens)] <- "" } - torch::torch_tensor(indices, dtype = torch::torch_long()) + # Find unique tokens that are not yet in the vocabulary + unique_tokens <- unique(tokens) + existing_indices <- match(unique_tokens, names(self$code_vocab)) + new_token_mask <- is.na(existing_indices) + new_tokens <- unique_tokens[new_token_mask] + + # Add new tokens to the vocabulary in one go + if (length(new_tokens) > 0) { + new_indices_start <- self$.next_index + new_indices <- seq.int(from = new_indices_start, length.out = length(new_tokens)) + names(new_indices) <- new_tokens + self$code_vocab <- c(self$code_vocab, new_indices) + self$.next_index <- new_indices_start + length(new_tokens) + } + + # Get indices for all tokens in the original sequence using a single lookup + indices <- self$code_vocab[tokens] + torch::torch_tensor(unname(indices), dtype = torch::torch_long()) }, #' @description Return size of vocabulary. diff --git a/R/Processor_timeseries.R b/R/Processor_timeseries.R index bcab498..02864a2 100644 --- a/R/Processor_timeseries.R +++ b/R/Processor_timeseries.R @@ -32,7 +32,7 @@ TimeseriesProcessor <- R6::R6Class("TimeseriesProcessor", #' Step 2: impute missing entries using selected strategy. #' #' @param value A list: list(timestamps = POSIXct vector, values = matrix). - #' @return A torch tensor of shape [T, F]. + #' @return A torch tensor of shape `[T, F]`. process = function(value) { timestamps <- value[[1]] values <- value[[2]] diff --git a/R/RHealth-package.R b/R/RHealth-package.R new file mode 100644 index 0000000..693c459 --- /dev/null +++ b/R/RHealth-package.R @@ -0,0 +1,8 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start +#' @importFrom stats aggregate +#' @importFrom stats setNames +## usethis namespace: end +NULL diff --git a/R/Task_BaseTask.R b/R/Task_BaseTask.R index bafe82d..4ade5a7 100644 --- a/R/Task_BaseTask.R +++ b/R/Task_BaseTask.R @@ -32,7 +32,7 @@ BaseTask <- R6::R6Class( #' @description Main processing function. Must be overridden in subclasses. #' @param patient A list or structured object representing a single patient or record. - #' @return A list of named lists (equivalent to Python List[Dict]) representing the task result. + #' @return A list of named lists representing the task result. call = function(patient) { stop("`call()` is an abstract method and must be implemented by a subclass.") } diff --git a/R/Task_InHospitalMortalityMIMIC3.R b/R/Task_InHospitalMortalityMIMIC3.R new file mode 100644 index 0000000..49d4af0 --- /dev/null +++ b/R/Task_InHospitalMortalityMIMIC3.R @@ -0,0 +1,180 @@ +#' @title InHospitalMortalityMIMIC3 Task +#' @description Task for predicting in-hospital mortality using MIMIC-III dataset. +#' This task leverages lab results from the first 48 hours of an admission to +#' predict the likelihood of in-hospital mortality. +#' @import R6 +#' @import dplyr +#' @import tidyr +#' @import lubridate +#' @export +InHospitalMortalityMIMIC3 <- R6::R6Class( + classname = "InHospitalMortalityMIMIC3", + inherit = BaseTask, + public = list( + #' @field task_name The name of the task. + task_name = "InHospitalMortalityMIMIC3", + #' @field input_schema The schema for input data. + input_schema = list(labs = "timeseries"), + #' @field output_schema The schema for output data. + output_schema = list(mortality = "binary"), + #' @field label The name of the label column. + label = "mortality", + #' @field LABITEMS A list of lab item IDs used in this task for MIMIC-III. + LABITEMS = c( + # Electrolytes & Metabolic + "50824", "52455", "50983", "52623", # Sodium + "50822", "52452", "50971", "52610", # Potassium + "50806", "52434", "50902", "52535", # Chloride + "50803", "50804", # Bicarbonate + "50809", "52027", "50931", "52569", # Glucose + "50808", "51624", # Calcium + "50960", # Magnesium + "50868", "52500", # Anion Gap + "52031", "50964", "51701", # Osmolality + "50970" # Phosphate + ), + + #' @description Initialize a new InHospitalMortalityMIMIC3 instance. + initialize = function() { + super$initialize( + task_name = self$task_name, + input_schema = self$input_schema, + output_schema = self$output_schema + ) + }, + + #' @description Pre-filter hook to retain only necessary columns for this task. + #' @param df A lazy query containing all events. + #' @return A filtered LazyFrame with only relevant columns. + pre_filter = function(df) { + required_cols <- c( + "patient_id", "event_type", "timestamp", + "patients/dob", + "admissions/dischtime", "admissions/hospital_expire_flag", "admissions/hadm_id", + "labevents/itemid", "labevents/charttime", "labevents/valuenum" + ) + existing_cols <- colnames(df) + keep_cols <- intersect(required_cols, existing_cols) + df %>% dplyr::select(dplyr::all_of(keep_cols)) + }, + + #' @description Main processing method to generate samples. + #' @param patient An object with method `get_events(event_type, ...)`. + #' @return A list of samples. + call = function(patient) { + input_window_hours <- 48 + samples <- list() + + demographics <- patient$get_events(event_type = "patients") + if (length(demographics) == 0) { + return(samples) + } + dob <- tryCatch(as.POSIXct(demographics[[1]]$get("dob")), error = function(e) NULL) + if (is.null(dob)) { + return(samples) + } + + admissions <- patient$get_events(event_type = "admissions") + for (admission in admissions) { + admission_timestamp <- admission$timestamp + age <- as.numeric(difftime(admission_timestamp, dob, units = "days") / 365.25) + + if (is.na(age) || age < 18) { + next + } + + admission_dischtime <- tryCatch( + as.POSIXct(admission$get("dischtime")), + error = function(e) NULL + ) + if (is.null(admission_dischtime)) next + + duration_hour <- as.numeric(difftime( + admission_dischtime, + admission_timestamp, + units = "hours") + ) + if (duration_hour <= input_window_hours) { + next + } + predict_time <- admission_timestamp + lubridate::hours(input_window_hours) + + labevents <- patient$get_events( + event_type = "labevents", + start = admission_timestamp, + end = predict_time + ) + + if (length(labevents) == 0) next + + labevents_df_list <- purrr::map(labevents, ~{ + itemid <- .x$get("itemid") + if (!is.null(itemid) && itemid %in% self$LABITEMS) { + charttime_str <- tryCatch(.x$get("charttime"), error = function(e) NULL) + if (is.null(charttime_str)) { + charttime_str <- .x$timestamp + } + if (!is.null(charttime_str)) { + charttime <- tryCatch(as.POSIXct(charttime_str), error = function(e) NULL) + if (!is.null(charttime) && charttime <= predict_time) { + valuenum <- .x$get("valuenum") + if (!is.null(valuenum) && !is.na(valuenum)) { + return( + tibble::tibble( + timestamp = .x$timestamp, + itemid = as.character(itemid), + valuenum = as.numeric(valuenum) + ) + ) + } + } + } + } + return(NULL) + }) + + labevents_df <- dplyr::bind_rows(labevents_df_list) + + if (nrow(labevents_df) == 0) next + + labevents_df <- labevents_df %>% + dplyr::filter(!is.na(.data$valuenum)) %>% + dplyr::group_by(.data$timestamp, .data$itemid) %>% + dplyr::summarise(valuenum = dplyr::first(.data$valuenum), .groups = "drop") %>% + tidyr::pivot_wider( + names_from = .data$itemid, + values_from = .data$valuenum + ) %>% + dplyr::arrange(.data$timestamp) + + existing_cols <- setdiff(colnames(labevents_df), "timestamp") + missing_cols <- setdiff(self$LABITEMS, existing_cols) + + if (length(missing_cols) > 0) { + for (col in missing_cols) { + labevents_df[[col]] <- NA_real_ + } + } + + labevents_df <- labevents_df %>% + dplyr::select(timestamp, dplyr::all_of(self$LABITEMS)) + + timestamps <- labevents_df$timestamp + lab_values <- as.matrix(labevents_df[, -1]) + + mortality_label <- as.integer(admission$get("hospital_expire_flag")) + if (is.na(mortality_label) || !mortality_label %in% c(0, 1, "0", "1")) { + mortality_label <- 0 + } + + samples[[length(samples) + 1]] <- list( + patient_id = patient$patient_id, + admission_id = admission$get("hadm_id"), + labs = list(timestamps, lab_values), + mortality = mortality_label + ) + } + return(samples) + } + ) +) diff --git a/R/Task_InHospitalMortalityMIMIC4.R b/R/Task_InHospitalMortalityMIMIC4.R index 7ef6718..0bae226 100644 --- a/R/Task_InHospitalMortalityMIMIC4.R +++ b/R/Task_InHospitalMortalityMIMIC4.R @@ -1,7 +1,7 @@ - #' @title InHospitalMortalityMIMIC4 Task #' @description Task for predicting in-hospital mortality using MIMIC-IV dataset. -#' Uses lab results from the first 48 hours after admission as input features. +#' This task leverages lab results from the first 48 hours of an admission to +#' predict the likelihood of in-hospital mortality. #' @import R6 #' @import dplyr #' @import tidyr @@ -11,149 +11,165 @@ InHospitalMortalityMIMIC4 <- R6::R6Class( classname = "InHospitalMortalityMIMIC4", inherit = BaseTask, public = list( - #' @field input_window_hours Numeric, number of hours to look back for lab data. - input_window_hours = NULL, - #' @field LAB_CATEGORIES Named list mapping lab category to subcategory itemids. - LAB_CATEGORIES = list( - "Electrolytes & Metabolic" = list( - Sodium = c("50824", "52455", "50983", "52623"), - Potassium = c("50822", "52452", "50971", "52610"), - Chloride = c("50806", "52434", "50902", "52535"), - Bicarbonate = c("50803", "50804"), - Glucose = c("50809", "52027", "50931", "52569"), - Calcium = c("50808", "51624"), - Magnesium = c("50960"), - `Anion Gap` = c("50868", "52500"), - Osmolality = c("52031", "50964", "51701"), - Phosphate = c("50970") - ) + #' @field task_name The name of the task. + task_name = "InHospitalMortalityMIMIC4", + #' @field input_schema The schema for input data. + input_schema = list(labs = "timeseries"), + #' @field output_schema The schema for output data. + output_schema = list(mortality = "binary"), + #' @field label The name of the label column. + label = "mortality", + #' @field LABITEMS A list of lab item IDs used in this task. + LABITEMS = c( + # Electrolytes & Metabolic + "50824", "52455", "50983", "52623", # Sodium + "50822", "52452", "50971", "52610", # Potassium + "50806", "52434", "50902", "52535", # Chloride + "50803", "50804", # Bicarbonate + "50809", "52027", "50931", "52569", # Glucose + "50808", "51624", # Calcium + "50960", # Magnesium + "50868", "52500", # Anion Gap + "52031", "50964", "51701", # Osmolality + "50970" # Phosphate ), - #' @field LABITEMS Character vector of all lab itemids (flattened). - LABITEMS = NULL, #' @description Initialize a new InHospitalMortalityMIMIC4 instance. - #' @param input_window_hours Numeric, number of hours to look back (default: 48). - initialize = function(input_window_hours = 48) { + initialize = function() { super$initialize( - task_name = "InHospitalMortalityMIMIC4", - input_schema = list(labs = "timeseries"), - output_schema = list(mortality = "binary") + task_name = self$task_name, + input_schema = self$input_schema, + output_schema = self$output_schema ) - self$input_window_hours <- input_window_hours - # Flatten nested LAB_CATEGORIES into vector of itemids - self$LABITEMS <- unlist(self$LAB_CATEGORIES, use.names = FALSE) }, - #' @description Pre-filter hook to retain only necessary columns for this task. - #' @param df A polars LazyFrame containing all events. + #' @param df A lazy query containing all events. #' @return A filtered LazyFrame with only relevant columns. pre_filter = function(df) { - # Define required columns required_cols <- c( - "patient_id", "event_type", "timestamp", # always required - "anchor_age", # from patients - "dischtime", "hospital_expire_flag", "hadm_id", # from admissions - "labevents/itemid", "labevents/storetime", "labevents/valuenum" # from labevents + "patient_id", "event_type", "timestamp", + "patients/anchor_age", + "admissions/dischtime", "admissions/hospital_expire_flag", "admissions/hadm_id", + "labevents/itemid", "labevents/storetime", "labevents/valuenum" ) - - # Drop other columns, keep only required ones if present - existing_cols <- names(df$schema) - + existing_cols <- colnames(df) keep_cols <- intersect(required_cols, existing_cols) - exprs <- lapply(keep_cols, pl$col) - lf <- do.call(df$select, exprs) - return(lf) + df %>% dplyr::select(dplyr::all_of(keep_cols)) }, - - #' @description Main processing method to generate samples. #' @param patient An object with method `get_events(event_type, ...)`. - #' @return A list of samples. Each sample is a named list containing: - #' - patient_id: character - #' - admission_id: character or integer - #' - labs: a list of [timestamps, lab_values_matrix] - #' - mortality: binary indicator (0/1) + #' @return A list of samples. call = function(patient) { + input_window_hours <- 48 samples <- list() - # Get demographics (should be single event) + demographics <- patient$get_events(event_type = "patients") - if (length(demographics) != 1) return(samples) - demo <- demographics[[1]] - anchor_age <- as.integer(demo$anchor_age) - # Exclude minors - if (is.na(anchor_age) || anchor_age < 18) return(samples) + if (length(demographics) == 0) { + return(samples) + } + anchor_age <- as.integer(demographics[[1]]$get("anchor_age")) + if (length(anchor_age) != 1 || is.na(anchor_age) || anchor_age < 18) { + return(samples) + } - # Iterate over admissions admissions <- patient$get_events(event_type = "admissions") - for (ad in admissions) { - admit_time <- lubridate::ymd_hms(ad$timestamp) - discharge_time <- lubridate::ymd_hms(ad$dischtime) - duration_hours <- as.numeric(difftime(discharge_time, admit_time, units = "hours")) - # Only consider stays longer than input_window_hours - if (duration_hours <= self$input_window_hours) next - predict_time <- admit_time + lubridate::hours(self$input_window_hours) - - # Extract lab events in the window - labevents_df <- patient$get_events( + for (admission in admissions) { + admission_timestamp <- admission$timestamp + admission_dischtime <- tryCatch( + as.POSIXct(admission$get("dischtime")), + error = function(e) NULL + ) + if (is.null(admission_dischtime)) next + + duration_hour <- as.numeric(difftime( + admission_dischtime, + admission_timestamp, + units = "hours") + ) + if (duration_hour <= input_window_hours) { + next + } + predict_time <- admission_timestamp + lubridate::hours(input_window_hours) + + labevents <- patient$get_events( event_type = "labevents", - start = admit_time, - end = predict_time, - return_df = TRUE + start = admission_timestamp, + end = predict_time ) - if (nrow(labevents_df) == 0) next - # Rename columns, convert types, and filter relevant itemids - labevents_df <- labevents_df %>% - dplyr::rename( - timestamp = timestamp, - itemid = `labevents/itemid`, - storetime = `labevents/storetime`, - valuenum = `labevents/valuenum` - ) %>% - dplyr::mutate( - timestamp = lubridate::ymd_hms(timestamp), - storetime = lubridate::ymd_hms(storetime), - valuenum = as.numeric(valuenum) - ) %>% - dplyr::filter( - itemid %in% self$LABITEMS, - storetime <= predict_time - ) + if (length(labevents) == 0) next + + labevents_df_list <- purrr::map(labevents, ~{ + itemid <- .x$get("itemid") + if (!is.null(itemid) && itemid %in% self$LABITEMS) { + storetime_str <- .x$get("storetime") + if (!is.null(storetime_str)) { + storetime <- tryCatch(as.POSIXct(storetime_str), error = function(e) NULL) + if (!is.null(storetime) && storetime <= predict_time) { + valuenum <- .x$get("valuenum") + if (!is.null(valuenum) && !is.na(valuenum)) { + return( + tibble::tibble( + timestamp = .x$timestamp, + itemid = as.character(itemid), + valuenum = as.numeric(valuenum) + ) + ) + } + } + } + } + return(NULL) + }) + + labevents_df <- dplyr::bind_rows(labevents_df_list) + if (nrow(labevents_df) == 0) next - # Pivot to wide format with first aggregation for duplicates - labevents_wide <- labevents_df %>% - dplyr::group_by(timestamp, itemid) %>% - dplyr::summarise(valuenum = dplyr::first(valuenum), .groups = "drop") %>% + labevents_df <- labevents_df %>% + dplyr::filter(!is.na(.data$valuenum)) %>% + dplyr::group_by(.data$timestamp, .data$itemid) %>% + dplyr::summarise(valuenum = dplyr::first(.data$valuenum), .groups = "drop") %>% tidyr::pivot_wider( - names_from = itemid, - values_from = valuenum + names_from = .data$itemid, + values_from = .data$valuenum ) %>% - dplyr::arrange(timestamp) + dplyr::arrange(.data$timestamp) - # Add missing columns with NA and reorder - missing <- setdiff(self$LABITEMS, colnames(labevents_wide)) - if (length(missing) > 0) { - labevents_wide[missing] <- NA + existing_cols <- setdiff(colnames(labevents_df), "timestamp") + missing_cols <- setdiff(self$LABITEMS, existing_cols) + + if (length(missing_cols) > 0) { + for (col in missing_cols) { + labevents_df[[col]] <- NA_real_ + } } - labevents_wide <- labevents_wide %>% + + labevents_df <- labevents_df %>% dplyr::select(timestamp, dplyr::all_of(self$LABITEMS)) - # Extract timestamps and numeric matrix - timestamps <- labevents_wide$timestamp - lab_values <- as.matrix(labevents_wide %>% dplyr::select(-timestamp)) + timestamps <- labevents_df$timestamp + lab_values <- as.matrix(labevents_df[, -1]) + + mortality_label_raw <- admission$get("hospital_expire_flag") - # Mortality flag - mortality_flag <- as.integer(ad$hospital_expire_flag) + mortality_label <- mortality_label_raw + if (is.factor(mortality_label)) { + mortality_label <- as.character(mortality_label) + } + mortality_label <- as.integer(mortality_label) + + if (is.na(mortality_label) || !mortality_label %in% c(0, 1)) { + mortality_label <- 0 + } - # Append to samples samples[[length(samples) + 1]] <- list( patient_id = patient$patient_id, - admission_id = ad$hadm_id, + admission_id = admission$get("hadm_id"), labs = list(timestamps, lab_values), - mortality = mortality_flag + mortality = mortality_label ) } return(samples) diff --git a/R/Task_NextMortalityMIMIC3.R b/R/Task_NextMortalityMIMIC3.R new file mode 100644 index 0000000..0c2fe0f --- /dev/null +++ b/R/Task_NextMortalityMIMIC3.R @@ -0,0 +1,118 @@ +#' @title NextMortalityMIMIC3 Task +#' @description Task for predicting in-hospital mortality using MIMIC-III dataset. +#' This task aims to predict whether the patient will decease in the next +#' hospital visit based on clinical information from the current visit. +#' @import R6 +#' @import dplyr +#' @import tidyr +#' @import lubridate +#' @export +NextMortalityMIMIC3 <- R6::R6Class( + classname = "NextMortalityMIMIC3", + inherit = BaseTask, + public = list( + #' @field label the name of the label column. + label = NULL, + + #' @description Initialize a new NextMortalityMIMIC3 instance. + initialize = function() { + super$initialize( + task_name = "NextMortalityMIMIC3", + input_schema = list( + conditions = "sequence", + procedures = "sequence", + drugs = "sequence" + ), + output_schema = list(mortality = "binary") + ) + self$label <- "mortality" + }, + + #' @description Pre-filter hook to retain only necessary columns for this task. + #' @param df A lazy query containing all events. + #' @return A filtered LazyFrame with only relevant columns. + pre_filter = function(df) { + required_cols <- c( + "patient_id", "event_type", "timestamp", + "admissions/hadm_id", "admissions/hospital_expire_flag", + "diagnoses_icd/icd9_code", "diagnoses_icd/hadm_id", + "procedures_icd/icd9_code", "procedures_icd/hadm_id", + "prescriptions/drug", "prescriptions/hadm_id" + ) + existing_cols <- colnames(df) + keep_cols <- intersect(required_cols, existing_cols) + df %>% dplyr::select(dplyr::all_of(keep_cols)) + }, + + #' @description Main processing method to generate samples. + #' @param patient An object with method `get_events(event_type, ...)`. + #' @return A list of samples. + call = function(patient) { + samples <- list() + admissions <- patient$get_events(event_type = "admissions") + + if (length(admissions) <= 1) { + return(samples) + } + + for (i in 1:(length(admissions) - 1)) { + visit <- admissions[[i]] + next_visit <- admissions[[i + 1]] + + mortality_label <- as.integer(next_visit$get("hospital_expire_flag")) + if (is.na(mortality_label) || !mortality_label %in% c(0, 1)) { + message(paste0( + "patient_id: ", patient$patient_id, + ", hadm_id: ", visit$get("hadm_id"), + ", hospital_expire_flag: '", mortality_label, "'", + ", class: ", class(mortality_label) + )) + mortality_label <- 0 + } + + hadm_id <- visit$get("hadm_id") + + diagnoses <- patient$get_events( + event_type = "diagnoses_icd", + filters = list(list("hadm_id", "==", hadm_id)) + ) + procedures <- patient$get_events( + event_type = "procedures_icd", + filters = list(list("hadm_id", "==", hadm_id)) + ) + prescriptions <- patient$get_events( + event_type = "prescriptions", + filters = list(list("hadm_id", "==", hadm_id)) + ) + + conditions <- purrr::map_chr(diagnoses, ~ .x$get("icd9_code")) + procedures_list <- purrr::map_chr(procedures, ~ .x$get("icd9_code")) + drugs <- purrr::map_chr(prescriptions, ~ .x$get("drug")) + + # Helper to clean sequences + clean_sequence <- function(seq) { + seq <- seq[!is.na(seq) & nzchar(trimws(seq))] + return(seq) + } + + conditions <- clean_sequence(conditions) + procedures_list <- clean_sequence(procedures_list) + drugs <- clean_sequence(drugs) + + if (length(conditions) == 0 || length(procedures_list) == 0 || length(drugs) == 0) { + next + } + + samples[[length(samples) + 1]] <- list( + hadm_id = hadm_id, + patient_id = patient$patient_id, + conditions = conditions, + procedures = procedures_list, + drugs = drugs, + mortality = mortality_label + ) + } + return(samples) + } + ) +) diff --git a/R/Task_NextMortalityMIMIC4.R b/R/Task_NextMortalityMIMIC4.R new file mode 100644 index 0000000..b10d8a8 --- /dev/null +++ b/R/Task_NextMortalityMIMIC4.R @@ -0,0 +1,134 @@ +#' @title NextMortalityMIMIC4 Task +#' @description Task for predicting in-hospital mortality using MIMIC-IV dataset. +#' Uses lab results from the first 48 hours after admission as input features. +#' @import R6 +#' @import dplyr +#' @import tidyr +#' @import lubridate +#' @export +NextMortalityMIMIC4 <- R6::R6Class( + classname = "NextMortalityMIMIC4", + inherit = BaseTask, + public = list( + #' @field label the name of the label column. + label = NULL, + + #' @description Initialize a new NextMortalityMIMIC4 instance. + initialize = function() { + super$initialize( + task_name = "NextMortalityMIMIC4", + input_schema = list( + conditions = "sequence", + procedures = "sequence", + drugs = "sequence" + ), + output_schema = list(mortality = "binary") + ) + self$label <- "mortality" + }, + + #' @description Pre-filter hook to retain only necessary columns for this task. + #' @param df A lazy query containing all events. + #' @return A filtered LazyFrame with only relevant columns. + pre_filter = function(df) { + required_cols <- c( + "patient_id", "event_type", "timestamp", + "patients/anchor_age", + "admissions/dischtime", "admissions/hospital_expire_flag", "admissions/hadm_id", + "diagnoses_icd/icd_code", + "procedures_icd/icd_code", + "prescriptions/drug" + ) + + existing_cols <- colnames(df) + + keep_cols <- intersect(required_cols, existing_cols) + df <- df %>% dplyr::select(dplyr::all_of(keep_cols)) + return(df) + }, + + #' @description Main processing method to generate samples. + #' @param patient An object with method `get_events(event_type, ...)`. + #' @return A list of samples. + call = function(patient) { + samples <- list() + # Get demographics (should be single event) + demographics <- patient$get_events(event_type = "patients") + if (length(demographics) == 0) return(samples) + demo <- demographics[[1]] + anchor_age <- as.integer(demo$get("anchor_age")) + + # Exclude minors or patients with invalid age + if (length(anchor_age) != 1 || is.na(anchor_age) || anchor_age < 18) { + return(samples) + } + + # Iterate over admissions + admissions <- patient$get_events(event_type = "admissions") + if (length(admissions) <= 1) return(samples) + + for (i in 1:(length(admissions) - 1)) { + admission <- admissions[[i]] + next_admission <- admissions[[i + 1]] + + admit_time <- as.POSIXct(admission$timestamp) + discharge_time <- tryCatch( + as.POSIXct(admission$get("dischtime")), + error = function(e) NULL + ) + if (is.null(discharge_time)) next + + # Get mortality label from the next admission + mortality_label <- as.integer(next_admission$get("hospital_expire_flag")) + if (is.na(mortality_label) || !mortality_label %in% c(0, 1)) { + mortality_label <- 0 + } + + # Get events during the current admission + diagnoses <- patient$get_events( + event_type = "diagnoses_icd", + start = admit_time, + end = discharge_time + ) + procedures <- patient$get_events( + event_type = "procedures_icd", + start = admit_time, + end = discharge_time + ) + prescriptions <- patient$get_events( + event_type = "prescriptions", + start = admit_time, + end = discharge_time + ) + + conditions <- purrr::map_chr(diagnoses, ~ .x$get("icd_code")) + procedures_list <- purrr::map_chr(procedures, ~ .x$get("icd_code")) + drugs <- purrr::map_chr(prescriptions, ~ .x$get("drug")) + + # Helper to clean sequences + clean_sequence <- function(seq) { + seq <- seq[!is.na(seq) & nzchar(trimws(seq))] + return(seq) + } + + conditions <- clean_sequence(conditions) + procedures_list <- clean_sequence(procedures_list) + drugs <- clean_sequence(drugs) + + if (length(conditions) == 0 || length(procedures_list) == 0 || length(drugs) == 0) { + next + } + + samples[[length(samples) + 1]] <- list( + patient_id = patient$patient_id, + admission_id = admission$get("hadm_id"), + conditions = conditions, + procedures = procedures_list, + drugs = drugs, + mortality = mortality_label + ) + } + return(samples) + } + ) +) diff --git a/R/Task_Readmission30DaysMIMIC4.R b/R/Task_Readmission30DaysMIMIC4.R index a5c688a..0d545ea 100644 --- a/R/Task_Readmission30DaysMIMIC4.R +++ b/R/Task_Readmission30DaysMIMIC4.R @@ -42,18 +42,19 @@ #' } #' #' @import R6 -#' @import polars #' @export Readmission30DaysMIMIC4 <- R6::R6Class( classname = "Readmission30DaysMIMIC4", inherit = BaseTask, public = list( + #' @field label the name of the label column. + label = NULL, #' @description #' Initialize the Readmission30DaysMIMIC4 task. #' Sets the task name and defines the expected input/output schema. initialize = function() { self$task_name <- "Readmission30DaysMIMIC4" - + self$label <- "readmission" self$input_schema <- list( conditions = "sequence", procedures = "sequence", @@ -101,6 +102,9 @@ Readmission30DaysMIMIC4 <- R6::R6Class( # Extract admissions admissions <- patient$get_events(event_type = "admissions") + + # Sort admissions by timestamp to ensure chronological order + admissions <- admissions[order(sapply(admissions, function(x) x$timestamp))] for (i in seq_along(admissions)) { admission <- admissions[[i]] @@ -139,7 +143,17 @@ Readmission30DaysMIMIC4 <- R6::R6Class( readmission <- 0L } - # Retrieve events during the admission period as Polars DataFrames + # Debug print for the first 5 patients + if (patient$patient_id <= 5) { + message(sprintf( + "Patient ID: %s, Admission: %d, Time Diff (h): %.2f, Readmission: %d", + patient$patient_id, i, + if (!is.null(next_admission)) time_diff_hour else -1, + readmission + )) + } + + # Retrieve events during the admission period as DataFrames diagnoses_icd <- patient$get_events( event_type = "diagnoses_icd", start = admission_time, @@ -159,18 +173,18 @@ Readmission30DaysMIMIC4 <- R6::R6Class( return_df = TRUE ) - # Convert to lists of codes using Polars operations - conditions <- diagnoses_icd$select( - pl$concat_str(c("diagnoses_icd/icd_version", "diagnoses_icd/icd_code"), separator = "_") - )$to_series()$to_list() + # Convert to lists of codes using dplyr operations + conditions <- diagnoses_icd %>% + dplyr::mutate(code = paste(`diagnoses_icd/icd_version`, `diagnoses_icd/icd_code`, sep = "_")) %>% + dplyr::pull(code) - procedures <- procedures_icd$select( - pl$concat_str(c("procedures_icd/icd_version", "procedures_icd/icd_code"), separator = "_") - )$to_series()$to_list() + procedures <- procedures_icd %>% + dplyr::mutate(code = paste(`procedures_icd/icd_version`, `procedures_icd/icd_code`, sep = "_")) %>% + dplyr::pull(code) - drugs <- prescriptions$select( - pl$concat_str(c("prescriptions/drug"), separator = "_") - )$to_series()$to_list() + drugs <- prescriptions %>% + dplyr::mutate(code = paste(`prescriptions/drug`, sep = "_")) %>% + dplyr::pull(code) # Exclude visits without complete feature data if (length(conditions) * length(procedures) * length(drugs) == 0) { diff --git a/R/Trainer.R b/R/Trainer.R index 800a1e2..6192331 100644 --- a/R/Trainer.R +++ b/R/Trainer.R @@ -49,6 +49,10 @@ create_directory <- function(directory) { } } +multiclass_metrics_fn <- NULL +multilabel_metrics_fn <- NULL +regression_metrics_fn <- NULL + #' @title Get Metrics Function #' @description Returns appropriate metric function according to task mode. #' @@ -74,7 +78,7 @@ get_metrics_fn <- function(mode) { #' * **Dynamic `steps_per_epoch`**: can iterate indefinitely over a dataloader to reach a target number of steps, just like Python. #' * **Parameter‑group–wise weight decay**: bias and *LayerNorm* parameters are excluded from L2 regularisation. #' * **Gradient clipping**. -#' * **Optional progress bar** using `cli::cli_progress_bar()` (falls back to simple logging). +#' * **Optional progress bar** using `progressr::progressor()` (falls back to simple logging). #' * **Correctly named `additional_outputs`** collection. #' #' @export @@ -123,6 +127,8 @@ Trainer <- R6::R6Class( set_logger(file.path(self$exp_path, "train.log")) } flog.info("Initialised model on %s", self$device) + flog.info(paste(capture.output(self$model), collapse = "\n")) + flog.info("Metrics: %s", paste(self$metrics, collapse = ", ")) # Checkpoint -------------------------------------------------------------- if (!is.null(checkpoint_path)) { @@ -162,13 +168,24 @@ Trainer <- R6::R6Class( load_best_model_at_last = TRUE, use_progress_bar = TRUE) { + flog.info("Training:") + flog.info("Batch size: %d", train_dataloader$batch_size) + flog.info("Optimizer: %s", deparse(substitute(optimizer_class))) + flog.info("Optimizer params: %s", paste(names(optimizer_params), optimizer_params, sep = "=", collapse = ", ")) + flog.info("Weight decay: %f", weight_decay) + flog.info("Max grad norm: %s", if (is.null(max_grad_norm)) "NULL" else max_grad_norm) + flog.info("Val dataloader: %s", if (is.null(val_dataloader)) "NULL" else "provided") + flog.info("Monitor: %s", if (is.null(monitor)) "NULL" else monitor) + flog.info("Monitor criterion: %s", monitor_criterion) + flog.info("Epochs: %d", epochs) + # ---------- parameter grouping (bias / LayerNorm excluded) --------------- all_named <- self$model$named_parameters() no_decay_keys <- c("bias", "LayerNorm.weight", "LayerNorm.bias") params_wd <- list() params_nowd <- list() for (nm in names(all_named)) { - if (any(startsWith(nm, no_decay_keys))) { + if (any(vapply(no_decay_keys, function(key) grepl(key, nm, fixed = TRUE), logical(1)))) { params_nowd[[length(params_nowd)+1]] <- all_named[[nm]] } else { params_wd[[length(params_wd)+1]] <- all_named[[nm]] @@ -191,9 +208,8 @@ Trainer <- R6::R6Class( # Create iterator that can restart ------------------------------------ iter <- torch::dataloader_make_iter(train_dataloader) # Progress bar --------------------------------------------------------- - if (use_progress_bar && requireNamespace("cli", quietly = TRUE)) { - pb <- cli::cli_progress_bar(total = steps_per_epoch, - format = "Epoch {epoch}/{epochs} :current/:total :elapsed") + if (use_progress_bar) { + p <- progressr::progressor(steps = steps_per_epoch) } for (step in seq_len(steps_per_epoch)) { @@ -216,23 +232,36 @@ Trainer <- R6::R6Class( epoch_losses <- c(epoch_losses, loss$item()) global_step <- global_step + 1L - if (use_progress_bar && exists("pb")) cli::cli_progress_update() + if (use_progress_bar) p(message = sprintf("Epoch %d", epoch)) } - if (use_progress_bar && exists("pb")) cli::cli_progress_done() - flog.info("Epoch %d/%d | train loss %.4f", epoch, epochs, mean(epoch_losses)) + train_header <- sprintf("--- Train epoch-%d, step-%d ---", epoch, global_step) + flog.info(train_header) + message(train_header) + train_loss <- paste(sprintf("loss: %.4f", mean(epoch_losses)), collapse = "\n") + flog.info(train_loss) + message(train_loss) + flog.info("loss: %.4f", mean(epoch_losses)) # Save last ckpt -------------------------------------------------------- if (!is.null(self$exp_path)) self$save_ckpt(file.path(self$exp_path, "last.ckpt")) # Validation ------------------------------------------------------------ if (!is.null(val_dataloader)) { - scores <- self$evaluate(val_dataloader) - flog.info("Val scores: %s", paste(sprintf("%s=%.4f", names(scores), scores), collapse = ", ")) + scores <- self$evaluate(val_dataloader, use_progress_bar) + eval_header <- sprintf("--- Eval epoch-%d, step-%d ---", epoch, global_step) + flog.info(eval_header) + message(eval_header) + scores_log <- paste(sprintf("%s: %.4f", names(scores), scores), collapse = "\n") + flog.info(scores_log) + message(scores_log) if (!is.null(monitor)) { current <- scores[[monitor]] if (is_best(best_score, current, monitor_criterion)) { best_score <- current - flog.info("New best %s: %.4f", monitor, current) + flog.info("New best %s score (%.4f) at epoch-%d, step-%d", + monitor, current, epoch, global_step) + message(sprintf("New best %s score (%.4f) at epoch-%d, step-%d", + monitor, current, epoch, global_step)) if (!is.null(self$exp_path)) self$save_ckpt(file.path(self$exp_path, "best.ckpt")) } } @@ -241,12 +270,20 @@ Trainer <- R6::R6Class( # Reload best ------------------------------------------------------------- best_path <- file.path(self$exp_path, "best.ckpt") - if (load_best_model_at_last && file.exists(best_path)) self$load_ckpt(best_path) + if (load_best_model_at_last && file.exists(best_path)) { + flog.info("Loaded best model") + message("Loaded best model") + self$load_ckpt(best_path) + } # Test ------------------------------------------------------------------- if (!is.null(test_dataloader)) { - scores <- self$evaluate(test_dataloader) - flog.info("Test scores: %s", paste(sprintf("%s=%.4f", names(scores), scores), collapse = ", ")) + scores <- self$evaluate(test_dataloader, use_progress_bar) + flog.info("--- Test ---") + message("--- Test ---") + scores_log <- paste(sprintf("%s: %.4f", names(scores), scores), collapse = "\n") + flog.info(scores_log) + message(scores_log) } }, @@ -255,43 +292,55 @@ Trainer <- R6::R6Class( #' @param dataloader A dataloader. #' @param additional_outputs Vector of additional outputs to capture. #' @param return_patient_ids Whether to return patient IDs. - inference = function(dataloader, additional_outputs = NULL, return_patient_ids = FALSE) { - losses <- c(); y_true <- list(); y_prob <- list() + #' @param use_progress_bar Whether to show a progress bar. + inference = function(dataloader, + additional_outputs = NULL, + return_patient_ids = FALSE, + use_progress_bar = FALSE) { + losses <- c() + y_true_batches <- list() + y_prob_batches <- list() + if (!is.null(additional_outputs)) { add_outputs <- setNames(lapply(additional_outputs, function(x) list()), additional_outputs) } pids <- c() - + self$model$eval() - + + if (use_progress_bar) { + p <- progressr::progressor(steps = length(dataloader)) + } + torch::with_no_grad({ - + coro::loop(for (batch in dataloader) { - + out <- self$model(batch) - + losses <- c(losses, out$loss$item()) - - y_true[[length(y_true) + 1]] <- as_array(out$y_true$cpu()) - y_prob[[length(y_prob) + 1]] <- as_array(out$y_prob$cpu()) - + + y_true_batches[[length(y_true_batches) + 1]] <- out$y_true$cpu() + y_prob_batches[[length(y_prob_batches) + 1]] <- out$y_prob$cpu() + if (!is.null(additional_outputs)) { for (nm in additional_outputs) { add_outputs[[nm]][[length(add_outputs[[nm]]) + 1]] <- as_array(out[[nm]]$cpu()) } } - + if (return_patient_ids && "patient_id" %in% names(batch)) { pids <- c(pids, batch$patient_id) } + if (use_progress_bar) p() }) }) - - + + res <- list( - y_true = do.call(rbind, y_true), - y_prob = do.call(rbind, y_prob), + y_true = torch::torch_cat(y_true_batches, dim = 1), + y_prob = torch::torch_cat(y_prob_batches, dim = 1), loss = mean(losses) ) if (!is.null(additional_outputs)) res$additional <- lapply(add_outputs, function(x) do.call(rbind, x)) @@ -302,8 +351,9 @@ Trainer <- R6::R6Class( #' @description #' Evaluate the model using a dataloader. #' @param dataloader A dataloader to evaluate on. - evaluate = function(dataloader) { - inf <- self$inference(dataloader) + #' @param use_progress_bar Whether to show a progress bar. + evaluate = function(dataloader, use_progress_bar = FALSE) { + inf <- self$inference(dataloader, use_progress_bar = use_progress_bar) if (!is.null(self$model$mode)) { fn <- get_metrics_fn(self$model$mode) scores <- fn(inf$y_true, inf$y_prob, metrics = self$metrics) diff --git a/R/zzz.R b/R/zzz.R index 74e9b06..5026207 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -8,10 +8,7 @@ #' @keywords internal #' @noRd .onLoad <- function(libname, pkgname) { - if (is.null(getOption("RHealth.medcode_base"))) { - options( - RHealth.medcode_base = - "https://storage.googleapis.com/pyhealth/resource/" - ) + if (interactive()) { + progressr::handlers("progress") } } diff --git a/README.Rmd b/README.Rmd index d30e6f7..7261a43 100644 --- a/README.Rmd +++ b/README.Rmd @@ -105,6 +105,8 @@ get_descendants(code = "428", system = "ICD9CM") ```{r, eval=FALSE} # Map from ICD-9 to CCS map_code(code = "428.0", from = "ICD9CM", to = "CCSCM") +# Map from ICD-9 to ICD-10 +map_code(code = "589", from = "ICD9CM", to = "ICD10CM") ``` ### 🛢️ 2. Dataset Module @@ -113,7 +115,6 @@ The **Dataset** module is the foundation of RHealth. It transforms raw, multi-ta **Key Features:** - * **Efficient Ingestion**: Streams large CSVs using Polars `LazyFrame` to avoid loading massive files into RAM. * **Data Harmonisation**: Merges heterogeneous tables into a single, canonical event table. * **Built-in Caching**: Uses DuckDB for CSV → Parquet caching, enabling up to 10x faster reloads. * **Dev Mode**: Allows for lightning-fast iteration by using a small subset of patients. @@ -189,7 +190,7 @@ MyReadmissionTask <- R6::R6Class( Once a task is defined, use it with your dataset to create a `SampleDataset` compatible with `{torch}`. ```{r, eval=FALSE} -task <- Readmission30DaysMIMIC4$new() # A built-in task +task <- InHospitalMortalityMIMIC4$new() # A built-in task samples <- ds$set_task(task) ``` @@ -249,7 +250,7 @@ The **Trainer** module provides a high-level, configurable training loop that ha ```{r, eval=FALSE} # 1. Create data loaders -splits <- split_by_patient(samples, c(0.8, 0.1, 0.1)) +splits <- split_by_patient(samples, c(0.8, 0.1, 0.1), stratify = TRUE, stratify_by = 'mortality') train_dl <- get_dataloader(splits[[1]], batch_size = 32, shuffle = TRUE) val_dl <- get_dataloader(splits[[2]], batch_size = 32) test_dl <- get_dataloader(splits[[3]], batch_size = 32) diff --git a/README.md b/README.md index 91758c7..15c23f5 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,8 @@ get_descendants(code = "428", system = "ICD9CM") ``` r # Map from ICD-9 to CCS map_code(code = "428.0", from = "ICD9CM", to = "CCSCM") +# Map from ICD-9 to ICD-10 +map_code(code = "589", from = "ICD9CM", to = "ICD10CM") ``` ### 🛢️ 2. Dataset Module @@ -117,8 +119,6 @@ tensors that any downstream model can consume. **Key Features:** -- **Efficient Ingestion**: Streams large CSVs using Polars `LazyFrame` - to avoid loading massive files into RAM. - **Data Harmonisation**: Merges heterogeneous tables into a single, canonical event table. - **Built-in Caching**: Uses DuckDB for CSV → Parquet caching, enabling @@ -201,7 +201,7 @@ Once a task is defined, use it with your dataset to create a `SampleDataset` compatible with `{torch}`. ``` r -task <- Readmission30DaysMIMIC4$new() # A built-in task +task <- InHospitalMortalityMIMIC4$new() # A built-in task samples <- ds$set_task(task) ``` @@ -265,7 +265,7 @@ that handles logging, checkpointing, evaluation, and progress bars. ``` r # 1. Create data loaders -splits <- split_by_patient(samples, c(0.8, 0.1, 0.1)) +splits <- split_by_patient(samples, c(0.8, 0.1, 0.1), stratify = TRUE, stratify_by = 'mortality') train_dl <- get_dataloader(splits[[1]], batch_size = 32, shuffle = TRUE) val_dl <- get_dataloader(splits[[2]], batch_size = 32) test_dl <- get_dataloader(splits[[3]], batch_size = 32) diff --git a/_pkgdown.yml b/_pkgdown.yml index f6bc4ec..39b76f9 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -45,6 +45,7 @@ reference: - split_by_patient - split_by_visit - split_by_sample + - load_sample_dataset - title: "Processors" @@ -70,8 +71,11 @@ reference: Benchmark prediction tasks based on MIMIC-IV. contents: - BaseTask - - Readmission30DaysMIMIC4 + - InHospitalMortalityMIMIC3 - InHospitalMortalityMIMIC4 + - NextMortalityMIMIC3 + - NextMortalityMIMIC4 + - Readmission30DaysMIMIC4 - title: "Models" desc: > diff --git a/inst/extdata/configs/mimic3.yaml b/inst/extdata/configs/mimic3.yaml index 9fe5eb3..9447198 100644 --- a/inst/extdata/configs/mimic3.yaml +++ b/inst/extdata/configs/mimic3.yaml @@ -41,7 +41,7 @@ tables: - "first_careunit" - "dbsource" - "last_careunit" - - "outtime" + - "outtime" diagnoses_icd: file_path: "DIAGNOSES_ICD.csv.gz" @@ -54,6 +54,7 @@ tables: - "dischtime" timestamp: "dischtime" attributes: + - "hadm_id" - "icd9_code" - "seq_num" @@ -61,7 +62,14 @@ tables: file_path: "PRESCRIPTIONS.csv.gz" patient_id: "subject_id" timestamp: "startdate" + join: + - file_path: "ADMISSIONS.csv.gz" + "on": "hadm_id" + how: "inner" + columns: + - "dischtime" attributes: + - "hadm_id" - "drug" - "drug_type" - "drug_name_poe" @@ -88,6 +96,7 @@ tables: - "dischtime" timestamp: "dischtime" attributes: + - "hadm_id" - "icd9_code" - "seq_num" @@ -116,9 +125,16 @@ tables: noteevents: file_path: "NOTEEVENTS.csv.gz" patient_id: "subject_id" + join: + - file_path: "ADMISSIONS.csv.gz" + "on": "hadm_id" + how: "inner" + columns: + - "dischtime" timestamp: - "charttime" attributes: + - "hadm_id" - "text" - "category" - "description" diff --git a/man/BaseDataset.Rd b/man/BaseDataset.Rd index d71bbe5..f7d2315 100644 --- a/man/BaseDataset.Rd +++ b/man/BaseDataset.Rd @@ -10,20 +10,20 @@ BaseDataset — R6 infrastructure for clinical event datasets } \details{ The \strong{BaseDataset} class mirrors rhealth's \code{BaseDataset}, providing a -fully‑featured, YAML driven loader that converts multi‑table electronic +fully-featured, YAML driven loader that converts multi-table electronic health records into a single \emph{event} table. It supports: \itemize{ -\item URL or local‐file ingestion (with automatic \code{.csv} / \code{.csv.gz} +\item URL or local-file ingestion (with automatic \code{.csv} / \code{.csv.gz} fallback). -\item Per‑table joins as declared in the config. -\item Flexible timestamp parsing (single or multi‑column). +\item Per-table joins as declared in the config. +\item Flexible timestamp parsing (single or multi-column). \item A \code{dev} mode that caps the number of patients for rapid prototyping. -\item Multi‑threaded sample generation with progress bars. +\item Multi-threaded sample generation with progress bars. } -Down‑stream, it cooperates with \code{BaseTask} (task definition), -\code{Patient} (per‑subject wrapper), and \code{SampleDataset} (collection of +Down-stream, it cooperates with \code{BaseTask} (task definition), +\code{Patient} (per-subject wrapper), and \code{SampleDataset} (collection of input/output pairs). } \section{Dependencies}{ @@ -39,13 +39,15 @@ reporting require \code{future}, \code{future.apply}, and \code{progressr}. \item{\code{tables}}{Character vector of table names to ingest.} -\item{\code{dataset_name}}{Human‑readable dataset label.} +\item{\code{dataset_name}}{Human-readable dataset label.} \item{\code{config}}{Parsed YAML configuration list.} -\item{\code{dev}}{Logical flag — when TRUE limits to 1 000 patients.} +\item{\code{dev}}{Logical flag — when TRUE limits to 1000 patients.} -\item{\code{global_event_df}}{A polars LazyFrame with all events combined.} +\item{\code{con}}{a duckdb connection} + +\item{\code{global_event_df}}{A duckdb lazy query with all events combined.} \item{\code{.collected_global_event_df}}{Polars dataframe storing all global events.} @@ -104,14 +106,14 @@ Instantiate a \code{BaseDataset}. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-BaseDataset-collected_global_event_df}{}}} \subsection{Method \code{collected_global_event_df()}}{ -Materialise (collect) the lazy event dataframe. In dev‑mode only the +Materialise (collect) the lazy event dataframe. In dev-mode only the first 1000 patients are kept. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{BaseDataset$collected_global_event_df()}\if{html}{\out{
}} } \subsection{Returns}{ -A polars DataFrame containing all selected events. +A dataframe containing all selected events. } } \if{html}{\out{
}} @@ -132,7 +134,7 @@ event schema. \if{html}{\out{}} } \subsection{Returns}{ -A polars LazyFrame in event format. +A dplyr lazy query in event format. } } \if{html}{\out{
}} @@ -145,7 +147,7 @@ Load every configured table, returning a single \emph{lazy} frame. } \subsection{Returns}{ -A polars LazyFrame. +A duckdb lazy query. } } \if{html}{\out{
}} @@ -193,7 +195,7 @@ Iterate over all patients (optionally a filtered dataframe). \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{Optional polars DataFrame (already collected).} +\item{\code{df}}{Optional dataframe (already collected).} } \if{html}{\out{
}} } @@ -205,13 +207,13 @@ List of \code{Patient} objects. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-BaseDataset-stats}{}}} \subsection{Method \code{stats()}}{ -Print dataset‑level statistics. +Print dataset-level statistics. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{BaseDataset$stats()}\if{html}{\out{
}} } \subsection{Returns}{ -Invisible NULL (called for side‑effects). +Invisible NULL (called for side-effects). } } \if{html}{\out{
}} @@ -233,7 +235,12 @@ NULL \subsection{Method \code{set_task()}}{ Apply a \code{BaseTask} to build a \code{SampleDataset}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{BaseDataset$set_task(task = NULL, num_workers = 1)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{BaseDataset$set_task( + task = NULL, + num_workers = 1, + chunk_size = 1000, + cache_dir = NULL +)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -243,6 +250,12 @@ Apply a \code{BaseTask} to build a \code{SampleDataset}. used.} \item{\code{num_workers}}{Integer ≥1. Number of parallel workers.} + +\item{\code{chunk_size}}{Integer. Number of patients to process in each chunk.} + +\item{\code{cache_dir}}{Optional path to a directory for caching samples. If set, +processed samples will be saved to an \code{.rds} file and reloaded on +subsequent runs, skipping the generation step.} } \if{html}{\out{}} } diff --git a/man/BaseModel.Rd b/man/BaseModel.Rd index 3061993..516d651 100644 --- a/man/BaseModel.Rd +++ b/man/BaseModel.Rd @@ -8,8 +8,6 @@ BaseModel(dataset) } \arguments{ \item{dataset}{A dataset object (must have input_schema, output_schema, output_processors).} - -\item{logits}{Torch tensor with raw model outputs.} } \value{ Integer scalar representing the output dimension. @@ -29,6 +27,7 @@ Selects appropriate loss function based on task type in output schema. Converts logits into predicted probabilities for evaluation. Format depends on task mode (sigmoid or softmax, or raw). +This method takes \code{logits} as input, which is a torch tensor with raw model outputs. } \section{Fields}{ diff --git a/man/BaseTask.Rd b/man/BaseTask.Rd index aba6813..f8ea82e 100644 --- a/man/BaseTask.Rd +++ b/man/BaseTask.Rd @@ -84,7 +84,7 @@ Main processing function. Must be overridden in subclasses. \if{html}{\out{}} } \subsection{Returns}{ -A list of named lists (equivalent to Python List\link{Dict}) representing the task result. +A list of named lists representing the task result. } } \if{html}{\out{
}} diff --git a/man/BinaryLabelProcessor.Rd b/man/BinaryLabelProcessor.Rd index 3c1c868..b0caa78 100644 --- a/man/BinaryLabelProcessor.Rd +++ b/man/BinaryLabelProcessor.Rd @@ -8,7 +8,7 @@ Processor for binary classification labels. Supports numeric (0/1), logical (TRUE/FALSE), or categorical binary labels. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{BinaryLabelProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{BinaryLabelProcessor} } \section{Public fields}{ \if{html}{\out{
}} @@ -69,7 +69,7 @@ Fit the processor by analyzing all unique labels in the dataset. \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-BinaryLabelProcessor-process}{}}} \subsection{Method \code{process()}}{ -Process a label into a torch tensor \link{0} or \link{1}. +Process a label into a torch tensor \verb{[0]} or \verb{[1]}. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{BinaryLabelProcessor$process(value)}\if{html}{\out{
}} } @@ -82,7 +82,7 @@ Process a label into a torch tensor \link{0} or \link{1}. \if{html}{\out{
}} } \subsection{Returns}{ -A float32 torch tensor of shape \link{1}. +A float32 torch tensor of shape \code{1}. } } \if{html}{\out{
}} @@ -102,9 +102,6 @@ Integer 1 \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-BinaryLabelProcessor-print}{}}} \subsection{Method \code{print()}}{ -Print a summary of the processor. - - Print a summary of the processor. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{BinaryLabelProcessor$print(...)}\if{html}{\out{
}} diff --git a/man/DatasetProcessor.Rd b/man/DatasetProcessor.Rd index c78c251..5ab2db9 100644 --- a/man/DatasetProcessor.Rd +++ b/man/DatasetProcessor.Rd @@ -7,7 +7,7 @@ Optional class for processing entire datasets in bulk (e.g., batch statistics). } \section{Super class}{ -\code{RHealth::Processor} -> \code{DatasetProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{DatasetProcessor} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/EmbeddingModel.Rd b/man/EmbeddingModel.Rd index 3dc9486..c69382e 100644 --- a/man/EmbeddingModel.Rd +++ b/man/EmbeddingModel.Rd @@ -10,8 +10,6 @@ EmbeddingModel(dataset, embedding_dim = 128) \item{dataset}{A SampleDataset object containing input_processors.} \item{embedding_dim}{Integer embedding dimension. Default is 128.} - -\item{inputs}{A named list of \code{torch_tensor} objects, with names matching dataset$input_processors.} } \value{ An \code{EmbeddingModel} object that inherits from \code{BaseModel}. @@ -28,6 +26,7 @@ EmbeddingModel is responsible for creating embedding layers for different types Initialize an EmbeddingModel by constructing embedding layers based on input processors. Perform a forward pass by computing embeddings (or passing through) for each field. +This method takes \code{inputs}, a named list of \code{torch_tensor} objects, with names matching dataset$input_processors. Return a concise string representation of the EmbeddingModel, listing its embedding layers. } diff --git a/man/FeatureProcessor.Rd b/man/FeatureProcessor.Rd index cfee46c..21c34e4 100644 --- a/man/FeatureProcessor.Rd +++ b/man/FeatureProcessor.Rd @@ -17,7 +17,7 @@ model-ready tensors. } \section{Super class}{ -\code{RHealth::Processor} -> \code{FeatureProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{FeatureProcessor} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/InHospitalMortalityMIMIC3.Rd b/man/InHospitalMortalityMIMIC3.Rd new file mode 100644 index 0000000..013964c --- /dev/null +++ b/man/InHospitalMortalityMIMIC3.Rd @@ -0,0 +1,105 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Task_InHospitalMortalityMIMIC3.R +\name{InHospitalMortalityMIMIC3} +\alias{InHospitalMortalityMIMIC3} +\title{InHospitalMortalityMIMIC3 Task} +\description{ +Task for predicting in-hospital mortality using MIMIC-III dataset. +This task leverages lab results from the first 48 hours of an admission to +predict the likelihood of in-hospital mortality. +} +\section{Super class}{ +\code{\link[RHealth:BaseTask]{RHealth::BaseTask}} -> \code{InHospitalMortalityMIMIC3} +} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{task_name}}{The name of the task.} + +\item{\code{input_schema}}{The schema for input data.} + +\item{\code{output_schema}}{The schema for output data.} + +\item{\code{label}}{The name of the label column.} + +\item{\code{LABITEMS}}{A list of lab item IDs used in this task for MIMIC-III.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-InHospitalMortalityMIMIC3-new}{\code{InHospitalMortalityMIMIC3$new()}} +\item \href{#method-InHospitalMortalityMIMIC3-pre_filter}{\code{InHospitalMortalityMIMIC3$pre_filter()}} +\item \href{#method-InHospitalMortalityMIMIC3-call}{\code{InHospitalMortalityMIMIC3$call()}} +\item \href{#method-InHospitalMortalityMIMIC3-clone}{\code{InHospitalMortalityMIMIC3$clone()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-InHospitalMortalityMIMIC3-new}{}}} +\subsection{Method \code{new()}}{ +Initialize a new InHospitalMortalityMIMIC3 instance. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC3$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-InHospitalMortalityMIMIC3-pre_filter}{}}} +\subsection{Method \code{pre_filter()}}{ +Pre-filter hook to retain only necessary columns for this task. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC3$pre_filter(df)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{df}}{A lazy query containing all events.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A filtered LazyFrame with only relevant columns. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-InHospitalMortalityMIMIC3-call}{}}} +\subsection{Method \code{call()}}{ +Main processing method to generate samples. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC3$call(patient)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{patient}}{An object with method \code{get_events(event_type, ...)}.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A list of samples. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-InHospitalMortalityMIMIC3-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC3$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/InHospitalMortalityMIMIC4.Rd b/man/InHospitalMortalityMIMIC4.Rd index 04ea8fe..565a410 100644 --- a/man/InHospitalMortalityMIMIC4.Rd +++ b/man/InHospitalMortalityMIMIC4.Rd @@ -5,19 +5,24 @@ \title{InHospitalMortalityMIMIC4 Task} \description{ Task for predicting in-hospital mortality using MIMIC-IV dataset. -Uses lab results from the first 48 hours after admission as input features. +This task leverages lab results from the first 48 hours of an admission to +predict the likelihood of in-hospital mortality. } \section{Super class}{ -\code{RHealth::BaseTask} -> \code{InHospitalMortalityMIMIC4} +\code{\link[RHealth:BaseTask]{RHealth::BaseTask}} -> \code{InHospitalMortalityMIMIC4} } \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{input_window_hours}}{Numeric, number of hours to look back for lab data.} +\item{\code{task_name}}{The name of the task.} -\item{\code{LAB_CATEGORIES}}{Named list mapping lab category to subcategory itemids.} +\item{\code{input_schema}}{The schema for input data.} -\item{\code{LABITEMS}}{Character vector of all lab itemids (flattened).} +\item{\code{output_schema}}{The schema for output data.} + +\item{\code{label}}{The name of the label column.} + +\item{\code{LABITEMS}}{A list of lab item IDs used in this task.} } \if{html}{\out{
}} } @@ -36,16 +41,9 @@ Uses lab results from the first 48 hours after admission as input features. \subsection{Method \code{new()}}{ Initialize a new InHospitalMortalityMIMIC4 instance. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC4$new(input_window_hours = 48)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{InHospitalMortalityMIMIC4$new()}\if{html}{\out{
}} } -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{input_window_hours}}{Numeric, number of hours to look back (default: 48).} -} -\if{html}{\out{
}} -} } \if{html}{\out{
}} \if{html}{\out{}} @@ -59,7 +57,7 @@ Pre-filter hook to retain only necessary columns for this task. \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{A polars LazyFrame containing all events.} +\item{\code{df}}{A lazy query containing all events.} } \if{html}{\out{
}} } @@ -84,13 +82,7 @@ Main processing method to generate samples. \if{html}{\out{}} } \subsection{Returns}{ -A list of samples. Each sample is a named list containing: -\itemize{ -\item patient_id: character -\item admission_id: character or integer -\item labs: a list of \link{timestamps, lab_values_matrix} -\item mortality: binary indicator (0/1) -} +A list of samples. } } \if{html}{\out{
}} diff --git a/man/MIMIC3Dataset.Rd b/man/MIMIC3Dataset.Rd index ec971b6..fc8676b 100644 --- a/man/MIMIC3Dataset.Rd +++ b/man/MIMIC3Dataset.Rd @@ -1,12 +1,8 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/Dataset_MIMIC3Dataset.R -\docType{class} \name{MIMIC3Dataset} \alias{MIMIC3Dataset} \title{MIMIC3Dataset: Dataset class for MIMIC-III} -\usage{ -\method{MIMIC3Dataset}{initialize}(...) -} \description{ MIMIC3Dataset: Dataset class for MIMIC-III @@ -18,13 +14,12 @@ It ensures key tables like patients, admissions, and icustays are loaded, and allows appending additional tables. Also provides per-table preprocessing. } \section{Super class}{ -\code{RHealth::BaseDataset} -> \code{MIMIC3Dataset} +\code{\link[RHealth:BaseDataset]{RHealth::BaseDataset}} -> \code{MIMIC3Dataset} } \section{Methods}{ \subsection{Public methods}{ \itemize{ \item \href{#method-MIMIC3Dataset-new}{\code{MIMIC3Dataset$new()}} -\item \href{#method-MIMIC3Dataset-preprocess_noteevents}{\code{MIMIC3Dataset$preprocess_noteevents()}} \item \href{#method-MIMIC3Dataset-clone}{\code{MIMIC3Dataset$clone()}} } } @@ -72,33 +67,12 @@ initialize MIMIC3Dataset \item{\code{dev}}{Logical flag for dev mode.} -\item{\code{...}}{Additional arguments passed to \code{BaseDataset}.} +\item{\code{...}}{Additional arguments passed to \code{BaseDataset$initialize}.} } \if{html}{\out{}} } } \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-MIMIC3Dataset-preprocess_noteevents}{}}} -\subsection{Method \code{preprocess_noteevents()}}{ -Table-specific preprocessing for noteevents. -If \code{charttime} is missing, fills it with \code{chartdate} + " 00:00:00". -\subsection{Usage}{ -\if{html}{\out{
}}\preformatted{MIMIC3Dataset$preprocess_noteevents(df)}\if{html}{\out{
}} -} - -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{df}}{A polars LazyFrame.} -} -\if{html}{\out{
}} -} -\subsection{Returns}{ -A modified LazyFrame. -} -} -\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-MIMIC3Dataset-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/MIMIC4EHRDataset.Rd b/man/MIMIC4EHRDataset.Rd index 9b68f71..5345de9 100644 --- a/man/MIMIC4EHRDataset.Rd +++ b/man/MIMIC4EHRDataset.Rd @@ -1,12 +1,8 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/Dataset_MIMIC4EHRDataset.R -\docType{class} \name{MIMIC4EHRDataset} \alias{MIMIC4EHRDataset} \title{MIMIC4EHRDataset: Dataset class for MIMIC-IV EHR} -\usage{ -\method{MIMIC4EHRDataset}{initialize}(...) -} \description{ MIMIC4EHRDataset: Dataset class for MIMIC-IV EHR @@ -18,7 +14,7 @@ It ensures key tables like patients, admissions, and icustays are included, and allows appending additional tables. It also logs memory usage if needed. } \section{Super class}{ -\code{RHealth::BaseDataset} -> \code{MIMIC4EHRDataset} +\code{\link[RHealth:BaseDataset]{RHealth::BaseDataset}} -> \code{MIMIC4EHRDataset} } \section{Methods}{ \subsection{Public methods}{ @@ -71,7 +67,7 @@ Initialize MIMIC4EHRDataset \item{\code{dev}}{Logical flag for dev mode.} -\item{\code{...}}{Additional arguments passed to \code{BaseDataset}.} +\item{\code{...}}{Additional arguments passed to \code{BaseDataset$initialize}.} } \if{html}{\out{}} } diff --git a/man/MIMIC4NoteDataset.Rd b/man/MIMIC4NoteDataset.Rd index 55b772e..d5f3271 100644 --- a/man/MIMIC4NoteDataset.Rd +++ b/man/MIMIC4NoteDataset.Rd @@ -1,12 +1,8 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/Dataset_MIMIC4NoteDataset.R -\docType{class} \name{MIMIC4NoteDataset} \alias{MIMIC4NoteDataset} \title{MIMIC4NoteDataset: Dataset class for MIMIC-IV Clinical Notes} -\usage{ -\method{MIMIC4NoteDataset}{initialize}(...) -} \description{ MIMIC4NoteDataset: Dataset class for MIMIC-IV Clinical Notes @@ -17,7 +13,7 @@ This class inherits from BaseDataset and is specialized for handling MIMIC-IV Cl It includes tables such as discharge, discharge_detail, and radiology. } \section{Super class}{ -\code{RHealth::BaseDataset} -> \code{MIMIC4NoteDataset} +\code{\link[RHealth:BaseDataset]{RHealth::BaseDataset}} -> \code{MIMIC4NoteDataset} } \section{Methods}{ \subsection{Public methods}{ @@ -70,7 +66,7 @@ Initialize MIMIC4NoteDataset \item{\code{dev}}{Logical flag for dev mode.} -\item{\code{...}}{Additional arguments passed to \code{BaseDataset}.} +\item{\code{...}}{Additional arguments passed to \code{BaseDataset$initialize}.} } \if{html}{\out{}} } diff --git a/man/MultiClassLabelProcessor.Rd b/man/MultiClassLabelProcessor.Rd index 8c37a60..8466f88 100644 --- a/man/MultiClassLabelProcessor.Rd +++ b/man/MultiClassLabelProcessor.Rd @@ -8,7 +8,7 @@ Processor for multi-class classification tasks. Converts string or integer label into integer indices, with one output per label. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{MultiClassLabelProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{MultiClassLabelProcessor} } \section{Public fields}{ \if{html}{\out{
}} diff --git a/man/MultiLabelProcessor.Rd b/man/MultiLabelProcessor.Rd index f5975bb..73f270e 100644 --- a/man/MultiLabelProcessor.Rd +++ b/man/MultiLabelProcessor.Rd @@ -8,7 +8,7 @@ Processor for multi-label classification. Converts a list of active labels into one-hot tensor with multiple 1s. Inherits from FeatureProcessor. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{MultiLabelProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{MultiLabelProcessor} } \section{Public fields}{ \if{html}{\out{
}} @@ -82,7 +82,7 @@ Process a list of active labels into a one-hot float tensor. \if{html}{\out{
}} } \subsection{Returns}{ -A torch tensor of shape \link{num_classes} with 0s and 1s. +A torch tensor of shape \code{num_classes} with 0s and 1s. } } \if{html}{\out{
}} diff --git a/man/NextMortalityMIMIC3.Rd b/man/NextMortalityMIMIC3.Rd new file mode 100644 index 0000000..0d4924f --- /dev/null +++ b/man/NextMortalityMIMIC3.Rd @@ -0,0 +1,97 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Task_NextMortalityMIMIC3.R +\name{NextMortalityMIMIC3} +\alias{NextMortalityMIMIC3} +\title{NextMortalityMIMIC3 Task} +\description{ +Task for predicting in-hospital mortality using MIMIC-III dataset. +This task aims to predict whether the patient will decease in the next +hospital visit based on clinical information from the current visit. +} +\section{Super class}{ +\code{\link[RHealth:BaseTask]{RHealth::BaseTask}} -> \code{NextMortalityMIMIC3} +} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{label}}{the name of the label column.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-NextMortalityMIMIC3-new}{\code{NextMortalityMIMIC3$new()}} +\item \href{#method-NextMortalityMIMIC3-pre_filter}{\code{NextMortalityMIMIC3$pre_filter()}} +\item \href{#method-NextMortalityMIMIC3-call}{\code{NextMortalityMIMIC3$call()}} +\item \href{#method-NextMortalityMIMIC3-clone}{\code{NextMortalityMIMIC3$clone()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC3-new}{}}} +\subsection{Method \code{new()}}{ +Initialize a new NextMortalityMIMIC3 instance. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC3$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC3-pre_filter}{}}} +\subsection{Method \code{pre_filter()}}{ +Pre-filter hook to retain only necessary columns for this task. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC3$pre_filter(df)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{df}}{A lazy query containing all events.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A filtered LazyFrame with only relevant columns. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC3-call}{}}} +\subsection{Method \code{call()}}{ +Main processing method to generate samples. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC3$call(patient)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{patient}}{An object with method \code{get_events(event_type, ...)}.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A list of samples. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC3-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC3$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/NextMortalityMIMIC4.Rd b/man/NextMortalityMIMIC4.Rd new file mode 100644 index 0000000..58526ad --- /dev/null +++ b/man/NextMortalityMIMIC4.Rd @@ -0,0 +1,96 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Task_NextMortalityMIMIC4.R +\name{NextMortalityMIMIC4} +\alias{NextMortalityMIMIC4} +\title{NextMortalityMIMIC4 Task} +\description{ +Task for predicting in-hospital mortality using MIMIC-IV dataset. +Uses lab results from the first 48 hours after admission as input features. +} +\section{Super class}{ +\code{\link[RHealth:BaseTask]{RHealth::BaseTask}} -> \code{NextMortalityMIMIC4} +} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{label}}{the name of the label column.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-NextMortalityMIMIC4-new}{\code{NextMortalityMIMIC4$new()}} +\item \href{#method-NextMortalityMIMIC4-pre_filter}{\code{NextMortalityMIMIC4$pre_filter()}} +\item \href{#method-NextMortalityMIMIC4-call}{\code{NextMortalityMIMIC4$call()}} +\item \href{#method-NextMortalityMIMIC4-clone}{\code{NextMortalityMIMIC4$clone()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC4-new}{}}} +\subsection{Method \code{new()}}{ +Initialize a new NextMortalityMIMIC4 instance. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC4$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC4-pre_filter}{}}} +\subsection{Method \code{pre_filter()}}{ +Pre-filter hook to retain only necessary columns for this task. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC4$pre_filter(df)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{df}}{A lazy query containing all events.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A filtered LazyFrame with only relevant columns. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC4-call}{}}} +\subsection{Method \code{call()}}{ +Main processing method to generate samples. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC4$call(patient)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{patient}}{An object with method \code{get_events(event_type, ...)}.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A list of samples. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-NextMortalityMIMIC4-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{NextMortalityMIMIC4$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/Patient.Rd b/man/Patient.Rd index c165265..eb6513e 100644 --- a/man/Patient.Rd +++ b/man/Patient.Rd @@ -5,11 +5,11 @@ \title{Patient: R6 Class for a Sequence of Events} \description{ The \code{Patient} class manages all clinical events for a single patient. -It supports efficient event-type partitioning, fast time-range slicing, and flexible multi-condition filtering using rpolars. +It supports efficient event-type partitioning, fast time-range slicing, and flexible multi-condition filtering. } \details{ \itemize{ -\item Data is held as a polars DataFrame. +\item Data is held as a data.frame. \item Events can be retrieved as either raw data frames or Event object lists. } } @@ -18,9 +18,9 @@ It supports efficient event-type partitioning, fast time-range slicing, and flex \describe{ \item{\code{patient_id}}{Character. Unique identifier for the patient.} -\item{\code{data_source}}{Polars DataFrame. All events for this patient, sorted by timestamp.} +\item{\code{data_source}}{data.frame. All events for this patient, sorted by timestamp.} -\item{\code{event_type_partitions}}{List. Mapping event type to corresponding polars DataFrames.} +\item{\code{event_type_partitions}}{List. Mapping event type to corresponding data.frames.} } \if{html}{\out{
}} } @@ -50,7 +50,7 @@ Create a Patient object. \describe{ \item{\code{patient_id}}{Character. Unique patient identifier.} -\item{\code{data_source}}{Polars DataFrame. All events (must include event_type, timestamp columns).} +\item{\code{data_source}}{data.frame. All events (must include event_type, timestamp columns).} } \if{html}{\out{}} } @@ -70,7 +70,7 @@ Filter events by time range (O(n), regular scan). \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{Polars DataFrame. Source event data.} +\item{\code{df}}{data.frame. Source event data.} \item{\code{start}}{Character/POSIXct. (Optional) Start time.} @@ -79,7 +79,7 @@ Filter events by time range (O(n), regular scan). \if{html}{\out{
}} } \subsection{Returns}{ -Polars DataFrame. Events in specified range. +data.frame. Events in specified range. } } \if{html}{\out{
}} @@ -94,7 +94,7 @@ Efficient time range filter via binary search (O(log n)), requires sorted data. \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{Polars DataFrame. Source event data.} +\item{\code{df}}{data.frame. Source event data.} \item{\code{start}}{Character/POSIXct. (Optional) Start time.} @@ -103,7 +103,7 @@ Efficient time range filter via binary search (O(log n)), requires sorted data. \if{html}{\out{
}} } \subsection{Returns}{ -Polars DataFrame. Filtered events. +data.frame. Filtered events. } } \if{html}{\out{
}} @@ -118,14 +118,14 @@ Regular event type filter (O(n)). \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{Polars DataFrame.} +\item{\code{df}}{data.frame.} \item{\code{event_type}}{Character. Type of event.} } \if{html}{\out{
}} } \subsection{Returns}{ -Polars DataFrame. +data.frame. } } \if{html}{\out{
}} @@ -140,14 +140,14 @@ Fast event type filter (O(1)) using partitioned lookup. \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{df}}{Polars DataFrame.} +\item{\code{df}}{data.frame.} \item{\code{event_type}}{Character. Type of event.} } \if{html}{\out{
}} } \subsection{Returns}{ -Polars DataFrame. Only the given event type. +data.frame. Only the given event type. } } \if{html}{\out{
}} diff --git a/man/RHealth-package.Rd b/man/RHealth-package.Rd new file mode 100644 index 0000000..20ac65a --- /dev/null +++ b/man/RHealth-package.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RHealth-package.R +\docType{package} +\name{RHealth-package} +\alias{RHealth} +\alias{RHealth-package} +\title{RHealth: A Deep Learning Toolkit for Healthcare Predictive Modeling} +\description{ +RHealth is an open-source R package specifically designed to bring comprehensive deep learning toolkits to the R community for healthcare predictive modeling. +} +\seealso{ +Useful links: +\itemize{ + \item \url{https://v1xerunt.github.io/dev/RHealth} + \item \url{https://github.com/v1xerunt/RHealth} + \item Report bugs at \url{https://github.com/v1xerunt/RHealth/issues} +} + +} +\author{ +\strong{Maintainer}: Junyi Gao \email{junyii.gao@gmail.com} + +Authors: +\itemize{ + \item Ji Song \email{eaad0907@163.com} + \item Zhixia Ren \email{jexica0921@gmail.com} + \item Zhenbang Wu \email{zw12@illinois.edu} + \item John Wu \email{johnwu3@illinois.edu} + \item Chaoqi Yang \email{ycqsjtu@gmail.com} + \item Jimeng Sun \email{jimeng@illinois.edu} + \item Liantao Ma \email{malt@pku.edu.cn} + \item Ewen Harrison \email{ewen.harrison@ed.ac.uk} +} + +} +\keyword{internal} diff --git a/man/RawProcessor.Rd b/man/RawProcessor.Rd index 2f2c2a8..004a95b 100644 --- a/man/RawProcessor.Rd +++ b/man/RawProcessor.Rd @@ -8,7 +8,7 @@ Processor that returns raw values without any transformation. Inherits from FeatureProcessor. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{RawProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{RawProcessor} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/Readmission30DaysMIMIC4.Rd b/man/Readmission30DaysMIMIC4.Rd index 6005b3f..126dc56 100644 --- a/man/Readmission30DaysMIMIC4.Rd +++ b/man/Readmission30DaysMIMIC4.Rd @@ -51,7 +51,14 @@ samples <- task$call(patient) } \section{Super class}{ -\code{RHealth::BaseTask} -> \code{Readmission30DaysMIMIC4} +\code{\link[RHealth:BaseTask]{RHealth::BaseTask}} -> \code{Readmission30DaysMIMIC4} +} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{label}}{the name of the label column.} +} +\if{html}{\out{
}} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/RegressionLabelProcessor.Rd b/man/RegressionLabelProcessor.Rd index 0dd4130..efad9ed 100644 --- a/man/RegressionLabelProcessor.Rd +++ b/man/RegressionLabelProcessor.Rd @@ -7,7 +7,7 @@ Processor for scalar regression labels. Converts values to a 1D float tensor. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{RegressionLabelProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{RegressionLabelProcessor} } \section{Methods}{ \subsection{Public methods}{ @@ -44,7 +44,7 @@ Process a numeric label into a single-element float tensor. \if{html}{\out{}} } \subsection{Returns}{ -A torch tensor of shape \link{1}. +A torch tensor of shape \verb{[1]}. } } \if{html}{\out{
}} @@ -57,7 +57,7 @@ Return the size of the processed label (always 1). } \subsection{Returns}{ -Integer 1 +Integer \code{1} } } \if{html}{\out{
}} diff --git a/man/SampleDataset.Rd b/man/SampleDataset.Rd index c22f58c..7d9e3a0 100644 --- a/man/SampleDataset.Rd +++ b/man/SampleDataset.Rd @@ -9,7 +9,8 @@ SampleDataset( input_schema, output_schema, dataset_name = "", - task_name = "" + task_name = "", + save_path = NULL ) } \arguments{ @@ -23,9 +24,7 @@ SampleDataset( \item{task_name}{Optional task name} -\item{index}{Integer index} - -\item{...}{Ignored} +\item{save_path}{Optional path to save the processed dataset.} } \value{ Named list representing the sample diff --git a/man/SampleProcessor.Rd b/man/SampleProcessor.Rd index ba66c0e..ea1e366 100644 --- a/man/SampleProcessor.Rd +++ b/man/SampleProcessor.Rd @@ -8,7 +8,7 @@ Optional processor for transformations applied at the whole-sample level (e.g., normalizing an image+label pair). } \section{Super class}{ -\code{RHealth::Processor} -> \code{SampleProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{SampleProcessor} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/SequenceProcessor.Rd b/man/SequenceProcessor.Rd index 4c19585..7c94e22 100644 --- a/man/SequenceProcessor.Rd +++ b/man/SequenceProcessor.Rd @@ -8,7 +8,7 @@ Feature processor for encoding categorical sequences (e.g., medical codes) into numerical indices. Supports dynamic vocabulary construction. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{SequenceProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{SequenceProcessor} } \section{Public fields}{ \if{html}{\out{
}} diff --git a/man/TextProcessor.Rd b/man/TextProcessor.Rd index eeca7fe..4554ed5 100644 --- a/man/TextProcessor.Rd +++ b/man/TextProcessor.Rd @@ -7,7 +7,7 @@ Processor for textual input. Inherits from FeatureProcessor and defines a minimal no-op process method. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{TextProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{TextProcessor} } \section{Methods}{ \subsection{Public methods}{ diff --git a/man/TimeseriesProcessor.Rd b/man/TimeseriesProcessor.Rd index ba3ed71..70e4d7c 100644 --- a/man/TimeseriesProcessor.Rd +++ b/man/TimeseriesProcessor.Rd @@ -8,7 +8,7 @@ Processor for irregular time series data with missing values. Supports uniform resampling and two imputation strategies: forward-fill and zero-fill. } \section{Super classes}{ -\code{RHealth::Processor} -> \code{RHealth::FeatureProcessor} -> \code{TimeseriesProcessor} +\code{\link[RHealth:Processor]{RHealth::Processor}} -> \code{\link[RHealth:FeatureProcessor]{RHealth::FeatureProcessor}} -> \code{TimeseriesProcessor} } \section{Public fields}{ \if{html}{\out{
}} @@ -81,7 +81,7 @@ Step 2: impute missing entries using selected strategy. \if{html}{\out{
}} } \subsection{Returns}{ -A torch tensor of shape \link{T, F}. +A torch tensor of shape \verb{[T, F]}. } } \if{html}{\out{
}} diff --git a/man/Trainer.Rd b/man/Trainer.Rd index 066d427..69de85f 100644 --- a/man/Trainer.Rd +++ b/man/Trainer.Rd @@ -9,7 +9,7 @@ An enhanced R6 trainer mirroring PyHealth's Python version. It supports: \item \strong{Dynamic \code{steps_per_epoch}}: can iterate indefinitely over a dataloader to reach a target number of steps, just like Python. \item \strong{Parameter‑group–wise weight decay}: bias and \emph{LayerNorm} parameters are excluded from L2 regularisation. \item \strong{Gradient clipping}. -\item \strong{Optional progress bar} using \code{cli::cli_progress_bar()} (falls back to simple logging). +\item \strong{Optional progress bar} using \code{progressr::progressor()} (falls back to simple logging). \item \strong{Correctly named \code{additional_outputs}} collection. } } @@ -142,7 +142,8 @@ Perform inference on a dataloader. \if{html}{\out{
}}\preformatted{Trainer$inference( dataloader, additional_outputs = NULL, - return_patient_ids = FALSE + return_patient_ids = FALSE, + use_progress_bar = FALSE )}\if{html}{\out{
}} } @@ -154,6 +155,8 @@ Perform inference on a dataloader. \item{\code{additional_outputs}}{Vector of additional outputs to capture.} \item{\code{return_patient_ids}}{Whether to return patient IDs.} + +\item{\code{use_progress_bar}}{Whether to show a progress bar.} } \if{html}{\out{
}} } @@ -164,13 +167,15 @@ Perform inference on a dataloader. \subsection{Method \code{evaluate()}}{ Evaluate the model using a dataloader. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Trainer$evaluate(dataloader)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Trainer$evaluate(dataloader, use_progress_bar = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{dataloader}}{A dataloader to evaluate on.} + +\item{\code{use_progress_bar}}{Whether to show a progress bar.} } \if{html}{\out{
}} } diff --git a/man/dot-create_task_runner.Rd b/man/dot-create_task_runner.Rd new file mode 100644 index 0000000..b65f5a6 --- /dev/null +++ b/man/dot-create_task_runner.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Dataset_BaseDataset.R +\name{.create_task_runner} +\alias{.create_task_runner} +\title{Helper function to create a clean closure for parallel processing} +\usage{ +.create_task_runner(task) +} +\arguments{ +\item{task}{The task object to be used inside the closure.} +} +\value{ +A function that takes a patient dataframe (\code{pdf}) and applies the task. +} +\description{ +This function acts as a factory. It creates and returns another function +(a closure) that is suitable for use with \code{future.apply}. The returned +function's environment is intentionally minimal, containing only the \code{task} +object. This prevents large objects from the parent environment (like the +full event dataframe) from being accidentally exported to parallel workers. +} +\keyword{internal} diff --git a/man/dot-csv2parquet_path.Rd b/man/dot-csv2parquet_path.Rd index 8707c5a..e06362c 100644 --- a/man/dot-csv2parquet_path.Rd +++ b/man/dot-csv2parquet_path.Rd @@ -6,6 +6,13 @@ \usage{ .csv2parquet_path(csv_path) } +\arguments{ +\item{csv_path}{Path to the csv file.} +} +\value{ +A character string. +} \description{ Given a *.csv(.gz) path, return *.parquet path in a /subset folder } +\keyword{internal} diff --git a/man/dot-find_path_with_fallback.Rd b/man/dot-find_path_with_fallback.Rd new file mode 100644 index 0000000..9c2d7e5 --- /dev/null +++ b/man/dot-find_path_with_fallback.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Dataset_BaseDataset.R +\name{.find_path_with_fallback} +\alias{.find_path_with_fallback} +\title{Find an existing data file path with fallback for .gz extension.} +\usage{ +.find_path_with_fallback(path) +} +\arguments{ +\item{path}{A character path to a .csv, .csv.gz, .tsv, or .tsv.gz file.} +} +\value{ +A list with \code{path} to an existing file and \code{separator} (',' or a tab). +} +\description{ +This function checks for the existence of a path and its alternative with/without +\code{.gz}. It also determines the separator based on the file extension (.csv or .tsv). +} +\keyword{internal} diff --git a/man/dot-scan_csv_gz_or_csv.Rd b/man/dot-scan_csv_gz_or_csv.Rd deleted file mode 100644 index aa9acb1..0000000 --- a/man/dot-scan_csv_gz_or_csv.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/Dataset_BaseDataset.R -\name{.scan_csv_gz_or_csv} -\alias{.scan_csv_gz_or_csv} -\title{Lazily loads a \code{.csv} or \code{.csv.gz} file and returns a Polars LazyFrame. -Automatically tries the alternate extension if the primary path fails.} -\usage{ -.scan_csv_gz_or_csv(path) -} -\arguments{ -\item{path}{File path or URL ending in \code{.csv} or \code{.csv.gz}.} -} -\value{ -A polars LazyFrame. -} -\description{ -Lazily loads a \code{.csv} or \code{.csv.gz} file and returns a Polars LazyFrame. -Automatically tries the alternate extension if the primary path fails. -} -\keyword{internal} diff --git a/man/ece_confidence_binary.Rd b/man/ece_confidence_binary.Rd index 2b26994..a4b29b7 100644 --- a/man/ece_confidence_binary.Rd +++ b/man/ece_confidence_binary.Rd @@ -12,12 +12,12 @@ containing predicted probabilities for the \emph{positive} class (only the first column is used if a matrix is supplied).} \item{label}{Numeric vector \strong{or} two-column matrix of true labels \ -encoded as 0/1 (only the first column is used if a matrix is supplied).} +encoded as \code{0/1} (only the first column is used if a matrix is supplied).} \item{bins}{Integer. Number of bins (default 20).} \item{adaptive}{Logical. If \code{FALSE} (default) equal-width bins \ -spanning \\(\link{0,1}\\) are used; if \code{TRUE} each bin contains the \ +spanning \verb{0, 1} are used; if \code{TRUE} each bin contains the \ same number of samples (equal-size bins).} } \value{ diff --git a/man/load_sample_dataset.Rd b/man/load_sample_dataset.Rd new file mode 100644 index 0000000..5188add --- /dev/null +++ b/man/load_sample_dataset.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/Dataset_SampleDataset.R +\name{load_sample_dataset} +\alias{load_sample_dataset} +\title{Load a SampleDataset object from a directory} +\usage{ +load_sample_dataset(path) +} +\arguments{ +\item{path}{The directory path from where to load the dataset.} +} +\value{ +The reconstructed SampleDataset object. +} +\description{ +This function reconstructs a SampleDataset object from a directory. +} diff --git a/man/split_by_patient.Rd b/man/split_by_patient.Rd index 4efc053..9b2e627 100644 --- a/man/split_by_patient.Rd +++ b/man/split_by_patient.Rd @@ -4,7 +4,14 @@ \alias{split_by_patient} \title{split_by_patient} \usage{ -split_by_patient(dataset, ratios, seed = NULL) +split_by_patient( + dataset, + ratios, + seed = NULL, + stratify = FALSE, + stratify_by = NULL, + get_index = FALSE +) } \arguments{ \item{dataset}{A \code{SampleDataset} object.} @@ -12,9 +19,15 @@ split_by_patient(dataset, ratios, seed = NULL) \item{ratios}{A numeric vector of length 3 indicating train/val/test split ratios. Must sum to 1.} \item{seed}{Optional integer for reproducibility.} + +\item{stratify}{Logical, whether to perform stratified sampling. Default: FALSE.} + +\item{stratify_by}{Character, the name of the field to stratify by (e.g., the label). Required if \code{stratify} is TRUE.} + +\item{get_index}{Logical, whether to return the indices instead of subsets. Default: FALSE.} } \value{ -A list of 3 torch::dataset_subset objects split by patient id. +A list of 3 torch::dataset_subset objects or 3 tensors of indices if get_index = TRUE, split by patient id. } \description{ split_by_patient diff --git a/man/split_by_sample.Rd b/man/split_by_sample.Rd index 89d516b..03d10dd 100644 --- a/man/split_by_sample.Rd +++ b/man/split_by_sample.Rd @@ -4,7 +4,14 @@ \alias{split_by_sample} \title{Dataset Split Functions} \usage{ -split_by_sample(dataset, ratios, seed = NULL, get_index = FALSE) +split_by_sample( + dataset, + ratios, + seed = NULL, + stratify = FALSE, + stratify_by = NULL, + get_index = FALSE +) } \arguments{ \item{dataset}{A \code{SampleDataset} object.} @@ -13,6 +20,10 @@ split_by_sample(dataset, ratios, seed = NULL, get_index = FALSE) \item{seed}{Optional integer for reproducibility.} +\item{stratify}{Logical, whether to perform stratified sampling. Default: FALSE.} + +\item{stratify_by}{Character, the name of the field to stratify by (e.g., the label). Required if \code{stratify} is TRUE.} + \item{get_index}{Logical, whether to return the indices instead of subsets. Default: FALSE.} } \value{ diff --git a/man/split_by_visit.Rd b/man/split_by_visit.Rd index e167bb7..97042b3 100644 --- a/man/split_by_visit.Rd +++ b/man/split_by_visit.Rd @@ -4,7 +4,13 @@ \alias{split_by_visit} \title{split_by_visit} \usage{ -split_by_visit(dataset, ratios, seed = NULL) +split_by_visit( + dataset, + ratios, + seed = NULL, + stratify = FALSE, + stratify_by = NULL +) } \arguments{ \item{dataset}{A \code{SampleDataset} object.} @@ -12,6 +18,10 @@ split_by_visit(dataset, ratios, seed = NULL) \item{ratios}{A numeric vector of length 3 indicating train/val/test split ratios. Must sum to 1.} \item{seed}{Optional integer for reproducibility.} + +\item{stratify}{Logical, whether to perform stratified sampling. Default: FALSE.} + +\item{stratify_by}{Character, the name of the field to stratify by (e.g., the label). Required if \code{stratify} is TRUE.} } \value{ A list of 3 torch::dataset_subset objects. diff --git a/tests/testthat/test-model-learning.R b/tests/testthat/test-model-learning.R new file mode 100644 index 0000000..2a850bc --- /dev/null +++ b/tests/testthat/test-model-learning.R @@ -0,0 +1,107 @@ +library(testthat) +library(torch) +library(R6) + +# This test script is designed to verify that the RNN model can learn from a simple, +# synthetically generated dataset. If this test passes, it suggests that the model +# implementation is correct and any training issues might be related to the actual +# data being used. + +test_that("RNN model can learn from pseudo-data", { + + # Mock Processor Class + # A minimal R6 class to mimic the necessary properties of a feature processor, + # specifically the vocabulary size needed by the EmbeddingModel. + MockProcessor <- R6::R6Class("MockProcessor", + inherit = SequenceProcessor, + public = list( + vocab_size = NULL, + initialize = function(vocab_size) { + self$vocab_size <- vocab_size + # The EmbeddingModel uses the length of code_vocab to set num_embeddings. + self$code_vocab <- seq_len(vocab_size) + } + )) + + # Mock Dataset + # A list-based dataset, which is a more standard way to create custom datasets + # in torch for R. This avoids potential R6 class inheritance issues. + create_mock_dataset <- function(num_samples = 128) { + lapply(1:num_samples, function(i) { + # Create a sequence with a learnable pattern. + feature_seq <- torch_randint(1, 20, size = c(10), dtype = torch_long()) + # The label is 1 if the first token is > 10, otherwise 0. + label_val <- ifelse(as.integer(feature_seq[1]$item()) > 10, 1, 0) + + list( + feature_a = feature_seq, + # A random binary label (0 or 1), as a float for the loss function. + label = torch_tensor(label_val, dtype = torch_float()) + ) + }) + } + + mock_data <- create_mock_dataset() + mock_dataset_generator <- torch::dataset( + name = "mock_dataset", + initialize = function() { + self$input_processors <- list(feature_a = MockProcessor$new(vocab_size = 20)) + self$input_schema <- list(feature_a = "sequence") + self$output_processors <- list(label = BinaryLabelProcessor$new()) + self$output_schema <- list(label = "binary") + self$label_keys <- c("label") + self$feature_keys <- "feature_a" + }, + .getitem = function(i) mock_data[[i]], + .length = function() length(mock_data) + ) + + # Add the necessary metadata attributes that the model expects. + # mock_dataset_instance$input_processors <- list(feature_a = MockProcessor$new(vocab_size = 20)) + # mock_dataset_instance$input_schema <- list(feature_a = "sequence") + # mock_dataset_instance$output_processors <- list(label = BinaryLabelProcessor$new()) + # mock_dataset_instance$output_schema <- list(label = "binary") + # mock_dataset_instance$label_keys <- c("label") + # mock_dataset_instance$feature_keys <- "feature_a" + + + # 1. Instantiate the mock dataset. + mock_dataset_instance <- mock_dataset_generator() + + # 2. Instantiate the RNN model with the mock dataset's schema. + # We use small dimensions for this test and disable dropout for stability. + model <- RNN(dataset = mock_dataset_instance, embedding_dim = 128, hidden_dim = 128, dropout = 0) + + # 3. Create a dataloader for the mock dataset. + dataloader <- dataloader(mock_dataset_instance, batch_size = 32) + + # 4. Initialize the Trainer. + trainer <- Trainer$new( + model = model, + device = "cpu", + metrics = c("roc_auc", "accuracy") + ) + + # 5. Train the model for a few epochs. + # We use the same dataloader for training and validation for simplicity. + # With enough training, the model should be able to overfit to this simple data. + trainer$train( + train_dataloader = dataloader, + val_dataloader = dataloader, + epochs = 20, + monitor = "roc_auc", + use_progress_bar = FALSE + ) + + # 6. Evaluate the trained model. + result <- trainer$evaluate(dataloader, use_progress_bar = FALSE) + print("Final evaluation results on pseudo-data:") + print(result) + + # 7. Assert that the model has learned. + # A model that has learned should have an ROC AUC significantly > 0.5 + # and an accuracy significantly > 0.5. We'll check for > 0.9 as the pattern is simple. + expect_gt(result$roc_auc, 0.9, label = "ROC AUC should be greater than 0.9 after training.") + expect_gt(result$accuracy, 0.9, label = "Accuracy should be greater than 0.9 after training.") + +}) diff --git a/vignettes/MedCode.rmd b/vignettes/MedCode.rmd index 0e6d00a..39959f2 100644 --- a/vignettes/MedCode.rmd +++ b/vignettes/MedCode.rmd @@ -1,6 +1,10 @@ --- title: "DeepRHealth::medcode Module Demo" subtitle: "R/medicine 2025 Prototype Showcase" +vignette: > + %\VignetteIndexEntry{DeepRHealth::medcode Module Demo} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} author: "Zhixia Ren" date: 05/25/2025 format: diff --git a/vignettes/Overview.Rmd b/vignettes/Overview.Rmd index 8d543f7..b2ce85b 100644 --- a/vignettes/Overview.Rmd +++ b/vignettes/Overview.Rmd @@ -1,6 +1,10 @@ --- title: "Descriptions" output: html_document +vignette: > + %\VignetteIndexEntry{Overview: RHealth} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- ---