A lightweight, configurable CLI for loading the CIC-MalMem-2022 dataset, preprocessing features, training classic ML models for malware detection (binary) and family classification (multi-class), and visualizing results. A companion CLI (malimg) handles image-based malware classification.
- Loads dataset from Hugging Face Datasets and caches locally
- Preprocesses features, derives category_family and category_encoded
- Trains and evaluates Random Forest, MLP, KNN, XGBoost
- Saves trained models under a versioned directory structure
- Generates Sankey and radar plots and confusion matrices for trained models
# 1) Create and activate a virtual environment
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
# 2) Install dependencies
pip install -r requirements.txt
# 3) Preview the data (downloads and caches the dataset on first run)
python main.py --view head -r 10
# 4) Train models
# Binary detection (benign vs. malware)
python main.py --train detection
# Multi-class classification of malware family
python main.py --train classification
# Train both
python main.py --train both
# 5) Visualize
# Sankey of Category distribution
python main.py --plot sankey
# Radar of selected numeric features per Category
python main.py --plot radar -x handles.nsemaphore handles.nmutant -C Benign
# Confusion matrix for a saved model
python main.py --plot confusion --confusion-task classification --model-name XGBoost
If you encounter download issues (e.g., no internet), see “Dataset & Caching”.
.
├── main.py # CLI entrypoint
├── config.yaml # Project configuration (logging, data, models)
├── requirements.txt
├── malimg.py # CLI entrypoint for malware image classification
├── malimg.yaml # Configuration for malimg.py
├── malimg/ # Malware imaging related files
│   └── dataset/
│       └── malimg.npz # Dataset for malimg
├── src/
│   ├── __init__.py
│   ├── config.py # YAML loader with basic validation
│   ├── data.py # HF dataset loader + preprocessing
│   ├── models.py # Training/eval + model persistence
│   └── plot.py # Sankey, radar, and confusion matrix plots
└── models/ # Created at runtime, stores saved models
The malimg component provides functionality for malware image classification using Convolutional Neural Networks (CNNs). It allows for loading and preprocessing image-based malware datasets, training a CNN model, and evaluating its performance.
To run the malware imaging classification:
python malimg.py --help
This displays the available command-line arguments for malimg.py, including options for data paths, training epochs, batch size, and model saving.
The behavior of malimg.py is controlled by malimg.yaml, which specifies parameters such as:
- training: Epochs, batch size, test split size, and random seed.
- paths: Locations for the dataset (malimg/dataset/malimg.npz) and saved models (malimg/models).
- mappings: A dictionary mapping numerical labels to malware family names.
- Data Loading and Preprocessing: Loads image data from malimg/dataset/malimg.npz, resizes images to 32x32 pixels, flattens them, and applies LabelEncoder to the labels.
- Model Building and Training: Constructs a CNN model for One-vs-All classification and trains it using the preprocessed image data.
- Model Evaluation: Evaluates the trained model on a test set and reports metrics such as accuracy, precision, recall, and F1-score.
- Model Persistence: Saves the trained CNN model to malimg/models/cnn_ova.keras for future use.
All behavior is controlled via config.yaml.
logging:
  save: True
  dir: "logs/app.log" # file path or directory; file created if suffix is present
data:
  dataset: "bvk/CIC-MalMem-2022"
  cache_dir: "data/" # HF Datasets cache for this project
  selected_features: # optional: normalize and keep only these
    - handles.nhandles
    - dlllist.ndlls
  selected_features_radar: # optional: which features to show on radar plot
    - svcscan.nservices
    - malfind.ninjections
models:
  target_col: "category_encoded" # used for classification
  dirs:
    root: "models"
    binary: "binary"
    classification: "classification"
  params: # per-model defaults (overrides are merged)
    Random Forest:
      n_estimators: 100
      random_state: 42
      n_jobs: -1
    MLP:
      hidden_layer_sizes: [100, 50]
      max_iter: 1000
      random_state: 42
    KNN:
      n_neighbors: 5
    XGBoost:
      n_estimators: 100
      random_state: 42
      verbosity: 0
      n_jobs: -1
Notes:
- logging.dir can be a directory or a file path. If a filename is provided (has a suffix), logs go to that file; otherwise, app.log will be created inside the directory.
- data.selected_features controls which numeric columns are standardized and kept after preprocessing. If omitted, all numeric features (excluding labels) are standardized and kept.
- selected_features_radar filters which numeric features appear on the radar plot.
- Model hyperparameters in models.params are merged with sensible defaults in code.
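The log-path rule and the parameter merge described in these notes can be sketched as follows. This is a minimal illustration, not the project's actual code: resolve_log_path, merge_params, and the DEFAULTS table are hypothetical names, and the default values shown are placeholders.

```python
from pathlib import Path
from typing import Optional

# Hypothetical in-code defaults; the real defaults live in the project source.
DEFAULTS = {
    "KNN": {"n_neighbors": 5},
    "Random Forest": {"n_estimators": 100, "random_state": 42},
}


def resolve_log_path(log_dir: str, default_name: str = "app.log") -> Path:
    """Return the log file path for a logging.dir value."""
    path = Path(log_dir)
    # A suffix such as ".log" means the value already names a file.
    if path.suffix:
        return path
    # Otherwise treat the value as a directory and place app.log inside it.
    return path / default_name


def merge_params(model_name: str, overrides: Optional[dict]) -> dict:
    """Overlay values from models.params on top of the defaults for one model."""
    merged = dict(DEFAULTS.get(model_name, {}))
    merged.update(overrides or {})
    return merged
```

With this sketch, resolve_log_path("logs") and resolve_log_path("logs/app.log") both resolve to logs/app.log, and a config entry such as n_neighbors: 7 replaces the default for KNN while untouched defaults are kept.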
- Data loading (src/data.py)
  - Uses datasets.load_dataset("bvk/CIC-MalMem-2022") and converts the train split to a pandas DataFrame.
  - Derives category_family from Category by taking the substring before the first dash (e.g., “Ransomware-Foo” → “Ransomware”).
  - Encodes category_family into category_encoded via LabelEncoder.
  - If Class is a string column ("Malware"/"Benign"), it is converted to 1/0.
  - Standardizes numeric features (z-score) either on all numeric columns or only on data.selected_features if provided.
- Training (src/models.py)
  - Binary detection uses Class (0/1) as the target.
  - Multi-class classification uses category_encoded as the target (ensure preprocessing ran).
  - Splits data with stratified train/test (--test-size, --seed).
  - Trains: Random Forest, MLP, KNN, XGBoost; reports Accuracy, Precision, Recall, F1.
  - Saves each model to models/<task>/<name>.joblib (skips retraining if the saved artifact exists and features match).
  - Optional hyperparameter tuning for classification (--tune-classification) via GridSearchCV with stratified 3-fold.
- Plotting (src/plot.py)
  - sankey: shows distribution of Category values.
  - radar: per-category mean values across selected numeric features (can --exclude features and --exclude-category).
  - confusion: loads a saved model and plots a confusion matrix on the test split. For detection, if a saved model is missing, it will train a fresh model on the fly; for classification, you must train first.
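The preprocessing steps listed under data loading can be sketched in plain Python. This is only an illustration under stated assumptions: the project itself uses pandas and scikit-learn, the helper names below are hypothetical, and the standard library stands in for LabelEncoder (which assigns codes in sorted label order) and for z-score scaling with the population standard deviation.

```python
from statistics import fmean, pstdev


def category_family(category: str) -> str:
    """Keep the text before the first dash: 'Ransomware-Foo' -> 'Ransomware'."""
    return category.split("-", 1)[0]


def label_encode(values):
    """Map each distinct label to an integer code, in sorted label order."""
    classes = sorted(set(values))
    index = {label: code for code, label in enumerate(classes)}
    return [index[v] for v in values], classes


def encode_class(value):
    """Convert 'Malware'/'Benign' strings to 1/0; leave numeric values alone."""
    if isinstance(value, str):
        return 1 if value == "Malware" else 0
    return value


def zscore(column):
    """Standardize a numeric column to zero mean and unit variance."""
    mean = fmean(column)
    # A constant column has zero spread; divide by 1.0 instead of 0.
    std = pstdev(column) or 1.0
    return [(x - mean) / std for x in column]


# Example on a few made-up Category values.
families = [category_family(c) for c in ["Benign", "Ransomware-Foo", "Ransomware-Bar"]]
codes, classes = label_encode(families)
```

Here families contains one "Benign" entry and two "Ransomware" entries, so the two ransomware rows share a single integer code.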
General options:
python main.py --help
Viewing data:
# Show first 5 rows (default)
python main.py --view head
# Show a random sample of 10
python main.py --view sample --rows 10
Training:
# Detection (binary)
python main.py --train detection
# Classification (multi-class)
python main.py --train classification
# Train both
python main.py --train both
# Choose models directory (else taken from config)
python main.py --train detection --models-dir ./models
# Control split and seed
python main.py --train classification --test-size 0.25 --seed 123
# Enable hyperparameter tuning (classification only)
python main.py --train classification --tune-classification
Plotting:
# Sankey of Category distribution
python main.py --plot sankey
# Radar chart (exclude specific features and/or categories)
python main.py --plot radar --exclude handles.nsemaphore handles.nmutant --exclude-category Benign
# Confusion matrix (requires a saved model; train first for classification)
python main.py --plot confusion --confusion-task classification --model-name XGBoost
# For detection you can also visualize without prior save (auto-trains if missing)
python main.py --plot confusion --confusion-task detection --model-name "Random Forest"
- The first run will download bvk/CIC-MalMem-2022 from Hugging Face and cache it.
- Cache location is controlled by data.cache_dir in config.yaml. You can also rely on standard HF cache env vars like HF_HOME or HF_DATASETS_CACHE if you prefer.
- If running in a restricted/offline environment, pre-download the dataset on a machine with internet and copy the cache directory to this project’s data/ folder.
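As a rough illustration of how the cache location is chosen, the lookup order can be sketched like this. It is an assumption-laden sketch: effective_cache_dir is a hypothetical helper that belongs neither to this project nor to the datasets library, it assumes an explicit data.cache_dir takes priority over the environment variables, and it ignores less common settings such as XDG_CACHE_HOME.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def effective_cache_dir(
    config_cache_dir: Optional[str],
    env: Mapping[str, str] = os.environ,
) -> Path:
    """Guess where the dataset cache ends up for a given config and environment."""
    # 1) An explicit data.cache_dir from config.yaml wins.
    if config_cache_dir:
        return Path(config_cache_dir)
    # 2) Otherwise use the datasets-specific environment variable, if set.
    if env.get("HF_DATASETS_CACHE"):
        return Path(env["HF_DATASETS_CACHE"])
    # 3) Otherwise fall back to <HF_HOME>/datasets, with HF_HOME defaulting
    #    to ~/.cache/huggingface.
    hf_home = env.get("HF_HOME") or str(Path.home() / ".cache" / "huggingface")
    return Path(hf_home) / "datasets"
```

For an offline machine, the directory this returns on the online machine is the one to copy over.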
- Console logs are always enabled.
- If logging.save: True, file logs are written to the path configured by logging.dir.
- Trained models are saved under models/ by default:
  - models/binary/<model>.joblib
  - models/classification/<model>.joblib
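The directory layout above follows from the models.dirs settings in config.yaml. A minimal sketch of how such a path could be assembled, assuming the detection task maps to the binary directory and the model's display name is used verbatim as the file stem (model_artifact_path and TASK_TO_DIR_KEY are hypothetical names, not the project's API):

```python
from pathlib import Path

# Mirrors the models.dirs section of config.yaml.
DIRS = {"root": "models", "binary": "binary", "classification": "classification"}

# Assumption: the "detection" task is stored under the "binary" directory.
TASK_TO_DIR_KEY = {"detection": "binary", "classification": "classification"}


def model_artifact_path(task: str, model_name: str, dirs: dict = DIRS) -> Path:
    """Build <root>/<task dir>/<model name>.joblib for a task and model."""
    task_dir = dirs[TASK_TO_DIR_KEY[task]]
    return Path(dirs["root"]) / task_dir / f"{model_name}.joblib"
```

Changing models.dirs.root in the config (or passing --models-dir) would move every artifact under the new root while keeping the per-task subdirectories.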
- Python 3.10+
- See requirements.txt for exact versions. Notable dependencies: datasets, pandas, scikit-learn, xgboost, matplotlib, plotly, seaborn.
- Classification requires category_encoded; ensure preprocessing ran (it does in main.py).
- If you change the feature set, previously saved models might be skipped or retrained depending on feature compatibility.
- If --plot confusion fails for classification due to a missing model, run a classification training first.
- XGBoost objective is set automatically to multi-class when necessary.
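The last point can be pictured with a small sketch. It assumes the choice depends only on the number of distinct target labels; pick_xgb_objective is a hypothetical helper, and whether the project uses multi:softprob or multi:softmax for the multi-class case is not stated here.

```python
def pick_xgb_objective(labels) -> dict:
    """Choose XGBoost objective parameters from the target labels."""
    n_classes = len(set(labels))
    if n_classes <= 2:
        # Binary detection: benign vs. malware.
        return {"objective": "binary:logistic"}
    # Family classification: one probability per class.
    return {"objective": "multi:softprob", "num_class": n_classes}
```

For the detection target (Class, 0/1) this yields the binary objective; for category_encoded with more than two families it yields a multi-class objective with num_class set accordingly.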