Probabilistic computation with emerging covariance: towards efficient uncertainty quantification,
Hengyuan Ma, Yang Qi, Li Zhang, Wenlian Lu, Jianfeng Feng
This repository provides an implementation for training and evaluating moment neural networks (MNNs) trained with the Supervised Mean and Unsupervised Covariance (SMUC) approach, based on the PyTorch library.
Building robust, interpretable, and secure artificial intelligence systems requires some degree of quantifying and representing uncertainty from a probabilistic perspective, as this allows the system to mimic human cognitive abilities. However, probabilistic computation presents significant challenges due to its inherent complexity. In this paper, we develop an efficient and interpretable probabilistic computation framework by truncating the probabilistic representation at its first two moments, i.e., mean and covariance. We instantiate the framework by training a deterministic surrogate of a stochastic network that learns the complex probabilistic representation via combinations of simple activations, encapsulating the non-linear coupling of the mean and covariance. We show that when the mean is supervised to optimize the task objective, the unsupervised covariance that spontaneously emerges from its non-linear coupling with the mean faithfully captures the uncertainty associated with model predictions. Our research highlights the inherent computability and simplicity of probabilistic computation, enabling its wider application in large-scale settings.
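The core of the framework is propagating a mean and a covariance through each layer via closed-form "moment activations". As a minimal illustration, the sketch below gives the standard closed-form output mean and variance of ReLU and Heaviside applied to a Gaussian input; this covers only the diagonal (variance) part, whereas the MNN also propagates cross-covariances, so it is a simplified sketch rather than the repository's `activations.py` implementation.

```python
import math

def relu_moments(mu, var):
    """Mean and variance of ReLU(x) for x ~ N(mu, var), var > 0.

    Uses the standard closed forms
      E[ReLU(x)]   = mu * Phi(z) + sigma * phi(z),  z = mu / sigma,
      E[ReLU(x)^2] = (mu^2 + var) * Phi(z) + mu * sigma * phi(z).
    """
    sigma = math.sqrt(var)
    z = mu / sigma
    phi = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal pdf
    Phi = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # standard normal cdf
    out_mean = mu * Phi + sigma * phi
    out_second = (mu * mu + var) * Phi + mu * sigma * phi
    return out_mean, out_second - out_mean ** 2

def heaviside_moments(mu, var):
    """Mean and variance of Heaviside(x) for x ~ N(mu, var), var > 0.

    Heaviside(x) is Bernoulli with p = P(x > 0) = Phi(mu / sigma),
    so the mean is p and the variance is p * (1 - p).
    """
    p = 0.5 * (1.0 + math.erf(mu / math.sqrt(2.0 * var)))
    return p, p * (1.0 - p)
```

Note how the output variance is a non-linear function of the input mean: this is the mean-covariance coupling through which the unsupervised covariance emerges during training.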
The models.py module provides fully-connected MNNs (Heaviside2, ReluNet2), mixed MNNs (CNN_ReluNet2, CNN_Heaviside2),
and the corresponding versions using the batch-wise trick (Heaviside, ReluNet, CNN_ReluNet, CNN_Heaviside).
- activations.py: Heaviside and ReLU moment activations.
- models.py: network architectures of the MNN, such as Heaviside, ReluNet, CNN_ReluNet, CNN_Heaviside.
- stoc_models.py: fully-connected feed-forward stochastic neural network.
- utils.py: useful utilities for analyzing results of the MNN.
- dataset.py: preprocessing of datasets.
- test_classification.py: examples for training and evaluating the MNN on classification tasks.
- test_regression.py: examples for training and evaluating the MNN on regression tasks.
- test_ood.py: demo of out-of-distribution detection using the MNN.
- test_attack.py: demo of defending against and detecting gradient-based adversarial attacks using the MNN.
- test_stoc_models.py: demo for simulating a stochastic network using the parameters of a trained MNN.
- data: directory for storing datasets such as the UCI regression datasets and the notMNIST dataset.
- checkpoints: directory for storing some pretrained checkpoints.
- python 3.7
- pytorch: 1.11.0
- torchvision: 0.12.0
- scipy: 1.7.3
- numpy: 1.21.6
- matplotlib: 3.5.1
Code for training a Moment Neural Network (MNN) using the Supervised Mean and Unsupervised Covariance (SMUC) approach for image classification tasks on MNIST, FashionMNIST, CIFAR-10, and CIFAR-100. As a demonstration, we have constructed a different architecture for each dataset; feel free to modify them according to your specific needs. For the MNIST dataset, we include an example of a fully-connected network. For the FashionMNIST dataset, we provide an example of a modified LeNet5 network. For the CIFAR-10 dataset, we include an example of a modified VGG13 network. For the CIFAR-100 dataset, we include an example of a modified VGG13 network with a larger number of channels.
```
$ python test_classification.py --dataset cifar10 \
    --heaviside \
    --batch_cov \
    --amp \
    --train_batch 256 \
    --test_batch 128 \
    --opt adam \
    --lr 0.0001 \
    --epochs 120 \
    --weight_decay 0.0005 \
    --save_model
```
- dataset is the name of the dataset (mnist, fashionmnist, cifar10, or cifar100).
- amp is the option for whether to use data amplification.
- heaviside is the option for whether to use Heaviside moment activations; if not set, ReLU moment activations are applied.
- batch_cov is the option for whether to use the batch-wise covariance trick during training (see the Supplementary Information of the paper for details).
- train_batch is the batch size for training.
- test_batch is the batch size for evaluation.
- opt is the type of optimizer (adam or sgd).
- lr is the learning rate.
- epochs is the number of training epochs.
- weight_decay is the factor of weight decay.
- save_model is the option for whether to save the model checkpoint.
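The training performed by this command can be summarized as: only the output mean receives the supervised task loss, while the covariance is propagated but left unsupervised. The sketch below is a minimal illustration of such a step, assuming (as in this repository's MNN architectures) that the model's forward pass returns a (mean, covariance) pair; it is not the actual training loop in test_classification.py.

```python
import torch
import torch.nn.functional as F

def smuc_train_step(model, optimizer, x, y):
    """One SMUC-style step: supervise the mean, leave the covariance alone.

    Assumes model(x) returns (mean, cov). The covariance receives no loss
    term; its behavior emerges purely from the non-linear coupling with the
    mean inside the moment activations.
    """
    mean, cov = model(x)
    loss = F.cross_entropy(mean, y)  # task objective on the mean only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), cov.detach()
```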
Code for training a Moment Neural Network (MNN) using the Supervised Mean and Unsupervised Covariance (SMUC) approach
for regression tasks on the UCI regression datasets. Before running the following experiment,
please make sure the UCI regression datasets are available in the data/uci_datasets directory.
```
$ python test_regression.py --datadir data/uci_datasets \
    --dataset concrete \
    --heaviside \
    --data_norm \
    --norm_type min_max \
    --batch_cov \
    --train_batch 64 \
    --opt sgd \
    --lr 0.0001 \
    --epochs 1000 \
    --save_model \
    --sigma1 0.5 \
    --sigma2 0.5 \
    --test_times 20
```
- datadir is the directory of the dataset.
- dataset is the name of the dataset (housing, concrete, energy, kin8, protein, power, wine, yacht).
- heaviside is the option for whether to use Heaviside moment activations; if not set, ReLU moment activations are applied.
- data_norm is the option for whether to normalize the dataset before training.
- norm_type is the type of data normalization (min_max or mean_std).
- batch_cov is the option for whether to use the batch-wise covariance trick during training (see the Supplementary Information of the paper for details).
- train_batch is the batch size for training.
- opt is the type of optimizer (adam or sgd).
- lr is the learning rate.
- epochs is the number of training epochs.
- sigma1 is the setting of the hyperparameter $\sigma_1$.
- sigma2 is the setting of the hyperparameter $\sigma_2$.
- test_times is the number of repetitions of the experiment.
Code for out-of-distribution (OOD) detection using the MNN. Before running the following experiment,
please download the pretrained checkpoint checkpoints/mnist_ood_demo.pt and the notMNIST dataset
data/notMNIST_small.zip, and extract the latter in the data directory.
```
$ python test_ood.py
```
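The intuition behind OOD detection here is that the emergent covariance should be larger for inputs far from the training distribution. A minimal sketch of turning a predicted covariance into a detection decision is shown below; the use of the summed diagonal as the score and a fixed threshold are illustrative assumptions, not the exact criterion used in test_ood.py.

```python
def ood_score(cov_diag):
    """Total predictive variance as an uncertainty summary (assumed score)."""
    return sum(cov_diag)

def flag_ood(cov_diag, threshold):
    """Flag an input as out-of-distribution when its score exceeds a threshold,
    e.g. one calibrated on held-out in-distribution data."""
    return ood_score(cov_diag) > threshold
```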
Code for adversarial attack defense and awareness using the Heaviside MNN.
We apply the fast gradient sign method (FGSM) to generate adversarial samples.
Before running the following experiment, please download the pretrained checkpoint checkpoints/mnist_ood_demo.pt for the
MNIST case and checkpoints/cifar10_attack_demo.pt for the CIFAR-10 case.
```
$ python test_attack.py --dataset mnist \
    --defending
```
- dataset is the name of the dataset (mnist, cifar10).
- defending is the option for whether to set $\sigma_1, \sigma_2$ to zero to defend against the adversarial attack.
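FGSM perturbs each input in the direction of the sign of the loss gradient with respect to that input. The idea can be illustrated on a one-layer logistic model, where the gradient is available in closed form; this toy model stands in for the MNN, for which the repository obtains gradients via autograd.

```python
import math

def fgsm_logistic(w, b, x, y, eps):
    """One-step FGSM on a logistic model: x_adv = x + eps * sign(dL/dx).

    For binary cross-entropy with p = sigmoid(w . x + b) and label y in {0, 1},
    the input gradient is (p - y) * w, so no autograd is needed here.
    """
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    p = 1.0 / (1.0 + math.exp(-z))
    sign = lambda g: (g > 0) - (g < 0)
    return [xi + eps * sign((p - y) * wi) for xi, wi in zip(x, w)]
```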
Code for simulating the feed-forward stochastic neural network constructed from the parameters of a trained MNN. Before running the following experiment,
please download the pretrained checkpoint checkpoints/mnist_heav_mlp.pt for the
Heaviside case and checkpoints/mnist_relu_mlp.pt for the ReLU case.
```
$ python test_stoc_model.py --net_type heav \
    --load_model_path checkpoints/mnist_heav_mlp.pt \
    --times 5500 \
    --dt 0.1
```
- net_type is the type of moment activation used in the network (heav, relu).
- load_model_path is the path to the pretrained MNN checkpoint.
- times is the number of simulation time steps.
- dt is the size of the time step.
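The role of the `times` and `dt` options can be understood as Euler-Maruyama integration of stochastic dynamics with step size dt over a fixed number of steps. The sketch below integrates a single leaky unit driven by white noise; it illustrates the scheme only, and the unit's parameters (mu, tau, sigma) are assumed names, not the interface of stoc_models.py.

```python
import math
import random

def simulate_unit(mu, tau, sigma, dt, steps, x0=0.0):
    """Euler-Maruyama integration of dx = (mu - x)/tau dt + sigma dW:
    x <- x + (mu - x)/tau * dt + sigma * sqrt(dt) * xi, with xi ~ N(0, 1)
    drawn independently at each step.
    """
    x = x0
    traj = []
    for _ in range(steps):
        x += (mu - x) / tau * dt + sigma * math.sqrt(dt) * random.gauss(0.0, 1.0)
        traj.append(x)
    return traj
```

With sigma set to zero the scheme reduces to deterministic relaxation toward mu, which is a quick way to sanity-check the step size before adding noise.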