From 7973626c433aceb0a2171176046ca9e3fb9bcf22 Mon Sep 17 00:00:00 2001 From: Klus3kk Date: Mon, 30 Jun 2025 10:18:14 +0200 Subject: [PATCH] Finished the notebook examples --- examples/notebooks/intro.ipynb | 483 +++++++++++++++++++++++++++++---- fit/nn/modules/container.py | 2 + fit/nn/utils/model_io.py | 2 +- 3 files changed, 433 insertions(+), 54 deletions(-) diff --git a/examples/notebooks/intro.ipynb b/examples/notebooks/intro.ipynb index aac5eac..a5a5486 100644 --- a/examples/notebooks/intro.ipynb +++ b/examples/notebooks/intro.ipynb @@ -55,9 +55,9 @@ "Vector: Tensor([1. 2. 3.], requires_grad=True)\n", "Matrix: Tensor([[1. 2.]\n", " [3. 4.]], requires_grad=True)\n", - "Random tensor: Tensor([[ 0.25159698 0.36753971 1.09971921]\n", - " [ 0.4132483 1.51174617 -0.2696245 ]\n", - " [-0.92404499 0.40118634 1.46629455]], requires_grad=True)\n" + "Random tensor: Tensor([[-1.08957361 -1.14300669 -1.06784566]\n", + " [ 0.33019738 -0.22200625 -1.82231365]\n", + " [ 0.33581401 1.90352181 -0.31104841]], requires_grad=True)\n" ] } ], @@ -334,7 +334,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Model output: [[0.84966947 0.07866769 0.07166284]]\n", + "Model output: [[0.70130264 0.12593192 0.17276544]]\n", "Output sum (should be ~1.0): 1.0\n", "Number of parameters: 6\n" ] @@ -429,7 +429,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loss: 4.794796120705038\n", + "Loss: 4.835619504464079\n", "Training step completed!\n" ] } @@ -477,11 +477,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Step 1, Loss: 0.2421\n", - "Step 2, Loss: 2.3305\n", - "Step 3, Loss: 0.3512\n", - "Step 4, Loss: 0.2987\n", - "Step 5, Loss: 0.0059\n" + "Step 1, Loss: 0.8734\n", + "Step 2, Loss: 0.4638\n", + "Step 3, Loss: 0.8612\n", + "Step 4, Loss: 1.1857\n", + "Step 5, Loss: 0.6425\n" ] } ], @@ -596,10 +596,10 @@ "Iris Dataset loaded:\n", "Train batches: 4\n", "Validation batches: 1\n", - "XOR - X: [[1. 1.]\n", + "XOR - X: [[0. 0.]\n", + " [0. 1.]\n", " [0. 0.]\n", - " [1. 0.]\n", - " [1. 1.]], y: [0. 0. 1. 0.]\n" + " [1. 0.]], y: [0. 1. 0. 1.]\n" ] } ], @@ -641,7 +641,7 @@ "text": [ "Original features: 10\n", "Selected features: 5\n", - "Selected feature indices: [ True True False False True False True False False True]\n" + "Selected feature indices: [ True False True False False True True False False True]\n" ] } ], @@ -883,18 +883,18 @@ "text": [ "Epoch train_loss val_loss accuracy Time \n", "--------------------------------------------------\n", - "1 0.9928 1.2488 0.0071 \n", - "2 0.9054 1.1707 0.1051 \n", - "3 0.8113 1.0706 0.1652 \n", - "4 0.6538 0.9592 0.2041 \n", - "5 0.6012 0.9884 0.3324 \n", - "6 0.5177 0.7813 0.3942 \n", - "7 0.4032 0.7712 0.4748 \n", - "8 0.3692 0.5478 0.5609 \n", - "9 0.1141 0.4863 0.6513 \n", - "10 0.1353 0.4451 0.7141 \n", + "1 0.9389 0.9989 0.0152 \n", + "2 0.8907 1.1735 0.0667 \n", + "3 0.8214 1.1364 0.1414 \n", + "4 0.7042 0.9471 0.2042 \n", + "5 0.5579 0.8620 0.3140 \n", + "6 0.5659 0.7661 0.3624 \n", + "7 0.3904 0.8223 0.4745 \n", + "8 0.3932 0.5757 0.5251 \n", + "9 0.3083 0.4118 0.6616 \n", + "10 0.1303 0.5031 0.7300 \n", "Metrics logged!\n", - "Best validation loss: 0.4450874973037048\n", + "Best validation loss: 0.41183470041323994\n", "Logs exported to training_log.json\n" ] }, @@ -952,7 +952,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Evaluation metrics: {'loss': np.float64(1.1119251456841315), 'accuracy': np.float64(0.35)}\n" + "Evaluation metrics: {'loss': np.float64(1.0865359454049197), 'accuracy': np.float64(0.5)}\n" ] } ], @@ -997,15 +997,16 @@ "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "Can't pickle local object 'Tensor.__init__..'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[23], line 11\u001b[0m\n\u001b[1;32m 4\u001b[0m model \u001b[38;5;241m=\u001b[39m Sequential(\n\u001b[1;32m 5\u001b[0m Linear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m),\n\u001b[1;32m 6\u001b[0m ReLU(),\n\u001b[1;32m 7\u001b[0m Linear(\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 8\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# Save model\u001b[39;00m\n\u001b[0;32m---> 11\u001b[0m \u001b[43msave_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdemo_model.pkl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel saved!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# Load model\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/Codes/fit/fit/nn/utils/model_io.py:33\u001b[0m, in \u001b[0;36msave_model\u001b[0;34m(model, path)\u001b[0m\n\u001b[1;32m 27\u001b[0m save_data \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 28\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m'\u001b[39m: model,\n\u001b[1;32m 29\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstate\u001b[39m\u001b[38;5;124m'\u001b[39m: model_state\n\u001b[1;32m 30\u001b[0m }\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mwb\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m---> 33\u001b[0m \u001b[43mpickle\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdump\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel saved to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "\u001b[0;31mAttributeError\u001b[0m: Can't pickle local object 'Tensor.__init__..'" + "name": "stdout", + "output_type": "stream", + "text": [ + "Model saved to demo_model.pkl\n", + "Model saved!\n", + "Model loaded from demo_model.pkl\n", + "Model loaded!\n", + "Original output: [[-0.21497577]]\n", + "Loaded output: [[-0.21497577]]\n", + "Outputs match: True\n" ] } ], @@ -1049,11 +1050,234 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training XOR model...\n", + "Starting training for 100 epochs...\n", + "Model: Sequential\n", + "Optimizer: Adam\n", + "Loss: MSELoss\n", + "Batch size: 32\n", + "--------------------------------------------------\n", + "Epoch train_loss Time \n", + "--------------------------\n", + "1 0.5087 \n", + "Epoch 1: train_loss=0.5087\n", + "2 0.4190 \n", + "Epoch 2: train_loss=0.4190\n", + "3 0.3445 \n", + "Epoch 3: train_loss=0.3445\n", + "4 0.2844 \n", + "Epoch 4: train_loss=0.2844\n", + "5 0.2368 \n", + "Epoch 5: train_loss=0.2368\n", + "6 0.1989 \n", + "Epoch 6: train_loss=0.1989\n", + "7 0.1719 \n", + "Epoch 7: train_loss=0.1719\n", + "8 0.1514 \n", + "Epoch 8: train_loss=0.1514\n", + "9 0.1336 \n", + "Epoch 9: train_loss=0.1336\n", + "10 0.1180 \n", + "Epoch 10: train_loss=0.1180\n", + "11 0.1040 \n", + "Epoch 11: train_loss=0.1040\n", + "12 0.0917 \n", + "Epoch 12: train_loss=0.0917\n", + "13 0.0811 \n", + "Epoch 13: train_loss=0.0811\n", + "14 0.0721 \n", + "Epoch 14: train_loss=0.0721\n", + "15 0.0646 \n", + "Epoch 15: train_loss=0.0646\n", + "16 0.0587 \n", + "Epoch 16: train_loss=0.0587\n", + "17 0.0538 \n", + "Epoch 17: train_loss=0.0538\n", + "18 0.0497 \n", + "Epoch 18: train_loss=0.0497\n", + "19 0.0459 \n", + "Epoch 19: train_loss=0.0459\n", + "20 0.0420 \n", + "Epoch 20: train_loss=0.0420\n", + "21 0.0380 \n", + "Epoch 21: train_loss=0.0380\n", + "22 0.0337 \n", + "Epoch 22: train_loss=0.0337\n", + "23 0.0295 \n", + "Epoch 23: train_loss=0.0295\n", + "24 0.0254 \n", + "Epoch 24: train_loss=0.0254\n", + "25 0.0224 \n", + "Epoch 25: train_loss=0.0224\n", + "26 0.0195 \n", + "Epoch 26: train_loss=0.0195\n", + "27 0.0167 \n", + "Epoch 27: train_loss=0.0167\n", + "28 0.0142 \n", + "Epoch 28: train_loss=0.0142\n", + "29 0.0124 \n", + "Epoch 29: train_loss=0.0124\n", + "30 0.0108 \n", + "Epoch 30: train_loss=0.0108\n", + "31 0.0091 \n", + "Epoch 31: train_loss=0.0091\n", + "32 0.0075 \n", + "Epoch 32: train_loss=0.0075\n", + "33 0.0059 \n", + "Epoch 33: train_loss=0.0059\n", + "34 0.0046 \n", + "Epoch 34: train_loss=0.0046\n", + "35 0.0035 \n", + "Epoch 35: train_loss=0.0035\n", + "36 0.0026 \n", + "Epoch 36: train_loss=0.0026\n", + "37 0.0018 \n", + "Epoch 37: train_loss=0.0018\n", + "38 0.0013 \n", + "Epoch 38: train_loss=0.0013\n", + "39 0.0009 \n", + "Epoch 39: train_loss=0.0009\n", + "40 0.0006 \n", + "Epoch 40: train_loss=0.0006\n", + "41 0.0004 \n", + "Epoch 41: train_loss=0.0004\n", + "42 0.0003 \n", + "Epoch 42: train_loss=0.0003\n", + "43 0.0003 \n", + "Epoch 43: train_loss=0.0003\n", + "44 0.0003 \n", + "Epoch 44: train_loss=0.0003\n", + "45 0.0004 \n", + "Epoch 45: train_loss=0.0004\n", + "46 0.0005 \n", + "Epoch 46: train_loss=0.0005\n", + "47 0.0006 \n", + "Epoch 47: train_loss=0.0006\n", + "48 0.0007 \n", + "Epoch 48: train_loss=0.0007\n", + "49 0.0007 \n", + "Epoch 49: train_loss=0.0007\n", + "50 0.0007 \n", + "Epoch 50: train_loss=0.0007\n", + "51 0.0006 \n", + "Epoch 51: train_loss=0.0006\n", + "52 0.0006 \n", + "Epoch 52: train_loss=0.0006\n", + "53 0.0005 \n", + "Epoch 53: train_loss=0.0005\n", + "54 0.0004 \n", + "Epoch 54: train_loss=0.0004\n", + "55 0.0003 \n", + "Epoch 55: train_loss=0.0003\n", + "56 0.0003 \n", + "Epoch 56: train_loss=0.0003\n", + "57 0.0002 \n", + "Epoch 57: train_loss=0.0002\n", + "58 0.0002 \n", + "Epoch 58: train_loss=0.0002\n", + "59 0.0002 \n", + "Epoch 59: train_loss=0.0002\n", + "60 0.0001 \n", + "Epoch 60: train_loss=0.0001\n", + "61 0.0001 \n", + "Epoch 61: train_loss=0.0001\n", + "62 0.0001 \n", + "Epoch 62: train_loss=0.0001\n", + "63 0.0002 \n", + "Epoch 63: train_loss=0.0002\n", + "64 0.0002 \n", + "Epoch 64: train_loss=0.0002\n", + "65 0.0002 \n", + "Epoch 65: train_loss=0.0002\n", + "66 0.0002 \n", + "Epoch 66: train_loss=0.0002\n", + "67 0.0002 \n", + "Epoch 67: train_loss=0.0002\n", + "68 0.0002 \n", + "Epoch 68: train_loss=0.0002\n", + "69 0.0002 \n", + "Epoch 69: train_loss=0.0002\n", + "70 0.0001 \n", + "Epoch 70: train_loss=0.0001\n", + "71 0.0001 \n", + "Epoch 71: train_loss=0.0001\n", + "72 0.0001 \n", + "Epoch 72: train_loss=0.0001\n", + "73 0.0001 \n", + "Epoch 73: train_loss=0.0001\n", + "74 0.0001 \n", + "Epoch 74: train_loss=0.0001\n", + "75 0.0001 \n", + "Epoch 75: train_loss=0.0001\n", + "76 0.0001 \n", + "Epoch 76: train_loss=0.0001\n", + "77 0.0000 \n", + "Epoch 77: train_loss=0.0000\n", + "78 0.0000 \n", + "Epoch 78: train_loss=0.0000\n", + "79 0.0000 \n", + "Epoch 79: train_loss=0.0000\n", + "80 0.0000 \n", + "Epoch 80: train_loss=0.0000\n", + "81 0.0000 \n", + "Epoch 81: train_loss=0.0000\n", + "82 0.0000 \n", + "Epoch 82: train_loss=0.0000\n", + "83 0.0000 \n", + "Epoch 83: train_loss=0.0000\n", + "84 0.0000 \n", + "Epoch 84: train_loss=0.0000\n", + "85 0.0000 \n", + "Epoch 85: train_loss=0.0000\n", + "86 0.0000 \n", + "Epoch 86: train_loss=0.0000\n", + "87 0.0000 \n", + "Epoch 87: train_loss=0.0000\n", + "88 0.0000 \n", + "Epoch 88: train_loss=0.0000\n", + "89 0.0000 \n", + "Epoch 89: train_loss=0.0000\n", + "90 0.0000 \n", + "Epoch 90: train_loss=0.0000\n", + "91 0.0000 \n", + "Epoch 91: train_loss=0.0000\n", + "92 0.0000 \n", + "Epoch 92: train_loss=0.0000\n", + "93 0.0000 \n", + "Epoch 93: train_loss=0.0000\n", + "94 0.0000 \n", + "Epoch 94: train_loss=0.0000\n", + "95 0.0000 \n", + "Epoch 95: train_loss=0.0000\n", + "96 0.0000 \n", + "Epoch 96: train_loss=0.0000\n", + "97 0.0000 \n", + "Epoch 97: train_loss=0.0000\n", + "98 0.0000 \n", + "Epoch 98: train_loss=0.0000\n", + "99 0.0000 \n", + "Epoch 99: train_loss=0.0000\n", + "100 0.0000 \n", + "Epoch 100: train_loss=0.0000\n", + "Training completed!\n", + "\n", + "XOR Results:\n", + "Input: [0. 0.], Expected: 0, Predicted: 0.003\n", + "Input: [0. 1.], Expected: 1, Predicted: 0.997\n", + "Input: [1. 0.], Expected: 1, Predicted: 1.001\n", + "Input: [1. 1.], Expected: 0, Predicted: -0.004\n" + ] + } + ], "source": [ - "from fit.simple.trainer import Trainer\n", + "from fit.simple.trainer import SimpleTrainer\n", "\n", "# Create XOR model\n", "xor_model = Sequential(Linear(2, 8), ReLU(), Linear(8, 4), ReLU(), Linear(4, 1))\n", @@ -1063,16 +1287,19 @@ "y = Tensor([[0], [1], [1], [0]])\n", "\n", "# Create trainer\n", - "trainer = Trainer(model=xor_model, loss=\"mse\", optimizer=\"adam\", lr=0.01)\n", + "trainer = SimpleTrainer(\n", + " model=xor_model,\n", + " data=(X, y),\n", + " loss='mse',\n", + " optimizer='adam',\n", + " lr=0.01\n", + ")\n", "\n", "# Train the model\n", "print(\"Training XOR model...\")\n", "history = trainer.fit(\n", - " data=(X, y),\n", - " epochs=100,\n", - " batch_size=4,\n", - " validation_split=0.0, # Use all data for training\n", - " verbose=True,\n", + " epochs=100, \n", + " verbose=True\n", ")\n", "\n", "# Test the trained model\n", @@ -1080,7 +1307,7 @@ "for i, (input_val, expected) in enumerate(zip(X.data, y.data)):\n", " prediction = xor_model(Tensor([input_val]))\n", " print(\n", - " f\"Input: {input_val}, Expected: {expected[0]:.0f}, Predicted: {prediction.data[0]:.3f}\"\n", + " f\"Input: {input_val}, Expected: {expected[0]:.0f}, Predicted: {float(prediction.data.flatten()[0]):.3f}\"\n", " )" ] }, @@ -1093,9 +1320,128 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Iris dataset...\n", + "Training Iris classifier...\n", + "Starting training for 50 epochs...\n", + "Model: Sequential\n", + "Optimizer: Adam\n", + "Loss: CrossEntropyLoss\n", + "Batch size: 16\n", + "--------------------------------------------------\n", + "Epoch train_loss val_loss val_accuracyTime \n", + "--------------------------------------------------\n", + "1 0.9388 0.7904 0.7111 \n", + "Epoch 1: train_loss=0.9388, val_loss=0.7904, val_acc=0.7111\n", + "2 0.7772 0.7367 0.7556 \n", + "Epoch 2: train_loss=0.7772, val_loss=0.7367, val_acc=0.7556\n", + "3 0.7353 0.7237 0.8222 \n", + "Epoch 3: train_loss=0.7353, val_loss=0.7237, val_acc=0.8222\n", + "4 0.6990 0.7130 0.8222 \n", + "Epoch 4: train_loss=0.6990, val_loss=0.7130, val_acc=0.8222\n", + "5 0.6880 0.6977 0.8222 \n", + "Epoch 5: train_loss=0.6880, val_loss=0.6977, val_acc=0.8222\n", + "6 0.6643 0.6939 0.8444 \n", + "Epoch 6: train_loss=0.6643, val_loss=0.6939, val_acc=0.8444\n", + "7 0.6346 0.6692 0.8889 \n", + "Epoch 7: train_loss=0.6346, val_loss=0.6692, val_acc=0.8889\n", + "8 0.6180 0.6601 0.8889 \n", + "Epoch 8: train_loss=0.6180, val_loss=0.6601, val_acc=0.8889\n", + "9 0.6207 0.6591 0.9111 \n", + "Epoch 9: train_loss=0.6207, val_loss=0.6591, val_acc=0.9111\n", + "10 0.5999 0.6356 0.9111 \n", + "Epoch 10: train_loss=0.5999, val_loss=0.6356, val_acc=0.9111\n", + "11 0.5926 0.6311 0.9333 \n", + "Epoch 11: train_loss=0.5926, val_loss=0.6311, val_acc=0.9333\n", + "12 0.5865 0.6442 0.9333 \n", + "Epoch 12: train_loss=0.5865, val_loss=0.6442, val_acc=0.9333\n", + "13 0.5819 0.6296 0.9333 \n", + "Epoch 13: train_loss=0.5819, val_loss=0.6296, val_acc=0.9333\n", + "14 0.5740 0.6224 0.9333 \n", + "Epoch 14: train_loss=0.5740, val_loss=0.6224, val_acc=0.9333\n", + "15 0.5725 0.6266 0.9333 \n", + "Epoch 15: train_loss=0.5725, val_loss=0.6266, val_acc=0.9333\n", + "16 0.5698 0.6344 0.8889 \n", + "Epoch 16: train_loss=0.5698, val_loss=0.6344, val_acc=0.8889\n", + "17 0.5677 0.6385 0.9111 \n", + "Epoch 17: train_loss=0.5677, val_loss=0.6385, val_acc=0.9111\n", + "18 0.5675 0.6283 0.9111 \n", + "Epoch 18: train_loss=0.5675, val_loss=0.6283, val_acc=0.9111\n", + "19 0.5710 0.6383 0.9111 \n", + "Epoch 19: train_loss=0.5710, val_loss=0.6383, val_acc=0.9111\n", + "20 0.5655 0.6257 0.9111 \n", + "Epoch 20: train_loss=0.5655, val_loss=0.6257, val_acc=0.9111\n", + "21 0.5630 0.6246 0.9333 \n", + "Epoch 21: train_loss=0.5630, val_loss=0.6246, val_acc=0.9333\n", + "22 0.5628 0.6333 0.9111 \n", + "Epoch 22: train_loss=0.5628, val_loss=0.6333, val_acc=0.9111\n", + "23 0.5608 0.6335 0.8889 \n", + "Epoch 23: train_loss=0.5608, val_loss=0.6335, val_acc=0.8889\n", + "24 0.5633 0.6279 0.9111 \n", + "Epoch 24: train_loss=0.5633, val_loss=0.6279, val_acc=0.9111\n", + "25 0.5654 0.6437 0.9111 \n", + "Epoch 25: train_loss=0.5654, val_loss=0.6437, val_acc=0.9111\n", + "26 0.5589 0.6344 0.8889 \n", + "Epoch 26: train_loss=0.5589, val_loss=0.6344, val_acc=0.8889\n", + "27 0.5596 0.6301 0.9111 \n", + "Epoch 27: train_loss=0.5596, val_loss=0.6301, val_acc=0.9111\n", + "28 0.5588 0.6376 0.9111 \n", + "Epoch 28: train_loss=0.5588, val_loss=0.6376, val_acc=0.9111\n", + "29 0.5594 0.6356 0.9111 \n", + "Epoch 29: train_loss=0.5594, val_loss=0.6356, val_acc=0.9111\n", + "30 0.5577 0.6313 0.9111 \n", + "Epoch 30: train_loss=0.5577, val_loss=0.6313, val_acc=0.9111\n", + "31 0.5572 0.6349 0.9111 \n", + "Epoch 31: train_loss=0.5572, val_loss=0.6349, val_acc=0.9111\n", + "32 0.5573 0.6323 0.9111 \n", + "Epoch 32: train_loss=0.5573, val_loss=0.6323, val_acc=0.9111\n", + "33 0.5570 0.6322 0.9111 \n", + "Epoch 33: train_loss=0.5570, val_loss=0.6322, val_acc=0.9111\n", + "34 0.5566 0.6310 0.9111 \n", + "Epoch 34: train_loss=0.5566, val_loss=0.6310, val_acc=0.9111\n", + "35 0.5560 0.6322 0.9111 \n", + "Epoch 35: train_loss=0.5560, val_loss=0.6322, val_acc=0.9111\n", + "36 0.5595 0.6344 0.9111 \n", + "Epoch 36: train_loss=0.5595, val_loss=0.6344, val_acc=0.9111\n", + "37 0.5622 0.6246 0.9333 \n", + "Epoch 37: train_loss=0.5622, val_loss=0.6246, val_acc=0.9333\n", + "38 0.5554 0.6438 0.9111 \n", + "Epoch 38: train_loss=0.5554, val_loss=0.6438, val_acc=0.9111\n", + "39 0.5593 0.6429 0.9111 \n", + "Epoch 39: train_loss=0.5593, val_loss=0.6429, val_acc=0.9111\n", + "40 0.5558 0.6332 0.9333 \n", + "Epoch 40: train_loss=0.5558, val_loss=0.6332, val_acc=0.9333\n", + "41 0.5558 0.6288 0.9111 \n", + "Epoch 41: train_loss=0.5558, val_loss=0.6288, val_acc=0.9111\n", + "42 0.5553 0.6308 0.9333 \n", + "Epoch 42: train_loss=0.5553, val_loss=0.6308, val_acc=0.9333\n", + "43 0.5547 0.6262 0.9111 \n", + "Epoch 43: train_loss=0.5547, val_loss=0.6262, val_acc=0.9111\n", + "44 0.5556 0.6264 0.9111 \n", + "Epoch 44: train_loss=0.5556, val_loss=0.6264, val_acc=0.9111\n", + "45 0.5564 0.6353 0.9111 \n", + "Epoch 45: train_loss=0.5564, val_loss=0.6353, val_acc=0.9111\n", + "46 0.5549 0.6264 0.9111 \n", + "Epoch 46: train_loss=0.5549, val_loss=0.6264, val_acc=0.9111\n", + "47 0.5562 0.6269 0.9111 \n", + "Epoch 47: train_loss=0.5562, val_loss=0.6269, val_acc=0.9111\n", + "48 0.5538 0.6401 0.9111 \n", + "Epoch 48: train_loss=0.5538, val_loss=0.6401, val_acc=0.9111\n", + "49 0.5613 0.6395 0.9111 \n", + "Epoch 49: train_loss=0.5613, val_loss=0.6395, val_acc=0.9111\n", + "50 0.5559 0.6223 0.9333 \n", + "Epoch 50: train_loss=0.5559, val_loss=0.6223, val_acc=0.9333\n", + "Training completed!\n", + "Final validation metrics: {'loss': np.float64(0.6222946536996398), 'accuracy': np.float64(0.9333333333333333)}\n" + ] + } + ], "source": [ "# Load Iris dataset\n", "iris_data = load_dataset(\"iris\", batch_size=16, validation_split=0.3)\n", @@ -1106,12 +1452,20 @@ ")\n", "\n", "# Create trainer\n", - "trainer = Trainer(model=classifier, loss=\"crossentropy\", optimizer=\"adam\", lr=0.01)\n", + "trainer = SimpleTrainer(\n", + " model=classifier,\n", + " data=iris_data['train'],\n", + " validation_data=iris_data['val'], \n", + " loss=\"crossentropy\", \n", + " optimizer=\"adam\", \n", + " lr=0.01\n", + ")\n", "\n", "# Train\n", "print(\"Training Iris classifier...\")\n", "history = trainer.fit(\n", - " data=iris_data[\"train\"], validation_data=iris_data[\"val\"], epochs=50, verbose=True\n", + " epochs=50, \n", + " verbose=True\n", ")\n", "\n", "# Evaluate\n", @@ -1128,9 +1482,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training loop for regression...\n", + "Epoch 0: Loss = 13.2949\n", + "Epoch 20: Loss = 4.1434\n", + "Epoch 40: Loss = 0.9452\n", + "Epoch 60: Loss = 0.4565\n", + "Epoch 80: Loss = 0.2559\n", + "Training completed!\n", + "Test prediction: 3.706, Expected: 3.750\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1677/1613204889.py:42: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)\n", + " print(f\"Test prediction: {float(prediction.data[0]):.3f}, Expected: {expected:.3f}\")\n" + ] + } + ], "source": [ "from fit.optim.adam import Adam\n", "from fit.loss.regression import MSELoss\n", @@ -1165,7 +1542,7 @@ " optimizer.step()\n", "\n", " if epoch % 20 == 0:\n", - " print(f\"Epoch {epoch}: Loss = {loss.data[0]:.4f}\")\n", + " print(f\"Epoch {epoch}: Loss = {float(loss.data):.4f}\")\n", "\n", "print(\"Training completed!\")\n", "\n", @@ -1173,7 +1550,7 @@ "test_X = Tensor([[1.0, -1.0, 0.5]])\n", "prediction = regression_model(test_X)\n", "expected = 1.0 * 1.5 + (-1.0) * (-2.0) + 0.5 * 0.5 # Using true weights\n", - "print(f\"Test prediction: {prediction.data[0]:.3f}, Expected: {expected:.3f}\")" + "print(f\"Test prediction: {float(prediction.data[0]):.3f}, Expected: {expected:.3f}\")" ] } ], diff --git a/fit/nn/modules/container.py b/fit/nn/modules/container.py index 02a2345..ff87b4f 100644 --- a/fit/nn/modules/container.py +++ b/fit/nn/modules/container.py @@ -80,6 +80,8 @@ def get_config(self): layer_config.update(layer.get_config()) layers_config.append(layer_config) return {"layers": layers_config} + + class ModuleList(Layer): diff --git a/fit/nn/utils/model_io.py b/fit/nn/utils/model_io.py index b7adbfa..9f27d71 100644 --- a/fit/nn/utils/model_io.py +++ b/fit/nn/utils/model_io.py @@ -125,7 +125,7 @@ def load_model(path, model_class=None): else: raise ValueError(f"Unknown layer type: {layer_type}") - model.layers.append(layer) + model.add_layer(layer) else: raise ValueError(f"Cannot instantiate model of type {model_type}")