{ "cells": [ { "cell_type": "raw", "id": "d045d53a", "metadata": {}, "source": [ "Run in Google Colab" ] }, { "cell_type": "markdown", "id": "c28251c6", "metadata": {}, "source": [ "# MLPClassifier and MLPRegressor in SciKeras\n", "\n", "SciKeras is a bridge between Keras and Scikit-Learn. As such, one of SciKeras' design goals is to be able to create a Scikit-Learn style estimator backed by Keras.\n", "\n", "This notebook implements an estimator that is analogous to `sklearn.neural_network.MLPClassifier` using Keras. This estimator should (for the most part) work as a drop-in replacement for `MLPClassifier`!\n", "\n", "## Table of contents\n", "\n", "* [1. Setup](#1.-Setup)\n", "* [2. Defining the Keras Model](#2.-Defining-the-Keras-Model)\n", "  * [2.1 Inputs](#2.1-Inputs)\n", "  * [2.2 Hidden layers](#2.2-Hidden-layers)\n", "  * [2.3 Output layers](#2.3-Output-layers)\n", "  * [2.4 Losses and optimizer](#2.4-Losses-and-optimizer)\n", "  * [2.5 Wrapping with SciKeras](#2.5-Wrapping-with-SciKeras)\n", "* [3. Testing our classifier](#3.-Testing-our-classifier)\n", "* [4. Self contained MLPClassifier](#4.-Self-contained-MLPClassifier)\n", "  * [4.1 Subclassing](#4.1-Subclassing)\n", "* [5. MLPRegressor](#5.-MLPRegressor)\n", "\n", "## 1. Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "24809c6f", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:15.249790Z", "iopub.status.busy": "2024-04-11T22:25:15.249490Z", "iopub.status.idle": "2024-04-11T22:25:17.504796Z", "shell.execute_reply": "2024-04-11T22:25:17.504154Z" } }, "outputs": [], "source": [ "try:\n", "    import scikeras\n", "except ImportError:\n", "    !python -m pip install scikeras" ] }, { "cell_type": "markdown", "id": "86ade6ed", "metadata": {}, "source": [ "Silence TensorFlow logging to keep output succinct." ] }, { "cell_type": "code", "execution_count": 2, "id": "b2b0662c", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.508346Z", "iopub.status.busy": "2024-04-11T22:25:17.507792Z", "iopub.status.idle": "2024-04-11T22:25:17.511589Z", "shell.execute_reply": "2024-04-11T22:25:17.510927Z" } }, "outputs": [], "source": [ "import warnings\n", "from tensorflow import get_logger\n", "get_logger().setLevel('ERROR')\n", "warnings.filterwarnings(\"ignore\", message=\"Setting the random state for TF\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "de49e173", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.513798Z", "iopub.status.busy": "2024-04-11T22:25:17.513603Z", "iopub.status.idle": "2024-04-11T22:25:17.840318Z", "shell.execute_reply": "2024-04-11T22:25:17.839695Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from scikeras.wrappers import KerasClassifier, KerasRegressor\n", "import keras" ] }, { "cell_type": "markdown", "id": "1783c293", "metadata": {}, "source": [ "## 2. Defining the Keras Model\n", "\n", "First, we outline our model-building function, using a `Sequential` model:" ] }, { "cell_type": "code", "execution_count": 4, "id": "fbcace67", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.844027Z", "iopub.status.busy": "2024-04-11T22:25:17.843416Z", "iopub.status.idle": "2024-04-11T22:25:17.846824Z", "shell.execute_reply": "2024-04-11T22:25:17.846144Z" } }, "outputs": [], "source": [ "def get_clf_model():\n", "    model = keras.Sequential()\n", "    return model" ] }, { "cell_type": "markdown", "id": "b2c1984c", "metadata": {}, "source": [ "### 2.1 Inputs\n", "\n", "We need to define an input layer for Keras. 
SciKeras allows you to dynamically determine the input size based on the features (`X`). To do this, you need to add the `meta` parameter to `get_clf_model`'s parameters. `meta` will be a dictionary with all of the `meta` attributes that `KerasClassifier` generates during the `fit` call, including `n_features_in_`, which we will use to dynamically size the input layer." ] }, { "cell_type": "code", "execution_count": 5, "id": "a7111c9f", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.849374Z", "iopub.status.busy": "2024-04-11T22:25:17.848886Z", "iopub.status.idle": "2024-04-11T22:25:17.852752Z", "shell.execute_reply": "2024-04-11T22:25:17.852094Z" } }, "outputs": [], "source": [ "from typing import Dict, Iterable, Any\n", "\n", "\n", "def get_clf_model(meta: Dict[str, Any]):\n", "    model = keras.Sequential()\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    model.add(inp)\n", "    return model" ] }, { "cell_type": "markdown", "id": "e93cb5c4", "metadata": {}, "source": [ "### 2.2 Hidden layers\n", "\n", "Multilayer perceptrons are generally composed of an input layer, an output layer, and zero or more hidden layers. The sizes of the hidden layers are specified via the `hidden_layer_sizes` parameter in `MLPClassifier`, where the ith element represents the number of neurons in the ith hidden layer. Let's add that parameter:" ] }, { "cell_type": "code", "execution_count": 6, "id": "5645b7c8", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.855556Z", "iopub.status.busy": "2024-04-11T22:25:17.855322Z", "iopub.status.idle": "2024-04-11T22:25:17.859429Z", "shell.execute_reply": "2024-04-11T22:25:17.858757Z" } }, "outputs": [], "source": [ "def get_clf_model(hidden_layer_sizes: Iterable[int], meta: Dict[str, Any]):\n", "    model = keras.Sequential()\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    model.add(inp)\n", "    for hidden_layer_size in hidden_layer_sizes:\n", "        layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", "        model.add(layer)\n", "    return model" ] }, { "cell_type": "markdown", "id": "83631030", "metadata": {}, "source": [ "### 2.3 Output layers\n", "\n", "The output layer needs to reflect the type of classification task being performed. Here, we will handle two cases:\n", "\n", "- binary classification: a single output unit with sigmoid activation\n", "- multiclass classification: one output unit per class, with softmax activation\n", "\n", "The main complication arises from determining which one to use. As with the input features, SciKeras provides useful information on the target within the `meta` parameter. Specifically, we will use the `n_classes_` and `target_type_` attributes to determine the number of output units and the activation function."
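] }, { "cell_type": "markdown", "id": "a1f2b3c4", "metadata": {}, "source": [ "SciKeras infers `target_type_` from the target `y`, using the same type names as scikit-learn. As a quick illustration (a minimal sketch using `sklearn.utils.multiclass.type_of_target`, which follows the same convention), two unique labels yield `\"binary\"` and three or more yield `\"multiclass\"`:" ] }, { "cell_type": "code", "execution_count": null, "id": "b2e3c4d5", "metadata": {}, "outputs": [], "source": [ "from sklearn.utils.multiclass import type_of_target\n", "\n", "# Two unique labels -> \"binary\"; three or more -> \"multiclass\"\n", "print(type_of_target([0, 1, 1, 0]))\n", "print(type_of_target([0, 1, 2, 1, 0]))"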
] }, { "cell_type": "code", "execution_count": 7, "id": "2ae944e6", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.861868Z", "iopub.status.busy": "2024-04-11T22:25:17.861494Z", "iopub.status.idle": "2024-04-11T22:25:17.866160Z", "shell.execute_reply": "2024-04-11T22:25:17.865496Z" } }, "outputs": [], "source": [ "def get_clf_model(hidden_layer_sizes: Iterable[int], meta: Dict[str, Any]):\n", "    model = keras.Sequential()\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    model.add(inp)\n", "    for hidden_layer_size in hidden_layer_sizes:\n", "        layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", "        model.add(layer)\n", "    if meta[\"target_type_\"] == \"binary\":\n", "        n_output_units = 1\n", "        output_activation = \"sigmoid\"\n", "    elif meta[\"target_type_\"] == \"multiclass\":\n", "        n_output_units = meta[\"n_classes_\"]\n", "        output_activation = \"softmax\"\n", "    else:\n", "        raise NotImplementedError(f\"Unsupported task type: {meta['target_type_']}\")\n", "    out = keras.layers.Dense(n_output_units, activation=output_activation)\n", "    model.add(out)\n", "    return model" ] }, { "cell_type": "markdown", "id": "be7bcc6a", "metadata": {}, "source": [ "For now, we raise a `NotImplementedError` for other target types. For an example handling multi-output target types, see the [Multi Output notebook](https://colab.research.google.com/github/adriangb/scikeras/blob/master/notebooks/MultiInput.ipynb).\n", "\n", "### 2.4 Losses and optimizer\n", "\n", "Like the output layer, the loss must match the type of classification task. Generally, it is easier and safer to allow SciKeras to compile your model for you by passing the loss to `KerasClassifier` directly (`KerasClassifier(loss=\"binary_crossentropy\")`). However, in order to implement custom logic around the choice of loss function, we compile the model ourselves within `get_clf_model`; SciKeras will not re-compile the model." ] }, { "cell_type": "code", "execution_count": 8, "id": "d0222012", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.868363Z", "iopub.status.busy": "2024-04-11T22:25:17.868169Z", "iopub.status.idle": "2024-04-11T22:25:17.873059Z", "shell.execute_reply": "2024-04-11T22:25:17.872403Z" } }, "outputs": [], "source": [ "def get_clf_model(hidden_layer_sizes: Iterable[int], meta: Dict[str, Any]):\n", "    model = keras.Sequential()\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    model.add(inp)\n", "    for hidden_layer_size in hidden_layer_sizes:\n", "        layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", "        model.add(layer)\n", "    if meta[\"target_type_\"] == \"binary\":\n", "        n_output_units = 1\n", "        output_activation = \"sigmoid\"\n", "        loss = \"binary_crossentropy\"\n", "    elif meta[\"target_type_\"] == \"multiclass\":\n", "        n_output_units = meta[\"n_classes_\"]\n", "        output_activation = \"softmax\"\n", "        loss = \"sparse_categorical_crossentropy\"\n", "    else:\n", "        raise NotImplementedError(f\"Unsupported task type: {meta['target_type_']}\")\n", "    out = keras.layers.Dense(n_output_units, activation=output_activation)\n", "    model.add(out)\n", "    model.compile(loss=loss)\n", "    return model" ] }, { "cell_type": "markdown", "id": "905a07ab", "metadata": {}, "source": [ "At this point, we have a valid, compiled model. However, if we want to be able to tune the optimizer, we should accept `compile_kwargs` as a parameter in `get_clf_model`. 
`compile_kwargs` will be a dictionary containing valid `kwargs` for `Model.compile`, so we can unpack it directly like `model.compile(**compile_kwargs)`. In this case, however, we will only use the `optimizer` kwarg." ] }, { "cell_type": "code", "execution_count": 9, "id": "a8bdca10", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.875409Z", "iopub.status.busy": "2024-04-11T22:25:17.875226Z", "iopub.status.idle": "2024-04-11T22:25:17.879879Z", "shell.execute_reply": "2024-04-11T22:25:17.879229Z" } }, "outputs": [], "source": [ "def get_clf_model(hidden_layer_sizes: Iterable[int], meta: Dict[str, Any], compile_kwargs: Dict[str, Any]):\n", "    model = keras.Sequential()\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    model.add(inp)\n", "    for hidden_layer_size in hidden_layer_sizes:\n", "        layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", "        model.add(layer)\n", "    if meta[\"target_type_\"] == \"binary\":\n", "        n_output_units = 1\n", "        output_activation = \"sigmoid\"\n", "        loss = \"binary_crossentropy\"\n", "    elif meta[\"target_type_\"] == \"multiclass\":\n", "        n_output_units = meta[\"n_classes_\"]\n", "        output_activation = \"softmax\"\n", "        loss = \"sparse_categorical_crossentropy\"\n", "    else:\n", "        raise NotImplementedError(f\"Unsupported task type: {meta['target_type_']}\")\n", "    out = keras.layers.Dense(n_output_units, activation=output_activation)\n", "    model.add(out)\n", "    model.compile(loss=loss, optimizer=compile_kwargs[\"optimizer\"])\n", "    return model" ] }, { "cell_type": "markdown", "id": "5a76408c", "metadata": {}, "source": [ "### 2.5 Wrapping with SciKeras\n", "\n", "Our last step in defining our model is to wrap it with SciKeras. A couple of things to note:\n", "\n", "- Every user-defined parameter in `model`/`get_clf_model` (in our case just `hidden_layer_sizes`) must be defined as a keyword argument to `KerasClassifier` with a default value.\n", "- Keras defaults to `\"rmsprop\"` for `optimizer`. We set it to `\"adam\"` to mimic MLPClassifier.\n", "- We set the learning rate for the optimizer to `0.001`, again to mimic MLPClassifier. We set this parameter using [parameter routing](https://www.adriangb.com/scikeras/stable/advanced.html#routed-parameters).\n", "- Other parameters, such as `activation`, can be added similarly to `hidden_layer_sizes`, but we omit them here for simplicity." ] }, { "cell_type": "code", "execution_count": 10, "id": "ac50ec5c", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.882093Z", "iopub.status.busy": "2024-04-11T22:25:17.881897Z", "iopub.status.idle": "2024-04-11T22:25:17.885155Z", "shell.execute_reply": "2024-04-11T22:25:17.884519Z" } }, "outputs": [], "source": [ "clf = KerasClassifier(\n", "    model=get_clf_model,\n", "    hidden_layer_sizes=(100, ),\n", "    optimizer=\"adam\",\n", "    optimizer__learning_rate=0.001,\n", "    epochs=50,\n", "    verbose=0,\n", ")" ] }, { "cell_type": "markdown", "id": "5818ec7c", "metadata": {}, "source": [ "## 3. Testing our classifier\n", "\n", "Before continuing, we will run a small test to make sure we get somewhat reasonable results."
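] }, { "cell_type": "markdown", "id": "c3d4e5f6", "metadata": {}, "source": [ "First, a quick sanity check that the wrapper behaves like a scikit-learn estimator (a minimal sketch): `clone` rebuilds an estimator from `get_params()` and raises if any constructor parameter is not exposed correctly." ] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f6a7", "metadata": {}, "outputs": [], "source": [ "from sklearn.base import clone\n", "\n", "# clone() round-trips the estimator through get_params()/set_params();\n", "# it raises if a constructor parameter is not recoverable\n", "clone(clf)\n", "print(clf.get_params()[\"hidden_layer_sizes\"])"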
] }, { "cell_type": "code", "execution_count": 11, "id": "5ec807b1", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:17.887776Z", "iopub.status.busy": "2024-04-11T22:25:17.887279Z", "iopub.status.idle": "2024-04-11T22:25:19.726667Z", "shell.execute_reply": "2024-04-11T22:25:19.725985Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.95\n" ] } ], "source": [ "from sklearn.datasets import make_classification\n", "\n", "\n", "X, y = make_classification()\n", "\n", "# check that fit works\n", "clf.fit(X, y)\n", "# check score\n", "print(clf.score(X, y))" ] }, { "cell_type": "markdown", "id": "d6c0b1d5", "metadata": {}, "source": [ "We get a score above 0.7, which is reasonable and indicates that our classifier is generally working.\n", "\n", "## 4. Self contained MLPClassifier\n", "\n", "You will have noticed that up until now, we have defined our Keras model in a function and passed that function to `KerasClassifier` via the `model` argument.\n", "\n", "This is convenient, but it does not give us a self-contained class that we could package within a module for users to instantiate. To do that, we need to subclass `KerasClassifier`.\n", "\n", "### 4.1 Subclassing\n", "\n", "By subclassing `KerasClassifier`, you can embed your Keras model directly into your estimator class. We start by inheriting from `KerasClassifier` and defining an `__init__` method with all of our parameters." ] }, { "cell_type": "code", "execution_count": 12, "id": "21e02279", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:19.729946Z", "iopub.status.busy": "2024-04-11T22:25:19.729182Z", "iopub.status.idle": "2024-04-11T22:25:19.733824Z", "shell.execute_reply": "2024-04-11T22:25:19.733278Z" } }, "outputs": [], "source": [ "class MLPClassifier(KerasClassifier):\n", "\n", "    def __init__(\n", "        self,\n", "        hidden_layer_sizes=(100, ),\n", "        optimizer=\"adam\",\n", "        optimizer__learning_rate=0.001,\n", "        epochs=200,\n", "        verbose=0,\n", "        **kwargs\n", "    ):\n", "        super().__init__(**kwargs)\n", "        self.hidden_layer_sizes = hidden_layer_sizes\n", "        self.optimizer = optimizer\n", "        self.epochs = epochs\n", "        self.verbose = verbose" ] }, { "cell_type": "markdown", "id": "6ae32328", "metadata": {}, "source": [ "Next, we will embed our model into `_keras_build_fn`, which takes the place of `get_clf_model`. Note that since this is now a method on the class, we no longer need to accept any parameters in the function signature: the `meta` attributes are available directly on `self` (e.g. `self.n_features_in_`). We still accept `compile_kwargs` because we use it to get the optimizer initialized with all of its parameters."
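] }, { "cell_type": "markdown", "id": "e5f6a7b8", "metadata": {}, "source": [ "For example, with the classifier we fit above, the attributes used by `_keras_build_fn` below are already populated (a quick illustration; they exist only after `fit`):" ] }, { "cell_type": "code", "execution_count": null, "id": "f6a7b8c9", "metadata": {}, "outputs": [], "source": [ "# The meta attributes are plain attributes on the fitted wrapper\n", "print(clf.n_features_in_, clf.target_type_, clf.n_classes_)"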
] }, { "cell_type": "code", "execution_count": 13, "id": "86693c54", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:19.736453Z", "iopub.status.busy": "2024-04-11T22:25:19.736008Z", "iopub.status.idle": "2024-04-11T22:25:19.741854Z", "shell.execute_reply": "2024-04-11T22:25:19.741316Z" } }, "outputs": [], "source": [ "class MLPClassifier(KerasClassifier):\n", "\n", "    def __init__(\n", "        self,\n", "        hidden_layer_sizes=(100, ),\n", "        optimizer=\"adam\",\n", "        optimizer__learning_rate=0.001,\n", "        epochs=200,\n", "        verbose=0,\n", "        **kwargs,\n", "    ):\n", "        super().__init__(**kwargs)\n", "        self.hidden_layer_sizes = hidden_layer_sizes\n", "        self.optimizer = optimizer\n", "        self.epochs = epochs\n", "        self.verbose = verbose\n", "\n", "    def _keras_build_fn(self, compile_kwargs: Dict[str, Any]):\n", "        model = keras.Sequential()\n", "        inp = keras.layers.Input(shape=(self.n_features_in_,))\n", "        model.add(inp)\n", "        for hidden_layer_size in self.hidden_layer_sizes:\n", "            layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", "            model.add(layer)\n", "        if self.target_type_ == \"binary\":\n", "            n_output_units = 1\n", "            output_activation = \"sigmoid\"\n", "            loss = \"binary_crossentropy\"\n", "        elif self.target_type_ == \"multiclass\":\n", "            n_output_units = self.n_classes_\n", "            output_activation = \"softmax\"\n", "            loss = \"sparse_categorical_crossentropy\"\n", "        else:\n", "            raise NotImplementedError(f\"Unsupported task type: {self.target_type_}\")\n", "        out = keras.layers.Dense(n_output_units, activation=output_activation)\n", "        model.add(out)\n", "        model.compile(loss=loss, optimizer=compile_kwargs[\"optimizer\"])\n", "        return model" ] }, { "cell_type": "markdown", "id": "4bb9ab35", "metadata": {}, "source": [ "Let's check that our subclassed model works:" ] }, { "cell_type": "code", "execution_count": 14, "id": "c6c3641a", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:19.744508Z", "iopub.status.busy": "2024-04-11T22:25:19.744024Z", "iopub.status.idle": "2024-04-11T22:25:20.981280Z", "shell.execute_reply": "2024-04-11T22:25:20.980612Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.89\n" ] } ], "source": [ "clf = MLPClassifier(epochs=20)  # for notebook execution time\n", "\n", "# check score\n", "print(clf.fit(X, y).score(X, y))" ] }, { "cell_type": "markdown", "id": "45b27d4f", "metadata": {}, "source": [ "## 5. MLPRegressor\n", "\n", "The process for MLPRegressor is similar; we only change the loss function and the output layer."
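] }, { "cell_type": "markdown", "id": "aa11bb22", "metadata": {}, "source": [ "One note on scoring: `KerasRegressor.score` returns the coefficient of determination (R²), scikit-learn's default metric for regressors. As a reference point for the value printed below (a small sketch using `sklearn.metrics.r2_score`), always predicting the mean of `y` scores exactly 0 and perfect predictions score 1:" ] }, { "cell_type": "code", "execution_count": null, "id": "bb22cc33", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics import r2_score\n", "\n", "y_true = np.arange(100)\n", "mean_pred = np.full(y_true.shape, y_true.mean())  # constant mean predictor\n", "print(r2_score(y_true, mean_pred))  # 0.0\n", "print(r2_score(y_true, y_true))  # 1.0"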
] }, { "cell_type": "code", "execution_count": 15, "id": "36472e1a", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:20.984131Z", "iopub.status.busy": "2024-04-11T22:25:20.983670Z", "iopub.status.idle": "2024-04-11T22:25:20.992735Z", "shell.execute_reply": "2024-04-11T22:25:20.992019Z" }, "lines_to_next_cell": 2 }, "outputs": [], "source": [ "class MLPRegressor(KerasRegressor):\n", "\n", " def __init__(\n", " self,\n", " hidden_layer_sizes=(100, ),\n", " optimizer=\"adam\",\n", " optimizer__learning_rate=0.001,\n", " epochs=200,\n", " verbose=0,\n", " **kwargs,\n", " ):\n", " super().__init__(**kwargs)\n", " self.hidden_layer_sizes = hidden_layer_sizes\n", " self.optimizer = optimizer\n", " self.epochs = epochs\n", " self.verbose = verbose\n", "\n", " def _keras_build_fn(self, compile_kwargs: Dict[str, Any]):\n", " model = keras.Sequential()\n", " inp = keras.layers.Input(shape=(self.n_features_in_,))\n", " model.add(inp)\n", " for hidden_layer_size in self.hidden_layer_sizes:\n", " layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n", " model.add(layer)\n", " out = keras.layers.Dense(1)\n", " model.add(out)\n", " model.compile(loss=\"mse\", optimizer=compile_kwargs[\"optimizer\"])\n", " return model" ] }, { "cell_type": "code", "execution_count": 16, "id": "addaca13", "metadata": { "execution": { "iopub.execute_input": "2024-04-11T22:25:20.996246Z", "iopub.status.busy": "2024-04-11T22:25:20.995969Z", "iopub.status.idle": "2024-04-11T22:25:22.153150Z", "shell.execute_reply": "2024-04-11T22:25:22.152452Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-0.008032493811841235\n" ] } ], "source": [ "from sklearn.datasets import make_regression\n", "\n", "\n", "reg = MLPRegressor(epochs=20) # for notebook execution time\n", "\n", "# Define a simple linear relationship\n", "y = np.arange(100)\n", "X = (y/2).reshape(-1, 1)\n", "\n", "# check score\n", "reg.fit(X, y)\n", "print(reg.score(X, y))" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }