diff --git a/notebook/experiments.ipynb b/notebook/experiments.ipynb new file mode 100644 index 0000000..b7b6c11 --- /dev/null +++ b/notebook/experiments.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restricted Botlzman Machines (RBM)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#FIXME: Review the generation process (theoretically) and fix the implementation " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import List, Dict, Tuple, Literal, Optional, Union, Iterable\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import scipy.io\n", + "from tqdm import tqdm\n", + "from numpy._typing import ArrayLike\n", + "\n", + "ArrayLike = Union[List, Tuple, np.ndarray]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_FOLDER = \"../data/\"\n", + "ALPHA_DIGIT_PATH = os.path.join(DATA_FOLDER, \"binaryalphadigs.mat\")\n", + "MNIST_PATH = os.path.join(DATA_FOLDER, \"mnist_all.mat\")\n", + "\n", + "if not os.path.exists(ALPHA_DIGIT_PATH):\n", + " raise FileNotFoundError(f\"The file {ALPHA_DIGIT_PATH} does not exist.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 Implementing a RBM and testing on Binary AlphaDigits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _load_data(file_path: str) -> Dict[str, np.ndarray]:\n", + " \"\"\"\n", + " Load Binary AlphaDigits data from a .mat file.\n", + "\n", + " Parameters:\n", + " - file_path (str): Path to the .mat file containing the data.\n", + "\n", + " Returns:\n", + " - data (dict): Loaded data dictionary.\n", + " \"\"\"\n", + " if file_path is None:\n", + " raise ValueError(\"File path must be provided.\")\n", + "\n", + " return scipy.io.loadmat(file_path)\n", + "\n", + "\n", + "data = _load_data(ALPHA_DIGIT_PATH)\n", + "class_labels = data[\"classlabels\"].flatten() \n", + "class_count = data[\"classcounts\"].flatten()\n", + "df = pd.DataFrame(\n", + " {\n", + " \"Class Labels\": class_labels,\n", + " \"Class Count\": class_count\n", + " }\n", + ")\n", + "df[\"Class Labels\"] = df[\"Class Labels\"].apply(lambda x: x[0])\n", + "df[\"Class Count\"] = df[\"Class Count\"].apply(lambda x: x[0][0])\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _load_data(file_path: str, which: Literal[\"alphadigit\", \"mnist\"]=\"alphadigit\") -> Dict[str, np.ndarray]:\n", + " \"\"\"\n", + " Load Binary AlphaDigits data from a .mat file.\n", + "\n", + " Parameters:\n", + " - file_path (str): Path to the .mat file containing the data.\n", + " - which (Literal[\"alphadigit\", \"mnist\"], optional): Specifies \n", + " which data to load. The default value is \"alphadigit\".\n", + "\n", + " Returns:\n", + " - data (dict): A dictionary containing the loaded data.\n", + "\n", + " Raises:\n", + " - ValueError: If the file_path parameter is None.\n", + " - ValueError: If the which parameter is not \"alphadigit\".\n", + "\n", + " Example Usage:\n", + " ```python\n", + " data = _load_data(\"data.mat\", \"alphadigit\")\n", + " ```\n", + " \"\"\"\n", + " if file_path is None:\n", + " raise ValueError(\"File path must be provided.\")\n", + " \n", + " if which == \"alphadigit\":\n", + " return scipy.io.loadmat(file_path)[\"dat\"]\n", + " \n", + " raise ValueError(\"MNIST NOT YET AVAILABLE.\")\n", + "\n", + "alphadigit_data = _load_data(ALPHA_DIGIT_PATH) \n", + "print(alphadigit_data.shape)\n", + "print(alphadigit_data[0][0].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _map_characters_to_indices(characters: Union[str, int, List[Union[str, int]]]) -> List[int]:\n", + " \"\"\"\n", + " Map alphanumeric character to its corresponding index.\n", + "\n", + " Parameters:\n", + " - character (str, int, list of str or int): Alphanumeric character or its index.\n", + "\n", + " Returns:\n", + " - char_index (int): Corresponding index for the character.\n", + " \"\"\"\n", + " if isinstance(characters, list):\n", + " return [_map_characters_to_indices(char) for char in characters]\n", + " if isinstance(characters, int) and 0 <= characters <= 35:\n", + " return [characters]\n", + " if (isinstance(characters, str) and characters.isdigit()\n", + " and 0 <= int(characters) <= 9):\n", + " return [int(characters)]\n", + " if (isinstance(characters, str) and characters.isalpha()\n", + " and 'A' <= characters.upper() <= 'Z'):\n", + " return [ord(characters.upper()) - ord('A') + 10]\n", + " \n", + " raise ValueError(\n", + " \"Invalid character input. It should be an alphanumeric\" \n", + " \"character '[0-9|A-Z]' or its index representing '[0-35]'.\"\n", + " )\n", + "\n", + "for char in [0, 10, \"A\", [1, \"C\"], 36]:\n", + " try:\n", + " map = _map_characters_to_indices(char)\n", + " print(f\"{char} > map to > {map}\")\n", + " except:\n", + " print(f\"{char} > no mapping available, out of range\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def read_alpha_digit(characters: Optional[Union[str, int, List[Union[str, int]]]] = None,\n", + " file_path: Optional[str] = ALPHA_DIGIT_PATH,\n", + " data: Optional[Dict[str, np.ndarray]] = None,\n", + " use_data: bool = False,\n", + " ) -> np.ndarray:\n", + " \"\"\"\n", + " Reads binary AlphaDigits data from a .mat file or uses already loaded data. \n", + " It extracts the data for a specified alphanumeric character or its index, and \n", + " flattens the images into one-dimensional vectors.\n", + "\n", + " Parameters:\n", + " - characters (Union[str, int, List[Union[str, int]]], optional): Alphanumeric character \n", + " or its index whose data needs to be extracted. It can be a single character or \n", + " a list of characters. Default is None.\n", + " - file_path (str, optional): Path to the .mat file containing the data. \n", + " Default is None.\n", + " - data (dict, optional): Already loaded data dictionary. \n", + " Default is None.\n", + " - use_data (bool): Flag to indicate whether to use already loaded data.\n", + " Default is False.\n", + "\n", + " Returns:\n", + " - flattened_images (numpy.ndarray): Flattened images for the specified character(s).\n", + " \"\"\"\n", + " if not use_data:\n", + " data = _load_data(file_path, which=\"alphadigit\")\n", + "\n", + " char_indices = _map_characters_to_indices(characters)\n", + "\n", + " # Select the rows corresponding to the characters indices.\n", + " char_data: np.ndarray = data[char_indices]\n", + " \n", + " # Flatten each image into a one-dimensional vector.\n", + " flattened_images = np.array([image.flatten() for image in char_data.flatten()])\n", + " return flattened_images\n", + "\n", + "def plot_characters(chars, data):\n", + " num_chars = len(chars)\n", + " num_images_per_char = data.shape[0] // num_chars\n", + " fig, ax = plt.subplots(1, num_chars, figsize=(num_chars * 2, 2))\n", + "\n", + " for i, char in enumerate(chars):\n", + " # Find the index of the first image corresponding to the current char\n", + " start_index = i * num_images_per_char\n", + " image = data[start_index].reshape(20, 16)\n", + " ax[i].imshow(image, cmap='gray')\n", + " ax[i].set_title(f'Char: {char}')\n", + " ax[i].axis('off')\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "# Example\n", + "chars = [0, \"K\", 7, \"Z\"]\n", + "data = read_alpha_digit(chars, data=alphadigit_data, use_data=True)\n", + "plot_characters(chars, data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"data shape:\", data.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RBM:\n", + " def __init__(self, n_visible: int, n_hidden: int=100, random_state=None) -> None:\n", + " \"\"\"\n", + " Initialize the Restricted Boltzmann Machine.\n", + "\n", + " Parameters:\n", + " - n_visible (int): Number of visible units.\n", + " - n_hidden (int): Number of hidden units. Default 100.\n", + " - random_state: Random seed for reproducibility.\n", + " \"\"\"\n", + " self.n_visible = n_visible\n", + " self.n_hidden = n_hidden\n", + " \n", + " self.a = np.zeros((1, n_visible)) # visible_bias\n", + " self.b = np.zeros((1, n_hidden)) # hidden_bias\n", + " self.rng = np.random.default_rng(random_state)\n", + " self.W = 1e-4 * self.rng.standard_normal(size=(n_visible, n_hidden)) # weights\n", + "\n", + " def __repr__(self) -> str:\n", + " return f\"RBM(n_visible={self.n_visible}, n_hidden={self.n_hidden})\"\n", + "\n", + " def _sigmoid(self, x: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Sigmoid activation function.\n", + "\n", + " Parameters:\n", + " - x (numpy.ndarray): Input array.\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Result of applying the sigmoid function to the input.\n", + " \"\"\"\n", + " return 1 / (1 + np.exp(-x))\n", + " \n", + " def _reconstruction_error(self, input: np.ndarray, image: np.ndarray) -> float:\n", + " \"\"\"\n", + " Compute reconstruction error.\n", + "\n", + " Parameters:\n", + " - input (numpy.ndarray): Original input data.\n", + " - image (numpy.ndarray): Reconstructed image.\n", + "\n", + " Returns:\n", + " - float: Reconstruction error.\n", + " \"\"\"\n", + " return np.round(np.power(image - input, 2).mean(), 5)\n", + "\n", + " def input_output(self, data: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Compute hidden units given visible units.\n", + "\n", + " Parameters:\n", + " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Hidden unit activations, shape (n_samples, n_hidden).\n", + " \"\"\"\n", + " return self._sigmoid(data @ self.W + self.b)\n", + "\n", + " def output_input(self, data_h: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Compute visible units given hidden units.\n", + "\n", + " Parameters:\n", + " - data_h (numpy.ndarray): Hidden unit activations, shape (n_samples, n_hidden).\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Reconstructed visible units, shape (n_samples, n_visible).\n", + " \"\"\"\n", + " return self._sigmoid(data_h @ self.W.T + self.a)\n", + " \n", + " def calcul_softmax(self, data: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Calculate softmax probabilities for the output units.\n", + "\n", + " Parameters:\n", + " - input_data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Softmax probabilities, shape (n_samples, n_hidden).\n", + " \"\"\"\n", + " # Compute activations for the hidden layer\n", + " hidden_activations = self.input_output(data)\n", + " \n", + " # Compute softmax probabilities for the output layer\n", + " exp_hidden_activations = np.exp(hidden_activations)\n", + " softmax_probs = exp_hidden_activations / np.sum(exp_hidden_activations, axis=1, keepdims=True)\n", + " \n", + " return softmax_probs\n", + "\n", + " def update(\n", + " self, \n", + " batch: np.ndarray,\n", + " learning_rate: float=0.1,\n", + " batch_size: Optional[int]=None,\n", + " return_output: bool=False\n", + " ):\n", + " \"\"\"_summary_\n", + "\n", + " Args:\n", + " batch (np.ndarray): _description_\n", + " learning_rate (float, optional): _description_. Defaults to 0.1.\n", + " batch_size (Optional[int], optional): _description_. Defaults to None.\n", + " return_output (bool, optional): _description_. Defaults to False.\n", + " \"\"\"\n", + " if not batch_size:\n", + " batch_size = batch.shape[0]\n", + " pos_h_probs = self.input_output(batch)\n", + " pos_v_probs = self.output_input(pos_h_probs)\n", + " neg_h_probs = self.input_output(pos_v_probs)\n", + " \n", + " # Update weights and biases\n", + " self.W += learning_rate * (batch.T @ pos_h_probs - pos_v_probs.T @ neg_h_probs) / batch_size\n", + " self.b += learning_rate * (pos_h_probs - neg_h_probs).mean(axis=0)\n", + " self.a += learning_rate * (batch - pos_v_probs).mean(axis=0)\n", + "\n", + " if return_output:\n", + " return self, pos_v_probs\n", + " \n", + " return self \n", + "\n", + " def train(self, \n", + " data: np.ndarray,\n", + " learning_rate: float=0.1,\n", + " n_epochs: int=10,\n", + " batch_size: int=10,\n", + " print_each=10\n", + " ) -> 'RBM':\n", + " \"\"\"\n", + " Train the RBM using Contrastive Divergence.\n", + "\n", + " Parameters:\n", + " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + " - learning_rate (float): Learning rate for gradient descent. Default is 0.1.\n", + " - n_epochs (int): Number of training epochs. Default is 10.\n", + " - batch_size (int): Size of mini-batches. Default is 10.\n", + "\n", + " Returns:\n", + " - RBM: Trained RBM instance.\n", + " \"\"\"\n", + " n_samples = data.shape[0]\n", + " for epoch in range(n_epochs):\n", + " self.rng.shuffle(data)\n", + " for i in tqdm(range(0, n_samples, batch_size), desc=f\"Epoch {epoch}\"):\n", + " batch = data[i:i+batch_size]\n", + " _, pos_v_probs = self.update(\n", + " batch=batch,\n", + " learning_rate=learning_rate,\n", + " batch_size=batch_size,\n", + " return_output=True\n", + " )\n", + " \n", + " if epoch % print_each == 0:\n", + " tqdm.write(\n", + " f\"Reconstruction error: {self._reconstruction_error(batch, pos_v_probs)}.\")\n", + "\n", + " return self\n", + "\n", + " def generate_image(self, n_samples: int=1, n_gibbs_steps: int=1) -> np.ndarray:\n", + " \"\"\"\n", + " Generate samples from the RBM using Gibbs sampling.\n", + "\n", + " Parameters:\n", + " - n_samples (int): Number of samples to generate. Default is 10.\n", + " - n_gibbs_steps (int): Number of Gibbs sampling steps. Default is 1.\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Generated samples, shape (n_samples, n_visible).\n", + " \"\"\"\n", + " samples = np.zeros((n_samples, self.n_visible))\n", + " \n", + " # Matrix of initlization value of Gibbs samples for each sample. \n", + " V = self.rng.binomial(1, self.rng.random(), size=n_samples*self.n_visible).reshape((n_samples, self.n_visible))\n", + " for i in range(n_samples):\n", + " for _ in range(n_gibbs_steps):\n", + " h_probs = self._sigmoid(V[i] @ self.W + self.b) # vector\n", + " h = self.rng.binomial(1, h_probs)\n", + " v_probs = self._sigmoid(h @ self.W.T + self.a)\n", + " v = self.rng.binomial(1, v_probs)\n", + " samples[i] = v\n", + " return samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the alpha_digit data\n", + "data = read_alpha_digit(file_path=ALPHA_DIGIT_PATH, characters=['Z'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parameters\n", + "n_visible = data.shape[1] # Number of visible units (size of each image)\n", + "n_hidden = 200 # Number of hidden units (hyperparameter)\n", + "\n", + "# Initialize RBM\n", + "rbm = RBM(n_visible=n_visible, n_hidden=n_hidden, random_state=42)\n", + "print(rbm)\n", + "\n", + "# Train RBM\n", + "rbm.train(data, learning_rate=0.1, n_epochs=500, batch_size=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_allclose(rbm.calcul_softmax(data).sum(1), 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate samples\n", + "generated_samples = rbm.generate_image(n_samples=10, n_gibbs_steps=1)\n", + "\n", + "# Plot original and generated samples\n", + "plt.figure(figsize=(12, 6))\n", + "for i in range(10):\n", + " plt.subplot(2, 10, i + 1)\n", + " plt.imshow(data[i].reshape(20, 16), cmap='gray')\n", + " plt.title('Original')\n", + " plt.axis('off')\n", + " \n", + " plt.subplot(2, 10, i + 11)\n", + " plt.imshow(generated_samples[i].reshape(20, 16), cmap='gray')\n", + " plt.title('Generated')\n", + " plt.axis('off')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(rbm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Implementing a Deep Belief Network (DBN) and test on Binary AlphaDigits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DBN:\n", + " def __init__(self, n_visible: int, hidden_layer_sizes: list[int], random_state=None):\n", + " \"\"\"\n", + " Initialize the Deep Belief Network.\n", + "\n", + " Parameters:\n", + " - n_visible (int): Number of visible units.\n", + " - hidden_layer_sizes (list[int]): List of sizes for each hidden layer.\n", + " - random_state: Random seed for reproducibility.\n", + " \"\"\"\n", + " self.n_visible = n_visible\n", + " self.hidden_layer_sizes = hidden_layer_sizes\n", + " self.rbms: List[RBM] = []\n", + " self.rng = np.random.default_rng(random_state)\n", + "\n", + " # Initialize the first RBM\n", + " first_rbm = RBM(\n", + " n_visible=n_visible,\n", + " n_hidden=hidden_layer_sizes[0],\n", + " random_state=random_state,\n", + " )\n", + " self.rbms.append(first_rbm)\n", + "\n", + " # Initialize RBMs for subsequent hidden layers\n", + " for i, size in enumerate(hidden_layer_sizes[1:], start=1):\n", + " rbm = RBM(\n", + " n_visible=hidden_layer_sizes[i - 1],\n", + " n_hidden=size,\n", + " random_state=random_state,\n", + " )\n", + " self.rbms.append(rbm)\n", + "\n", + "\n", + " def __getitem__(self, key):\n", + " return self.rbms[key]\n", + " \n", + "\n", + " def __repr__(self):\n", + " \"\"\"\n", + " Return a string representation of the DBN object.\n", + " \"\"\"\n", + " rbm_reprs = [f\"{'':4}{repr(rbm)}\" for rbm in self.rbms]\n", + " join_rbm_reprs = ',\\n'.join(rbm_reprs)\n", + " return f\"DBN([\\n{join_rbm_reprs}\\n])\"\n", + "\n", + "\n", + " def train(self,\n", + " data: np.ndarray,\n", + " learning_rate: float=0.1,\n", + " n_epochs: int=10,\n", + " batch_size: int=10,\n", + " print_each: int=10,\n", + " ) -> \"DBN\":\n", + " \"\"\"\n", + " Train the DBN using Greedy layer-wise procedure.\n", + "\n", + " Parameters:\n", + " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + " - learning_rate (float): Learning rate for gradient descent. Default is 0.1.\n", + " - n_epochs (int): Number of training epochs. Default is 10.\n", + " - batch_size (int): Size of mini-batches. Default is 10.\n", + " - print_each: Print reconstruction error each `print_each` epochs.\n", + " - verbose\n", + "\n", + " Returns:\n", + " - DBN: Trained DBN instance.\n", + " \"\"\"\n", + " input_data = data\n", + " for rbm in self.rbms:\n", + " rbm.train(\n", + " input_data,\n", + " learning_rate=learning_rate,\n", + " n_epochs=n_epochs,\n", + " batch_size=batch_size,\n", + " print_each=print_each,\n", + " )\n", + " # Update input data for the next RBM\n", + " input_data = rbm.input_output(input_data)\n", + "\n", + " return self\n", + "\n", + " def generate_image(self, n_samples: int=1, n_gibbs_steps: int=1) -> np.ndarray:\n", + " \"\"\"\n", + " Generate samples from the DBN using Gibbs sampling.\n", + "\n", + " Parameters:\n", + " - n_samples (int): Number of samples to generate. Default is 1.\n", + " - n_gibbs_steps (int): Number of Gibbs sampling steps. Default is 100.\n", + "\n", + " Returns:\n", + " - numpy.ndarray: Generated samples, shape (n_samples, n_visible).\n", + " \"\"\"\n", + " # samples = np.zeros((n_samples, self.n_visible))\n", + "\n", + " # Generate samples using the first RBM in the DBN\n", + " samples = self.rbms[-1].generate_image(n_samples, n_gibbs_steps)\n", + " for rbm in reversed(self.rbms[:-1]):\n", + " # Sample from the conditional probability of layer l-1 given layer l: p(h_{s-1}|h_{s}).\n", + " h_probs = rbm.output_input(samples)\n", + " h = self.rng.binomial(1, p=h_probs) \n", + " samples = h\n", + " return samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from principal_dbn_alpha import DBN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_visible=data.shape[1]\n", + "hidden_layer_sizes = [100, 50, 25]\n", + "\n", + "dbn = DBN(n_visible=n_visible, hidden_layer_sizes=hidden_layer_sizes, random_state=42)\n", + "dbn.train(data, learning_rate=0.1, n_epochs=10, batch_size=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if the RBM are accessibles via a slicing \n", + "print(dbn[1:3])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Generate images\n", + "generated_images = dbn.generate_image(n_samples=5, n_gibbs_steps=1)\n", + "\n", + "# Display generated images\n", + "fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(12, 4))\n", + "for i in range(5):\n", + " axes[i].imshow(generated_images[i].reshape(20, 16), cmap='gray')\n", + " axes[i].set_title(f\"Image {i+1}\")\n", + " axes[i].axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DNN(DBN):\n", + " def __init__(\n", + " self,\n", + " input_dim: int,\n", + " output_dim: int,\n", + " hidden_layer_sizes: List[int],\n", + " random_state=None\n", + " ):\n", + " \"\"\"\n", + " Initialize the Deep Neural Network (DNN).\n", + "\n", + " Parameters:\n", + " - input_dim (int): Dimension of the input.\n", + " - output_dim (int): Dimension of the output.\n", + " - hidden_layer_sizes (List[int]): List of sizes for each hidden layer.\n", + " - random_state: Random seed for reproducibility.\n", + " \"\"\"\n", + " super().__init__(\n", + " n_visible=input_dim,\n", + " hidden_layer_sizes=hidden_layer_sizes,\n", + " random_state=random_state\n", + " )\n", + " #--> self.rbms contains only the pre-trainable RBMs \n", + " self.clf = RBM(self.rbms[-1].n_hidden, output_dim)\n", + " self.network = self.rbms + [self.clf] # DNN = [DBN + Classifier] ~ [RBM_0,...,RBM_N, RBM_Clf]\n", + "\n", + " def __getitem__(self, key):\n", + " return self.network[key]\n", + " \n", + " def __repr__(self):\n", + " join_repr = \"\\n\".join([f\"{'':4}{repr(rbm)},\" for rbm in self.network])\n", + " return f\"DNN([\\n{join_repr} \\n])\"\n", + " \n", + " \n", + " def pretrain(self, n_epochs: int, learning_rate: float, batch_size: int, data: np.ndarray) -> \"DNN\":\n", + " \"\"\"\n", + " Pretrain the hidden layers of the DNN using the DBN training method.\n", + "\n", + " Parameters:\n", + " - n_epochs (int): Number of training epochs.\n", + " - learning_rate (float): Learning rate for gradient descent.\n", + " - batch_size (int): Size of mini-batches.\n", + " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + "\n", + " Returns:\n", + " - DNN: Pretrained DNN instance.\n", + " \"\"\"\n", + " # NOTE: Use the inherited `train` method to perform pre-training since `self.rbms`\n", + " # only contains the pre-trainable RBMs.\n", + " return self.train(data, n_epochs=n_epochs, learning_rate=learning_rate, batch_size=batch_size)\n", + " \n", + " def input_output(self, input_data: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:\n", + " \"\"\"\n", + " Get the outputs on each layer of the DNN and the softmax probabilities on the output layer.\n", + "\n", + " Parameters:\n", + " - input_data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + "\n", + " Returns:\n", + " - Tuple[List[np.ndarray], np.ndarray]: Outputs on each layer & softmax probabilities.\n", + " \"\"\"\n", + " layer_outputs = []\n", + " \n", + " # Input layer output\n", + " layer_outputs.append(input_data)\n", + " \n", + " # Hidden layers output\n", + " for rbm in self.rbms:\n", + " layer_outputs.append(rbm.input_output(layer_outputs[-1]))\n", + " \n", + " # Softmax probabilities on the output layer\n", + " output_probs = self.network[-1].calcul_softmax(layer_outputs[-1])\n", + " \n", + " return layer_outputs, output_probs\n", + " \n", + "\n", + " def _cross_entropy(batch_labels: np.ndarray, output_probs: np.ndarray, eps: float = 1e-15) -> float:\n", + " \"\"\"\n", + " Calculate the cross entropy between the batch labels and output probabilities.\n", + "\n", + " Parameters:\n", + " - batch_labels (numpy.ndarray): True labels for the batch, shape (batch_size, n_classes).\n", + " - output_probs (numpy.ndarray): Predicted probabilities for the batch, shape (batch_size, n_classes).\n", + " - eps (float): Small value to avoid numerical instability in logarithm calculation. Default is 1e-15.\n", + "\n", + " Returns:\n", + " - float: Cross entropy value.\n", + " \"\"\"\n", + " return -np.mean(np.sum(batch_labels * np.log(output_probs + eps), axis=1))\n", + "\n", + "\n", + " def backpropagation(\n", + " self,\n", + " input_data: np.ndarray,\n", + " labels: np.ndarray,\n", + " n_epochs: int,\n", + " learning_rate: float,\n", + " batch_size: int,\n", + " eps: float = 1e-15\n", + " ) -> \"DNN\":\n", + " \"\"\"\n", + " Estimate the weights/biases of the network using backpropagation algorithm.\n", + "\n", + " Parameters:\n", + " - input_data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", + " - labels (numpy.ndarray): Labels for the input data, shape (n_samples, n_classes).\n", + " - n_epochs (int): Number of training epochs.\n", + " - learning_rate (float): Learning rate for gradient descent.\n", + " - batch_size (int): Size of mini-batches.\n", + " - eps (float): Small value to avoid numerical instability in logarithm calculation. Default is 1e-15.\n", + "\n", + " Returns:\n", + " - DNN: Updated DNN instance.\n", + " \"\"\"\n", + " n_samples = input_data.shape[0]\n", + " \n", + " for epoch in tqdm(range(n_epochs), desc=\"Training\", unit=\"epoch\"):\n", + " for batch_start in range(0, n_samples, batch_size):\n", + " batch_end = min(batch_start + batch_size, n_samples)\n", + " batch_input = input_data[batch_start:batch_end]\n", + " batch_labels = labels[batch_start:batch_end]\n", + "\n", + " # Forward pass\n", + " layer_outputs, output_probs = self.input_output(batch_input)\n", + "\n", + " # Backward pass (update weights and biases)\n", + " self.network[-1].update(batch_labels, layer_outputs[-1], learning_rate)\n", + " for i in range(len(self.network) - 2, -1, -1):\n", + " self.network[i].update(layer_outputs[i], layer_outputs[i + 1], self.network[i + 1].weights, learning_rate)\n", + "\n", + " # Calculate cross entropy after each epoch\n", + " loss = self._cross_entropy(batch_labels, output_probs, eps)\n", + " tqdm.write(f\"Epoch {epoch + 1}/{n_epochs}, Cross Entropy: {loss}\")\n", + "\n", + " return self\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_visible=data.shape[1]\n", + "hidden_layer_sizes = [100, 50, 25]\n", + "output_dim = 20\n", + "\n", + "dnn = DNN(input_dim=n_visible, hidden_layer_sizes=hidden_layer_sizes, output_dim=output_dim, random_state=42)\n", + "# keep last RBM's weights for further test.\n", + "weights = dnn[-1].W \n", + "\n", + "dnn.train(data, learning_rate=0.1, n_epochs=10, batch_size=10)\n", + "\n", + "# Check that the last RBM has not been trained.\n", + "np.testing.assert_equal (dnn[-1].a, 0) # visible bias\n", + "np.testing.assert_equal (dnn[-1].b, 0) # hidden bias\n", + "np.testing.assert_equal (dnn[-1].W, weights) # weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "layer_outputs, softmax_output = dnn.input_output(data)\n", + "n_layers_net = len(layer_outputs) + 1\n", + "print(\"No. network output =\", n_layers_net)\n", + "\n", + "print(f\"Input data (0): {layer_outputs[0].shape}\")\n", + "for idx, layer_output in enumerate(layer_outputs[1:]):\n", + " print(f\"Hidden layer input ({idx+1}): {layer_output.shape}\")\n", + "\n", + "print(f\"Softmax output ({n_layers_net - 1}):\", softmax_output.shape)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebook/experiments_ALPHA_DIGITS.ipynb b/notebook/experiments_ALPHA_DIGITS.ipynb new file mode 100644 index 0000000..ef7b765 --- /dev/null +++ b/notebook/experiments_ALPHA_DIGITS.ipynb @@ -0,0 +1,468 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analysis on ALPHA DIGITS " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install nbformat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%run experiments.ipynb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## I- Effect of Layers and units" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we create a special function : generate_symmetric_configurations to create liste of hidden layer we want to use for the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_symmetric_configurations(min_layers, max_layers, min_neurons, max_neurons, step_neurons):\n", + " \"\"\"\n", + " Generate symmetrical configurations for the DBN's hidden layers.\n", + "\n", + " Args:\n", + " min_layers (int): Minimum number of hidden layers.\n", + " max_layers (int): Maximum number of hidden layers.\n", + " min_neurons (int): Minimum number of neurons per layer.\n", + " max_neurons (int): Maximum number of neurons per layer.\n", + " step_neurons (int): No increase in the number of neurons.\n", + "\n", + " Returns:\n", + " List[List[int]]: List of symmetrical configurations of hidden layers.\n", + " \"\"\"\n", + " configurations = []\n", + " for num_layers in range(min_layers, max_layers + 1):\n", + " for num_neurons in range(min_neurons, max_neurons + 1, step_neurons):\n", + " half = num_layers // 2\n", + " config = [num_neurons] * half + [num_neurons] * (num_layers - half)\n", + " configurations.append(config)\n", + " return configurations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1.RBM launch function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_rbm_experiment(hidden_units_sizes, n_epochs=100, character_sets=[['A', 'B'], ['1', '2', '3', '4'], ['A', 'B', '1', '2']]):\n", + " \"\"\"\n", + " Conducts an experiment with Restricted Boltzmann Machines (RBMs) on different character sets.\n", + "\n", + " Args:\n", + " hidden_units_sizes (list): A list of sizes for the hidden layers to experiment with.\n", + " n_epochs (int): The number of training epochs. Default is 100.\n", + " character_sets (list of lists): A list containing different sets of characters to train RBMs on.\n", + "\n", + " This function trains an RBM for each combination of character set and hidden unit size,\n", + " generates and saves images representing the learned features, and plots the results in a grid.\n", + " \"\"\"\n", + " # Determine the unique number of units\n", + " unique_units = sorted(hidden_units_sizes)\n", + "\n", + " # Prepare a grid of subplots\n", + " fig, axes = plt.subplots(len(character_sets), len(unique_units), figsize=(len(unique_units) * 3, len(character_sets) * 3), squeeze=False)\n", + "\n", + " for row_idx, characters in enumerate(character_sets):\n", + " data = read_alpha_digit(characters, file_path=ALPHA_DIGIT_PATH)\n", + "\n", + " for col_idx, num_units in enumerate(unique_units):\n", + " print(f\"\\nTraining RBM with {num_units} hidden units on characters {characters}\")\n", + " rbm = RBM(n_visible=data.shape[1], n_hidden=num_units, random_state=42)\n", + " rbm.train(data, learning_rate=0.1, n_epochs=n_epochs, batch_size=15, print_each=5000)\n", + "\n", + " # Generate and display an image\n", + " generated_image = rbm.generate_image(n_samples=1)\n", + "\n", + " # Save the image\n", + " save_path = f\"../resultat/rbm/{num_units}_Units_{len(characters)}_Chars.npy\"\n", + " os.makedirs(os.path.dirname(save_path), exist_ok=True)\n", + " np.save(save_path, generated_image)\n", + "\n", + " ax = axes[row_idx, col_idx]\n", + " ax.imshow(generated_image[0].reshape(20, 16), cmap='plasma')\n", + " ax.set_title(f\"Units: {num_units}, N_Chars: {len(characters)}\")\n", + " ax.axis('off')\n", + "\n", + " plt.tight_layout()\n", + " # Save the generated image\n", + " directory_image = \"../resultat/images/rbm\"\n", + " os.makedirs(directory_image, exist_ok=True)\n", + " plt.savefig(f\"{directory_image}/rbm_{len(characters)}_chars_Units_{num_units}_Layers_{character_sets}.png\")\n", + " plt.tight_layout()\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hidden_units_sizes = [100, 200, 300, 400, 500, 600, 700]\n", + "run_rbm_experiment(hidden_units_sizes, n_epochs=1000, character_sets=['E'])\n", + "run_rbm_experiment(hidden_units_sizes, n_epochs=1000, character_sets=['A'])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. DBM launch function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "def run_dbm_experiment(hidden_layers_sizes, n_epochs=100, characters=['A', 'B', '1', '2']):\n", + " \"\"\"\n", + " Conducts an experiment with Deep Belief Networks (DBNs) on a set of characters.\n", + "\n", + " Args:\n", + " hidden_layers_sizes (list of lists): A list containing the sizes of hidden layers to experiment with.\n", + " n_epochs (int): The number of training epochs. Default is 100.\n", + " characters (list): The characters to use in the experiment. Default is ['A', 'B', '1', '2'].\n", + "\n", + " This function trains a DBN for each specified configuration of hidden layer sizes,\n", + " generates and saves images representing the learned features, and plots the results in a grid.\n", + " \"\"\"\n", + " # Load the data\n", + " data = read_alpha_digit(characters, file_path=ALPHA_DIGIT_PATH)\n", + "\n", + " # Determine the maximum number of layers and the unique number of units\n", + " max_layers = max(len(sizes) for sizes in hidden_layers_sizes)\n", + " unique_units = sorted({sizes[0] for sizes in hidden_layers_sizes})\n", + "\n", + " # Prepare a grid of subplots\n", + " fig, axes = plt.subplots(len(unique_units), max_layers, figsize=(max_layers * 3, len(unique_units) * 3), squeeze=False)\n", + "\n", + " # Initialize all axes as invisible; they will be activated when used\n", + " for ax_row in axes:\n", + " for ax in ax_row:\n", + " ax.set_visible(False)\n", + "\n", + " for layer_sizes in hidden_layers_sizes:\n", + " print(f\"\\nTraining DBN with hidden layers: {layer_sizes}\")\n", + " dbn = DBN(n_visible=data.shape[1], hidden_layer_sizes=layer_sizes, random_state=42)\n", + " dbn.train(data, learning_rate=0.1, n_epochs=n_epochs, batch_size=16, print_each=1000000)\n", + "\n", + " # Generate and display an image\n", + " generated_image = dbn.generate_image(n_samples=1)\n", + " unit_idx = unique_units.index(layer_sizes[0])\n", + " layer_idx = len(layer_sizes) - 2 # Index 0 for 2 layers, index 1 for 3 layers, etc.\n", + "\n", + " ax = axes[unit_idx][layer_idx]\n", + " ax.set_visible(True)\n", + " ax.imshow(generated_image[0].reshape(20, 16), cmap='plasma')\n", + " ax.set_title(f\"N_Layers: {len(layer_sizes)}, N_Units: {layer_sizes[0]}\")\n", + " ax.axis('off')\n", + "\n", + " # Save the generated image\n", + " directory = f\"../resultat/dbn/{layer_sizes[0]}_Units_{len(layer_sizes)}_Layers\"\n", + " os.makedirs(directory, exist_ok=True)\n", + " np.save(f\"{directory}/Units_{layer_sizes[0]}_Chars_{''.join(characters)}.npy\", generated_image[0])\n", + "\n", + " # Save the figure\n", + " plt.tight_layout()\n", + " directory_image = \"../resultat/images/dbn\"\n", + " os.makedirs(directory_image, exist_ok=True)\n", + " plt.savefig(f\"{directory_image}/dbn_{len(characters)}_chars_{layer_sizes[0]}_Units_{len(layer_sizes)}_Layers.png\")\n", + " plt.tight_layout()\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Exemple d'utilisation avec des configurations générées\n", + "configurations = generate_symmetric_configurations(min_layers = 2, max_layers = 5, min_neurons = 100, max_neurons = 700, step_neurons = 100)\n", + "run_dbm_experiment(configurations, n_epochs=1, characters=['Y'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## I- Effect of the number of characters on reconstruction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. RBM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we modify the above corresponding function to have plot adapted to our analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "\n", + "def run_rbm_experiment(hidden_units_sizes, n_epochs=100, character_sets=[['A', 'B'], ['1', '2', '3', '4'], ['A', 'B', '1', '2']]):\n", + " \"\"\"\n", + " Conducts an experiment with Restricted Boltzmann Machines (RBMs) on different character sets,\n", + " generating multiple samples for each configuration.\n", + "\n", + " Args:\n", + " hidden_units_sizes (list): A list of sizes for the hidden units to experiment with.\n", + " n_epochs (int): The number of training epochs. Default is 100.\n", + " character_sets (list of lists): A list containing different sets of characters to train RBMs on.\n", + "\n", + " This function trains an RBM for each combination of character set and hidden unit size,\n", + " generates and displays five samples from each trained RBM, and saves the results.\n", + " \"\"\"\n", + " unique_units = sorted(hidden_units_sizes)\n", + "\n", + " # Prepare a grid of subplots; each configuration now has 5 columns for the 5 samples\n", + " fig, axes = plt.subplots(len(character_sets), len(unique_units) * 5, figsize=(len(unique_units) * 3 * 5, len(character_sets) * 3), squeeze=False)\n", + "\n", + " for row_idx, characters in enumerate(character_sets):\n", + " data = read_alpha_digit(characters, file_path=ALPHA_DIGIT_PATH)\n", + "\n", + " for col_idx, num_units in enumerate(unique_units):\n", + " print(f\"\\nTraining RBM with {num_units} hidden units on characters {characters}\")\n", + " rbm = RBM(n_visible=data.shape[1], n_hidden=num_units, random_state=42)\n", + " rbm.train(data, learning_rate=0.1, n_epochs=n_epochs, batch_size=15, print_each=5000)\n", + "\n", + " # Generate 5 images\n", + " generated_images = rbm.generate_image(n_samples=5)\n", + "\n", + " for sample_idx in range(5):\n", + " ax = axes[row_idx, col_idx * 5 + sample_idx]\n", + " ax.imshow(generated_images[sample_idx].reshape(20, 16), cmap='plasma')\n", + " ax.set_title(f\"N_chars {len(characters)}, Generation: {sample_idx + 1}\")\n", + " ax.axis('off')\n", + "\n", + " # Save each generated sample\n", + " save_path = f\"../resultat/rbm/{num_units}_Units_{len(characters)}_Chars_Sample_{sample_idx}.npy\"\n", + " os.makedirs(os.path.dirname(save_path), exist_ok=True)\n", + " np.save(save_path, generated_images[sample_idx])\n", + "\n", + " plt.tight_layout()\n", + " directory_image = \"../resultat/images/rbm\"\n", + " os.makedirs(directory_image, exist_ok=True)\n", + " plt.savefig(f\"{directory_image}/rbm_samples.png\")\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# configuration = 2 layer with 200 units each\n", + "configurations_fixe = [200]\n", + "run_rbm_experiment(configurations_fixe, n_epochs=2, character_sets = [['E'],['E', 'O'], ['E', 'O', 'A'], ['E', 'O', 'A', '2'], ['E', 'O', 'A', '2', '7']])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. DBM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "we modify the above corresponding function to have plot adapted to our analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "\n", + "def run_dbm_experiment(hidden_layers_sizes, n_epochs=100, character_sets=[['A', 'B'], ['1', '2', '3', '4'], ['A', 'B', '1', '2']]):\n", + " \"\"\"\n", + " Conducts an experiment with Deep Boltzmann Machines (DBMs) on different character sets,\n", + " generating multiple samples for each configuration.\n", + "\n", + " Args:\n", + " hidden_layers_sizes (list of lists): A list containing the sizes of hidden layers to experiment with.\n", + " n_epochs (int): The number of training epochs. Default is 100.\n", + " character_sets (list of lists): A list containing different sets of characters to train DBMs on.\n", + "\n", + " For each character set, this function trains a DBM, generates and displays five samples,\n", + " and saves the generated images and the complete figure for each set.\n", + " \"\"\"\n", + " # Assumes there's only one configuration of hidden layers sizes provided\n", + " layer_sizes = hidden_layers_sizes[0]\n", + "\n", + " # For each set of characters, generate and display images\n", + " for characters in character_sets:\n", + " data = read_alpha_digit(characters, file_path=ALPHA_DIGIT_PATH)\n", + "\n", + " # Initialize a new figure\n", + " plt.figure(figsize=(15, 3)) # Adjusted size for the set of subplots\n", + "\n", + " print(f\"\\nTraining DBN with hidden layers: {layer_sizes}\")\n", + " dbn = DBN(n_visible=data.shape[1], hidden_layer_sizes=layer_sizes, random_state=42)\n", + " dbn.train(data, learning_rate=0.1, n_epochs=n_epochs, batch_size=16, print_each=1000000)\n", + "\n", + " # Generate 5 images\n", + " generated_images = dbn.generate_image(n_samples=5)\n", + "\n", + " for img_idx in range(5):\n", + " ax = plt.subplot(1, 5, img_idx + 1)\n", + " ax.imshow(generated_images[img_idx].reshape(20, 16), cmap='plasma') # Ensure the shape is correct\n", + " ax.set_title(f\"N_chars {len(characters)}, generation: {img_idx + 1}\")\n", + " ax.axis('off')\n", + "\n", + " # Save the generated images\n", + " directory = f\"../resultat/dbn/{'_'.join([str(size) for size in layer_sizes])}_Units_{len(characters)}_Chars\"\n", + " os.makedirs(directory, exist_ok=True)\n", + " for img_idx, img in enumerate(generated_images):\n", + " np.save(f\"{directory}/Sample_{img_idx}_Chars_{''.join(characters)}.npy\", img)\n", + "\n", + " # Save the complete figure for this set of characters\n", + " plt.tight_layout()\n", + " directory_image = f\"../resultat/images/dbn/{'_'.join(characters)}\"\n", + " os.makedirs(directory_image, exist_ok=True)\n", + " plt.savefig(f\"{directory_image}/dbn_{len(characters)}_chars_{'_'.join([str(size) for size in layer_sizes])}_Units.png\")\n", + "\n", + " # Display all figures at the end of the loop\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example fixed configuration: two layers with 200 units each\n", + "fixed_configuration = [[400, 400, 400, 400]]\n", + "\n", + "# Run the experiment with the fixed configuration and different character sets\n", + "run_dbm_experiment(fixed_configuration, n_epochs=2, character_sets=[['E'],['E', 'Y'], ['E', 'Y', 'A'], ['E', 'Y', 'A', '2'], ['E', 'Y', 'A', '2', '7']])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# read saved files for RBM and DBM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def load_and_display_image(file_path):\n", + " \"\"\"\n", + " Charge et affiche une image à partir d'un fichier .npy.\n", + " \n", + " Args:\n", + " file_path (str): Le chemin du fichier .npy à charger.\n", + " \"\"\"\n", + " # Charger l'image à partir du fichier .npy\n", + " image = np.load(file_path)\n", + " \n", + " image_data_reshaped = image.reshape((20, 16))\n", + "\n", + " # Afficher l'image\n", + " plt.imshow(image_data_reshaped, cmap='plasma') # ou 'gray ou 'viridis' ou 'inferno' ou 'plasma' ou 'magma' ou 'cividis')\n", + " plt.title(\"Loaded Image\")\n", + " plt.axis('off') # Désactiver les axes pour une meilleure visualisation\n", + " plt.show()\n", + "\n", + "# Exemple d'utilisation\n", + "file_path = \"../resultat/dbn/100_Units_2_Layers/Units_100_Chars_A.npy\"\n", + "load_and_display_image(file_path)\n", + "\n", + "file_path2 = \"../resultat/rbn/100_Units_2_Chars.npy\"\n", + "load_and_display_image(file_path2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebook/principal_RBM_alpha.ipynb b/notebook/principal_RBM_alpha.ipynb deleted file mode 100644 index 96bbc14..0000000 --- a/notebook/principal_RBM_alpha.ipynb +++ /dev/null @@ -1,981 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Restricted Botlzman Machines (RBM)" - ] - }, - { - "cell_type": "code", - "execution_count": 194, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from typing import List, Dict, Tuple, Literal, Optional, Union, Iterable\n", - "\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "import scipy.io\n", - "from tqdm import tqdm\n", - "from numpy._typing import ArrayLike\n", - "\n", - "ArrayLike = Union[List, Tuple, np.ndarray]" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_FOLDER = \"../data/\"\n", - "ALPHA_DIGIT_PATH = os.path.join(DATA_FOLDER, \"binaryalphadigs.mat\")\n", - "MNIST_PATH = os.path.join(DATA_FOLDER, \"mnist_all.mat\")\n", - "\n", - "if not os.path.exists(ALPHA_DIGIT_PATH):\n", - " raise FileNotFoundError(f\"The file {ALPHA_DIGIT_PATH} does not exist.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1 Implementing a RBM and testing on Binary AlphaDigits" - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Class LabelsClass Count
0039
1139
2239
3339
4439
5539
6639
7739
8839
9939
10A39
11B39
12C39
13D39
14E39
15F39
16G39
17H39
18I39
19J39
20K39
21L39
22M39
23N39
24O39
25P39
26Q39
27R39
28S39
29T39
30U39
31V39
32W39
33X39
34Y39
35Z39
\n", - "
" - ], - "text/plain": [ - " Class Labels Class Count\n", - "0 0 39\n", - "1 1 39\n", - "2 2 39\n", - "3 3 39\n", - "4 4 39\n", - "5 5 39\n", - "6 6 39\n", - "7 7 39\n", - "8 8 39\n", - "9 9 39\n", - "10 A 39\n", - "11 B 39\n", - "12 C 39\n", - "13 D 39\n", - "14 E 39\n", - "15 F 39\n", - "16 G 39\n", - "17 H 39\n", - "18 I 39\n", - "19 J 39\n", - "20 K 39\n", - "21 L 39\n", - "22 M 39\n", - "23 N 39\n", - "24 O 39\n", - "25 P 39\n", - "26 Q 39\n", - "27 R 39\n", - "28 S 39\n", - "29 T 39\n", - "30 U 39\n", - "31 V 39\n", - "32 W 39\n", - "33 X 39\n", - "34 Y 39\n", - "35 Z 39" - ] - }, - "execution_count": 164, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def _load_data(file_path: str) -> Dict[str, np.ndarray]:\n", - " \"\"\"\n", - " Load Binary AlphaDigits data from a .mat file.\n", - "\n", - " Parameters:\n", - " - file_path (str): Path to the .mat file containing the data.\n", - "\n", - " Returns:\n", - " - data (dict): Loaded data dictionary.\n", - " \"\"\"\n", - " if file_path is None:\n", - " raise ValueError(\"File path must be provided.\")\n", - "\n", - " return scipy.io.loadmat(file_path)\n", - "\n", - "\n", - "data = _load_data(ALPHA_DIGIT_PATH)\n", - "class_labels = data[\"classlabels\"].flatten() \n", - "class_count = data[\"classcounts\"].flatten()\n", - "df = pd.DataFrame(\n", - " {\n", - " \"Class Labels\": class_labels,\n", - " \"Class Count\": class_count\n", - " }\n", - ")\n", - "df[\"Class Labels\"] = df[\"Class Labels\"].apply(lambda x: x[0])\n", - "df[\"Class Count\"] = df[\"Class Count\"].apply(lambda x: x[0][0])\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": 176, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(36, 39)\n", - "(20, 16)\n" - ] - } - ], - "source": [ - "def _load_data(file_path: str, which: Literal[\"alphadigit\", \"mnist\"]=\"alphadigit\") -> Dict[str, np.ndarray]:\n", - " \"\"\"\n", - " Load Binary AlphaDigits data from a .mat file.\n", - "\n", - " Parameters:\n", - " - file_path (str): Path to the .mat file containing the data.\n", - " - which (Literal[\"alphadigit\", \"mnist\"], optional): Specifies \n", - " which data to load. The default value is \"alphadigit\".\n", - "\n", - " Returns:\n", - " - data (dict): A dictionary containing the loaded data.\n", - "\n", - " Raises:\n", - " - ValueError: If the file_path parameter is None.\n", - " - ValueError: If the which parameter is not \"alphadigit\".\n", - "\n", - " Example Usage:\n", - " ```python\n", - " data = _load_data(\"data.mat\", \"alphadigit\")\n", - " ```\n", - " \"\"\"\n", - " if file_path is None:\n", - " raise ValueError(\"File path must be provided.\")\n", - " \n", - " if which == \"alphadigit\":\n", - " return scipy.io.loadmat(file_path)[\"dat\"]\n", - " \n", - " raise ValueError(\"MNIST NOT YET AVAILABLE.\")\n", - "\n", - "alphadigit_data = _load_data(ALPHA_DIGIT_PATH) \n", - "print(alphadigit_data.shape)\n", - "print(alphadigit_data[0][0].shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 177, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 > map to > [0]\n", - "10 > map to > [10]\n", - "A > map to > [10]\n", - "[1, 'C'] > map to > [[1], [12]]\n", - "36 > no mapping available, out of range\n" - ] - } - ], - "source": [ - "def _map_characters_to_indices(characters: Union[str, int, List[Union[str, int]]]) -> List[int]:\n", - " \"\"\"\n", - " Map alphanumeric character to its corresponding index.\n", - "\n", - " Parameters:\n", - " - character (str, int, list of str or int): Alphanumeric character or its index.\n", - "\n", - " Returns:\n", - " - char_index (int): Corresponding index for the character.\n", - " \"\"\"\n", - " if isinstance(characters, list):\n", - " return [_map_characters_to_indices(char) for char in characters]\n", - " if isinstance(characters, int) and 0 <= characters <= 35:\n", - " return [characters]\n", - " if (isinstance(characters, str) and characters.isdigit()\n", - " and 0 <= int(characters) <= 9):\n", - " return [int(characters)]\n", - " if (isinstance(characters, str) and characters.isalpha()\n", - " and 'A' <= characters.upper() <= 'Z'):\n", - " return [ord(characters.upper()) - ord('A') + 10]\n", - " \n", - " raise ValueError(\n", - " \"Invalid character input. It should be an alphanumeric\" \n", - " \"character '[0-9|A-Z]' or its index representing '[0-35]'.\"\n", - " )\n", - "\n", - "for char in [0, 10, \"A\", [1, \"C\"], 36]:\n", - " try:\n", - " map = _map_characters_to_indices(char)\n", - " print(f\"{char} > map to > {map}\")\n", - " except:\n", - " print(f\"{char} > no mapping available, out of range\")" - ] - }, - { - "cell_type": "code", - "execution_count": 183, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(78, 320)\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAGdCAYAAAA7TzlCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiMklEQVR4nO3df2xV9f3H8dct0Ftj6K0K9PZKKaDyQ4RGme2KOr8b1dIZpM6JNk6KIhpSEg3TIMuwqMvqr5FNaHBZhGqciiYCmTocVH4MaUUpzUAdoay2ELglELm3LaM07ef7x8Z1V+4tXHpu+dzL85F8Es45n8/p+36499XDOZdzXMYYIwCAtVIudAEAgN4R1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWG7ghS7ACT09PTp06JAGDx4sl8t1ocsBgBBjjNra2uTz+ZSScn7HxkkR1IcOHVJ2dvaFLgMAojpw4ICGDx9+XmOTIqgHDx4s6T8TkZ6efoGr+Y7H47nQJQA4T4FAwJH9BINBZWdnh3LqfCRFUJ8+3ZGenm5VUANIXE5nSV9Oy3IxEQAsR1ADgOXiFtRVVVUaOXKk0tLSlJ+frx07dvTa/7333tO4ceOUlpamiRMn6qOPPopXaQCQUOIS1KtXr9aCBQtUUVGh+vp65ebmqqioSEeOHInYf/v27SotLdWcOXO0a9culZSUqKSkRHv27IlHeQCQWEwc5OXlmfLy8tByd3e38fl8prKyMmL/mTNnmjvuuCNsXX5+vnn00UfP6ecFAgEjyQQCgfMvOg4k0Wi0BG1OcSKfHD+iPnXqlHbu3KnCwsLQupSUFBUWFqq2tjbimNra2rD+klRUVBS1f2dnp4LBYFgDgGTleFAfPXpU3d3dyszMDFufmZkpv98fcYzf74+pf2VlpTweT6jxn10AJLOE/NbHokWLFAgEQu3AgQMXuiQAiBvH/8PLkCFDNGDAALW2toatb21tldfrjTjG6/XG1N/tdsvtdjtTMABYzvEj6tTUVE2ePFk1NTWhdT09PaqpqVFBQUHEMQUFBWH9JWnDhg1R+wPARcWxS5v/45133jFut9tUV1ebr776yjzyyCMmIyPD+P1+Y4wxDzzwgHnqqadC/T/99FMzcOBA8/LLL5uvv/7aVFRUmEGDBpndu3ef08/jWx80Gs3p5hQn8ikuQW2MMcuWLTMjRowwqampJi8vz9TV1YW23XrrraasrCys/7vvvmvGjBljUlNTzYQJE8yHH354zj+LoKbRaE43pziRT67/BkpCCwaD8ng8CgQCVt2UiXtjA4nLqWh0Ip8S8lsfAHAxSYrbnNrKqd/IHJkj2SXBP+zjiiNqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCW41FcAM4bj9DqHxxRA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLOR7UlZWVuvHGGzV48GANGzZMJSUl2rt3b69jqqur5XK5wlpaWprTpQFAQnI8qLds2aLy8nLV1dVpw4YN6urq0u23366Ojo5ex6Wnp+vw4cOh1tzc7HRpAJCQHH9wwPr168OWq6urNWzYMO3cuVM/+tGPoo5zuVzyer1OlwMACS/uT3gJBAKSpMsvv7zXfu3t7crJyVFPT49uuOEG/fa3v9WECRMi9u3s7FRnZ2doORgMOlcwYCmepnLxiuvFxJ6eHj3++OO66aabdN1110XtN3bsWK1cuVLr1q3Tm2++qZ6eHk2ZMkUHDx6M2L+yslIejyfUsrOz4/USAOCCc5k4/pqeN2+e/vrXv2rbtm0aPnz4OY/r6urS+PHjVVpaqueee+6M7ZGOqLOzsxUIBJSenu5I7TZxuVwXugRYgCPqxBQMBuXxePqUT3E79TF//nx98MEH2rp1a0whLUmDBg3S9ddfr8bGxojb3W633G63E2UCgPUcP/VhjNH8+fO1Zs0affLJJxo1alTM++ju7tbu3buVlZXldHkAkHAcP6IuLy/XW2+9pXXr1mnw4MHy+/2SJI/Ho0suuUSSNGvWLF155ZWqrKyUJD377LP64Q9/qKuvvlrHjx/XSy+9pObmZj388MNOlwcACcfxoF6xYoUk6f/+7//C1q9atUqzZ8+WJLW0tCgl5buD+W+//VZz586V3+/XZZddpsmTJ2v79u269tprnS4PABJOXC8m9hcnTtbbjIuJkLiYmKicyCfu9QEAliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFgu7o/iAi5m3J8DTuCIGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACzHE16QNHiaCpzkcrkudAkhHFEDgOUIagCwHEENAJYjqAHAcgQ1AFjO8aBesmSJXC5XWBs3blyvY9577z2NGzdOaWlpmjhxoj766COnywKAhBWXI+oJEybo8OHDobZt27aofbdv367S0lLNmTNHu3btUklJiUpKSrRnz554lAYACScuQT1w4EB5vd5QGzJkSNS+f/jDHzRt2jQ9+eSTGj9+vJ577jndcMMNWr58eTxKA4CEE5eg3rdvn3w+n0aPHq37779fLS0tUfvW1taqsLAwbF1RUZFqa2ujjuns7FQwGAxrAJCsHA/q/Px8VVdXa/369VqxYoWampp0yy23qK2tLWJ/v9+vzMzMsHWZmZny+/1Rf0ZlZaU8Hk+oZWdnO/oaAMAmjgd1cXGx7rnnHk2aNElFRUX66KOPdPz4cb377ruO/YxFixYpEAiE2oEDBxzbNwDYJu73+sjIyNCYMWPU2NgYcbvX61Vra2vYutbWVnm93qj7dLvdcrvdjtYJALaK+/eo29vbtX//fmVlZUXcXlBQoJqamrB1GzZsUEFBQbxLA4CE4HhQP/HEE9qyZYu++eYbbd++XXfddZcGDBig0tJSSdKsWbO0aNGiUP/HHntM69ev1+9+9zv985//1JIlS/TFF19o/vz5TpcGAAnJ8VMfBw8eVGlpqY4dO6ahQ4fq5ptvVl1dnYYOHSpJamlpUUrKd78fpkyZorfeeku//vWv9atf/UrXXHON1q5dq+uuu87p0gAgIblMEtzENxgMyuPxKBAIKD09/UKX4zib7otrsyR4K8MiTn/u+pJP3OsDACxHUAOA5XgUF84Lpxlgq2Q8VcgRNQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOJ7wkAJ6mAicl4xNQkh1H1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAco4H9ciRI+Vyuc5o5eXlEftXV1ef0TctLc3psgAgYTl+P+rPP/9c3d3doeU9e/botttu0z333BN1THp6uvbu3Rta5n65APAdx4N66NChYcvPP/+8rrrqKt16661Rx7hcLnm9XqdLAYCkENdz1KdOndKbb76phx56qNej5Pb2duXk5Cg7O1szZszQl19+Gc+yACChxPVRXGvXrtXx48c1e/bsqH3Gjh2rlStXatKkSQoEAnr55Zc1ZcoUffnllxo+fHjEMZ2dners7AwtB4NBp0sHHMFpPDjBZeL4QL6ioiKlpqbqL3/5yzmP6erq0vjx41VaWqrnnnsuYp8lS5bomWeeOWN9IBBQenr6edcLOI2gxml9yae4nfpobm7Wxo0b9fDDD8c0btCgQbr++uvV2NgYtc+iRYsUCARC7cCBA30tFwCsFbegXrVqlYYNG6Y77rgjpnHd3d3avXu3srKyovZxu91KT08PawCQrOIS1D09PVq1apXKyso0cGD4afBZs2Zp0aJFoeVnn31Wf/vb3/Svf/1L9fX1+sUvfqHm5uaYj8QBIFnF5WLixo0b1dLSooceeuiMbS0tLUpJ+e73w7fffqu5c+fK7/frsssu0+TJk7V9+3Zde+218SgNABJOXC8m9pdgMCiPx8PFRFiHi4k4zcqLiQAAZxDUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACwX1ye8AImKe3TAJhxRA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALBdzUG/dulXTp0+Xz+eTy+XS2rVrw7YbY/T0008rKytLl1xyiQoLC7Vv376z7reqqkojR45UWlqa8vPztWPHjlhLA4CkFHNQd3R0KDc3V1VVVRG3v/jii3rllVf06quv6rPPPtOll16qoqIinTx5Muo+V69erQULFqiiokL19fXKzc1VUVGRjhw5Emt5AJB8TB9IMmvWrAkt9/T0GK/Xa1566aXQuuPHjxu3223efvvtqPvJy8sz5eXloeXu7m7j8/lMZWXlOdURCASMJBMIBGJ/EUAEkmg0R1tf8snRc9RNTU3y+/0qLCwMrfN4PMrPz1dtbW3EMadOndLOnTvDxqSkpKiwsDDqmM7OTgWDwbAGAMnK0aD2+/2SpMzMzLD1mZmZoW3fd/ToUXV3d8c0prKyUh6PJ9Sys7MdqB4A7JSQ3/pYtGiRAoFAqB04cOBClwQAceNoUHu9XklSa2tr2PrW1tbQtu8bMmSIBgwYENMYt9ut9PT0sAYAycrRoB41apS8Xq9qampC64LBoD777DMVFBREHJOamqrJkyeHjenp6VFNTU3UMQBwMRkY64D29nY1NjaGlpuamtTQ0KDLL79cI0aM0OOPP67f/OY3uuaaazRq1CgtXrxYPp9PJSUloTFTp07VXXfdpfnz50uSFixYoLKyMv3gBz9QXl6efv/736ujo0MPPvhg318hACS6WL8msmnTpohfPSkrKzPG/OcreosXLzaZmZnG7XabqVOnmr1794btIycnx1RUVIStW7ZsmRkxYoRJTU01eXl5pq6u7pxr4ut5cFqk9ziN1pfWl3xy/fdNmdCCwaA8Ho8CgQDnq+EIl8t1oUtAkulLPiXktz4A4GIS8zlq9D8nj+6S4B9Q/cKpeeLIHE7giBoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOV4FNdFxqlHQ/FIr3Pj5DzxWK+LF0fUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByMQf11q1bNX36dPl8PrlcLq1duza0raurSwsXLtTEiRN16aWXyufzadasWTp06FCv+1yyZIlcLldYGzduXMwvBgCSUcxB3dHRodzcXFVVVZ2x7cSJE6qvr9fixYtVX1+v999/X3v37tWdd9551v1OmDBBhw8fDrVt27bFWhoAJKWYHxxQXFys4uLiiNs8Ho82bNgQtm758uXKy8tTS0uLRowYEb2QgQPl9XpjLQcAkl7cn/ASCATkcrmUkZHRa799+/bJ5/MpLS1NBQUFqqysjBrsnZ2d6uzsDC0Hg0EnS8Y5cPJpIzwt5tw4NU88KSbxxPVi4smTJ7Vw4UKVlpYqPT09ar/8/HxVV1dr/fr1WrFihZqamnTLLbeora0tYv/Kykp5PJ5Qy87OjtdLAIALzmX68Gva5XJpzZo1KikpOWNbV1eX7r77bh08eFCbN2/uNai/7/jx48rJydHSpUs1Z86cM7ZHOqLOzs5WIBCI6eckimQ/AuKIun8l+/vJVn3Jp7ic+ujq6tLMmTPV3NysTz75JObiMjIyNGbMGDU2Nkbc7na75Xa7nSgVAKzn+KmP0yG9b98+bdy4UVdccUXM+2hvb9f+/fuVlZXldHkAkHBiDur29nY1NDSooaFBktTU1KSGhga1tLSoq6tLP//5z/XFF1/oz3/+s7q7u+X3++X3+3Xq1KnQPqZOnarly5eHlp944glt2bJF33zzjbZv36677rpLAwYMUGlpad9fIQAkOhOjTZs2GUlntLKyMtPU1BRxmySzadOm0D5ycnJMRUVFaPnee+81WVlZJjU11Vx55ZXm3nvvNY2NjedcUyAQMJJMIBCI9eUkhGhzmiwN/etC/31frK0v+dSni4m2CAaD8ng8XExMUEnwFkwoyf5+slVf8ol7fQCA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJaL+6O4gLNx6t4T3DPk3Dg5T9w3pH9wRA0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCW4wkvSBpOPm2Ep8WcG6fmiSfF9I4jagCwHEENAJYjqAHAcgQ1AFiOoAYAy8Uc1Fu3btX06dPl8/nkcrm0du3asO2zZ8+Wy+UKa9OmTTvrfquqqjRy5EilpaUpPz9fO3bsiLU0AEhKMQd1R0eHcnNzVVVVFbXPtGnTdPjw4VB7++23e93n6tWrtWDBAlVUVKi+vl65ubkqKirSkSNHYi0PAJJOzN+jLi4uVnFxca993G63vF7vOe9z6dKlmjt3rh588EFJ0quvvqoPP/xQK1eu1FNPPRVriQCQVOJyjnrz5s0aNmyYxo4dq3nz5unYsWNR+546dUo7d+5UYWHhd0WlpKiwsFC1tbURx3R2dioYDIY1AEhWjgf1tGnT9MYbb6impkYvvPCCtmzZouLiYnV3d0fsf/ToUXV3dyszMzNsfWZmpvx+f8QxlZWV8ng8oZadne30ywAAazj+X8jvu+++0J8nTpyoSZMm6aqrrtLmzZs1depUR37GokWLtGDBgtByMBgkrAEkrbh/PW/06NEaMmSIGhsbI24fMmSIBgwYoNbW1rD1ra2tUc9zu91upaenhzUASFZxD+qDBw/q2LFjysrKirg9NTVVkydPVk1NTWhdT0+PampqVFBQEO/yAMB6MQd1e3u7Ghoa1NDQIElqampSQ0ODWlpa1N7erieffFJ1dXX65ptvVFNToxkzZujqq69WUVFRaB9Tp07V8uXLQ8sLFizQn/70J73++uv6+uuvNW/ePHV0dIS+BQIAFzUTo02bNhlJZ7SysjJz4sQJc/vtt5uhQ4eaQYMGmZycHDN37lzj9/vD9pGTk2MqKirC1i1btsyMGDHCpKammry8PFNXV3fONQUCASPJBAKBWF9OQog037T4NvSvC/333R+tL/nk+u8kJbRgMCiPx6NAIJCU56u5V2//S4KPRUK5GN7jfckn7vUBAJYjqAHAcjyKC4jAqX+Kcwrl3Dg5T8l4GoUjagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcT3gBkFScelqMTU+K4YgaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFgu5qDeunWrpk+fLp/PJ5fLpbVr14Ztd7lcEdtLL70UdZ9Lliw5o/+4ceNifjEAkIxiDuqOjg7l5uaqqqoq4vbDhw+HtZUrV8rlcunuu+/udb8TJkwIG7dt27ZYSwOApBTzgwOKi4tVXFwcdbvX6w1bXrdunX784x9r9OjRvRcycOAZYwEAcT5H3draqg8//FBz5sw5a999+/bJ5/Np9OjRuv/++9XS0hK1b2dnp4LBYFgDgGQV16B+/fXXNXjwYP3sZz/rtV9+fr6qq6u1fv16rVixQk1NTbrlllvU1tYWsX9lZaU8Hk+oZWdnx6N8axhjHGvoX9Gu2ZxPQ/9y6jMXCAT6XIvL9OHT63K5tGbNGpWUlETcPm7cON12221atmxZTPs9fvy4cnJytHTp0ohH452dners7AwtB4NBZWdnKxAIKD09PaafdbHhA5+4+EWbmILBoDweT5/yKW4Pt/373/+uvXv3avXq1TGPzcjI0JgxY9TY2Bhxu9vtltvt7muJAJAQ4nbq47XXXtPkyZOVm5sb89j29nbt379fWVlZcagMABJLzEHd3t6uhoYGNTQ0SJKamprU0NAQdvEvGAzqvffe08MPPxxxH1OnTtXy5ctDy0888YS2bNmib775Rtu3b9ddd92lAQMGqLS0NNbyACDpxHzq44svvtCPf/zj0PKCBQskSWVlZaqurpYkvfPOOzLGRA3a/fv36+jRo6HlgwcPqrS0VMeOHdPQoUN18803q66uTkOHDo21PABIOn26mGgLJ07WXyy4mJi4kuCjelFyIp+41wcAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGC5uN2PGoCznLpPC/cMSTwcUQOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlkuIJL6efWBEMBi9wJYD9+Jz0r9Pz3Zcn6yRFULe1tUmSsrOzL3AlgP08Hs+FLuGi1NbWdt5z7zJJ8AC1np4eHTp0SIMHD+71uXLBYFDZ2dk6cOCA0tPT+7HCvqHu/pWodUuJW3sy122MUVtbm3w+n1JSzu9sc1IcUaekpGj48OHn3D89PT2h3gynUXf/StS6pcStPVnr7uu/YriYCACWI6gBwHIXVVC73W5VVFTI7XZf6FJiQt39K1HrlhK3duruXVJcTASAZHZRHVEDQCIiqAHAcgQ1AFiOoAYAyyVdUFdVVWnkyJFKS0tTfn6+duzY0Wv/9957T+PGjVNaWpomTpyojz76qJ8q/Y/KykrdeOONGjx4sIYNG6aSkhLt3bu31zHV1dVyuVxhLS0trZ8q/o8lS5acUcO4ceN6HXOh51qSRo4ceUbdLpdL5eXlEftfyLneunWrpk+fLp/PJ5fLpbVr14ZtN8bo6aefVlZWli655BIVFhZq3759Z91vrJ8RJ+vu6urSwoULNXHiRF166aXy+XyaNWuWDh061Os+z+f95mTdkjR79uwzapg2bdpZ9+vEfCdVUK9evVoLFixQRUWF6uvrlZubq6KiIh05ciRi/+3bt6u0tFRz5szRrl27VFJSopKSEu3Zs6ffat6yZYvKy8tVV1enDRs2qKurS7fffrs6Ojp6HZeenq7Dhw+HWnNzcz9V/J0JEyaE1bBt27aofW2Ya0n6/PPPw2resGGDJOmee+6JOuZCzXVHR4dyc3NVVVUVcfuLL76oV155Ra+++qo+++wzXXrppSoqKtLJkyej7jPWz4jTdZ84cUL19fVavHix6uvr9f7772vv3r268847z7rfWN5vTtd92rRp08JqePvtt3vdp2PzbZJIXl6eKS8vDy13d3cbn89nKisrI/afOXOmueOOO8LW5efnm0cffTSudfbmyJEjRpLZsmVL1D6rVq0yHo+n/4qKoKKiwuTm5p5zfxvn2hhjHnvsMXPVVVeZnp6eiNttmGtjjJFk1qxZE1ru6ekxXq/XvPTSS6F1x48fN26327z99ttR9xPrZ8TpuiPZsWOHkWSam5uj9on1/dZXkeouKyszM2bMiGk/Ts130hxRnzp1Sjt37lRhYWFoXUpKigoLC1VbWxtxTG1tbVh/SSoqKoravz8EAgFJ0uWXX95rv/b2duXk5Cg7O1szZszQl19+2R/lhdm3b598Pp9Gjx6t+++/Xy0tLVH72jjXp06d0ptvvqmHHnqo15t52TDX39fU1CS/3x82px6PR/n5+VHn9Hw+I/0hEAjI5XIpIyOj136xvN/iZfPmzRo2bJjGjh2refPm6dixY1H7OjnfSRPUR48eVXd3tzIzM8PWZ2Zmyu/3Rxzj9/tj6h9vPT09evzxx3XTTTfpuuuui9pv7NixWrlypdatW6c333xTPT09mjJlig4ePNhvtebn56u6ulrr16/XihUr1NTUpFtuuSV0y9nvs22uJWnt2rU6fvy4Zs+eHbWPDXMdyel5i2VOz+czEm8nT57UwoULVVpa2utNjWJ9v8XDtGnT9MYbb6impkYvvPCCtmzZouLiYnV3d0fs7+R8J8Xd85JFeXm59uzZc9ZzbwUFBSooKAgtT5kyRePHj9cf//hHPffcc/EuU5JUXFwc+vOkSZOUn5+vnJwcvfvuu5ozZ06/1NBXr732moqLi+Xz+aL2sWGuk1VXV5dmzpwpY4xWrFjRa18b3m/33Xdf6M8TJ07UpEmTdNVVV2nz5s2aOnVqXH920hxRDxkyRAMGDFBra2vY+tbWVnm93ohjvF5vTP3jaf78+frggw+0adOmmG7ZKkmDBg3S9ddfr8bGxjhVd3YZGRkaM2ZM1BpsmmtJam5u1saNG/Xwww/HNM6GuZYUmrdY5vR8PiPxcjqkm5ubtWHDhphvbXq291t/GD16tIYMGRK1BifnO2mCOjU1VZMnT1ZNTU1oXU9Pj2pqasKOiP5XQUFBWH9J2rBhQ9T+8WCM0fz587VmzRp98sknGjVqVMz76O7u1u7du5WVlRWHCs9Ne3u79u/fH7UGG+b6f61atUrDhg3THXfcEdM4G+ZakkaNGiWv1xs2p8FgUJ999lnUOT2fz0g8nA7pffv2aePGjbriiiti3sfZ3m/94eDBgzp27FjUGhyd75guPVrunXfeMW6321RXV5uvvvrKPPLIIyYjI8P4/X5jjDEPPPCAeeqpp0L9P/30UzNw4EDz8ssvm6+//tpUVFSYQYMGmd27d/dbzfPmzTMej8ds3rzZHD58ONROnDgR6vP9up955hnz8ccfm/3795udO3ea++67z6SlpZkvv/yy3+r+5S9/aTZv3myamprMp59+agoLC82QIUPMkSNHItZsw1yf1t3dbUaMGGEWLlx4xjab5rqtrc3s2rXL7Nq1y0gyS5cuNbt27Qp9O+L55583GRkZZt26deYf//iHmTFjhhk1apT597//HdrHT37yE7Ns2bLQ8tk+I/Gu+9SpU+bOO+80w4cPNw0NDWHv+c7Ozqh1n+39Fu+629razBNPPGFqa2tNU1OT2bhxo7nhhhvMNddcY06ePBm1bqfmO6mC2hhjli1bZkaMGGFSU1NNXl6eqaurC2279dZbTVlZWVj/d99914wZM8akpqaaCRMmmA8//LBf65UUsa1atSpq3Y8//njoNWZmZpqf/vSnpr6+vl/rvvfee01WVpZJTU01V155pbn33ntNY2Nj1JqNufBzfdrHH39sJJm9e/eesc2mud60aVPE98bp+np6eszixYtNZmamcbvdZurUqWe8ppycHFNRURG2rrfPSLzrbmpqivqe37RpU9S6z/Z+i3fdJ06cMLfffrsZOnSoGTRokMnJyTFz5849I3DjNd/c5hQALJc056gBIFkR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOX+H+yZ99OUmw08AAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAGdCAYAAAA7TzlCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh+klEQVR4nO3dfWxUVf7H8c8U6NQQOlWBtiOlPCigCF0kthZl/SmV0mWRuruCjatlBTWkJmvQDbIRi5psfSD+ITa42VWqYeUpkZIoiwuVhwVaUNpGUEIoW1sITAlEZtoiLWnP749dZndkpjB0pj0d3q/kJNx7z7n99jDz4ebOcI/DGGMEALBWXG8XAADoGkENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4Dl+vd2AZHQ2dmpkydPatCgQXI4HL1dDgD4GWPU3Nwst9utuLhruzaOiaA+efKk0tLSersMAAjp+PHjGjZs2DWNjYmgHjRoUG+XAESd1+vt7RKixuVy9XYJUdednIqJoOZ2B64HiYmJvV0CuqE7OcWHiQBgOYIaACwXtaAuLS3ViBEjlJCQoKysLO3fv7/L/hs2bNC4ceOUkJCgCRMmaPPmzdEqDQD6lKgE9bp167Ro0SIVFxerurpaGRkZys3N1enTp4P237t3rwoKCjR//nzV1NQoPz9f+fn5OnToUDTKA4C+xURBZmamKSoq8m93dHQYt9ttSkpKgvafM2eOmTlzZsC+rKws8+yzz17Vz/N6vUYSjRbTLZb19tz2RPN6vdc8PxG/om5vb9eBAweUk5Pj3xcXF6ecnBxVVlYGHVNZWRnQX5Jyc3ND9m9ra5PP5wtoABCrIh7UZ86cUUdHh5KTkwP2Jycny+PxBB3j8XjC6l9SUiKXy+Vv/GcXALGsT37rY8mSJfJ6vf52/Pjx3i4JAKIm4v/hZfDgwerXr5+ampoC9jc1NSklJSXomJSUlLD6O51OOZ3OyBQMAJaL+BV1fHy8Jk+erIqKCv++zs5OVVRUKDs7O+iY7OzsgP6StHXr1pD9AeC6EsEPbv3Wrl1rnE6nKSsrM99995155plnTFJSkvF4PMYYY5544gnz0ksv+fvv2bPH9O/f3yxfvtwcPnzYFBcXmwEDBpiDBw9e1c/jWx+066HFst6e255o3fnWR9T+9lesWGGGDx9u4uPjTWZmpqmqqvIfu//++01hYWFA//Xr15sxY8aY+Ph4M378ePP5559f9c8iqGnXQ4tlvT23PdG6E9SO/0xSn+bz+a6Lp2/h+hYDb9WQrocHq3m93mt+sFaf/NYHAFxPYuIxp5EUy1ctgK1sfN/ZdJXPFTUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMuxFBcABBGp5cEisfg2V9QAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIRD+qSkhLdfffdGjRokIYOHar8/HwdOXKkyzFlZWVyOBwBLSEhIdKlAUCfFPGg3rlzp4qKilRVVaWtW7fq4sWLmj59ulpbW7scl5iYqFOnTvlbQ0NDpEsDgD4p4gsHbNmyJWC7rKxMQ4cO1YEDB/Tzn/885DiHw6GUlJRIlwMAfV7UV3jxer2SpJtuuqnLfi0tLUpPT1dnZ6fuuusu/elPf9L48eOD9m1ra1NbW5t/2+fzRa5gCzkcjt4uAYiqSK2mEqui+mFiZ2ennn/+ed1777268847Q/YbO3asPvzwQ23atEmrV69WZ2enpkyZohMnTgTtX1JSIpfL5W9paWnR+hUAoNc5TBT/KVu4cKH+/ve/a/fu3Ro2bNhVj7t48aJuv/12FRQU6PXXX7/seLAr6kiFtY3/snNFjVhn4/suUi6tmej1epWYmHhN54jarY/nnntOn332mXbt2hVWSEvSgAEDNGnSJNXV1QU97nQ65XQ6I1EmAFgv4rc+jDF67rnntHHjRn355ZcaOXJk2Ofo6OjQwYMHlZqaGunyAKDPifgVdVFRkT755BNt2rRJgwYNksfjkSS5XC7dcMMNkqQnn3xSt9xyi0pKSiRJr732mu655x7deuutOnfunN5++201NDRowYIFkS4PAPqciAf1ypUrJUn/93//F7B/1apVmjdvniSpsbFRcXH/vZj/4Ycf9PTTT8vj8ejGG2/U5MmTtXfvXt1xxx2RLg8A+pyofpjYUy7drI8EG6eDDxMR62x830VKJD5M5FkfAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5aK+FFdfE8nnasTy8wsAKbZf4zY9Y4cragCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcK7xEkU0rRADou7iiBgDLEdQAYDmCGgAsR1ADgOUIagCwXMSDetmyZXI4HAFt3LhxXY7ZsGGDxo0bp4SEBE2YMEGbN2+OdFkA0GdF5Yp6/PjxOnXqlL/t3r07ZN+9e/eqoKBA8+fPV01NjfLz85Wfn69Dhw5FozQA6HMcxhgTyRMuW7ZM5eXlqq2tvar+c+fOVWtrqz777DP/vnvuuUc/+9nP9P7771/VOXw+n1wu17WUC6AbIhwfVon0/4Pwer1KTEy8prFRuaI+evSo3G63Ro0apccff1yNjY0h+1ZWVionJydgX25uriorK0OOaWtrk8/nC2gAEKsiHtRZWVkqKyvTli1btHLlStXX12vq1Klqbm4O2t/j8Sg5OTlgX3JysjweT8ifUVJSIpfL5W9paWkR/R0AwCYRD+q8vDw9+uijmjhxonJzc7V582adO3dO69evj9jPWLJkibxer78dP348YucGANtE/VkfSUlJGjNmjOrq6oIeT0lJUVNTU8C+pqYmpaSkhDyn0+mU0+mMaJ0AYKuof4+6paVFx44dU2pqatDj2dnZqqioCNi3detWZWdnR7s0AOgbTIS98MILZseOHaa+vt7s2bPH5OTkmMGDB5vTp08bY4x54oknzEsvveTvv2fPHtO/f3+zfPlyc/jwYVNcXGwGDBhgDh48eNU/0+v1Gkk0Gq2HWyyL9Fx5vd5rriXitz5OnDihgoICnT17VkOGDNF9992nqqoqDRkyRJLU2NiouLj/XshPmTJFn3zyiV5++WX98Y9/1G233aby8nLdeeedkS4NAPqkiH+PujfwPWqgd8RAfIQU89+jBgBEDkENAJZjKS7gOhPLtyuk2FwCjytqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBwrvAB9RCyvzBKLq7JEElfUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByEQ/qESNGyOFwXNaKioqC9i8rK7usb0JCQqTLAoA+K+LPo/7qq6/U0dHh3z506JAeeughPfrooyHHJCYm6siRI/5tnk0LAP8V8aAeMmRIwPYbb7yh0aNH6/777w85xuFwKCUlJdKlAEBMiOo96vb2dq1evVpPPfVUl1fJLS0tSk9PV1pammbPnq1vv/02mmUBQJ8S1aW4ysvLde7cOc2bNy9kn7Fjx+rDDz/UxIkT5fV6tXz5ck2ZMkXffvuthg0bFnRMW1ub2tra/Ns+ny/SpQMREcvLZ0ncpuwxJoqmT59ufvnLX4Y1pr293YwePdq8/PLLIfsUFxcbSTSa9S3W9fb89qXm9XqveZ6jduujoaFB27Zt04IFC8IaN2DAAE2aNEl1dXUh+yxZskRer9ffjh8/3t1yAcBaUQvqVatWaejQoZo5c2ZY4zo6OnTw4EGlpqaG7ON0OpWYmBjQACBWRSWoOzs7tWrVKhUWFqp//8Db4E8++aSWLFni337ttdf0j3/8Q//6179UXV2t3/72t2poaAj7ShwAYlVUPkzctm2bGhsb9dRTT112rLGxUXFx//334YcfftDTTz8tj8ejG2+8UZMnT9bevXt1xx13RKM0AOhzHP/5QKBP8/l8crlcvV0GcJkYeHt1iW99XD2v13vNt2l51gcAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGC5qK7wAvSkWH+uRqTwfI6+hytqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWYykuoI9gCa3rF1fUAGA5ghoALEdQA4DlCGoAsBxBDQCWCzuod+3apVmzZsntdsvhcKi8vDzguDFGr7zyilJTU3XDDTcoJydHR48eveJ5S0tLNWLECCUkJCgrK0v79+8PtzQAiElhB3Vra6syMjJUWloa9Phbb72ld999V++//7727dungQMHKjc3VxcuXAh5znXr1mnRokUqLi5WdXW1MjIylJubq9OnT4dbHgDEHtMNkszGjRv9252dnSYlJcW8/fbb/n3nzp0zTqfTrFmzJuR5MjMzTVFRkX+7o6PDuN1uU1JSclV1eL1eI4l2nbdY19vzS+te83q91/x3H9F71PX19fJ4PMrJyfHvc7lcysrKUmVlZdAx7e3tOnDgQMCYuLg45eTkhBzT1tYmn88X0AAgVkU0qD0ejyQpOTk5YH9ycrL/2E+dOXNGHR0dYY0pKSmRy+Xyt7S0tAhUDwB26pPf+liyZIm8Xq+/HT9+vLdLAoCoiWhQp6SkSJKampoC9jc1NfmP/dTgwYPVr1+/sMY4nU4lJiYGNACIVREN6pEjRyolJUUVFRX+fT6fT/v27VN2dnbQMfHx8Zo8eXLAmM7OTlVUVIQcAwDXlXA/fWxubjY1NTWmpqbGSDLvvPOOqampMQ0NDcYYY9544w2TlJRkNm3aZL755hsze/ZsM3LkSPPjjz/6z/Hggw+aFStW+LfXrl1rnE6nKSsrM99995155plnTFJSkvF4PFdVE9/6oEl864Nmd+vOtz7CfnVv3749aBGFhYXGmH9/RW/p0qUmOTnZOJ1OM23aNHPkyJGAc6Snp5vi4uKAfStWrDDDhw838fHxJjMz01RVVV11TQQ1TSKoaXa37gS14z8vgD7N5/PJ5XL1dhnoZTHwUu4Sz6Pu27xe7zV/ntYnv/UBANcTVnhBr4vlK2GughEJXFEDgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcS3FdZ2J52atIYgkt2IQragCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgubCDeteuXZo1a5bcbrccDofKy8v9xy5evKjFixdrwoQJGjhwoNxut5588kmdPHmyy3MuW7ZMDocjoI0bNy7sXwYAYlHYQd3a2qqMjAyVlpZeduz8+fOqrq7W0qVLVV1drU8//VRHjhzRww8/fMXzjh8/XqdOnfK33bt3h1saAMSksBcOyMvLU15eXtBjLpdLW7duDdj33nvvKTMzU42NjRo+fHjoQvr3V0pKSrjlAEDMi/oKL16vVw6HQ0lJSV32O3r0qNxutxISEpSdna2SkpKQwd7W1qa2tjb/ts/ni2TJ1mFVlqvDqiyIVVH9MPHChQtavHixCgoKlJiYGLJfVlaWysrKtGXLFq1cuVL19fWaOnWqmpubg/YvKSmRy+Xyt7S0tGj9CgDQ+0w3SDIbN24Meqy9vd3MmjXLTJo0yXi93rDO+8MPP5jExETz17/+NejxCxcuGK/X62/Hjx83kmK24er09t8TjdZVCzcH/1dUbn1cvHhRc+bMUUNDg7788ssur6aDSUpK0pgxY1RXVxf0uNPplNPpjESpAGC9iN/6uBTSR48e1bZt23TzzTeHfY6WlhYdO3ZMqampkS4PAPqcsIO6paVFtbW1qq2tlSTV19ertrZWjY2Nunjxon7zm9/o66+/1t/+9jd1dHTI4/HI4/Govb3df45p06bpvffe82+/+OKL2rlzp77//nvt3btXjzzyiPr166eCgoLu/4YA0NeFe69k+/btQe+/FBYWmvr6+pD3Z7Zv3+4/R3p6uikuLvZvz50716Smppr4+Hhzyy23mLlz55q6urqrrsnr9fb6/adoNlyd3v57otG6at25R+34zwu8T/P5fHK5XL1dRtTEwF9Rj+DrebCZ1+sN+/O6S3jWBwBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYLmoL8V1PeMZHVeHZ3QAXeOKGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACzHCi/XGVZTAfoerqgBwHIENQBYjqAGAMsR1ABgOYIaACwXdlDv2rVLs2bNktvtlsPhUHl5ecDxefPmyeFwBLQZM2Zc8bylpaUaMWKEEhISlJWVpf3794dbGgDEpLCDurW1VRkZGSotLQ3ZZ8aMGTp16pS/rVmzpstzrlu3TosWLVJxcbGqq6uVkZGh3NxcnT59OtzyACD2mG6QZDZu3Biwr7Cw0MyePTus82RmZpqioiL/dkdHh3G73aakpOSqxnu9XiPJumaj3p4TGu16bV6v95rft1G5R71jxw4NHTpUY8eO1cKFC3X27NmQfdvb23XgwAHl5OT498XFxSknJ0eVlZVBx7S1tcnn8wU0AIhVEQ/qGTNm6OOPP1ZFRYXefPNN7dy5U3l5eero6Aja/8yZM+ro6FBycnLA/uTkZHk8nqBjSkpK5HK5/C0tLS3SvwYAWCPi/4X8scce8/95woQJmjhxokaPHq0dO3Zo2rRpEfkZS5Ys0aJFi/zbPp+PsAYQs6L+9bxRo0Zp8ODBqqurC3p88ODB6tevn5qamgL2NzU1KSUlJegYp9OpxMTEgAYAsSrqQX3ixAmdPXtWqampQY/Hx8dr8uTJqqio8O/r7OxURUWFsrOzo10eAFgv7KBuaWlRbW2tamtrJUn19fWqra1VY2OjWlpa9Ic//EFVVVX6/vvvVVFRodmzZ+vWW29Vbm6u/xzTpk3Te++9599etGiR/vKXv+ijjz7S4cOHtXDhQrW2tup3v/td939DAOjrwv2ayPbt24N+9aSwsNCcP3/eTJ8+3QwZMsQMGDDApKenm6efftp4PJ6Ac6Snp5vi4uKAfStWrDDDhw838fHxJjMz01RVVV11TXw97+r19pzQaNdr687X8xz/efP2aT6fTy6Xq7fLuIyNU8vzqIHe4fV6r/nzNJ71AQCWI6gBwHIsxfUTNt6uAHB944oaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALBdTK7x0Z/FIm7EgLSKJVYx6ViQW3+aKGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYLuyg3rVrl2bNmiW32y2Hw6Hy8vKA4w6HI2h7++23Q55z2bJll/UfN25c2L8MAMSisIO6tbVVGRkZKi0tDXr81KlTAe3DDz+Uw+HQr3/96y7PO378+IBxu3fvDrc0AIhJYS8ckJeXp7y8vJDHU1JSArY3bdqkBx54QKNGjeq6kP79LxsLAIjyPeqmpiZ9/vnnmj9//hX7Hj16VG63W6NGjdLjjz+uxsbGkH3b2trk8/kCGgDEqqguxfXRRx9p0KBB+tWvftVlv6ysLJWVlWns2LE6deqUXn31VU2dOlWHDh3SoEGDLutfUlKiV199NVplW4elk4Drm8N0IwUcDoc2btyo/Pz8oMfHjRunhx56SCtWrAjrvOfOnVN6erreeeedoFfjbW1tamtr82/7fD6lpaXF7JqJAPquS2smdiefonZF/c9//lNHjhzRunXrwh6blJSkMWPGqK6uLuhxp9Mpp9PZ3RIBoE+I2j3qDz74QJMnT1ZGRkbYY1taWnTs2DGlpqZGoTIA6FvCDuqWlhbV1taqtrZWklRfX6/a2tqAD/98Pp82bNigBQsWBD3HtGnT9N577/m3X3zxRe3cuVPff/+99u7dq0ceeUT9+vVTQUFBuOUBQMwJ+9bH119/rQceeMC/vWjRIklSYWGhysrKJElr166VMSZk0B47dkxnzpzxb584cUIFBQU6e/ashgwZovvuu09VVVUaMmRIuOUBQMzp1oeJtojEzXoAiIZI5BPP+gAAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCWI6gBwHIENQBYjqAGAMsR1ABgOYIaACxHUAOA5QhqALAcQQ0AliOoAcByBDUAWI6gBgDLEdQAYDmCGgAsR1ADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsFz/3i4gEowxkiSfz9fLlQBAoEu5dCmnrkVMBHVzc7MkKS0trZcrAYDgmpub5XK5rmmsw3Qn5i3R2dmpkydPatCgQXI4HCH7+Xw+paWl6fjx40pMTOzBCruHuntWX61b6ru1x3Ldxhg1NzfL7XYrLu7a7jbHxBV1XFychg0bdtX9ExMT+9SL4RLq7ll9tW6p79Yeq3Vf65X0JXyYCACWI6gBwHLXVVA7nU4VFxfL6XT2dilhoe6e1Vfrlvpu7dTdtZj4MBEAYtl1dUUNAH0RQQ0AliOoAcByBDUAWC7mgrq0tFQjRoxQQkKCsrKytH///i77b9iwQePGjVNCQoImTJigzZs391Cl/1ZSUqK7775bgwYN0tChQ5Wfn68jR450OaasrEwOhyOgJSQk9FDF/7Zs2bLLahg3blyXY3p7riVpxIgRl9XtcDhUVFQUtH9vzvWuXbs0a9Ysud1uORwOlZeXBxw3xuiVV15RamqqbrjhBuXk5Ojo0aNXPG+475FI1n3x4kUtXrxYEyZM0MCBA+V2u/Xkk0/q5MmTXZ7zWl5vkaxbkubNm3dZDTNmzLjieSMx3zEV1OvWrdOiRYtUXFys6upqZWRkKDc3V6dPnw7af+/evSooKND8+fNVU1Oj/Px85efn69ChQz1W886dO1VUVKSqqipt3bpVFy9e1PTp09Xa2trluMTERJ06dcrfGhoaeqji/xo/fnxADbt37w7Z14a5lqSvvvoqoOatW7dKkh599NGQY3prrltbW5WRkaHS0tKgx9966y29++67ev/997Vv3z4NHDhQubm5unDhQshzhvseiXTd58+fV3V1tZYuXarq6mp9+umnOnLkiB5++OErnjec11uk675kxowZATWsWbOmy3NGbL5NDMnMzDRFRUX+7Y6ODuN2u01JSUnQ/nPmzDEzZ84M2JeVlWWeffbZqNbZldOnTxtJZufOnSH7rFq1yrhcrp4rKoji4mKTkZFx1f1tnGtjjPn9739vRo8ebTo7O4Met2GujTFGktm4caN/u7Oz06SkpJi3337bv+/cuXPG6XSaNWvWhDxPuO+RSNcdzP79+40k09DQELJPuK+37gpWd2FhoZk9e3ZY54nUfMfMFXV7e7sOHDignJwc/764uDjl5OSosrIy6JjKysqA/pKUm5sbsn9P8Hq9kqSbbrqpy34tLS1KT09XWlqaZs+erW+//bYnygtw9OhRud1ujRo1So8//rgaGxtD9rVxrtvb27V69Wo99dRTXT7My4a5/qn6+np5PJ6AOXW5XMrKygo5p9fyHukJXq9XDodDSUlJXfYL5/UWLTt27NDQoUM1duxYLVy4UGfPng3ZN5LzHTNBfebMGXV0dCg5OTlgf3JysjweT9AxHo8nrP7R1tnZqeeff1733nuv7rzzzpD9xo4dqw8//FCbNm3S6tWr1dnZqSlTpujEiRM9VmtWVpbKysq0ZcsWrVy5UvX19Zo6dar/kbM/ZdtcS1J5ebnOnTunefPmhexjw1wHc2newpnTa3mPRNuFCxe0ePFiFRQUdPlQo3Bfb9EwY8YMffzxx6qoqNCbb76pnTt3Ki8vTx0dHUH7R3K+Y+LpebGiqKhIhw4duuK9t+zsbGVnZ/u3p0yZottvv11//vOf9frrr0e7TElSXl6e/88TJ05UVlaW0tPTtX79es2fP79HauiuDz74QHl5eXK73SH72DDXserixYuaM2eOjDFauXJll31teL099thj/j9PmDBBEydO1OjRo7Vjxw5NmzYtqj87Zq6oBw8erH79+qmpqSlgf1NTk1JSUoKOSUlJCat/ND333HP67LPPtH379rAe2SpJAwYM0KRJk1RXVxel6q4sKSlJY8aMCVmDTXMtSQ0NDdq2bZsWLFgQ1jgb5lqSf97CmdNreY9Ey6WQbmho0NatW8N+tOmVXm89YdSoURo8eHDIGiI53zET1PHx8Zo8ebIqKir8+zo7O1VRURFwRfS/srOzA/pL0tatW0P2jwZjjJ577jlt3LhRX375pUaOHBn2OTo6OnTw4EGlpqZGocKr09LSomPHjoWswYa5/l+rVq3S0KFDNXPmzLDG2TDXkjRy5EilpKQEzKnP59O+fftCzum1vEei4VJIHz16VNu2bdPNN98c9jmu9HrrCSdOnNDZs2dD1hDR+Q7ro0fLrV271jidTlNWVma+++4788wzz5ikpCTj8XiMMcY88cQT5qWXXvL337Nnj+nfv79Zvny5OXz4sCkuLjYDBgwwBw8e7LGaFy5caFwul9mxY4c5deqUv50/f97f56d1v/rqq+aLL74wx44dMwcOHDCPPfaYSUhIMN9++22P1f3CCy+YHTt2mPr6erNnzx6Tk5NjBg8ebE6fPh20Zhvm+pKOjg4zfPhws3jx4suO2TTXzc3NpqamxtTU1BhJ5p133jE1NTX+b0e88cYbJikpyWzatMl88803Zvbs2WbkyJHmxx9/9J/jwQcfNCtWrPBvX+k9Eu2629vbzcMPP2yGDRtmamtrA17zbW1tIeu+0ust2nU3NzebF1980VRWVpr6+nqzbds2c9ddd5nbbrvNXLhwIWTdkZrvmApqY4xZsWKFGT58uImPjzeZmZmmqqrKf+z+++83hYWFAf3Xr19vxowZY+Lj48348ePN559/3qP1SgraVq1aFbLu559/3v87Jicnm1/84hemurq6R+ueO3euSU1NNfHx8eaWW24xc+fONXV1dSFrNqb35/qSL774wkgyR44cueyYTXO9ffv2oK+NS/V1dnaapUuXmuTkZON0Os20adMu+53S09NNcXFxwL6u3iPRrru+vj7ka3779u0h677S6y3adZ8/f95Mnz7dDBkyxAwYMMCkp6ebp59++rLAjdZ885hTALBczNyjBoBYRVADgOUIagCwHEENAJYjqAHAcgQ1AFiOoAYAyxHUAGA5ghoALEdQA4DlCGoAsBxBDQCW+38oKWAqVJcDrgAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def read_alpha_digit(characters: Optional[Union[str, int, List[Union[str, int]]]] = None,\n", - " file_path: Optional[str] = ALPHA_DIGIT_PATH,\n", - " data: Optional[Dict[str, np.ndarray]] = None,\n", - " use_data: bool = False,\n", - " ) -> np.ndarray:\n", - " \"\"\"\n", - " Reads binary AlphaDigits data from a .mat file or uses already loaded data. \n", - " It extracts the data for a specified alphanumeric character or its index, and \n", - " flattens the images into one-dimensional vectors.\n", - "\n", - " Parameters:\n", - " - characters (Union[str, int, List[Union[str, int]]], optional): Alphanumeric character \n", - " or its index whose data needs to be extracted. It can be a single character or \n", - " a list of characters. Default is None.\n", - " - file_path (str, optional): Path to the .mat file containing the data. \n", - " Default is None.\n", - " - data (dict, optional): Already loaded data dictionary. \n", - " Default is None.\n", - " - use_data (bool): Flag to indicate whether to use already loaded data.\n", - " Default is False.\n", - "\n", - " Returns:\n", - " - flattened_images (numpy.ndarray): Flattened images for the specified character(s).\n", - " \"\"\"\n", - " if not use_data:\n", - " data = _load_data(file_path, which=\"alphadigit\")\n", - "\n", - " char_indices = _map_characters_to_indices(characters)\n", - "\n", - " # Select the rows corresponding to the characters indices.\n", - " char_data: np.ndarray = data[char_indices]\n", - " \n", - " # Flatten each image into a one-dimensional vector.\n", - " flattened_images = np.array([image.flatten() for image in char_data.flatten()])\n", - " return flattened_images\n", - "\n", - "char = [20, \"Z\"]\n", - "data = read_alpha_digit(char, ALPHA_DIGIT_PATH)\n", - "print(data.shape)\n", - "plt.imshow(data[0].reshape(20, 16), cmap=\"gray\")\n", - "plt.show()\n", - "plt.imshow(data[40].reshape(20, 16), cmap=\"gray\")\n", - "plt.show()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 179, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data shape: (78, 320)\n" - ] - } - ], - "source": [ - "print(\"data shape:\", data.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "metadata": {}, - "outputs": [], - "source": [ - "class RBM:\n", - " def __init__(self, n_visible: int, n_hidden: int, random_state=None) -> None:\n", - " \"\"\"\n", - " Initialize the Restricted Boltzmann Machine.\n", - "\n", - " Parameters:\n", - " - n_visible (int): Number of visible units.\n", - " - n_hidden (int): Number of hidden units.\n", - " - random_state: Random seed for reproducibility.\n", - " \"\"\"\n", - " self.n_visible = n_visible\n", - " self.n_hidden = n_hidden\n", - " \n", - " self.a = np.zeros((1, n_visible))\n", - " self.b = np.zeros((1, n_hidden))\n", - " self.rng = np.random.default_rng(random_state)\n", - " self.W = 1e-4 * self.rng.standard_normal(size=(n_visible, n_hidden))\n", - "\n", - " def _sigmoid(self, x: np.ndarray) -> np.ndarray:\n", - " \"\"\"\n", - " Sigmoid activation function.\n", - "\n", - " Parameters:\n", - " - x (numpy.ndarray): Input array.\n", - "\n", - " Returns:\n", - " - numpy.ndarray: Result of applying the sigmoid function to the input.\n", - " \"\"\"\n", - " return 1 / (1 + np.exp(-x))\n", - " \n", - " def _reconstruction_error(self, input: np.ndarray, image: np.ndarray) -> float:\n", - " \"\"\"\n", - " Compute reconstruction error.\n", - "\n", - " Parameters:\n", - " - input (numpy.ndarray): Original input data.\n", - " - image (numpy.ndarray): Reconstructed image.\n", - "\n", - " Returns:\n", - " - float: Reconstruction error.\n", - " \"\"\"\n", - " return np.round(np.power(image - input, 2).mean(), 3)\n", - "\n", - " def entree_sortie(self, data: np.ndarray) -> np.ndarray:\n", - " \"\"\"\n", - " Compute hidden units given visible units.\n", - "\n", - " Parameters:\n", - " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", - "\n", - " Returns:\n", - " - numpy.ndarray: Hidden unit activations, shape (n_samples, n_hidden).\n", - " \"\"\"\n", - " return self._sigmoid(data @ self.W + self.b)\n", - "\n", - " def sortie_entree(self, data_h: np.ndarray) -> np.ndarray:\n", - " \"\"\"\n", - " Compute visible units given hidden units.\n", - "\n", - " Parameters:\n", - " - data_h (numpy.ndarray): Hidden unit activations, shape (n_samples, n_hidden).\n", - "\n", - " Returns:\n", - " - numpy.ndarray: Reconstructed visible units, shape (n_samples, n_visible).\n", - " \"\"\"\n", - " return self._sigmoid(data_h @ self.W.T + self.a)\n", - "\n", - " def train(self, data: np.ndarray, learning_rate: float=0.1, n_epochs: int=10, batch_size: int=10, print_each=10) -> 'RBM':\n", - " \"\"\"\n", - " Train the RBM using Contrastive Divergence.\n", - "\n", - " Parameters:\n", - " - data (numpy.ndarray): Input data, shape (n_samples, n_visible).\n", - " - learning_rate (float): Learning rate for gradient descent. Default is 0.1.\n", - " - n_epochs (int): Number of training epochs. Default is 10.\n", - " - batch_size (int): Size of mini-batches. Default is 10.\n", - "\n", - " Returns:\n", - " - RBM: Trained RBM instance.\n", - " \"\"\"\n", - " n_samples = data.shape[0]\n", - " for epoch in range(n_epochs):\n", - " self.rng.shuffle(data)\n", - " for i in tqdm(range(0, n_samples, batch_size), desc=f\"Epoch {epoch}\"):\n", - " batch = data[i:i+batch_size]\n", - " pos_h_probs = self.entree_sortie(batch)\n", - " pos_v_probs = self.sortie_entree(pos_h_probs)\n", - " neg_h_probs = self.entree_sortie(pos_v_probs)\n", - " \n", - " # Update weights and biases\n", - " self.W += learning_rate * (batch.T @ pos_h_probs - pos_v_probs.T @ neg_h_probs) / batch_size\n", - " self.b += learning_rate * (pos_h_probs.mean(axis=0) - neg_h_probs.mean(axis=0))\n", - " self.a += learning_rate * (batch.mean(axis=0) - pos_v_probs.mean(axis=0))\n", - " \n", - " if epoch % print_each == 0:\n", - " tqdm.write(\n", - " f\"Reconstruction error: {self._reconstruction_error(batch, pos_v_probs)}.\")\n", - "\n", - " return self\n", - "\n", - " def generer_image(self, n_samples: int=1, n_gibbs_steps: int=1) -> np.ndarray:\n", - " \"\"\"\n", - " Generate samples from the RBM using Gibbs sampling.\n", - "\n", - " Parameters:\n", - " - n_samples (int): Number of samples to generate. Default is 1.\n", - " - n_gibbs_steps (int): Number of Gibbs sampling steps. Default is 100.\n", - "\n", - " Returns:\n", - " - numpy.ndarray: Generated samples, shape (n_samples, n_visible).\n", - " \"\"\"\n", - " samples = np.zeros((n_samples, self.n_visible))\n", - " \n", - " # Matrix of initlization value of Gibbs samples for each sample. \n", - " V = self.rng.binomial(1, self.rng.random(), size=n_samples * self.n_visible).reshape((n_samples, self.n_visible))\n", - " for i in range(n_samples):\n", - " for _ in range(n_gibbs_steps):\n", - " h_probs = self._sigmoid(V[i] @ self.W + self.b)\n", - " h = self.rng.binomial(1, h_probs)\n", - " v_probs = self._sigmoid(h @ self.W.T + self.a)\n", - " v = self.rng.binomial(1, v_probs)\n", - " samples[i] = v\n", - " return samples\n" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 999.83it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.163.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 1333.64it/s]\n", - "Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 1000.67it/s]\n", - "Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 999.89it/s]\n", - "Epoch 4: 100%|██████████| 4/4 [00:00<00:00, 998.94it/s]\n", - "Epoch 5: 100%|██████████| 4/4 [00:00<00:00, 999.89it/s]\n", - "Epoch 6: 100%|██████████| 4/4 [00:00<00:00, 799.37it/s]\n", - "Epoch 7: 100%|██████████| 4/4 [00:00<00:00, 499.96it/s]\n", - "Epoch 8: 100%|██████████| 4/4 [00:00<00:00, 800.36it/s]\n", - "Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 798.80it/s]\n", - "Epoch 10: 100%|██████████| 4/4 [00:00<00:00, 799.56it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.141.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 11: 100%|██████████| 4/4 [00:00<00:00, 800.52it/s]\n", - "Epoch 12: 100%|██████████| 4/4 [00:00<00:00, 1000.01it/s]\n", - "Epoch 13: 100%|██████████| 4/4 [00:00<00:00, 800.02it/s]\n", - "Epoch 14: 100%|██████████| 4/4 [00:00<00:00, 999.77it/s]\n", - "Epoch 15: 100%|██████████| 4/4 [00:00<00:00, 666.82it/s]\n", - "Epoch 16: 100%|██████████| 4/4 [00:00<00:00, 801.09it/s]\n", - "Epoch 17: 100%|██████████| 4/4 [00:00<00:00, 800.52it/s]\n", - "Epoch 18: 100%|██████████| 4/4 [00:00<00:00, 800.33it/s]\n", - "Epoch 19: 100%|██████████| 4/4 [00:00<00:00, 1001.57it/s]\n", - "Epoch 20: 100%|██████████| 4/4 [00:00<00:00, 1000.19it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.106.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 21: 100%|██████████| 4/4 [00:00<00:00, 798.34it/s]\n", - "Epoch 22: 100%|██████████| 4/4 [00:00<00:00, 999.66it/s]\n", - "Epoch 23: 100%|██████████| 4/4 [00:00<00:00, 666.50it/s]\n", - "Epoch 24: 100%|██████████| 4/4 [00:00<00:00, 666.90it/s]\n", - "Epoch 25: 100%|██████████| 4/4 [00:00<00:00, 85.11it/s]\n", - "Epoch 26: 100%|██████████| 4/4 [00:00<00:00, 668.18it/s]\n", - "Epoch 27: 100%|██████████| 4/4 [00:00<00:00, 666.79it/s]\n", - "Epoch 28: 100%|██████████| 4/4 [00:00<00:00, 399.90it/s]\n", - "Epoch 29: 100%|██████████| 4/4 [00:00<00:00, 666.93it/s]\n", - "Epoch 30: 100%|██████████| 4/4 [00:00<00:00, 667.22it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.088.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 31: 100%|██████████| 4/4 [00:00<00:00, 800.06it/s]\n", - "Epoch 32: 100%|██████████| 4/4 [00:00<00:00, 799.87it/s]\n", - "Epoch 33: 100%|██████████| 4/4 [00:00<00:00, 799.79it/s]\n", - "Epoch 34: 100%|██████████| 4/4 [00:00<00:00, 666.74it/s]\n", - "Epoch 35: 100%|██████████| 4/4 [00:00<00:00, 800.10it/s]\n", - "Epoch 36: 100%|██████████| 4/4 [00:00<00:00, 800.10it/s]\n", - "Epoch 37: 100%|██████████| 4/4 [00:00<00:00, 1000.01it/s]\n", - "Epoch 38: 100%|██████████| 4/4 [00:00<00:00, 799.91it/s]\n", - "Epoch 39: 100%|██████████| 4/4 [00:00<00:00, 799.98it/s]\n", - "Epoch 40: 100%|██████████| 4/4 [00:00<00:00, 800.02it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.074.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 41: 100%|██████████| 4/4 [00:00<00:00, 1333.75it/s]\n", - "Epoch 42: 100%|██████████| 4/4 [00:00<00:00, 1000.49it/s]\n", - "Epoch 43: 100%|██████████| 4/4 [00:00<00:00, 999.89it/s]\n", - "Epoch 44: 100%|██████████| 4/4 [00:00<00:00, 798.88it/s]\n", - "Epoch 45: 100%|██████████| 4/4 [00:00<00:00, 999.89it/s]\n", - "Epoch 46: 100%|██████████| 4/4 [00:00<00:00, 1000.43it/s]\n", - "Epoch 47: 100%|██████████| 4/4 [00:00<00:00, 799.87it/s]\n", - "Epoch 48: 100%|██████████| 4/4 [00:00<00:00, 799.83it/s]\n", - "Epoch 49: 100%|██████████| 4/4 [00:00<00:00, 999.54it/s]\n", - "Epoch 50: 100%|██████████| 4/4 [00:00<00:00, 1000.13it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.057.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 51: 100%|██████████| 4/4 [00:00<00:00, 1000.13it/s]\n", - "Epoch 52: 100%|██████████| 4/4 [00:00<00:00, 999.12it/s]\n", - "Epoch 53: 100%|██████████| 4/4 [00:00<00:00, 799.07it/s]\n", - "Epoch 54: 100%|██████████| 4/4 [00:00<00:00, 799.71it/s]\n", - "Epoch 55: 100%|██████████| 4/4 [00:00<00:00, 800.02it/s]\n", - "Epoch 56: 100%|██████████| 4/4 [00:00<00:00, 799.98it/s]\n", - "Epoch 57: 100%|██████████| 4/4 [00:00<00:00, 799.91it/s]\n", - "Epoch 58: 100%|██████████| 4/4 [00:00<00:00, 799.60it/s]\n", - "Epoch 59: 100%|██████████| 4/4 [00:00<00:00, 799.87it/s]\n", - "Epoch 60: 100%|██████████| 4/4 [00:00<00:00, 799.49it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.048.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 61: 100%|██████████| 4/4 [00:00<00:00, 800.63it/s]\n", - "Epoch 62: 100%|██████████| 4/4 [00:00<00:00, 1000.07it/s]\n", - "Epoch 63: 100%|██████████| 4/4 [00:00<00:00, 500.04it/s]\n", - "Epoch 64: 100%|██████████| 4/4 [00:00<00:00, 801.09it/s]\n", - "Epoch 65: 100%|██████████| 4/4 [00:00<00:00, 666.29it/s]\n", - "Epoch 66: 100%|██████████| 4/4 [00:00<00:00, 443.71it/s]\n", - "Epoch 67: 100%|██████████| 4/4 [00:00<00:00, 666.74it/s]\n", - "Epoch 68: 100%|██████████| 4/4 [00:00<00:00, 799.79it/s]\n", - "Epoch 69: 100%|██████████| 4/4 [00:00<00:00, 800.97it/s]\n", - "Epoch 70: 100%|██████████| 4/4 [00:00<00:00, 800.02it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.031.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 71: 100%|██████████| 4/4 [00:00<00:00, 799.83it/s]\n", - "Epoch 72: 100%|██████████| 4/4 [00:00<00:00, 799.71it/s]\n", - "Epoch 73: 100%|██████████| 4/4 [00:00<00:00, 666.69it/s]\n", - "Epoch 74: 100%|██████████| 4/4 [00:00<00:00, 665.23it/s]\n", - "Epoch 75: 100%|██████████| 4/4 [00:00<00:00, 666.58it/s]\n", - "Epoch 76: 100%|██████████| 4/4 [00:00<00:00, 571.43it/s]\n", - "Epoch 77: 100%|██████████| 4/4 [00:00<00:00, 799.71it/s]\n", - "Epoch 78: 100%|██████████| 4/4 [00:00<00:00, 799.75it/s]\n", - "Epoch 79: 100%|██████████| 4/4 [00:00<00:00, 666.69it/s]\n", - "Epoch 80: 100%|██████████| 4/4 [00:00<00:00, 800.52it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.036.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 81: 100%|██████████| 4/4 [00:00<00:00, 666.95it/s]\n", - "Epoch 82: 100%|██████████| 4/4 [00:00<00:00, 799.94it/s]\n", - "Epoch 83: 100%|██████████| 4/4 [00:00<00:00, 399.98it/s]\n", - "Epoch 84: 100%|██████████| 4/4 [00:00<00:00, 571.41it/s]\n", - "Epoch 85: 100%|██████████| 4/4 [00:00<00:00, 571.43it/s]\n", - "Epoch 86: 100%|██████████| 4/4 [00:00<00:00, 571.39it/s]\n", - "Epoch 87: 100%|██████████| 4/4 [00:00<00:00, 571.37it/s]\n", - "Epoch 88: 100%|██████████| 4/4 [00:00<00:00, 137.94it/s]\n", - "Epoch 89: 100%|██████████| 4/4 [00:00<00:00, 222.20it/s]\n", - "Epoch 90: 100%|██████████| 4/4 [00:00<00:00, 571.68it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reconstruction error: 0.023.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 91: 100%|██████████| 4/4 [00:00<00:00, 363.50it/s]\n", - "Epoch 92: 100%|██████████| 4/4 [00:00<00:00, 363.65it/s]\n", - "Epoch 93: 100%|██████████| 4/4 [00:00<00:00, 666.71it/s]\n", - "Epoch 94: 100%|██████████| 4/4 [00:00<00:00, 800.97it/s]\n", - "Epoch 95: 100%|██████████| 4/4 [00:00<00:00, 799.79it/s]\n", - "Epoch 96: 100%|██████████| 4/4 [00:00<00:00, 800.02it/s]\n", - "Epoch 97: 100%|██████████| 4/4 [00:00<00:00, 666.58it/s]\n", - "Epoch 98: 100%|██████████| 4/4 [00:00<00:00, 1000.79it/s]\n", - "Epoch 99: 100%|██████████| 4/4 [00:00<00:00, 798.99it/s]\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA78AAAGICAYAAACJAFemAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlk0lEQVR4nO3dfYxdZZ0H8N/0haEvtDbbIZSqLZRKg9sYbdQNIIIWq9CoicjCorYr8uILBld8CasragLB/UMIoBZRFAWjKLqrgLu4S1bJ7mbVRRBZY4stxrpLKaWAiBE6z/7RzEynM8CdOeee59znfj4JCb1z73n5Ps859/7mzPndgZRSCgAAACjYjNwbAAAAAN2m+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACK1zPF70UXXRQDAwPTeu2XvvSlGBgYiG3bttW7UfvYtm1bDAwMxJe+9KWurSMX2ecj+zzkno/s85B7PrLPR/Z5yD0f2TdU/P7iF7+It7zlLbF06dIYHByMQw89NM4444z4xS9+0cTq+5rs85F9HnLPR/Z5yD0f2ecj+zzkno/sa5K67Fvf+lY64IAD0iGHHJL+9m//Nl1zzTXpIx/5SFqyZEk64IAD0k033dTRcp588sn0xBNPTGsbnnrqqfTEE0+k4eHhab2+E1u3bk0Rka699tqurWOqZJ+P7POQez6yz0Pu+cg+H9nnIfd8ZF+frha/W7ZsSXPnzk2rVq1KO3bsGPezBx98MK1atSrNmzcv3XfffU+7jN///vfd3MTatO1AkX0+ss9D7vnIPg+55yP7fGSfh9zzkX29uvpnz3//938ff/jDH+Lqq6+OoaGhcT9bvHhxbNq0KR5//PH41Kc+FRFjf4d+7733xl/91V/FokWL4thjjx33s3098cQT8d73vjcWL14cBx10ULz+9a+P7du3x8DAQFx00UWjz5vsb9SXL18e69evjzvuuCNe9rKXxYEHHhiHH354XHfddePWsWvXrrjgggti9erVMX/+/FiwYEG87nWvi7vuuqvGpOon+3xkn4fc85F9HnLPR/b5yD4Puecj+3rN6ubCv/vd78by5cvjFa94xaQ/P+6442L58uVx8803j3v8zW9+c6xcuTIuvvjiSCk97fI3btwY3/jGN+Ktb31r/MVf/EX827/9W5x88skdb9+WLVvilFNOiTPPPDM2bNgQX/ziF2Pjxo2xZs2aeOELXxgREb/+9a/jO9/5Trz5zW+Oww47LB544IHYtGlTvPKVr4x77703Dj300I7X1yTZ5yP7POSej+zzkHs+ss9H9nnIPR/Z16xbl5R3796dIiK94Q1veMbnvf71r08RkR599NH0sY99LEVEOv300yc8b+RnI37605+miEjnn3/+uOdt3LgxRUT62Mc+NvrYtddemyIibd26dfSxZcuWpYhIP/zhD0cf27FjRxocHEzvf//7Rx/74x//mPbs2TNuHVu3bk2Dg4PpE5/4xLjHoiV/IiH7fGSfh9zzkX0ecs9H9vnIPg+55yP7+nXtz54fe+yxiIg46KCDnvF5Iz9/9NFHRx8799xzn3X53//+9yMi4l3vete4x88777yOt/Goo44a91uUoaGhOPLII+PXv/716GODg4MxY8bemPbs2RMPPfRQzJ8/P4488sj47//+747X1STZ5yP7POSej+zzkHs+ss9H9nnIPR/Z169rxe/IIIwM2tOZbFAPO+ywZ13+/fffHzNmzJjw3COOOKLjbXz+858/4bFFixbFww8/PPrv4eHh+PSnPx0rV66MwcHBWLx4cQwNDcXdd98djzzySMfrapLs85F9HnLPR/Z5yD0f2ecj+zzkno/s69e14nfhwoWxZMmSuPvuu5/xeXfffXcsXbo0FixYMPrYnDlzurVZ48ycOXPSx9M+fxd/8cUXx9/8zd/EcccdF1/96lfjn/7pn+K2226LF77whTE8PNzIdk6V7PORfR5yz0f2ecg9H9nnI/s85J6P7OvX1YZX69evj89//vNxxx13jHYZ29ePfvSj2LZtW5xzzjlTXvayZctieHg4tm7dGitXrhx9fMuWLZW2eX/f/OY344QTTogvfOEL4x7fvXt3LF68uNZ11Un2+cg+D7nnI/s85J6P7PORfR5yz0f29erqVx194AMfiDlz5sQ555wTDz300Lif7dq1K84999yYO3dufOADH5jystetWxcREZ/5zGfGPX7FFVdMf4MnMXPmzAkd0m688cbYvn17reupm+zzkX0ecs9H9nnIPR/Z5yP7POSej+zr1dUrvytXrowvf/nLccYZZ8Tq1avjzDPPjMMOOyy2bdsWX/jCF2Lnzp3xta99LVasWDHlZa9Zsybe9KY3xWWXXRYPPfTQaGvuX/3qVxERE77DarrWr18fn/jEJ+Kv//qv4+ijj46f//zncf3118fhhx9ey/K7Rfb5yD4Puecj+zzkno/s85F9HnLPR/b16mrxG7H3O6ZWrVoVl1xyyegA/dmf/VmccMIJceGFF8af//mfT3vZ1113XRxyyCHxta99Lb797W/H2rVr4+tf/3oceeSRceCBB9ay/RdeeGE8/vjjccMNN8TXv/71eMlLXhI333xzfPjDH65l+d0k+3xkn4fc85F9HnLPR/b5yD4Puecj+xp17UuUMrnzzjtTRKSvfvWruTel78g+H9nnIfd8ZJ+H3PORfT6yz0Pu+ZScfVfv+e22J554YsJjl112WcyYMSOOO+64DFvUP2Sfj+zzkHs+ss9D7vnIPh/Z5yH3fPot+67/2XM3fepTn4qf/vSnccIJJ8SsWbPi1ltvjVtvvTXOPvvseN7znpd784om+3xkn4fc85F9HnLPR/b5yD4PuefTd9nnvvRcxT//8z+nY445Ji1atCjNnj07rVixIl100UXpySefzL1pxZN9PrLPQ+75yD4Puecj+3xkn4fc8+m37AdS2q/vNAAAABSmp+/5BQAAgE4ofgEAAChepYZXdX3x8TOZ7l9lN7FtVeT6a/O259JtVXPvp/w6zarTTNp0h0WOcZxs/9t8Dh3Rpjnfyb5U2d4qWU223m5k36bjKIcmzjd1z/m2jFmvnm/adlw2rRfnfAmamidVxreJ46XNpruvrvwCAABQPMUvAAAAxVP8AgAAUDzFLwAAAMWr9D2/JTQSmEwvNIUgj1LnfN3kVK9cTZymsu5+H19znrbz2aYczjflMJbNc+UXAACA4il+AQAAKJ7iFwAAgOIpfgEAACjerCovdkM2/WayOa9ZASUzlwGgOZrTdZcrvwAAABRP8QsAAEDxFL8AAAAUT/ELAABA8So1vMql3xsMtelG+P23pfRxaCL7OtfbxHjkyqQEsutNOcet1DnTpveOKhm3aT96Ub9/voOm5Hgvacux7MovAAAAxVP8AgAAUDzFLwAAAMVT/AIAAFC8Sg2v2tR4o5NtacuN1r1gumNb95zo1TFr07FRp15r+NWEXp2jbdeWJm+aWzWjlEZH3gPr15bjINdYNLH/bck4l+k2eZ3Ka/tNE+/F012HK78AAAAUT/ELAABA8RS/AAAAFE/xCwAAQPEqNbyCOmka0C793gCjUxpgTE2d80rO5TCWQBOca+rXa40mXfkFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn4RVZaDhQP5nmUXrubWp81u2s27SvJWuiOUrpx+WzMZeryzGHjFt7GIvJ9Vpzq8m48gsAAEDxFL8AAAAUT/ELAABA8RS/AAAAFK/Yhlf93uwil1y5d3rDfEnzos590dgB6IaSzrndJqv+4n2XtiuhudVkXPkFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIpXbMMrmtHtBh1tuTkeRkx3zpvLU6OhW7u1vblhP5FJdT7LlEkTuanpl/ddV34BAAAonuIXAACA4il+AQAAKJ7iFwAAgOIV0/BKE5ru0zggn35pQtA25jzsleNYcK6iGzS3Yn/9OGb9/LnSlV8AAACKp/gFAACgeIpfAAAAilfMPb/Uyz0xZZBz5/r5/pfcZN9fjJE+JU3xWaZ/6NHRjBLmvCu/AAAAFE/xCwAAQPEUvwAAABRP8QsAAEDx+qrhVQk3aXeDhhDtomlDbzG/p0Zzq3Zx/m83+XXOXIbJabA3niu/AAAAFE/xCwAAQPEUvwAAABRP8QsAAEDxerLhlYZA06chRLvkaP7T6TpLHssquZecS9vJvhrn/zx8ZqlfE5maz2UqfVx9vnl2rvwCAABQPMUvAAAAxVP8AgAAUDzFLwAAAMXryYZXneiXm7ZH5Gqo0W85l2CyMSu5IYvmD/Sjuo9px0JnnG/qV/L7E9NjTlCFK78AAAAUT/ELAABA8RS/AAAAFE/xCwAAQPFa3/DKTe3dp8lGM3qtEUqn22v+8HR6bc73Ks2toHPmd/8w1k+vn7Nx5RcAAIDiKX4BAAAonuIXAACA4il+AQAAKF6rGl5psEMpeq1R22THVKf7MNnzch2jziH5NDHne+246gbNrdrDfKyfTNmfObFX3Tm05bNbjs9trvwCAABQPMUvAAAAxVP8AgAAUDzFLwAAAMXL1vDKDeztoUlQNb3W6KfTcay7CVYu5m1vatMcglwcB/XrdqZV3juZvpLe63PNl36Zp678AgAAUDzFLwAAAMVT/AIAAFA8xS8AAADFa6ThVd03UPfLDdlt00nuJTUcKFUpx09b9qMfG8a1JXuqy9FMz/yZnFzKYByrmW5+TeRe0vt4P3PlFwAAgOIpfgEAACie4hcAAIDiKX4BAAAoXu0Nr9zon8dkN+HnGIvSx7/0/WurKrm3Zczash1T1avbTfNKnisl71ubtOWzDEymF+Zim4+htjQMc+UXAACA4il+AQAAKJ7iFwAAgOIpfgEAAChe7Q2vaI9Obyxvy43wAM+kLc0yOuXcCsAzqfI+0el7Yq+9d3abK78AAAAUT/ELAABA8RS/AAAAFE/xCwAAQPEGko4cAAAAFM6VXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4il8AAACKp/gFAACgeIpfAAAAiqf4zWj58uWxcePG3JvRd+Sej+zzkX0ecs9H9nnIPR/Z5yH3fKaT/bSL361bt8Z73vOeeMELXhBz586NuXPnxlFHHRXvfve74+67757uYlvnlltuiYsuuij3ZoySez6yz0f2ecg9H9nnIfd8ZJ+H3PORfR6zpvOi733ve/GXf/mXMWvWrDjjjDPiRS96UcyYMSN++ctfxk033RSf/exnY+vWrbFs2bK6t7dxt9xyS1x11VWtGDS55yP7fGSfh9zzkX0ecs9H9nnIPR/Z5zPl4ve+++6L0047LZYtWxb/8i//EkuWLBn380svvTQ+85nPxIwZ7fyL6scffzzmzZuXezOmTO75yD4f2ech93xkn4fc85F9HnLPR/aZpSk6++yzU0Sk//zP/+z4Nf/zP/+T3vSmN6VFixalwcHBtGbNmvQP//AP455z7bXXpohId9xxR3rf+96XFi9enObOnZve+MY3ph07dkxY5i233JKOPfbYNHfu3DR//vx00kknpXvuuWfcczZs2JDmzZuXtmzZkl73utel+fPnpze84Q0ppZR++MMfplNOOSU973nPSwcccEB67nOfm84///z0hz/8YdzrI2LCfyP27NmTPv3pT6ejjjoqDQ4OpoMPPjidffbZadeuXeO2Y3h4OH3yk59MS5cuTXPmzEnHH398uueee9KyZcvShg0bOspQ7nlyT0n2su+/7OVuzvdb9nI35/ste7mb8/2YfUopTbn4PfTQQ9MRRxzR8fPvueeetHDhwnTUUUelSy+9NF155ZXpuOOOSwMDA+mmm24afd7IgL34xS9Or3rVq9IVV1yR3v/+96eZM2emU089ddwyr7vuujQwMJBe+9rXpiuuuCJdeumlafny5ek5z3lO2rp16+jzNmzYkAYHB9OKFSvShg0b0uc+97l03XXXpZRSOu+889JJJ52ULr744rRp06Z05plnppkzZ6ZTTjll9PX//u//nk488cQUEekrX/nK6H8j3vGOd6RZs2als846K33uc59LH/rQh9K8efPSS1/60vSnP/1p9Hkf+chHUkSkk046KV155ZXp7W9/ezr00EPT4sWLOx4wuefJPSXZy77/spe7Od9v2cvdnO+37OVuzvdj9ilNsfh95JFHUkSkN77xjRN+9vDDD6cHH3xw9L+Rqv/Vr351Wr16dfrjH/84+tzh4eF09NFHp5UrV44+NjJga9euTcPDw6OPv+9970szZ85Mu3fvTiml9Nhjj6XnPOc56ayzzhq3/v/7v/9LCxcuHPf4yG8bPvzhD0/Y3n1/KzHikksuSQMDA+n+++8ffezd7373uN9QjPjRj36UIiJdf/314x7//ve/P+7xHTt2pAMOOCCdfPLJ4/brwgsvTBHR0YDJfUyTuack+33JfkzJ2ct9jDk/puTs5T7GnB9TcvZyH2POjyk9+xFT+mPyRx99NCIi5s+fP+Fnxx9/fAwNDY3+d9VVV8WuXbviX//1X+PUU0+Nxx57LHbu3Bk7d+6Mhx56KNatWxebN2+O7du3j1vO2WefHQMDA6P/fsUrXhF79uyJ+++/PyIibrvttti9e3ecfvrpo8vbuXNnzJw5M17+8pfH7bffPmHb3vnOd054bM6cOaP///jjj8fOnTvj6KOPjpRS3Hnnnc+axY033hgLFy6ME088cdx2rFmzJubPnz+6HT/4wQ/iT3/6U5x33nnj9uv8889/1nWMkPuYJnOPkP2+ZN8f2ct9jDnfH9nLfYw53x/Zy32MOd8/2Y+YUsOrgw46KCIifv/730/42aZNm+Kxxx6LBx54IN7ylrdERMSWLVsipRQf/ehH46Mf/eiky9yxY0csXbp09N/Pf/7zx/180aJFERHx8MMPR0TE5s2bIyLiVa961aTLW7Bgwbh/z5o1K5773OdOeN5vfvOb+Lu/+7v4x3/8x9Flj3jkkUcmXfa+Nm/eHI888kgcfPDBk/58x44dERGjE23lypXjfj40NDS6b89G7mOazD1C9vuS/UQlZi/3Meb8RCVmL/cx5vxEJWYv9zHm/ESlZj9iSsXvwoULY8mSJXHPPfdM+NnLX/7yiIjYtm3b6GPDw8MREXHBBRfEunXrJl3mEUccMe7fM2fOnPR5KaVxy/zKV74ShxxyyITnzZo1fpcGBwcndEvbs2dPnHjiibFr16740Ic+FKtWrYp58+bF9u3bY+PGjaPreCbDw8Nx8MEHx/XXXz/pz4eGhp51GZ2S+5gmc4+Q/b5k3x/Zy32MOd8f2ct9jDnfH9nLfYw53z/Zj5jyVx2dfPLJcc0118R//dd/xcte9rJnfO7hhx8eERGzZ8+OtWvXTm8L97NixYqIiDj44IOnvcyf//zn8atf/Sq+/OUvx9ve9rbRx2+77bYJz9338vr+2/GDH/wgjjnmmHGX/fc38v1cmzdvHs0jIuLBBx+c8FuSZyL3se1oMvcI2e+7HbKful7MXu5j22HOT10vZi/3se0w56euF7OX+9h2mPNT16vZR0RM+QukPvjBD8bcuXPj7W9/ezzwwAMTfj7yW4WIvaEef/zxsWnTpvjf//3fCc998MEHp7r6WLduXSxYsCAuvvjiePLJJ6e1zJHfiOy7rSmluPzyyyc8d+R7rHbv3j3u8VNPPTX27NkTn/zkJye85qmnnhp9/tq1a2P27NlxxRVXjFvfZZdd9qzbuS+579V07hGyHyH7/sle7nuZ8/2Tvdz3Muf7J3u572XO91f2EdO48rty5cq44YYb4vTTT48jjzwyzjjjjHjRi14UKaXYunVr3HDDDTFjxozRvw2/6qqr4thjj43Vq1fHWWedFYcffng88MAD8R//8R/x29/+Nu66664prX/BggXx2c9+Nt761rfGS17ykjjttNNiaGgofvOb38TNN98cxxxzTFx55ZXPuIxVq1bFihUr4oILLojt27fHggUL4lvf+takvz1Ys2ZNRES8973vjXXr1sXMmTPjtNNOi1e+8pVxzjnnxCWXXBI/+9nP4jWveU3Mnj07Nm/eHDfeeGNcfvnlccopp8TQ0FBccMEFcckll8T69evjpJNOijvvvDNuvfXWWLx4ccf7Lfc8uUfIXvb9l73czfl+y17u5ny/ZS93c74fs4+ISfpOd2jLli3pne98ZzriiCPSgQcemObMmZNWrVqVzj333PSzn/1s3HPvu+++9La3vS0dcsghafbs2Wnp0qVp/fr16Zvf/Oboc0bac//4xz8e99rbb789RUS6/fbbJzy+bt26tHDhwnTggQemFStWpI0bN6af/OQno88Z+WLmydx7771p7dq1af78+Wnx4sXprLPOSnfddVeKiHTttdeOPu+pp55K5513XhoaGkoDAwMTWnVfffXVac2aNWnOnDnpoIMOSqtXr04f/OAH0+9+97vR5+zZsyd9/OMfT0uWLKn8xcxy36vp3FOS/QjZ90/2ct/LnO+f7OW+lznfP9nLfS9zvn+yH0hpn+vHAAAAUKAp3/MLAAAAvUbxCwAAQPEUvwAAABRP8QsAAEDxFL8AAAAUT/ELAABA8WZVefHAwEBHz5vs25Q6fW2dOv1Wpyr71eny6v6GqU7XkSP3Nqmae5ty3n+9da8z1/HSqVzr7USO8e+Wus9fVbJp8zfzNXGer6ot5/8mzpl1z9EmxrKT9VZ5D2riuM15Xqpbt99jq8j1GbLK8tqsiWPo6VRZT7dzznWu7vb5xpVfAAAAiqf4BQAAoHiKXwAAAIqn+AUAAKB4A6ltHTkK1EQjgbY0P+iFpi9V9Nr+tblBF53L1eSmKd3e7l7NpW7dPh+0qZFjP45vN7Wt4VVbyKU9jEU+db+3dPs878ovAAAAxVP8AgAAUDzFLwAAAMVT/AIAAFC8Sg2v2tx4Cdm1bf/b3Iyh7iYEnb62bm0b825qKvcqmfbaeDTRtKMb21P3a9uihH2oav8M2rT/xqdd739NKHV/c87lXjuO6n5fqkLDKwAAAHgail8AAACKp/gFAACgeIpfAAAAild7w6vJ5GiC1aZtq3u9vdZorE1j0Y0mAr3WrKBTbWlMUFWOJkZVsmv7uaHfVJk/ORte9ZNc74mdKmHMcjY6avN7kWO0fm1u8tYNpTYRm0xbjmVXfgEAACie4hcAAIDiKX4BAAAonuIXAACA4s3KteK6m9B08rxcjWSauEm97hvmq4xFyeqet1V00uStibnXxP63ae51e1tyNliqe3sm0+k2Tnf/mjgXtk2OhiklN2mpsm+TvbbOc0auPJt6n2viXD/d8a3yebTtx0Gu99jJcmnLZ5vJ5NyWupvudbtu6lSOzzeu/AIAAFA8xS8AAADFU/wCAABQPMUvAAAAxcvW8GoydTZWaaIhSyk6zb2TTNvUKKukJgTd3pd+OzbqHrPpLr9Knr06FjnmWluOs6lqoinXdI+Ffms405bGhrm0qelgU/pxn+tSpRnc/q+tcv7uhfN8Fd1uutdEk9ccx5krvwAAABRP8QsAAEDxFL8AAAAUT/ELAABA8RppeJXjRui6b6Au5Qb5TnNpS+5V1tu2McvRlKbKstqWX1t0eyzqbibRjXHstSYYdZ+XmjrPVTmv5XgPzNWEponzf1uaH7VpTuSUowFZmxp6tl2O98kmlte2z0X9VOvUOQ9c+QUAAKB4il8AAACKp/gFAACgeIpfAAAAitdIw6u2NAmosh11N5lo803ldWti/JtoqlJ13d3e5yrrrHsu99P87lSuedKUtjRYqqIXztW9Oj+Ynv3Hu8p8bMtnsTrk+AzRRAM/nysnqruhW6k5NaWJhpfd5sovAAAAxVP8AgAAUDzFLwAAAMVT/AIAAFC8Rhpe1X1zdJ03q1dplNTmm7lz6XRs6h7/nM0fmpgHnexL3Q3dqqyjlIYSdR73Vc4rJZ1/StqXHJqYM9NtNtbE+b/T5ZWik8aGjqm96j42Osm+7m2bTL+N5XT3ty3NR7u1zBxN3kptaOrKLwAAAMVT/AIAAFA8xS8AAADFU/wCAABQvEYaXk2myo3bORrs9FvDgemq0lCl0zGs0iSibTfqV2nQkKPxW0nZ7y9H05h+bF7T5iYvvZB9E03suq3uTHvxfDOZ6TY2ZK9uN01romlV28433ZajqWSV5TWliQaAnawjV5O3bnPlFwAAgOIpfgEAACie4hcAAIDi1X7Pb64vZp6uOr8A/emel0uOsWjintC67w1uSt33atW5fyUfB5Pp9jY3cW9YSXIc070wb9u0jfuvt8q2VRnvNmVSRbc/x7CXnOlEyeeaiOnvX5VzdZvvF3blFwAAgOIpfgEAACie4hcAAIDiKX4BAAAoXu0Nr5r4Yubp6rcvI2/Ll2S3udFTHTq9Mb/ThgPTzUuzmWZ00uSt3zKZirbMobrndzfev6qsu+7mg9Pdv7aMN+Wrcgyak/VqS2PSKssvaU5UaWA63de1OXtXfgEAACie4hcAAIDiKX4BAAAonuIXAACA4tXe8KoJ022yVOWG75JufN9fjkZjbWmGkFuV+bf/85po9lH6eOyv384FdWvLObdKU7pOVTnP5VTnNjZxvLTp+KuziUyn6m5k2aY8p0Jzq/bo9nFQdyPCXh3/XpvzdWdf53nUlV8AAACKp/gFAACgeIpfAAAAiqf4BQAAoHjZGl5VuXG5k5uo625cUooqjabqbFZQ9zpzNptpYr50e19KagrRiW7vW7/l+XTa0tyqH+WYg7nOw7nOwTn2t8q49uo5qIk55Jw9fVWOg06e16vNBKtoorlVtxvxdbvOq7IOV34BAAAonuIXAACA4il+AQAAKJ7iFwAAgOI10vCqiRu361R3k6U2NU2ou6lUndqeXTfUmX0TDVna0vSlKaXPv37RRDO9puZKW5oA1t1Usu3njCpNpeocM+ekqamzGVg/fkbpRN2fIdvcNK4pTZwjpzvnm9Dt8XDlFwAAgOIpfgEAACie4hcAAIDiKX4BAAAoXqWGV22+abzKDfh1N/Jok7rHbLo3wzfR+KAXm6pMRY7mNVWa/+TKvtsN3Tpdflsay3VLp/Mgxxyqct5r+/yOaKYZU12vezpVxqPt55Yqy9s/l1xzuVfPSzmUnF8TnyE7WV7d87ZtY1b3+06Oc2Sbzy2u/AIAAFA8xS8AAADFU/wCAABQPMUvAAAAxavU8KqJph85btJue2OaEm6Eb2KdbWtuVWV7ptsIpW5taVZQVZ1NMJpo6lBK7hH1NrFr4lzdCzn3QlOufbX9PTaXTsYsV6O2ps5BbZmjVZQyb+ueG3V+Js3VrDbn/Kw75072udufW3Nx5RcAAIDiKX4BAAAonuIXAACA4il+AQAAKF6lhle5dLuRR6fLz3Xje5sbmVRp/tOWfZiqKs0F6sym00YCuRo0NaHKfnS7QVwJzQCfSd2NN+rcv7ZlVbccTUTadA5vU/PNbjeVqnu/mng/mIo2n0fqbjbWqTY3Dopoz2fyurUp46fT7WOjTef5OrnyCwAAQPEUvwAAABRP8QsAAEDxFL8AAAAUbyCVcOdyA9p003ebtqXb2ravbWq8kCOHNu3/ZHotk7aM6zNpooldCc04urEtTcz76Y5vE+uk/9Q9N3LMtSbOhU3kNBnHaTPachyUeq525RcAAIDiKX4BAAAonuIXAACA4il+AQAAKF6rGl7VeWN1rkYhda+j0/V22hChU000WOhk+Z1qanvbfPN/yY2X2tQEpC25NHX+qbIeTVSqaVP2dTa8Yvrqfv9v2zi2ucFeyXK9x9K5fs++zv135RcAAIDiKX4BAAAonuIXAACA4il+AQAAKF6rGl4BAABAN7jyCwAAQPEUvwAAABRP8QsAAEDxFL8AAAAUT/ELAABA8RS/AAAAFE/xCwAAQPEUvwAAABRP8QsAAEDx/h/ScFlJHos8KwAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Load the alpha_digit data\n", - "data = read_alpha_digit(file_path=ALPHA_DIGIT_PATH, characters='Z')\n", - "\n", - "# Initialize RBM\n", - "n_visible = data.shape[1] # Number of visible units (size of each image)\n", - "n_hidden = 100 # Number of hidden units (hyperparameter)\n", - "rbm = RBM(n_visible=n_visible, n_hidden=n_hidden, random_state=42)\n", - "\n", - "# Train RBM\n", - "rbm.train(data, learning_rate=0.1, n_epochs=100, batch_size=10)\n", - "\n", - "# Generate samples\n", - "generated_samples = rbm.generer_image(n_samples=10, n_gibbs_steps=1)\n", - "\n", - "# Plot original and generated samples\n", - "plt.figure(figsize=(12, 6))\n", - "for i in range(10):\n", - " plt.subplot(2, 10, i + 1)\n", - " plt.imshow(data[i].reshape(20, 16), cmap='gray')\n", - " plt.title('Original')\n", - " plt.axis('off')\n", - " \n", - " plt.subplot(2, 10, i + 11)\n", - " plt.imshow(generated_samples[i].reshape(20, 16), cmap='gray')\n", - " plt.title('Generated')\n", - " plt.axis('off')\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/requirements.txt b/requirements.txt index 011b590..e3e4e7b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ numpy scipy -numba \ No newline at end of file +numba +matplotlib +pandas +tqdm +torch \ No newline at end of file diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_A.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_A.npy new file mode 100644 index 0000000..1826de6 Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_A.npy differ diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY.npy new file mode 100644 index 0000000..c2b335e Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY.npy differ diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY2.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY2.npy new file mode 100644 index 0000000..7cbff89 Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EY2.npy differ diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG2.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG2.npy new file mode 100644 index 0000000..e1a4542 Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG2.npy differ diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG27.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG27.npy new file mode 100644 index 0000000..41d5cf1 Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_EYG27.npy differ diff --git a/resultat/dbn/100_Units_2_Layers/Units_100_Chars_Y.npy b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_Y.npy new file mode 100644 index 0000000..57059b4 Binary files /dev/null and b/resultat/dbn/100_Units_2_Layers/Units_100_Chars_Y.npy differ diff --git a/resultat/dbn/100_Units_3_Layers/Units_100_Chars_Y.npy b/resultat/dbn/100_Units_3_Layers/Units_100_Chars_Y.npy new file mode 100644 index 0000000..9314dc4 Binary files /dev/null and b/resultat/dbn/100_Units_3_Layers/Units_100_Chars_Y.npy differ diff --git a/resultat/dbn/100_Units_4_Layers/Units_100_Chars_Y.npy b/resultat/dbn/100_Units_4_Layers/Units_100_Chars_Y.npy new file mode 100644 index 0000000..d16cda8 Binary files /dev/null and b/resultat/dbn/100_Units_4_Layers/Units_100_Chars_Y.npy differ diff --git a/resultat/dbn/100_Units_5_Layers/Units_100_Chars_Y.npy b/resultat/dbn/100_Units_5_Layers/Units_100_Chars_Y.npy new file mode 100644 index 0000000..ddfe1a2 Binary files /dev/null and b/resultat/dbn/100_Units_5_Layers/Units_100_Chars_Y.npy differ diff --git a/resultat/dbn/200_200_Units_2_Chars/Sample_0_Chars_AB.npy b/resultat/dbn/200_200_Units_2_Chars/Sample_0_Chars_AB.npy new file mode 100644 index 0000000..e2137c9 Binary files /dev/null and b/resultat/dbn/200_200_Units_2_Chars/Sample_0_Chars_AB.npy differ diff --git a/resultat/dbn/200_200_Units_2_Chars/Sample_1_Chars_AB.npy b/resultat/dbn/200_200_Units_2_Chars/Sample_1_Chars_AB.npy new file mode 100644 index 0000000..87f698e Binary files /dev/null and b/resultat/dbn/200_200_Units_2_Chars/Sample_1_Chars_AB.npy differ diff --git a/resultat/dbn/200_200_Units_2_Chars/Sample_2_Chars_AB.npy b/resultat/dbn/200_200_Units_2_Chars/Sample_2_Chars_AB.npy new file mode 100644 index 0000000..a3a7a1f Binary files /dev/null and b/resultat/dbn/200_200_Units_2_Chars/Sample_2_Chars_AB.npy differ diff --git a/resultat/dbn/200_200_Units_2_Chars/Sample_3_Chars_AB.npy b/resultat/dbn/200_200_Units_2_Chars/Sample_3_Chars_AB.npy new file mode 100644 index 0000000..bc231d8 Binary files /dev/null and b/resultat/dbn/200_200_Units_2_Chars/Sample_3_Chars_AB.npy differ diff --git a/resultat/dbn/200_200_Units_2_Chars/Sample_4_Chars_AB.npy b/resultat/dbn/200_200_Units_2_Chars/Sample_4_Chars_AB.npy new file mode 100644 index 0000000..bc9f411 Binary files /dev/null and b/resultat/dbn/200_200_Units_2_Chars/Sample_4_Chars_AB.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_1234.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_1234.npy new file mode 100644 index 0000000..519fcbb Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_1234.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_AB12.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_AB12.npy new file mode 100644 index 0000000..cd75ee0 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_0_Chars_AB12.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_1234.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_1234.npy new file mode 100644 index 0000000..6cfa5e5 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_1234.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_AB12.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_AB12.npy new file mode 100644 index 0000000..1ff8eba Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_1_Chars_AB12.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_1234.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_1234.npy new file mode 100644 index 0000000..dc1efcf Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_1234.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_AB12.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_AB12.npy new file mode 100644 index 0000000..0ab13b2 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_2_Chars_AB12.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_1234.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_1234.npy new file mode 100644 index 0000000..ecd7dad Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_1234.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_AB12.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_AB12.npy new file mode 100644 index 0000000..c93ed94 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_3_Chars_AB12.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_1234.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_1234.npy new file mode 100644 index 0000000..d623204 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_1234.npy differ diff --git a/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_AB12.npy b/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_AB12.npy new file mode 100644 index 0000000..ce5a276 Binary files /dev/null and b/resultat/dbn/200_200_Units_4_Chars/Sample_4_Chars_AB12.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_A.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_A.npy new file mode 100644 index 0000000..c35c2a5 Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_A.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY.npy new file mode 100644 index 0000000..986a711 Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY2.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY2.npy new file mode 100644 index 0000000..db28dcc Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EY2.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG2.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG2.npy new file mode 100644 index 0000000..bd9d396 Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG2.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG27.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG27.npy new file mode 100644 index 0000000..51c2d12 Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_EYG27.npy differ diff --git a/resultat/dbn/200_Units_2_Layers/Units_200_Chars_Y.npy b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_Y.npy new file mode 100644 index 0000000..d90028b Binary files /dev/null and b/resultat/dbn/200_Units_2_Layers/Units_200_Chars_Y.npy differ diff --git a/resultat/dbn/200_Units_3_Layers/Units_200_Chars_Y.npy b/resultat/dbn/200_Units_3_Layers/Units_200_Chars_Y.npy new file mode 100644 index 0000000..540ba77 Binary files /dev/null and b/resultat/dbn/200_Units_3_Layers/Units_200_Chars_Y.npy differ diff --git a/resultat/dbn/200_Units_4_Layers/Units_200_Chars_Y.npy b/resultat/dbn/200_Units_4_Layers/Units_200_Chars_Y.npy new file mode 100644 index 0000000..cec1ecc Binary files /dev/null and b/resultat/dbn/200_Units_4_Layers/Units_200_Chars_Y.npy differ diff --git a/resultat/dbn/200_Units_5_Layers/Units_200_Chars_Y.npy b/resultat/dbn/200_Units_5_Layers/Units_200_Chars_Y.npy new file mode 100644 index 0000000..5c85713 Binary files /dev/null and b/resultat/dbn/200_Units_5_Layers/Units_200_Chars_Y.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_A.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_A.npy new file mode 100644 index 0000000..e33f092 Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_A.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY.npy new file mode 100644 index 0000000..a2408ce Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY2.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY2.npy new file mode 100644 index 0000000..9833027 Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EY2.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG2.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG2.npy new file mode 100644 index 0000000..7baaeed Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG2.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG27.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG27.npy new file mode 100644 index 0000000..c7a32e2 Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_EYG27.npy differ diff --git a/resultat/dbn/300_Units_2_Layers/Units_300_Chars_Y.npy b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_Y.npy new file mode 100644 index 0000000..ec84c14 Binary files /dev/null and b/resultat/dbn/300_Units_2_Layers/Units_300_Chars_Y.npy differ diff --git a/resultat/dbn/300_Units_3_Layers/Units_300_Chars_Y.npy b/resultat/dbn/300_Units_3_Layers/Units_300_Chars_Y.npy new file mode 100644 index 0000000..103f357 Binary files /dev/null and b/resultat/dbn/300_Units_3_Layers/Units_300_Chars_Y.npy differ diff --git a/resultat/dbn/300_Units_4_Layers/Units_300_Chars_Y.npy b/resultat/dbn/300_Units_4_Layers/Units_300_Chars_Y.npy new file mode 100644 index 0000000..887c1ea Binary files /dev/null and b/resultat/dbn/300_Units_4_Layers/Units_300_Chars_Y.npy differ diff --git a/resultat/dbn/300_Units_5_Layers/Units_300_Chars_Y.npy b/resultat/dbn/300_Units_5_Layers/Units_300_Chars_Y.npy new file mode 100644 index 0000000..57ea2b4 Binary files /dev/null and b/resultat/dbn/300_Units_5_Layers/Units_300_Chars_Y.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_0_Chars_E.npy b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_0_Chars_E.npy new file mode 100644 index 0000000..3c983c3 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_0_Chars_E.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_1_Chars_E.npy b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_1_Chars_E.npy new file mode 100644 index 0000000..a4cbe17 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_1_Chars_E.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_2_Chars_E.npy b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_2_Chars_E.npy new file mode 100644 index 0000000..2edd270 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_2_Chars_E.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_3_Chars_E.npy b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_3_Chars_E.npy new file mode 100644 index 0000000..0f9451d Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_3_Chars_E.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_4_Chars_E.npy b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_4_Chars_E.npy new file mode 100644 index 0000000..1f88607 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_1_Chars/Sample_4_Chars_E.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_0_Chars_EY.npy b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_0_Chars_EY.npy new file mode 100644 index 0000000..e239d66 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_0_Chars_EY.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_1_Chars_EY.npy b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_1_Chars_EY.npy new file mode 100644 index 0000000..cd2f9ab Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_1_Chars_EY.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_2_Chars_EY.npy b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_2_Chars_EY.npy new file mode 100644 index 0000000..180b18e Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_2_Chars_EY.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_3_Chars_EY.npy b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_3_Chars_EY.npy new file mode 100644 index 0000000..27528a2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_3_Chars_EY.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_4_Chars_EY.npy b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_4_Chars_EY.npy new file mode 100644 index 0000000..116edac Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_2_Chars/Sample_4_Chars_EY.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_0_Chars_EYA.npy b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_0_Chars_EYA.npy new file mode 100644 index 0000000..2b6382e Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_0_Chars_EYA.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_1_Chars_EYA.npy b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_1_Chars_EYA.npy new file mode 100644 index 0000000..7a96bb2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_1_Chars_EYA.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_2_Chars_EYA.npy b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_2_Chars_EYA.npy new file mode 100644 index 0000000..476e4ab Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_2_Chars_EYA.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_3_Chars_EYA.npy b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_3_Chars_EYA.npy new file mode 100644 index 0000000..be828f2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_3_Chars_EYA.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_4_Chars_EYA.npy b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_4_Chars_EYA.npy new file mode 100644 index 0000000..477f9e2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_3_Chars/Sample_4_Chars_EYA.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_0_Chars_EYA2.npy b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_0_Chars_EYA2.npy new file mode 100644 index 0000000..4e9ad91 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_0_Chars_EYA2.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_1_Chars_EYA2.npy b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_1_Chars_EYA2.npy new file mode 100644 index 0000000..bf32ee1 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_1_Chars_EYA2.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_2_Chars_EYA2.npy b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_2_Chars_EYA2.npy new file mode 100644 index 0000000..bd39cc2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_2_Chars_EYA2.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_3_Chars_EYA2.npy b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_3_Chars_EYA2.npy new file mode 100644 index 0000000..306c28f Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_3_Chars_EYA2.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_4_Chars_EYA2.npy b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_4_Chars_EYA2.npy new file mode 100644 index 0000000..de0f746 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_4_Chars/Sample_4_Chars_EYA2.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_0_Chars_EYA27.npy b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_0_Chars_EYA27.npy new file mode 100644 index 0000000..f2f5ab2 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_0_Chars_EYA27.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_1_Chars_EYA27.npy b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_1_Chars_EYA27.npy new file mode 100644 index 0000000..401cfa6 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_1_Chars_EYA27.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_2_Chars_EYA27.npy b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_2_Chars_EYA27.npy new file mode 100644 index 0000000..1655508 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_2_Chars_EYA27.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_3_Chars_EYA27.npy b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_3_Chars_EYA27.npy new file mode 100644 index 0000000..8b774ac Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_3_Chars_EYA27.npy differ diff --git a/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_4_Chars_EYA27.npy b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_4_Chars_EYA27.npy new file mode 100644 index 0000000..cc75825 Binary files /dev/null and b/resultat/dbn/400_400_400_400_Units_5_Chars/Sample_4_Chars_EYA27.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_A.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_A.npy new file mode 100644 index 0000000..8ad8291 Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_A.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY.npy new file mode 100644 index 0000000..60f2807 Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY2.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY2.npy new file mode 100644 index 0000000..3223d34 Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EY2.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG2.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG2.npy new file mode 100644 index 0000000..8181689 Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG2.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG27.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG27.npy new file mode 100644 index 0000000..99affe9 Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_EYG27.npy differ diff --git a/resultat/dbn/400_Units_2_Layers/Units_400_Chars_Y.npy b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_Y.npy new file mode 100644 index 0000000..58fe37a Binary files /dev/null and b/resultat/dbn/400_Units_2_Layers/Units_400_Chars_Y.npy differ diff --git a/resultat/dbn/400_Units_3_Layers/Units_400_Chars_Y.npy b/resultat/dbn/400_Units_3_Layers/Units_400_Chars_Y.npy new file mode 100644 index 0000000..6904306 Binary files /dev/null and b/resultat/dbn/400_Units_3_Layers/Units_400_Chars_Y.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_Y.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_Y.npy new file mode 100644 index 0000000..046c7c4 Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_Y.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YA.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YA.npy new file mode 100644 index 0000000..ea8b68b Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YA.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB.npy new file mode 100644 index 0000000..5f3e79f Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB1.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB1.npy new file mode 100644 index 0000000..0800bc8 Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB1.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB12.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB12.npy new file mode 100644 index 0000000..0a965dd Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAB12.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAZ.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAZ.npy new file mode 100644 index 0000000..7a7c67d Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YAZ.npy differ diff --git a/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YUZ.npy b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YUZ.npy new file mode 100644 index 0000000..7767ee0 Binary files /dev/null and b/resultat/dbn/400_Units_4_Layers/Units_400_Chars_YUZ.npy differ diff --git a/resultat/dbn/400_Units_5_Layers/Units_400_Chars_Y.npy b/resultat/dbn/400_Units_5_Layers/Units_400_Chars_Y.npy new file mode 100644 index 0000000..7009eb8 Binary files /dev/null and b/resultat/dbn/400_Units_5_Layers/Units_400_Chars_Y.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_A.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_A.npy new file mode 100644 index 0000000..e6f903a Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_A.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY.npy new file mode 100644 index 0000000..3012339 Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY2.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY2.npy new file mode 100644 index 0000000..d63b559 Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EY2.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG2.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG2.npy new file mode 100644 index 0000000..76ea8b4 Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG2.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG27.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG27.npy new file mode 100644 index 0000000..34f45b9 Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_EYG27.npy differ diff --git a/resultat/dbn/500_Units_2_Layers/Units_500_Chars_Y.npy b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_Y.npy new file mode 100644 index 0000000..4507820 Binary files /dev/null and b/resultat/dbn/500_Units_2_Layers/Units_500_Chars_Y.npy differ diff --git a/resultat/dbn/500_Units_3_Layers/Units_500_Chars_Y.npy b/resultat/dbn/500_Units_3_Layers/Units_500_Chars_Y.npy new file mode 100644 index 0000000..1929520 Binary files /dev/null and b/resultat/dbn/500_Units_3_Layers/Units_500_Chars_Y.npy differ diff --git a/resultat/dbn/500_Units_4_Layers/Units_500_Chars_Y.npy b/resultat/dbn/500_Units_4_Layers/Units_500_Chars_Y.npy new file mode 100644 index 0000000..7a4cc96 Binary files /dev/null and b/resultat/dbn/500_Units_4_Layers/Units_500_Chars_Y.npy differ diff --git a/resultat/dbn/500_Units_5_Layers/Units_500_Chars_Y.npy b/resultat/dbn/500_Units_5_Layers/Units_500_Chars_Y.npy new file mode 100644 index 0000000..d8b7538 Binary files /dev/null and b/resultat/dbn/500_Units_5_Layers/Units_500_Chars_Y.npy differ diff --git a/resultat/dbn/600_Units_2_Layers/Units_600_Chars_A.npy b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_A.npy new file mode 100644 index 0000000..16dc32a Binary files /dev/null and b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_A.npy differ diff --git a/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY.npy b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY.npy new file mode 100644 index 0000000..0d09f52 Binary files /dev/null and b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY.npy differ diff --git a/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY2.npy b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY2.npy new file mode 100644 index 0000000..f4294f3 Binary files /dev/null and b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EY2.npy differ diff --git a/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EYG2.npy b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EYG2.npy new file mode 100644 index 0000000..5674001 Binary files /dev/null and b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_EYG2.npy differ diff --git a/resultat/dbn/600_Units_2_Layers/Units_600_Chars_Y.npy b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_Y.npy new file mode 100644 index 0000000..e5d4c15 Binary files /dev/null and b/resultat/dbn/600_Units_2_Layers/Units_600_Chars_Y.npy differ diff --git a/resultat/dbn/600_Units_3_Layers/Units_600_Chars_Y.npy b/resultat/dbn/600_Units_3_Layers/Units_600_Chars_Y.npy new file mode 100644 index 0000000..a6290a7 Binary files /dev/null and b/resultat/dbn/600_Units_3_Layers/Units_600_Chars_Y.npy differ diff --git a/resultat/dbn/600_Units_4_Layers/Units_600_Chars_Y.npy b/resultat/dbn/600_Units_4_Layers/Units_600_Chars_Y.npy new file mode 100644 index 0000000..83ad53a Binary files /dev/null and b/resultat/dbn/600_Units_4_Layers/Units_600_Chars_Y.npy differ diff --git a/resultat/dbn/600_Units_5_Layers/Units_600_Chars_Y.npy b/resultat/dbn/600_Units_5_Layers/Units_600_Chars_Y.npy new file mode 100644 index 0000000..d617a85 Binary files /dev/null and b/resultat/dbn/600_Units_5_Layers/Units_600_Chars_Y.npy differ diff --git a/resultat/dbn/700_Units_2_Layers/Units_700_Chars_A.npy b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_A.npy new file mode 100644 index 0000000..00f799b Binary files /dev/null and b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_A.npy differ diff --git a/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY.npy b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY.npy new file mode 100644 index 0000000..949d1e3 Binary files /dev/null and b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY.npy differ diff --git a/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY2.npy b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY2.npy new file mode 100644 index 0000000..066267c Binary files /dev/null and b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EY2.npy differ diff --git a/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EYG2.npy b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EYG2.npy new file mode 100644 index 0000000..4d0b68f Binary files /dev/null and b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_EYG2.npy differ diff --git a/resultat/dbn/700_Units_2_Layers/Units_700_Chars_Y.npy b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_Y.npy new file mode 100644 index 0000000..f474c1a Binary files /dev/null and b/resultat/dbn/700_Units_2_Layers/Units_700_Chars_Y.npy differ diff --git a/resultat/dbn/700_Units_3_Layers/Units_700_Chars_Y.npy b/resultat/dbn/700_Units_3_Layers/Units_700_Chars_Y.npy new file mode 100644 index 0000000..a895c3b Binary files /dev/null and b/resultat/dbn/700_Units_3_Layers/Units_700_Chars_Y.npy differ diff --git a/resultat/dbn/700_Units_4_Layers/Units_700_Chars_Y.npy b/resultat/dbn/700_Units_4_Layers/Units_700_Chars_Y.npy new file mode 100644 index 0000000..051eaee Binary files /dev/null and b/resultat/dbn/700_Units_4_Layers/Units_700_Chars_Y.npy differ diff --git a/resultat/dbn/700_Units_5_Layers/Units_700_Chars_Y.npy b/resultat/dbn/700_Units_5_Layers/Units_700_Chars_Y.npy new file mode 100644 index 0000000..91b178d Binary files /dev/null and b/resultat/dbn/700_Units_5_Layers/Units_700_Chars_Y.npy differ diff --git a/resultat/images/dbn/1_2_3_4/dbn_4_chars_200_200_Units.png b/resultat/images/dbn/1_2_3_4/dbn_4_chars_200_200_Units.png new file mode 100644 index 0000000..59f1f9e Binary files /dev/null and b/resultat/images/dbn/1_2_3_4/dbn_4_chars_200_200_Units.png differ diff --git a/resultat/images/dbn/A_B/dbn_2_chars_200_200_Units.png b/resultat/images/dbn/A_B/dbn_2_chars_200_200_Units.png new file mode 100644 index 0000000..170a51c Binary files /dev/null and b/resultat/images/dbn/A_B/dbn_2_chars_200_200_Units.png differ diff --git a/resultat/images/dbn/A_B_1_2/dbn_4_chars_200_200_Units.png b/resultat/images/dbn/A_B_1_2/dbn_4_chars_200_200_Units.png new file mode 100644 index 0000000..1eb25cf Binary files /dev/null and b/resultat/images/dbn/A_B_1_2/dbn_4_chars_200_200_Units.png differ diff --git a/resultat/images/dbn/E/dbn_1_chars_400_400_400_400_Units.png b/resultat/images/dbn/E/dbn_1_chars_400_400_400_400_Units.png new file mode 100644 index 0000000..78abaa7 Binary files /dev/null and b/resultat/images/dbn/E/dbn_1_chars_400_400_400_400_Units.png differ diff --git a/resultat/images/dbn/E_Y/dbn_2_chars_400_400_400_400_Units.png b/resultat/images/dbn/E_Y/dbn_2_chars_400_400_400_400_Units.png new file mode 100644 index 0000000..c94461b Binary files /dev/null and b/resultat/images/dbn/E_Y/dbn_2_chars_400_400_400_400_Units.png differ diff --git a/resultat/images/dbn/E_Y_A/dbn_3_chars_400_400_400_400_Units.png b/resultat/images/dbn/E_Y_A/dbn_3_chars_400_400_400_400_Units.png new file mode 100644 index 0000000..171a720 Binary files /dev/null and b/resultat/images/dbn/E_Y_A/dbn_3_chars_400_400_400_400_Units.png differ diff --git a/resultat/images/dbn/E_Y_A_2/dbn_4_chars_400_400_400_400_Units.png b/resultat/images/dbn/E_Y_A_2/dbn_4_chars_400_400_400_400_Units.png new file mode 100644 index 0000000..8fdada1 Binary files /dev/null and b/resultat/images/dbn/E_Y_A_2/dbn_4_chars_400_400_400_400_Units.png differ diff --git a/resultat/images/dbn/E_Y_A_2_7/dbn_5_chars_400_400_400_400_Units.png b/resultat/images/dbn/E_Y_A_2_7/dbn_5_chars_400_400_400_400_Units.png new file mode 100644 index 0000000..a242697 Binary files /dev/null and b/resultat/images/dbn/E_Y_A_2_7/dbn_5_chars_400_400_400_400_Units.png differ diff --git a/resultat/images/dbn/dbn_1_chars_400_Units_4_Layers.png b/resultat/images/dbn/dbn_1_chars_400_Units_4_Layers.png new file mode 100644 index 0000000..3d4177c Binary files /dev/null and b/resultat/images/dbn/dbn_1_chars_400_Units_4_Layers.png differ diff --git a/resultat/images/dbn/dbn_1_chars_700_Units_5_Layers.png b/resultat/images/dbn/dbn_1_chars_700_Units_5_Layers.png new file mode 100644 index 0000000..c407cc0 Binary files /dev/null and b/resultat/images/dbn/dbn_1_chars_700_Units_5_Layers.png differ diff --git a/resultat/images/dbn/dbn_2_chars_400_Units_4_Layers.png b/resultat/images/dbn/dbn_2_chars_400_Units_4_Layers.png new file mode 100644 index 0000000..c823bfb Binary files /dev/null and b/resultat/images/dbn/dbn_2_chars_400_Units_4_Layers.png differ diff --git a/resultat/images/dbn/dbn_3_chars_400_Units_4_Layers.png b/resultat/images/dbn/dbn_3_chars_400_Units_4_Layers.png new file mode 100644 index 0000000..e29f3f0 Binary files /dev/null and b/resultat/images/dbn/dbn_3_chars_400_Units_4_Layers.png differ diff --git a/resultat/images/dbn/dbn_4_chars_400_Units_4_Layers.png b/resultat/images/dbn/dbn_4_chars_400_Units_4_Layers.png new file mode 100644 index 0000000..1061d38 Binary files /dev/null and b/resultat/images/dbn/dbn_4_chars_400_Units_4_Layers.png differ diff --git a/resultat/images/dbn/dbn_5_chars_400_Units_4_Layers.png b/resultat/images/dbn/dbn_5_chars_400_Units_4_Layers.png new file mode 100644 index 0000000..57a7b2f Binary files /dev/null and b/resultat/images/dbn/dbn_5_chars_400_Units_4_Layers.png differ diff --git a/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['2'].png b/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['2'].png new file mode 100644 index 0000000..63ac3f5 Binary files /dev/null and b/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['2'].png differ diff --git a/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['Y'].png b/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['Y'].png new file mode 100644 index 0000000..1297999 Binary files /dev/null and b/resultat/images/rbm/rbm_1_chars_Units_700_Layers_['Y'].png differ diff --git a/resultat/images/rbm/rbm_samples.png b/resultat/images/rbm/rbm_samples.png new file mode 100644 index 0000000..3327829 Binary files /dev/null and b/resultat/images/rbm/rbm_samples.png differ diff --git a/resultat/rbm/100_Units_1_Chars.npy b/resultat/rbm/100_Units_1_Chars.npy new file mode 100644 index 0000000..926c82d Binary files /dev/null and b/resultat/rbm/100_Units_1_Chars.npy differ diff --git a/resultat/rbm/200_Units_1_Chars.npy b/resultat/rbm/200_Units_1_Chars.npy new file mode 100644 index 0000000..185fb99 Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars.npy differ diff --git a/resultat/rbm/200_Units_1_Chars_Sample_0.npy b/resultat/rbm/200_Units_1_Chars_Sample_0.npy new file mode 100644 index 0000000..63e407f Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars_Sample_0.npy differ diff --git a/resultat/rbm/200_Units_1_Chars_Sample_1.npy b/resultat/rbm/200_Units_1_Chars_Sample_1.npy new file mode 100644 index 0000000..80e7aa9 Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars_Sample_1.npy differ diff --git a/resultat/rbm/200_Units_1_Chars_Sample_2.npy b/resultat/rbm/200_Units_1_Chars_Sample_2.npy new file mode 100644 index 0000000..043fa08 Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars_Sample_2.npy differ diff --git a/resultat/rbm/200_Units_1_Chars_Sample_3.npy b/resultat/rbm/200_Units_1_Chars_Sample_3.npy new file mode 100644 index 0000000..aed9946 Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars_Sample_3.npy differ diff --git a/resultat/rbm/200_Units_1_Chars_Sample_4.npy b/resultat/rbm/200_Units_1_Chars_Sample_4.npy new file mode 100644 index 0000000..10e4153 Binary files /dev/null and b/resultat/rbm/200_Units_1_Chars_Sample_4.npy differ diff --git a/resultat/rbm/300_Units_1_Chars.npy b/resultat/rbm/300_Units_1_Chars.npy new file mode 100644 index 0000000..e1e6f12 Binary files /dev/null and b/resultat/rbm/300_Units_1_Chars.npy differ diff --git a/resultat/rbm/400_Units_1_Chars.npy b/resultat/rbm/400_Units_1_Chars.npy new file mode 100644 index 0000000..bc1d65e Binary files /dev/null and b/resultat/rbm/400_Units_1_Chars.npy differ diff --git a/resultat/rbm/500_Units_1_Chars.npy b/resultat/rbm/500_Units_1_Chars.npy new file mode 100644 index 0000000..b742dab Binary files /dev/null and b/resultat/rbm/500_Units_1_Chars.npy differ diff --git a/resultat/rbm/600_Units_1_Chars.npy b/resultat/rbm/600_Units_1_Chars.npy new file mode 100644 index 0000000..6174cdf Binary files /dev/null and b/resultat/rbm/600_Units_1_Chars.npy differ diff --git a/resultat/rbm/700_Units_1_Chars.npy b/resultat/rbm/700_Units_1_Chars.npy new file mode 100644 index 0000000..ce9a0bb Binary files /dev/null and b/resultat/rbm/700_Units_1_Chars.npy differ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..d57a7ef --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup, find_packages + +setup( + name='generative_model', + version='0.1', + author='Yedidia AGNIMO & C. Yann Éric CHOHO', + author_email='yedidia.agnimo@ensae.fr // chohoyanneric.choho@ensae.fr', + description='Implement basics deep learning architecture for generative models.', + packages=find_packages(where='src'), + package_dir={'': 'src'}, + python_requires='>=3.10', +) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/principal_dbn_alpha.py b/src/principal_dbn_alpha.py new file mode 100644 index 0000000..69bcb20 --- /dev/null +++ b/src/principal_dbn_alpha.py @@ -0,0 +1,125 @@ +""" +Module: dbn.py +Module providing implementation of Deep Belief Network (DBN). +""" + +from typing import List + +import numpy as np +from principal_rbm_alpha import RBM + + +class DBN: + """ + Implementation of a Deep Belief Network (DBN). + + Attributes: + - n_visible (int): Number of visible units. + - hidden_layer_sizes (List[int]): List of sizes for each hidden layer. + - rbms (List[RBM]): List of Restricted Boltzmann Machines (RBMs) forming the DBN. + - rng (numpy.random.Generator): Random number generator for sampling. + """ + + def __init__( + self, + n_visible: int, + hidden_layer_sizes: List[int], + random_state=None + ): + """ + Initialize the Deep Belief Network. + + Parameters: + - n_visible (int): Number of visible units. + - hidden_layer_sizes (List[int]): List of sizes for each hidden layer. + - random_state: Random seed for reproducibility. + """ + self.n_visible = n_visible + self.hidden_layer_sizes = hidden_layer_sizes + self.rbms: List[RBM] = [] + self.rng = np.random.default_rng(random_state) + + # Initialize the first RBM + first_rbm = RBM( + n_visible=n_visible, + n_hidden=hidden_layer_sizes[0], + random_state=random_state, + ) + self.rbms.append(first_rbm) + + # Initialize RBMs for subsequent hidden layers + for i, size in enumerate(hidden_layer_sizes[1:], start=1): + rbm = RBM( + n_visible=hidden_layer_sizes[i - 1], + n_hidden=size, + random_state=random_state, + ) + self.rbms.append(rbm) + + def __getitem__(self, key): + return self.rbms[key] + + def __repr__(self): + """ + Return a string representation of the DBN object. + """ + rbm_reprs = [repr(rbm) for rbm in self.rbms] + join_rbm_reprs = ",\n ".join(rbm_reprs) + return f"DBN([\n {join_rbm_reprs}\n])" + + def train( + self, + data: np.ndarray, + learning_rate: float = 0.1, + n_epochs: int = 10, + batch_size: int = 10, + print_each: int = 10, + ) -> "DBN": + """ + Train the DBN using Greedy layer-wise procedure. + + Parameters: + - data (numpy.ndarray): Input data, shape (n_samples, n_visible). + - learning_rate (float): Learning rate for gradient descent. Default is 0.1. + - n_epochs (int): Number of training epochs. Default is 10. + - batch_size (int): Size of mini-batches. Default is 10. + - print_each: Print reconstruction error each `print_each` epochs. + + Returns: + - DBN: Trained DBN instance. + """ + input_data = data + for rbm in self.rbms: + rbm.train( + input_data, + learning_rate=learning_rate, + n_epochs=n_epochs, + batch_size=batch_size, + print_each=print_each, + ) + # Update input data for the next RBM + input_data = rbm.input_output(input_data) + + return self + + def generate_image(self, n_samples: int = 1, n_gibbs_steps: int = 1) -> np.ndarray: + """ + Generate samples from the DBN using Gibbs sampling. + + Parameters: + - n_samples (int): Number of samples to generate. Default is 1. + - n_gibbs_steps (int): Number of Gibbs sampling steps. Default is 100. + + Returns: + - numpy.ndarray: Generated samples, shape (n_samples, n_visible). + """ + # samples = np.zeros((n_samples, self.n_visible)) + + # Generate samples using the first RBM in the DBN + samples = self.rbms[-1].generate_image(n_samples, n_gibbs_steps) + for rbm in reversed(self.rbms[:-1]): + # Sample from the conditional probability of layer l-1 given layer l: p(h_{s-1}|h_{s}). + h_probs = rbm.output_input(samples) + h = self.rng.binomial(1, p=h_probs) + samples = h + return samples diff --git a/src/principal_rbm_alpha.py b/src/principal_rbm_alpha.py index 18ca253..cff9ef3 100644 --- a/src/principal_rbm_alpha.py +++ b/src/principal_rbm_alpha.py @@ -1,4 +1,15 @@ """Principal RBM alpha. +#TODO: control for verbosity (add 'verbose' arg / think about where to progression bar with `tqdm`) +#TODO: add a representation "__repr__" to the class. Look like `RBM(n_visible, n_hidden, rng)`. +#TODO: move `sigmoid`, and function related to data into others modules (utils, load_data). +# HACK: optimize code, accelerate matrix computation with numba, parallelized when possible. +#TODO: check relevance of using the RBM's RNG for generation phase (look inside the gibbs sampling). +# If a seed has been define, the gibbs sampling step will return the same sample for each it will +# sample the same h from the binomial -> #WARNING there might be something wrong with the function +# --------------------------- Other Tags (Example usage) ---------------------. +# FIXME: Example: This function is returning incorrect results for negative input values. +# BUG: Example: Division by zero error occurs in certain cases. +# HACK: Example: This code temporarily fixes the issue, but needs a proper solution. """ import os @@ -169,7 +180,7 @@ def _reconstruction_error( """ return np.round(np.power(output_img - input_img, 2).mean(), 3) - def entree_sortie(self, data: np.ndarray) -> np.ndarray: + def input_output(self, data: np.ndarray) -> np.ndarray: """ Compute hidden units given visible units. @@ -181,7 +192,7 @@ def entree_sortie(self, data: np.ndarray) -> np.ndarray: """ return self._sigmoid(data @ self.W + self.b) - def sortie_entree(self, data_h: np.ndarray) -> np.ndarray: + def output_input(self, data_h: np.ndarray) -> np.ndarray: """ Compute visible units given hidden units. @@ -218,9 +229,9 @@ def train( self.rng.shuffle(data) for i in tqdm(range(0, n_samples, batch_size), desc=f"Epoch {epoch}"): batch = data[i: i + batch_size] - pos_h_probs = self.entree_sortie(batch) - pos_v_probs = self.sortie_entree(pos_h_probs) - neg_h_probs = self.entree_sortie(pos_v_probs) + pos_h_probs = self.input_output(batch) + pos_v_probs = self.output_input(pos_h_probs) + neg_h_probs = self.input_output(pos_v_probs) # Update weights and biases self.W += ( @@ -242,7 +253,7 @@ def train( return self - def generer_image(self, n_samples: int = 1, n_gibbs_steps: int = 1) -> np.ndarray: + def generate_image(self, n_samples: int=1, n_gibbs_steps: int=1) -> np.ndarray: """ Generate samples from the RBM using Gibbs sampling. @@ -260,8 +271,8 @@ def generer_image(self, n_samples: int = 1, n_gibbs_steps: int = 1) -> np.ndarra 1, self.rng.random(), size=n_samples * self.n_visible ).reshape((n_samples, self.n_visible)) for i in range(n_samples): + h_probs = self._sigmoid(V[i] @ self.W + self.b) for _ in range(n_gibbs_steps): - h_probs = self._sigmoid(V[i] @ self.W + self.b) h = self.rng.binomial(1, h_probs) v_probs = self._sigmoid(h @ self.W.T + self.a) v = self.rng.binomial(1, v_probs) diff --git a/src/tests/test_rbm.py b/src/tests/test_rbm.py new file mode 100644 index 0000000..a349913 --- /dev/null +++ b/src/tests/test_rbm.py @@ -0,0 +1,62 @@ +import os +import numpy as np +import scipy.io +import unittest +from rbm import RBM, _load_data, _map_character_to_index, read_alpha_digit + +DATA_FOLDER = "../data/" +ALPHA_DIGIT_PATH = os.path.join(DATA_FOLDER, "binaryalphadigs.mat") + + +class TestRBM(unittest.TestCase): + def setUp(self): + # Load alpha_digit data for testing + self.data = read_alpha_digit(file_path=ALPHA_DIGIT_PATH, character='A') + self.n_samples = self.data.shape[0] + self.n_visible = self.data.shape[1] + self.n_hidden = 100 + self.rbm = RBM(n_visible=self.n_visible, n_hidden=self.n_hidden) + + def test__sigmoid(self): + # Test _sigmoid method with positive and negative values + x = np.array([1, -1, 0]) + sigmoid_x = self.rbm._sigmoid(x) + self.assertTrue(np.allclose(sigmoid_x, [0.73105858, 0.26894142, 0.5])) + + def test__reconstruction_error(self): + # Test _reconstruction_error method with arrays containing zeros + input_img = np.zeros_like(self.data) + output_img = np.zeros_like(self.data) + error = self.rbm._reconstruction_error(input_img, output_img) + self.assertEqual(error, 0.0) + + def test_entree_sortie(self): + # Test entree_sortie method with a small input array + input_data = np.ones((2, self.n_visible)) + output = self.rbm.entree_sortie(input_data) + self.assertEqual(output.shape, (2, self.n_hidden)) + + def test_sortie_entree(self): + # Test sortie_entree method with a small input array + input_data = np.ones((2, self.n_hidden)) + output = self.rbm.sortie_entree(input_data) + self.assertEqual(output.shape, (2, self.n_visible)) + + def test_train(self): + # Test train method + trained_rbm = self.rbm.train( + self.data[:100], learning_rate=0.1, n_epochs=1, batch_size=10 + ) + self.assertTrue(hasattr(trained_rbm, 'W')) + self.assertTrue(hasattr(trained_rbm, 'a')) + self.assertTrue(hasattr(trained_rbm, 'b')) + + def test_generer_image(self): + # Test generer_image method + samples = self.rbm.generer_image(n_samples=2, n_gibbs_steps=10) + self.assertEqual(samples.shape, (2, self.n_visible)) + self.assertTrue(np.all((samples == 0) | (samples == 1))) + + +if __name__ == "__main__": + unittest.main()