{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to train an MLP using the Pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import classes and define paths" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from cetaceo.pipeline import Pipeline\n", "from cetaceo.models import MLP\n", "from cetaceo.data import HDF5Dataset\n", "from cetaceo.evaluators import RegressionEvaluatorPlotter\n", "from cetaceo.plotting import TrueVsPredPlotter\n", "from cetaceo.utils import PathManager\n", "from pathlib import Path\n", "\n", "import torch\n", "from sklearn.preprocessing import MinMaxScaler\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "DATA_DIR = Path.cwd().parent / \"sample_data\"\n", "CASE_DIR = Path.cwd() / \"results\"\n", "PathManager.create_directory(CASE_DIR / 'models')\n", "PathManager.create_directory(CASE_DIR / 'hyperparameters')\n", "PathManager.create_directory(CASE_DIR / 'plots')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define scikit-learn scalers (optional)\n", "\n", "Here, we create two `MinMaxScaler` instances: one to scale the inputs and another to scale the outputs. They are passed to the datasets in the next step." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "x_scaler = MinMaxScaler()\n", "y_scaler = MinMaxScaler()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create datasets\n", "\n", "For this example, we use the airfoil data from the DLR paper. Since the files have been preprocessed into HDF5 (`.h5`) format, we need an `HDF5Dataset`. We create one for each dataset split (training, test and validation)."
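, "\n",
 "\n",
 "If you want to check what a `.h5` file contains before wrapping it in an `HDF5Dataset`, the sketch below lists every array stored in the file, with its shape and dtype, using `h5py` directly. It assumes only that `h5py` is installed and makes no assumption about the layout that cetaceo expects. The helper `describe_h5` and the toy file with keys `x` and `y` are hypothetical and exist only so that the snippet runs on its own.\n",
 "\n",
 "```python\n",
 "import os\n",
 "import tempfile\n",
 "\n",
 "import h5py\n",
 "import numpy as np\n",
 "\n",
 "\n",
 "def describe_h5(path):\n",
 "    # Hypothetical helper: recursively collect (name, shape, dtype) for every dataset in the file\n",
 "    found = []\n",
 "\n",
 "    def visit(name, obj):\n",
 "        if isinstance(obj, h5py.Dataset):\n",
 "            found.append((name, obj.shape, str(obj.dtype)))\n",
 "\n",
 "    with h5py.File(path, 'r') as f:\n",
 "        f.visititems(visit)\n",
 "    return found\n",
 "\n",
 "\n",
 "# Toy file with hypothetical keys, written only so the example is self-contained\n",
 "toy_path = os.path.join(tempfile.mkdtemp(), 'toy.h5')\n",
 "with h5py.File(toy_path, 'w') as f:\n",
 "    f.create_dataset('x', data=np.zeros((10, 4), dtype=np.float32))\n",
 "    f.create_dataset('y', data=np.zeros((10, 1), dtype=np.float32))\n",
 "\n",
 "for name, shape, dtype in describe_h5(toy_path):\n",
 "    print(name, shape, dtype)\n",
 "```\n",
 "\n",
 "In this notebook, calling `describe_h5(str(DATA_DIR / 'train.h5'))` would list the arrays stored in the training split. Because `visit` returns `None`, `visititems` keeps iterating over the whole file instead of stopping at the first match."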
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_dataset = HDF5Dataset(src_file=str(DATA_DIR / \"train.h5\"), x_scaler=x_scaler, y_scaler=y_scaler)\n", "test_dataset = HDF5Dataset(src_file=str(DATA_DIR / \"test.h5\"), x_scaler=x_scaler, y_scaler=y_scaler)\n", "valid_dataset = HDF5Dataset(src_file=str(DATA_DIR / \"val.h5\"), x_scaler=x_scaler, y_scaler=y_scaler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the scalers were passed to the constructors, the datasets can now be scaled with `scale_data()`. The training tensors `x` and `y` are read before scaling; they are only used later for their shapes." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\tTrain dataset length: 23283\n", "\tTest dataset length: 23283\n", "\tValid dataset length: 11940\n", "\tX, y train shapes: torch.Size([23283, 4]) torch.Size([23283, 1])\n" ] } ], "source": [ "x, y = train_dataset[:]\n", "train_dataset.scale_data()\n", "valid_dataset.scale_data()\n", "test_dataset.scale_data()\n", "print(\"\\tTrain dataset length: \", len(train_dataset))\n", "print(\"\\tTest dataset length: \", len(test_dataset))\n", "print(\"\\tValid dataset length: \", len(valid_dataset))\n", "print(\"\\tX, y train shapes:\", x.shape, y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluators and Plotters\n", "\n", "Here we define which evaluator to use. In this case, we use a `RegressionEvaluatorPlotter`, which computes regression metrics such as `mse`, `mae` and `r2`. 
Additionally, this class accepts a list of plotters, which create plots from the model's predictions; here we use a single `TrueVsPredPlotter`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "plotters = [TrueVsPredPlotter()]\n", "evaluator = RegressionEvaluatorPlotter(plots_path=CASE_DIR / 'plots', plotters=plotters)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model creation\n", "\n", "Now, the only thing left is to create the model. We first define the training parameters (number of epochs, learning rate and its decay schedule, batch size, optimizer and how often the losses are printed), and then instantiate the model, which for this example is an `MLP`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "training_params = {\n", " \"epochs\": 100,\n", " \"lr\": 0.00126,\n", " \"lr_gamma\": 0.966,\n", " \"lr_scheduler_step\": 1,\n", " \"batch_size\": 512,\n", " \"optimizer_class\": torch.optim.Adam,\n", " \"print_rate\": 1,\n", "}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model = MLP(\n", " input_size=x.shape[1],\n", " output_size=y.shape[1],\n", " hidden_size=512,\n", " n_layers=3,\n", " p_dropouts=0.15\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the pipeline\n", "\n", "The `Pipeline` trains the model, printing the train and test losses at every epoch, and then runs the evaluators on the training and test data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100 | Train loss (x1e5) 3968.2092 | Test loss (x1e5) 640.3254\n", "Epoch 2/100 | Train loss (x1e5) 531.2284 | Test loss (x1e5) 268.4059\n", "Epoch 3/100 | Train loss (x1e5) 363.7133 | Test loss (x1e5) 193.8889\n", "Epoch 4/100 | Train loss (x1e5) 281.3932 | Test loss (x1e5) 117.6769\n", "Epoch 5/100 | Train loss (x1e5) 219.3242 | Test loss (x1e5) 94.8438\n", "Epoch 6/100 | Train loss (x1e5) 188.3426 | Test loss (x1e5) 66.5134\n", "Epoch 7/100 | Train loss (x1e5) 162.8639 | Test loss (x1e5) 52.9106\n", "Epoch 8/100 | Train loss (x1e5) 147.2162 | Test loss (x1e5) 43.3563\n", "Epoch 9/100 | Train loss (x1e5) 133.1248 | Test loss (x1e5) 69.3100\n", "Epoch 10/100 | Train loss 
(x1e5) 129.9856 | Test loss (x1e5) 40.7987\n", "Epoch 11/100 | Train loss (x1e5) 116.6901 | Test loss (x1e5) 34.9828\n", "Epoch 12/100 | Train loss (x1e5) 111.6007 | Test loss (x1e5) 30.6759\n", "Epoch 13/100 | Train loss (x1e5) 107.0436 | Test loss (x1e5) 37.7164\n", "Epoch 14/100 | Train loss (x1e5) 100.2151 | Test loss (x1e5) 31.0014\n", "Epoch 15/100 | Train loss (x1e5) 97.2787 | Test loss (x1e5) 31.0849\n", "Epoch 16/100 | Train loss (x1e5) 94.6964 | Test loss (x1e5) 43.5958\n", "Epoch 17/100 | Train loss (x1e5) 91.4782 | Test loss (x1e5) 29.6458\n", "Epoch 18/100 | Train loss (x1e5) 90.4933 | Test loss (x1e5) 33.6109\n", "Epoch 19/100 | Train loss (x1e5) 86.9575 | Test loss (x1e5) 30.7852\n", "Epoch 20/100 | Train loss (x1e5) 88.1818 | Test loss (x1e5) 29.2189\n", "Epoch 21/100 | Train loss (x1e5) 84.4978 | Test loss (x1e5) 22.2148\n", "Epoch 22/100 | Train loss (x1e5) 84.4161 | Test loss (x1e5) 35.0264\n", "Epoch 23/100 | Train loss (x1e5) 80.5170 | Test loss (x1e5) 22.3276\n", "Epoch 24/100 | Train loss (x1e5) 80.5800 | Test loss (x1e5) 29.2161\n", "Epoch 25/100 | Train loss (x1e5) 77.0368 | Test loss (x1e5) 22.5613\n", "Epoch 26/100 | Train loss (x1e5) 77.6624 | Test loss (x1e5) 28.3152\n", "Epoch 27/100 | Train loss (x1e5) 75.9246 | Test loss (x1e5) 27.3747\n", "Epoch 28/100 | Train loss (x1e5) 75.1129 | Test loss (x1e5) 22.1203\n", "Epoch 29/100 | Train loss (x1e5) 73.9176 | Test loss (x1e5) 21.3708\n", "Epoch 30/100 | Train loss (x1e5) 73.8303 | Test loss (x1e5) 19.4375\n", "Epoch 31/100 | Train loss (x1e5) 71.4301 | Test loss (x1e5) 17.4285\n", "Epoch 32/100 | Train loss (x1e5) 71.6028 | Test loss (x1e5) 30.2202\n", "Epoch 33/100 | Train loss (x1e5) 70.7754 | Test loss (x1e5) 27.7524\n", "Epoch 34/100 | Train loss (x1e5) 70.7107 | Test loss (x1e5) 21.8441\n", "Epoch 35/100 | Train loss (x1e5) 67.9373 | Test loss (x1e5) 18.7597\n", "Epoch 36/100 | Train loss (x1e5) 66.5284 | Test loss (x1e5) 21.3425\n", "Epoch 37/100 | Train loss (x1e5) 69.2680 | Test 
loss (x1e5) 18.1248\n", "Epoch 38/100 | Train loss (x1e5) 66.0457 | Test loss (x1e5) 18.3276\n", "Epoch 39/100 | Train loss (x1e5) 66.4341 | Test loss (x1e5) 18.3668\n", "Epoch 40/100 | Train loss (x1e5) 64.8588 | Test loss (x1e5) 14.3345\n", "Epoch 41/100 | Train loss (x1e5) 65.0324 | Test loss (x1e5) 22.2836\n", "Epoch 42/100 | Train loss (x1e5) 64.2346 | Test loss (x1e5) 20.3639\n", "Epoch 43/100 | Train loss (x1e5) 63.5873 | Test loss (x1e5) 23.6401\n", "Epoch 44/100 | Train loss (x1e5) 63.5231 | Test loss (x1e5) 16.7890\n", "Epoch 45/100 | Train loss (x1e5) 62.1668 | Test loss (x1e5) 14.9163\n", "Epoch 46/100 | Train loss (x1e5) 60.7827 | Test loss (x1e5) 17.4896\n", "Epoch 47/100 | Train loss (x1e5) 62.4567 | Test loss (x1e5) 13.8604\n", "Epoch 48/100 | Train loss (x1e5) 61.4397 | Test loss (x1e5) 17.9461\n", "Epoch 49/100 | Train loss (x1e5) 60.8647 | Test loss (x1e5) 19.6511\n", "Epoch 50/100 | Train loss (x1e5) 60.5594 | Test loss (x1e5) 13.8324\n", "Epoch 51/100 | Train loss (x1e5) 59.3509 | Test loss (x1e5) 12.3390\n", "Epoch 52/100 | Train loss (x1e5) 60.1447 | Test loss (x1e5) 14.2438\n", "Epoch 53/100 | Train loss (x1e5) 59.2692 | Test loss (x1e5) 19.0007\n", "Epoch 54/100 | Train loss (x1e5) 59.7058 | Test loss (x1e5) 13.2120\n", "Epoch 55/100 | Train loss (x1e5) 59.4506 | Test loss (x1e5) 21.6260\n", "Epoch 56/100 | Train loss (x1e5) 59.7846 | Test loss (x1e5) 14.1505\n", "Epoch 57/100 | Train loss (x1e5) 59.1430 | Test loss (x1e5) 11.9013\n", "Epoch 58/100 | Train loss (x1e5) 57.7920 | Test loss (x1e5) 18.3841\n", "Epoch 59/100 | Train loss (x1e5) 58.3397 | Test loss (x1e5) 12.8296\n", "Epoch 60/100 | Train loss (x1e5) 56.5668 | Test loss (x1e5) 15.5571\n", "Epoch 61/100 | Train loss (x1e5) 57.3597 | Test loss (x1e5) 18.4371\n", "Epoch 62/100 | Train loss (x1e5) 58.0973 | Test loss (x1e5) 12.8372\n", "Epoch 63/100 | Train loss (x1e5) 56.9284 | Test loss (x1e5) 12.7845\n", "Epoch 64/100 | Train loss (x1e5) 55.6151 | Test loss (x1e5) 15.2220\n", 
"Epoch 65/100 | Train loss (x1e5) 56.3089 | Test loss (x1e5) 12.8679\n", "Epoch 66/100 | Train loss (x1e5) 57.0098 | Test loss (x1e5) 11.3940\n", "Epoch 67/100 | Train loss (x1e5) 55.8330 | Test loss (x1e5) 16.7325\n", "Epoch 68/100 | Train loss (x1e5) 55.6434 | Test loss (x1e5) 13.6628\n", "Epoch 69/100 | Train loss (x1e5) 55.7484 | Test loss (x1e5) 13.5055\n", "Epoch 70/100 | Train loss (x1e5) 54.4978 | Test loss (x1e5) 12.0881\n", "Epoch 71/100 | Train loss (x1e5) 55.6273 | Test loss (x1e5) 15.7260\n", "Epoch 72/100 | Train loss (x1e5) 55.3004 | Test loss (x1e5) 13.6676\n", "Epoch 73/100 | Train loss (x1e5) 53.8224 | Test loss (x1e5) 11.3997\n", "Epoch 74/100 | Train loss (x1e5) 54.8473 | Test loss (x1e5) 11.9121\n", "Epoch 75/100 | Train loss (x1e5) 54.4477 | Test loss (x1e5) 13.0294\n", "Epoch 76/100 | Train loss (x1e5) 54.2868 | Test loss (x1e5) 11.7455\n", "Epoch 77/100 | Train loss (x1e5) 53.2922 | Test loss (x1e5) 11.6593\n", "Epoch 78/100 | Train loss (x1e5) 52.6854 | Test loss (x1e5) 14.3897\n", "Epoch 79/100 | Train loss (x1e5) 52.8497 | Test loss (x1e5) 13.3810\n", "Epoch 80/100 | Train loss (x1e5) 53.2287 | Test loss (x1e5) 11.8646\n", "Epoch 81/100 | Train loss (x1e5) 54.0476 | Test loss (x1e5) 14.3325\n", "Epoch 82/100 | Train loss (x1e5) 53.6350 | Test loss (x1e5) 13.7774\n", "Epoch 83/100 | Train loss (x1e5) 53.7555 | Test loss (x1e5) 11.0659\n", "Epoch 84/100 | Train loss (x1e5) 53.1362 | Test loss (x1e5) 13.7008\n", "Epoch 85/100 | Train loss (x1e5) 52.4187 | Test loss (x1e5) 16.8150\n", "Epoch 86/100 | Train loss (x1e5) 52.6289 | Test loss (x1e5) 13.3272\n", "Epoch 87/100 | Train loss (x1e5) 51.3318 | Test loss (x1e5) 13.8505\n", "Epoch 88/100 | Train loss (x1e5) 53.0107 | Test loss (x1e5) 12.2213\n", "Epoch 89/100 | Train loss (x1e5) 51.8942 | Test loss (x1e5) 13.0393\n", "Epoch 90/100 | Train loss (x1e5) 51.7824 | Test loss (x1e5) 12.1874\n", "Epoch 91/100 | Train loss (x1e5) 51.4040 | Test loss (x1e5) 12.8719\n", "Epoch 92/100 | Train loss 
(x1e5) 50.7001 | Test loss (x1e5) 10.9069\n", "Epoch 93/100 | Train loss (x1e5) 51.5473 | Test loss (x1e5) 12.0467\n", "Epoch 94/100 | Train loss (x1e5) 50.8814 | Test loss (x1e5) 13.0899\n", "Epoch 95/100 | Train loss (x1e5) 51.5298 | Test loss (x1e5) 13.6348\n", "Epoch 96/100 | Train loss (x1e5) 51.8971 | Test loss (x1e5) 14.7821\n", "Epoch 97/100 | Train loss (x1e5) 51.7207 | Test loss (x1e5) 11.0060\n", "Epoch 98/100 | Train loss (x1e5) 51.4634 | Test loss (x1e5) 10.5973\n", "Epoch 99/100 | Train loss (x1e5) 51.1898 | Test loss (x1e5) 11.7234\n", "Epoch 100/100 | Train loss (x1e5) 50.5890 | Test loss (x1e5) 10.9737\n", "\n", "--------------------------------------------------\n", "Metrics on train data:\n", "--------------------------------------------------\n", "Rescale output: True\n", "\n", "Regression evaluator metrics:\n", "mse: 0.0015\n", "mae: 0.0246\n", "mre: 15.6336%\n", "ae_95: 0.0813\n", "ae_99: 0.1536\n", "r2: 0.9936\n", "l2_error: 0.0668\n", "--------------------------------------------------\n", "Metrics on test data:\n", "--------------------------------------------------\n", "Rescale output: True\n", "\n", "Regression evaluator metrics:\n", "mse: 0.0028\n", "mae: 0.0286\n", "mre: 17.6299%\n", "ae_95: 0.0993\n", "ae_99: 0.2377\n", "r2: 0.9885\n", "l2_error: 0.0895\n" ] } ], "source": [ "pipeline = Pipeline(\n", " train_dataset=train_dataset,\n", " test_dataset=test_dataset,\n", " model=model,\n", " training_params=training_params,\n", " evaluators=[evaluator],\n", " )\n", "pipeline.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To save the model:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model.save(path=str(CASE_DIR / \"models\"))" ] } ], "metadata": { "kernelspec": { "display_name": "cetaceo", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }