Phase-prediction-of-HEAs-private-share / SMOTETomeklink_augmented_data.ipynb
SMOTETomeklink_augmented_data.ipynb
Raw
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X5AMaKsOmkjy"
      },
      "source": [
        "## 1) Problem statement.\n",
        "Phase prediction of high-entropy alloys (HEAs) using data augmented with SMOTE-Tomek links"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# miceforest provides the MICE imputer; pinned to the version installed when this notebook was run\n",
        "%pip install miceforest==5.6.3"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0BnBo3LO4u8o",
        "outputId": "d7c82579-4650-47a9-c5c8-02cb76d85426"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting miceforest\n",
            "  Downloading miceforest-5.6.3-py3-none-any.whl (57 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.0/58.0 KB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from miceforest) (1.21.6)\n",
            "Requirement already satisfied: dill in /usr/local/lib/python3.8/dist-packages (from miceforest) (0.3.6)\n",
            "Collecting lightgbm>=3.3.1\n",
            "  Downloading lightgbm-3.3.5-py3-none-manylinux1_x86_64.whl (2.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m34.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting blosc\n",
            "  Downloading blosc-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m59.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: wheel in /usr/local/lib/python3.8/dist-packages (from lightgbm>=3.3.1->miceforest) (0.38.4)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from lightgbm>=3.3.1->miceforest) (1.7.3)\n",
            "Requirement already satisfied: scikit-learn!=0.22.0 in /usr/local/lib/python3.8/dist-packages (from lightgbm>=3.3.1->miceforest) (1.0.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn!=0.22.0->lightgbm>=3.3.1->miceforest) (3.1.0)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.8/dist-packages (from scikit-learn!=0.22.0->lightgbm>=3.3.1->miceforest) (1.2.0)\n",
            "Installing collected packages: blosc, lightgbm, miceforest\n",
            "  Attempting uninstall: lightgbm\n",
            "    Found existing installation: lightgbm 2.2.3\n",
            "    Uninstalling lightgbm-2.2.3:\n",
            "      Successfully uninstalled lightgbm-2.2.3\n",
            "Successfully installed blosc-1.11.1 lightgbm-3.3.5 miceforest-5.6.3\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OYaew9gYmkj1"
      },
      "source": [
        "## 2) Import required libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SG7iLs0Zmkj1"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "from statistics import mean\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "from sklearn import metrics\n",
        "from sklearn.compose import ColumnTransformer\n",
        "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier\n",
        "from sklearn.impute import SimpleImputer, KNNImputer\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.metrics import (accuracy_score, classification_report, ConfusionMatrixDisplay,\n",
        "                             precision_score, recall_score, f1_score, roc_auc_score, roc_curve,\n",
        "                             confusion_matrix)\n",
        "from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, cross_val_score\n",
        "from sklearn.neighbors import KNeighborsClassifier\n",
        "from sklearn.pipeline import Pipeline\n",
        "from sklearn.preprocessing import (LabelEncoder, MinMaxScaler, OneHotEncoder, PowerTransformer,\n",
        "                                   RobustScaler, StandardScaler)\n",
        "from sklearn.svm import SVC\n",
        "from sklearn.tree import DecisionTreeClassifier\n",
        "from sklearn.utils import resample\n",
        "from xgboost import XGBClassifier\n",
        "\n",
        "# All warnings are silenced to keep the output readable.\n",
        "# NOTE(review): this also hides pandas/sklearn deprecation warnings.\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Do4oeoyNmkj3"
      },
      "source": [
        "### Read Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W_8SVbfwmkj3"
      },
      "outputs": [],
      "source": [
        "# Load the raw HEA phase dataset from an Excel workbook; the literal string \"na\" is read as NaN\n",
        "df3 = pd.read_excel('Phase_data.xlsx', na_values=\"na\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "miijRwgmmkj4",
        "outputId": "20e57ed6-cde0-418b-a511-05538e205794"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1200, 36)"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ],
      "source": [
        "# (number of rows, number of columns) of the raw dataset\n",
        "(len(df3), len(df3.columns))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 235
        },
        "id": "ANWg7y2hmkj4",
        "outputId": "ac28407e-dec2-40f3-a2a3-b4b7dbf790ec"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "      Al  Co  Cr  Fe  Ni  Cu  Mn     Ti      V     Nb  ...  Pt   Y  Pd  Au  \\\n",
              "0  0.111 NaN NaN NaN NaN NaN NaN  0.222  0.222  0.222  ... NaN NaN NaN NaN   \n",
              "1  0.158 NaN NaN NaN NaN NaN NaN  0.215  0.215  0.215  ... NaN NaN NaN NaN   \n",
              "2  0.588 NaN NaN NaN NaN NaN NaN  0.235  0.235  0.235  ... NaN NaN NaN NaN   \n",
              "3  0.588 NaN NaN NaN NaN NaN NaN  0.235  0.235  0.235  ... NaN NaN NaN NaN   \n",
              "4  0.476 NaN NaN NaN NaN NaN NaN  0.239  0.239    NaN  ... NaN NaN NaN NaN   \n",
              "\n",
              "   dHmix   dSmix      δ     ᐃχ    VEC  Phases  \n",
              "0 -8.395  13.146  3.738  0.050  4.556     BCC  \n",
              "1 -9.352  13.333  3.863  0.233  4.684     BCC  \n",
              "2 -4.042  12.708  4.003  0.243  4.882     BCC  \n",
              "3 -4.817  12.708  3.832  0.050  4.647     BCC  \n",
              "4 -3.356  12.569  4.018  0.244  4.905     BCC  \n",
              "\n",
              "[5 rows x 36 columns]"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-b44f0d14-2622-433c-a1a7-4e22a96a97f3\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Al</th>\n",
              "      <th>Co</th>\n",
              "      <th>Cr</th>\n",
              "      <th>Fe</th>\n",
              "      <th>Ni</th>\n",
              "      <th>Cu</th>\n",
              "      <th>Mn</th>\n",
              "      <th>Ti</th>\n",
              "      <th>V</th>\n",
              "      <th>Nb</th>\n",
              "      <th>...</th>\n",
              "      <th>Pt</th>\n",
              "      <th>Y</th>\n",
              "      <th>Pd</th>\n",
              "      <th>Au</th>\n",
              "      <th>dHmix</th>\n",
              "      <th>dSmix</th>\n",
              "      <th>δ</th>\n",
              "      <th>ᐃχ</th>\n",
              "      <th>VEC</th>\n",
              "      <th>Phases</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.111</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>0.222</td>\n",
              "      <td>0.222</td>\n",
              "      <td>0.222</td>\n",
              "      <td>...</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>-8.395</td>\n",
              "      <td>13.146</td>\n",
              "      <td>3.738</td>\n",
              "      <td>0.050</td>\n",
              "      <td>4.556</td>\n",
              "      <td>BCC</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.158</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>0.215</td>\n",
              "      <td>0.215</td>\n",
              "      <td>0.215</td>\n",
              "      <td>...</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>-9.352</td>\n",
              "      <td>13.333</td>\n",
              "      <td>3.863</td>\n",
              "      <td>0.233</td>\n",
              "      <td>4.684</td>\n",
              "      <td>BCC</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.588</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>0.235</td>\n",
              "      <td>0.235</td>\n",
              "      <td>0.235</td>\n",
              "      <td>...</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>-4.042</td>\n",
              "      <td>12.708</td>\n",
              "      <td>4.003</td>\n",
              "      <td>0.243</td>\n",
              "      <td>4.882</td>\n",
              "      <td>BCC</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.588</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>0.235</td>\n",
              "      <td>0.235</td>\n",
              "      <td>0.235</td>\n",
              "      <td>...</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>-4.817</td>\n",
              "      <td>12.708</td>\n",
              "      <td>3.832</td>\n",
              "      <td>0.050</td>\n",
              "      <td>4.647</td>\n",
              "      <td>BCC</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.476</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>0.239</td>\n",
              "      <td>0.239</td>\n",
              "      <td>NaN</td>\n",
              "      <td>...</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>NaN</td>\n",
              "      <td>-3.356</td>\n",
              "      <td>12.569</td>\n",
              "      <td>4.018</td>\n",
              "      <td>0.244</td>\n",
              "      <td>4.905</td>\n",
              "      <td>BCC</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>5 rows × 36 columns</p>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b44f0d14-2622-433c-a1a7-4e22a96a97f3')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-b44f0d14-2622-433c-a1a7-4e22a96a97f3 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-b44f0d14-2622-433c-a1a7-4e22a96a97f3');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# preview the first five rows of the raw data\n",
        "df3.head(n=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EzuTamjymkj4",
        "outputId": "d09c6fb6-5d96-413b-ac7d-63ec6f6cae4d"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MIP        441\n",
              "BCC        372\n",
              "FCC        220\n",
              "FCC_BCC    167\n",
              "Name: Phases, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# Class distribution of the target variable (number of rows per phase label)\n",
        "phase_counts = df3['Phases'].value_counts()\n",
        "phase_counts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HHD7hYOzmkj5",
        "outputId": "20dcf07a-f13b-42bd-d82d-59de761e42fa"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "We have 35 numerical features : ['Al', 'Co', 'Cr', 'Fe', 'Ni', 'Cu', 'Mn', 'Ti', 'V', 'Nb', 'Mo', 'Zr', 'Hf', 'Ta', 'W', 'C', 'Mg', 'Zn', 'Si', 'Re', 'N', 'Li', 'Sn', 'Be', 'B', 'Ag', 'Pt', 'Y', 'Pd', 'Au', 'dHmix', 'dSmix', 'δ', 'ᐃχ', 'VEC']\n",
            "\n",
            "We have 1 categorical features : ['Phases']\n"
          ]
        }
      ],
      "source": [
        "# Partition the columns: object-dtype columns are categorical, everything else numerical\n",
        "numeric_features, categorical_features = [], []\n",
        "for feature in df3.columns:\n",
        "    bucket = categorical_features if df3[feature].dtype == 'O' else numeric_features\n",
        "    bucket.append(feature)\n",
        "\n",
        "print(f'We have {len(numeric_features)} numerical features : {numeric_features}')\n",
        "print(f'\\nWe have {len(categorical_features)} categorical features : {categorical_features}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_GxWn4Ukmkj6"
      },
      "source": [
        "### Checking missing values"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8OSwnxNPmkj6",
        "outputId": "7adda5ea-14b2-445a-971c-fcc5af1d61ec"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Al         515\n",
              "Co         514\n",
              "Cr         442\n",
              "Fe         386\n",
              "Ni         378\n",
              "Cu         858\n",
              "Mn         971\n",
              "Ti         754\n",
              "V          954\n",
              "Nb         882\n",
              "Mo         969\n",
              "Zr         964\n",
              "Hf        1081\n",
              "Ta        1074\n",
              "W         1164\n",
              "C         1158\n",
              "Mg        1161\n",
              "Zn        1165\n",
              "Si        1137\n",
              "Re        1199\n",
              "N         1194\n",
              "Li        1179\n",
              "Sn        1160\n",
              "Be        1199\n",
              "B         1196\n",
              "Ag        1199\n",
              "Pt        1199\n",
              "Y         1194\n",
              "Pd        1198\n",
              "Au        1199\n",
              "dHmix        0\n",
              "dSmix        0\n",
              "δ            0\n",
              "ᐃχ           0\n",
              "VEC          0\n",
              "Phases       0\n",
              "dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "# number of missing values per column\n",
        "df3.isna().sum()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Column names in dataset order: element columns, alloy descriptors, then the target\n",
        "element_cols = ['Al', 'Co', 'Cr', 'Fe', 'Ni', 'Cu', 'Mn', 'Ti', 'V', 'Nb', 'Mo', 'Zr',\n",
        "                'Hf', 'Ta', 'W', 'C', 'Mg', 'Zn', 'Si', 'Re', 'N', 'Li', 'Sn', 'Be',\n",
        "                'B', 'Ag', 'Pt', 'Y', 'Pd', 'Au']\n",
        "descriptor_cols = ['dHmix', 'dSmix', 'δ', 'ᐃχ', 'VEC']\n",
        "feature_names = element_cols + descriptor_cols + ['Phases']"
      ],
      "metadata": {
        "id": "Bqb2WLli8aio"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# sns.catplot(data=df33, x=\"Phases\", kind=\"count\", palette=\"winter_r\")\n",
        "# plt.show()\n",
        "# sns.set_style('whitegrid')\n",
        "# sns.set(font_scale = 1.2)\n",
        "\n",
        "# df3['Phases'].value_counts()\n",
        "\n",
        "# p= sns.countplot(df3['Phases'], dodge=False , hatch='/',palette=\"Accent\", ec='black')\n",
        "# # sns.set(rc={'figure.figsize': (8,10)})\n",
        "\n",
        "# p.set_xlabel(\"Phases\")\n",
        "# p.set_ylabel(\"Samples\")"
      ],
      "metadata": {
        "id": "lv0NNLuwynDH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Encode the phase labels as integers. The result is assigned back instead of calling\n",
        "# replace(..., inplace=True) on the selected column: that chained in-place call is not\n",
        "# guaranteed to update df3 under pandas copy-on-write, and warnings are silenced above.\n",
        "# astype(int) fails loudly if an unexpected label is left unmapped.\n",
        "phase_codes = {'MIP': 0, 'BCC': 1, 'FCC': 2, 'FCC_BCC': 3}\n",
        "df3['Phases'] = df3['Phases'].replace(phase_codes).astype(int)"
      ],
      "metadata": {
        "id": "gana74hF9F6V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Impute missing values with a k-nearest-neighbours imputer fitted on the input features only.\n",
        "# The target column is excluded so that the class label does not take part in the neighbour\n",
        "# distance (including it leaks label information into the imputed features).\n",
        "# NOTE(review): the imputer is still fitted on all rows before any train/test split.\n",
        "# NOTE(review): a missing element value may simply mean the element is absent (0) - confirm.\n",
        "target_col = 'Phases'\n",
        "input_cols = [col for col in df3.columns if col != target_col]\n",
        "imputer = KNNImputer(n_neighbors=10)\n",
        "imputed_inputs = imputer.fit_transform(df3[input_cols])\n",
        "df_afterimp = pd.DataFrame(imputed_inputs, columns=input_cols, index=df3.index)\n",
        "df_afterimp[target_col] = df3[target_col].astype(float)\n",
        "df_afterimp = df_afterimp[list(df3.columns)]  # restore the original column order\n",
        "df_imp = df_afterimp.to_numpy()"
      ],
      "metadata": {
        "id": "-Mm5xHNY8Eze"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Outlier capping after missing-value imputation\n"
      ],
      "metadata": {
        "id": "yNVhMH0m_hu-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# work on a separate copy: the outlier capping below modifies df in place\n",
        "df = df_afterimp.copy(deep=True)"
      ],
      "metadata": {
        "id": "HkRVz8WW_n_U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def IQR_capping(df_afterimp, cols, factor, lower_q=0.10, upper_q=0.90):\n",
        "    '''\n",
        "    Cap (winsorize) outliers of the given columns in place.\n",
        "\n",
        "    For each column the spread between the lower_q and upper_q quantiles is computed, and values\n",
        "    outside [q_low - factor*spread, q_high + factor*spread] are clipped to those bounds.\n",
        "    With the default 10th/90th percentiles this is an inter-decile range, not the classic\n",
        "    25th/75th percentile IQR; pass lower_q=0.25, upper_q=0.75 for the textbook rule.\n",
        "\n",
        "    Parameters: df_afterimp - DataFrame modified in place; cols - columns to cap;\n",
        "    factor - multiplier of the quantile spread; lower_q / upper_q - quantiles used as bounds.\n",
        "    Returns: None\n",
        "    '''\n",
        "    for col in cols:\n",
        "        q_low = df_afterimp[col].quantile(lower_q)\n",
        "        q_high = df_afterimp[col].quantile(upper_q)\n",
        "        spread = q_high - q_low\n",
        "        lower_boundary = q_low - factor * spread\n",
        "        upper_boundary = q_high + factor * spread\n",
        "        # clip leaves NaN values untouched, as the former nested np.where did\n",
        "        df_afterimp[col] = df_afterimp[col].clip(lower=lower_boundary, upper=upper_boundary)"
      ],
      "metadata": {
        "id": "mS450J8m_ux1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# cap only the input features; the encoded target 'Phases' must not be winsorized\n",
        "IQR_capping(df, [col for col in feature_names if col != 'Phases'], 1.5)"
      ],
      "metadata": {
        "id": "eS3oxy9-_v3r"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# for col in feature_names:\n",
        "#     plt.figure(figsize=(16,4))\n",
        "\n",
        "#     plt.subplot(141)\n",
        "#     sns.distplot(df_afterimp[col], label='skew: '+ str(np.round(df_afterimp[col].skew(),2)))\n",
        "#     plt.title('Before')\n",
        "#     plt.legend()\n",
        "\n",
        "#     plt.subplot(142)\n",
        "#     sns.distplot(df[col], label ='skew '+ str(np.round(df[col].skew(),2)))\n",
        "#     plt.title('After')\n",
        "#     plt.legend()\n",
        "\n",
        "#     plt.subplot(143)\n",
        "#     sns.boxplot(df_afterimp[col])\n",
        "#     plt.title('Before')\n",
        "\n",
        "#     plt.subplot(144)\n",
        "#     sns.boxplot(df[col])\n",
        "#     plt.title('After')\n",
        "#     plt.tight_layout()\n",
        "#     plt.show()"
      ],
      "metadata": {
        "id": "CfPevh6E_y8s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e3SW6fs3mkj8"
      },
      "source": [
        "## Visualization of unique values in Target variable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pRzK0OL9mkj9"
      },
      "outputs": [],
      "source": [
        "## Visualization of unique values in Target variable\n",
        "# sns.catplot(data=df, x=\"Phases\", kind=\"count\", palette=\"winter_r\", alpha=.6)\n",
        "# plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N5RYFNGjmkj9"
      },
      "source": [
        "## Create Functions for model training and evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "snr8L-EVmkj9"
      },
      "outputs": [],
      "source": [
        "def evaluate_clf(true, predicted):\n",
        "    '''\n",
        "    Score predicted labels against the true labels.\n",
        "    F1, precision and recall are computed with average='weighted'.\n",
        "    Returns: tuple (accuracy, F1-score, precision, recall)\n",
        "    '''\n",
        "    weighted_scorers = (f1_score, precision_score, recall_score)\n",
        "    weighted_scores = tuple(scorer(true, predicted, average='weighted') for scorer in weighted_scorers)\n",
        "    return (accuracy_score(true, predicted),) + weighted_scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qSI82B3xmkj-"
      },
      "outputs": [],
      "source": [
        "# Create a function which can evaluate models and return a report\n",
        "def evaluate_models(X, y, models, test_size=0.2, random_state=42):\n",
        "    '''\n",
        "    Fit and score every estimator of a {name: model} dictionary on one train/test split.\n",
        "\n",
        "    The data are split with train_test_split(test_size=test_size, random_state=random_state);\n",
        "    each model is fitted on the training part and its accuracy, F1, precision and recall\n",
        "    (as returned by evaluate_clf) are printed for both parts.\n",
        "\n",
        "    Returns: DataFrame with one row per model - 'Model Name', the test-set metrics and\n",
        "    'Cost' (= 1 - test accuracy, the test misclassification rate) - sorted by ascending Cost.\n",
        "    '''\n",
        "    # separate dataset into train and test\n",
        "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)\n",
        "\n",
        "    rows = []\n",
        "    for name, model in models.items():\n",
        "        model.fit(X_train, y_train)  # Train model\n",
        "\n",
        "        # Make predictions\n",
        "        y_train_pred = model.predict(X_train)\n",
        "        y_test_pred = model.predict(X_test)\n",
        "\n",
        "        # Training and test set performance\n",
        "        train_acc, train_f1, train_precision, train_recall = evaluate_clf(y_train, y_train_pred)\n",
        "        test_acc, test_f1, test_precision, test_recall = evaluate_clf(y_test, y_test_pred)\n",
        "\n",
        "        print(name)\n",
        "        print('Model performance for Training set')\n",
        "        print(\"- Accuracy: {:.4f}\".format(train_acc))\n",
        "        print('- F1 score: {:.4f}'.format(train_f1))\n",
        "        print('- Precision: {:.4f}'.format(train_precision))\n",
        "        print('- Recall: {:.4f}'.format(train_recall))\n",
        "\n",
        "        print('----------------------------------')\n",
        "\n",
        "        print('Model performance for Test set')\n",
        "        print('- Accuracy: {:.4f}'.format(test_acc))\n",
        "        print('- F1 score: {:.4f}'.format(test_f1))\n",
        "        print('- Precision: {:.4f}'.format(test_precision))\n",
        "        print('- Recall: {:.4f}'.format(test_recall))\n",
        "\n",
        "        print('='*35)\n",
        "        print('\\n')\n",
        "\n",
        "        # record the test metrics; the report used to be built from a list that was never filled\n",
        "        rows.append({'Model Name': name, 'Accuracy': test_acc, 'F1': test_f1,\n",
        "                     'Precision': test_precision, 'Recall': test_recall,\n",
        "                     'Cost': 1 - test_acc})\n",
        "\n",
        "    report = pd.DataFrame(rows, columns=['Model Name', 'Accuracy', 'F1', 'Precision', 'Recall', 'Cost'])\n",
        "    return report.sort_values(by=['Cost']).reset_index(drop=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5k1CaW9cmkj-"
      },
      "source": [
        "### Plot the distribution of all independent numerical variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YvJ2TGK9mkj-"
      },
      "outputs": [],
      "source": [
        "# Disabled exploratory cell: per-feature distribution plots for every non-object column of df.\n",
        "# NOTE(review): sns.distplot is deprecated in newer seaborn releases; use sns.histplot(..., kde=True) if re-enabling.\n",
        "# numeric_features = [feature for feature in df.columns if df[feature].dtype != 'O']\n",
        "\n",
        "# plt.figure(figsize=(15, 100))\n",
        "# for i, col in enumerate(numeric_features):\n",
        "#     plt.subplot(60, 3, i+1)\n",
        "#     sns.distplot(x=df[col], color='indianred')\n",
        "#     plt.xlabel(col, weight='bold')\n",
        "#     plt.tight_layout()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Log-transform each composition and descriptor column to reduce skew.\n",
        "# NOTE(review): np.log gives -inf for zeros and NaN for negatives; this assumes every column\n",
        "# below is strictly positive at this point - confirm against the imputed df.\n",
        "# dHmix is left untransformed (its values are negative) and is added raw in the next cell.\n",
        "df2= df.copy()\n",
        "Al_trans = np.log(df2['Al'])\n",
        "Co_trans = np.log(df2['Co'])\n",
        "Cr_trans = np.log(df2['Cr'])\n",
        "Fe_trans = np.log(df2['Fe'])\n",
        "Ni_trans = np.log(df2['Ni'])\n",
        "Cu_trans = np.log(df2['Cu'])\n",
        "Mn_trans = np.log(df2['Mn'])\n",
        "Ti_trans = np.log(df2['Ti'])\n",
        "V_trans = np.log(df2['V'])\n",
        "Nb_trans = np.log(df2['Nb'])\n",
        "Mo_trans = np.log(df2['Mo'])\n",
        "Zr_trans = np.log(df2['Zr'])\n",
        "Hf_trans = np.log(df2['Hf'])\n",
        "Ta_trans = np.log(df2['Ta'])\n",
        "W_trans = np.log(df2['W'])\n",
        "C_trans = np.log(df2['C'])\n",
        "Mg_trans = np.log(df2['Mg'])\n",
        "Zn_trans = np.log(df2['Zn'])\n",
        "Si_trans = np.log(df2['Si'])\n",
        "Re_trans = np.log(df2['Re'])\n",
        "N_trans = np.log(df2['N'])\n",
        "Li_trans = np.log(df2['Li'])\n",
        "Sn_trans = np.log(df2['Sn'])\n",
        "# Be must read the Be column (previously copy-pasted from Sn, which duplicated Sn in df_final).\n",
        "Be_trans = np.log(df2['Be'])\n",
        "B_trans = np.log(df2['B'])\n",
        "Ag_trans = np.log(df2['Ag'])\n",
        "Pt_trans = np.log(df2['Pt'])\n",
        "Y_trans = np.log(df2['Y'])\n",
        "Pd_trans = np.log(df2['Pd'])\n",
        "Au_trans = np.log(df2['Au'])\n",
        "#dHmix_trans = np.log(df2['dHmix'])\n",
        "dSmix_trans = np.log(df2['dSmix'])\n",
        "Atom_Size_Diff_trans = np.log(df2['δ'])\n",
        "Elect_Diff_trans = np.log(df2['ᐃχ'])\n",
        "VEC_trans = np.log(df2['VEC']) "
      ],
      "metadata": {
        "id": "xhftSkwu2TuH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Assemble the modelling frame: log-transformed element fractions, raw dHmix,\n",
        "# log-transformed descriptors, and the Phases target as the last column.\n",
        "element_cols = [\n",
        "    Al_trans, Co_trans, Cr_trans, Fe_trans, Ni_trans, Cu_trans, Mn_trans, Ti_trans, V_trans, Nb_trans,\n",
        "    Mo_trans, Zr_trans, Hf_trans, Ta_trans, W_trans, C_trans, Mg_trans, Zn_trans, Si_trans, Re_trans,\n",
        "    N_trans, Li_trans, Sn_trans, Be_trans, B_trans, Ag_trans, Pt_trans, Y_trans, Pd_trans, Au_trans,\n",
        "]\n",
        "descriptor_cols = [df2['dHmix'], dSmix_trans, Atom_Size_Diff_trans, Elect_Diff_trans, VEC_trans]\n",
        "df_final = pd.DataFrame(pd.concat(element_cols + descriptor_cols + [df2['Phases']], axis=1))\n",
        "df_final"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 487
        },
        "id": "YqyteeKkBzb2",
        "outputId": "3681e51b-f3c8-4f22-ee69-a4851d9b19df"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "            Al        Co        Cr        Fe        Ni        Cu        Mn  \\\n",
              "0    -2.198225 -1.445195 -0.998586 -1.275111 -1.738977 -1.192044 -1.534794   \n",
              "1    -1.845160 -1.582796 -1.289894 -1.427116 -1.581823 -1.543650 -1.541312   \n",
              "2    -0.531028 -1.595535 -1.283377 -1.352473 -1.621005 -0.931404 -1.552113   \n",
              "3    -0.531028 -1.581336 -1.375552 -1.322381 -1.565421 -1.119325 -1.539446   \n",
              "4    -0.742337 -1.716466 -1.388296 -1.358679 -1.543650 -0.886004 -1.566378   \n",
              "...        ...       ...       ...       ...       ...       ...       ...   \n",
              "1195 -1.944911 -1.944911 -1.944911 -1.944911 -1.944911 -1.944911 -1.308223   \n",
              "1196 -1.697723 -1.789761 -1.789761 -1.789761 -1.789761 -0.813509 -1.789761   \n",
              "1197 -1.002666 -1.789761 -1.789761 -1.789761 -1.789761 -1.085301 -1.789761   \n",
              "1198 -1.128247 -1.944911 -1.944911 -1.944911 -1.944911 -1.944911 -1.810942   \n",
              "1199 -0.877070 -1.609438 -1.505979 -1.609438 -0.916291 -1.488106 -1.609438   \n",
              "\n",
              "            Ti         V        Nb  ...        Pt         Y        Pd  \\\n",
              "0    -1.505078 -1.505078 -1.505078  ... -1.609438 -1.806873 -1.203973   \n",
              "1    -1.537117 -1.537117 -1.537117  ... -1.609438 -1.806873 -1.203973   \n",
              "2    -1.448170 -1.448170 -1.448170  ... -1.609438 -1.806873 -1.203973   \n",
              "3    -1.448170 -1.448170 -1.448170  ... -1.609438 -1.806873 -1.203973   \n",
              "4    -1.431292 -1.431292 -1.480166  ... -1.609438 -1.806873 -1.203973   \n",
              "...        ...       ...       ...  ...       ...       ...       ...   \n",
              "1195 -0.604770 -1.061895 -0.568631  ... -1.609438 -1.806873 -1.203973   \n",
              "1196 -1.864330 -1.101115 -1.789761  ... -1.609438 -1.806873 -1.203973   \n",
              "1197 -1.789761 -1.113827 -1.515948  ... -1.609438 -1.806873 -1.203973   \n",
              "1198 -1.251763 -1.013077 -1.197659  ... -1.609438 -1.806873 -1.203973   \n",
              "1199 -0.909315 -1.855978 -0.944176  ... -1.609438 -1.806873 -1.203973   \n",
              "\n",
              "            Au   dHmix     dSmix         δ        ᐃχ       VEC  Phases  \n",
              "0    -1.789761  -8.395  2.576118  1.318551 -2.995732  1.516445     1.0  \n",
              "1    -1.789761  -9.352  2.590242  1.351444 -1.456717  1.544152     1.0  \n",
              "2    -1.789761  -4.042  2.542232  1.387044 -1.414694  1.585555     1.0  \n",
              "3    -1.789761  -4.817  2.542232  1.343387 -2.995732  1.536222     1.0  \n",
              "4    -1.789761  -3.356  2.531233  1.390784 -1.410587  1.590255     1.0  \n",
              "...        ...     ...       ...       ...       ...       ...     ...  \n",
              "1195 -1.789761 -18.857  2.783467  1.812542 -2.154165  1.985955     0.0  \n",
              "1196 -1.789761 -12.000  2.700958  1.703839 -1.958995  2.014903     0.0  \n",
              "1197 -1.789761 -13.444  2.700958  1.839438 -1.883875  1.992385     0.0  \n",
              "1198 -1.789761 -14.041  2.676147  1.979759 -1.864330  2.005391     0.0  \n",
              "1199 -1.789761  -4.160  2.404239  1.190888 -1.994365  2.174752     2.0  \n",
              "\n",
              "[1200 rows x 36 columns]"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-c73e2d5d-6b28-46f1-ab87-71c17cc63a90\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Al</th>\n",
              "      <th>Co</th>\n",
              "      <th>Cr</th>\n",
              "      <th>Fe</th>\n",
              "      <th>Ni</th>\n",
              "      <th>Cu</th>\n",
              "      <th>Mn</th>\n",
              "      <th>Ti</th>\n",
              "      <th>V</th>\n",
              "      <th>Nb</th>\n",
              "      <th>...</th>\n",
              "      <th>Pt</th>\n",
              "      <th>Y</th>\n",
              "      <th>Pd</th>\n",
              "      <th>Au</th>\n",
              "      <th>dHmix</th>\n",
              "      <th>dSmix</th>\n",
              "      <th>δ</th>\n",
              "      <th>ᐃχ</th>\n",
              "      <th>VEC</th>\n",
              "      <th>Phases</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>-2.198225</td>\n",
              "      <td>-1.445195</td>\n",
              "      <td>-0.998586</td>\n",
              "      <td>-1.275111</td>\n",
              "      <td>-1.738977</td>\n",
              "      <td>-1.192044</td>\n",
              "      <td>-1.534794</td>\n",
              "      <td>-1.505078</td>\n",
              "      <td>-1.505078</td>\n",
              "      <td>-1.505078</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-8.395</td>\n",
              "      <td>2.576118</td>\n",
              "      <td>1.318551</td>\n",
              "      <td>-2.995732</td>\n",
              "      <td>1.516445</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>-1.845160</td>\n",
              "      <td>-1.582796</td>\n",
              "      <td>-1.289894</td>\n",
              "      <td>-1.427116</td>\n",
              "      <td>-1.581823</td>\n",
              "      <td>-1.543650</td>\n",
              "      <td>-1.541312</td>\n",
              "      <td>-1.537117</td>\n",
              "      <td>-1.537117</td>\n",
              "      <td>-1.537117</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-9.352</td>\n",
              "      <td>2.590242</td>\n",
              "      <td>1.351444</td>\n",
              "      <td>-1.456717</td>\n",
              "      <td>1.544152</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>-0.531028</td>\n",
              "      <td>-1.595535</td>\n",
              "      <td>-1.283377</td>\n",
              "      <td>-1.352473</td>\n",
              "      <td>-1.621005</td>\n",
              "      <td>-0.931404</td>\n",
              "      <td>-1.552113</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-4.042</td>\n",
              "      <td>2.542232</td>\n",
              "      <td>1.387044</td>\n",
              "      <td>-1.414694</td>\n",
              "      <td>1.585555</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>-0.531028</td>\n",
              "      <td>-1.581336</td>\n",
              "      <td>-1.375552</td>\n",
              "      <td>-1.322381</td>\n",
              "      <td>-1.565421</td>\n",
              "      <td>-1.119325</td>\n",
              "      <td>-1.539446</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>-1.448170</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-4.817</td>\n",
              "      <td>2.542232</td>\n",
              "      <td>1.343387</td>\n",
              "      <td>-2.995732</td>\n",
              "      <td>1.536222</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>-0.742337</td>\n",
              "      <td>-1.716466</td>\n",
              "      <td>-1.388296</td>\n",
              "      <td>-1.358679</td>\n",
              "      <td>-1.543650</td>\n",
              "      <td>-0.886004</td>\n",
              "      <td>-1.566378</td>\n",
              "      <td>-1.431292</td>\n",
              "      <td>-1.431292</td>\n",
              "      <td>-1.480166</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-3.356</td>\n",
              "      <td>2.531233</td>\n",
              "      <td>1.390784</td>\n",
              "      <td>-1.410587</td>\n",
              "      <td>1.590255</td>\n",
              "      <td>1.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1195</th>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.308223</td>\n",
              "      <td>-0.604770</td>\n",
              "      <td>-1.061895</td>\n",
              "      <td>-0.568631</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-18.857</td>\n",
              "      <td>2.783467</td>\n",
              "      <td>1.812542</td>\n",
              "      <td>-2.154165</td>\n",
              "      <td>1.985955</td>\n",
              "      <td>0.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1196</th>\n",
              "      <td>-1.697723</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-0.813509</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.864330</td>\n",
              "      <td>-1.101115</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-12.000</td>\n",
              "      <td>2.700958</td>\n",
              "      <td>1.703839</td>\n",
              "      <td>-1.958995</td>\n",
              "      <td>2.014903</td>\n",
              "      <td>0.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1197</th>\n",
              "      <td>-1.002666</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.085301</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-1.113827</td>\n",
              "      <td>-1.515948</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-13.444</td>\n",
              "      <td>2.700958</td>\n",
              "      <td>1.839438</td>\n",
              "      <td>-1.883875</td>\n",
              "      <td>1.992385</td>\n",
              "      <td>0.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1198</th>\n",
              "      <td>-1.128247</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.944911</td>\n",
              "      <td>-1.810942</td>\n",
              "      <td>-1.251763</td>\n",
              "      <td>-1.013077</td>\n",
              "      <td>-1.197659</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-14.041</td>\n",
              "      <td>2.676147</td>\n",
              "      <td>1.979759</td>\n",
              "      <td>-1.864330</td>\n",
              "      <td>2.005391</td>\n",
              "      <td>0.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1199</th>\n",
              "      <td>-0.877070</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.505979</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-0.916291</td>\n",
              "      <td>-1.488106</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-0.909315</td>\n",
              "      <td>-1.855978</td>\n",
              "      <td>-0.944176</td>\n",
              "      <td>...</td>\n",
              "      <td>-1.609438</td>\n",
              "      <td>-1.806873</td>\n",
              "      <td>-1.203973</td>\n",
              "      <td>-1.789761</td>\n",
              "      <td>-4.160</td>\n",
              "      <td>2.404239</td>\n",
              "      <td>1.190888</td>\n",
              "      <td>-1.994365</td>\n",
              "      <td>2.174752</td>\n",
              "      <td>2.0</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>1200 rows × 36 columns</p>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-c73e2d5d-6b28-46f1-ab87-71c17cc63a90')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-c73e2d5d-6b28-46f1-ab87-71c17cc63a90 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-c73e2d5d-6b28-46f1-ab87-71c17cc63a90');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ]
          },
          "metadata": {},
          "execution_count": 911
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IcaPyF-cmkj-"
      },
      "source": [
        "# Evaluate Models on Different Experiments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZLnlk8lLmkj_"
      },
      "outputs": [],
      "source": [
        "# Separate the target (Phases) from the feature matrix; X and y are shared by every experiment below.\n",
        "target_col = 'Phases'\n",
        "y = df_final[target_col]\n",
        "X = df_final.drop(columns=target_col)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "19sBusUMmkj_"
      },
      "outputs": [],
      "source": [
        "# Map phase names to integer class codes.\n",
        "# NOTE(review): the Phases values displayed in df_final above are already numeric (0.0/1.0/2.0),\n",
        "# so this mapping only has an effect if string labels are present.\n",
        "phase_codes = {'MIP': 0, 'BCC': 1, 'FCC': 2, 'FCC_BCC': 3}\n",
        "y = y.replace(phase_codes)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gYETL8wmmkj_"
      },
      "source": [
        "#### The dataset is not normally distributed, so I use RobustScaler. Alternatively, you can apply a log transformation followed by StandardScaler and compare the results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AGRi1lDdmkj_"
      },
      "outputs": [],
      "source": [
        "# Robust-scale the features for the KNN-imputer K-selection experiment below.\n",
        "robustscaler = RobustScaler()\n",
        "X1 = robustscaler.fit(X).transform(X)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K5UXXvl0mkj_"
      },
      "outputs": [],
      "source": [
        "# K-selection experiment for the KNN imputer: cross-validate a KNNImputer + LogisticRegression pipeline\n",
        "# on the robust-scaled features for several n_neighbors values and report the mean accuracy of each.\n",
        "results=[]\n",
        "# define imputer (NOTE(review): not used in this cell; kept in case later cells reference it)\n",
        "imputer = KNNImputer(n_neighbors=5, weights='uniform', metric='nan_euclidean')\n",
        "# Candidate n_neighbors values, stored as strings and converted back with int() below.\n",
        "strategies = [str(i) for i in [1,3,5,7,9]]\n",
        "for s in strategies:\n",
        "    pipeline = Pipeline(steps=[('i', KNNImputer(n_neighbors=int(s))), ('m', LogisticRegression())])\n",
        "    scores = cross_val_score(pipeline, X1, y, scoring='accuracy', cv=2, n_jobs=-1)\n",
        "    results.append(scores)\n",
        "    # Print the mean CV accuracy so the n_neighbors chosen for the pipeline below is visible.\n",
        "    print('n_neighbors= %s || accuracy (%.4f)' % (s, np.mean(scores)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lc9ydKsbmkj_"
      },
      "outputs": [],
      "source": [
        "# Preprocessing pipeline: KNN imputation with the selected K, followed by robust scaling.\n",
        "num_features = X.select_dtypes(exclude=\"object\").columns  # names of the non-object feature columns\n",
        "\n",
        "knn_steps = [\n",
        "    ('imputer', KNNImputer(n_neighbors=3)),\n",
        "    ('RobustScaler', RobustScaler()),\n",
        "]\n",
        "knn_pipeline = Pipeline(steps=knn_steps)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LNeqM7AqmkkA"
      },
      "outputs": [],
      "source": [
        "# Fit the imputer + scaler pipeline on the full feature matrix and keep the imputed, scaled features.\n",
        "X_knn =knn_pipeline.fit_transform(X)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U2WtW76UmkkA"
      },
      "source": [
        "## Handling Imbalanced data using SMOTE-Tomek links"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a_qoKqBImkkA"
      },
      "outputs": [],
      "source": [
        "from imblearn.combine import SMOTETomek\n",
        "\n",
        "# Over-sample only the smallest class with SMOTE, then clean with Tomek links.\n",
        "# Change sampling_strategy (e.g. 'not majority') to resample more classes.\n",
        "# NOTE(review): resampling happens before any train/test split; if the evaluation step splits\n",
        "# X_res/y_res internally, synthetic points derived from test rows can leak into training - confirm.\n",
        "smt = SMOTETomek(\n",
        "    sampling_strategy='minority',\n",
        "    random_state=42,\n",
        "    n_jobs=-1,\n",
        ")\n",
        "X_res, y_res = smt.fit_resample(X_knn, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "##### Number of samples for SMOTE-Tomek link data"
      ],
      "metadata": {
        "id": "Jv2WfROa_Jqf"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gnCJnOrimkkA",
        "outputId": "ea6b533c-fee9-4143-eb1e-b917ca41588f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1392,)"
            ]
          },
          "metadata": {},
          "execution_count": 919
        }
      ],
      "source": [
        "# Number of samples after SMOTE-Tomek resampling\n",
        "y_res.shape "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RJSirT9GmkkA",
        "outputId": "98364b68-8b33-4978-b8ea-2a290d30063f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1392, 35)"
            ]
          },
          "metadata": {},
          "execution_count": 920
        }
      ],
      "source": [
        "# (n_samples, n_features) after SMOTE-Tomek resampling\n",
        "X_res.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "##### Visualization of samples for SMOTE-Tomek link data"
      ],
      "metadata": {
        "id": "gTSW-AQX_Auv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Disabled: bar chart of class counts after SMOTE-Tomek resampling.\n",
        "# NOTE: if re-enabled, map labels into a new variable rather than overwriting y_res;\n",
        "# y_res is passed to the model evaluation below, which presumably expects the numeric class codes.\n",
        "# ## Visualization of unique values in Target variable\n",
        "# sns.set_style('whitegrid')\n",
        "# sns.set(font_scale = 1.2)\n",
        "# y_res_labels = y_res.replace({0:'MIP', 1:'BCC', 2:'FCC', 3:'FCC_BCC'})\n",
        "# y_res_labels.value_counts()\n",
        "\n",
        "# p= sns.countplot(x=y_res_labels,  dodge=False,palette=\"Accent\", ec='black')\n",
        "# # sns.set(rc={'figure.figsize': (8,10)})\n",
        "# p.set_xlabel(\"Phases\")\n",
        "# p.set_ylabel(\"Samples\")"
      ],
      "metadata": {
        "id": "pycJqRHdxpNl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJqwtYDWmkkB"
      },
      "source": [
        "### Initialize Default Models in a dictionary"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nz1HT0QAmkkB"
      },
      "outputs": [],
      "source": [
        "# Dictionary which contains models for experiment.\n",
        "# The stochastic models are seeded (same seed as the SMOTETomek resampler) so results are\n",
        "# reproducible across runs; KNN and SVC (without probability=True) need no seed.\n",
        "models = {\n",
        "    \"Random Forest\": RandomForestClassifier(random_state=42),\n",
        "    \"Decision Tree\": DecisionTreeClassifier(random_state=42),\n",
        "    \"K-Neighbors Classifier\": KNeighborsClassifier(),\n",
        "    \"XGBClassifier\": XGBClassifier(random_state=42),\n",
        "    \"SVM\": SVC()\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xy77_DslmkkB"
      },
      "source": [
        "### Fit the models in the dictionary on the KNN-imputed, SMOTE-Tomek resampled data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6eJUzP0TmkkB",
        "outputId": "4f189dbb-5705-4eeb-a611-bf42545ce187"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Random Forest\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.9211\n",
            "- F1 score: 0.9209\n",
            "- Precision: 0.9216\n",
            "- Recall: 0.9211\n",
            "===================================\n",
            "\n",
            "\n",
            "Decision Tree\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8459\n",
            "- F1 score: 0.8461\n",
            "- Precision: 0.8471\n",
            "- Recall: 0.8459\n",
            "===================================\n",
            "\n",
            "\n",
            "K-Neighbors Classifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8302\n",
            "- F1 score: 0.8284\n",
            "- Precision: 0.8325\n",
            "- Recall: 0.8302\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.7957\n",
            "- F1 score: 0.7910\n",
            "- Precision: 0.7970\n",
            "- Recall: 0.7957\n",
            "===================================\n",
            "\n",
            "\n",
            "XGBClassifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.9443\n",
            "- F1 score: 0.9440\n",
            "- Precision: 0.9442\n",
            "- Recall: 0.9443\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8925\n",
            "- F1 score: 0.8930\n",
            "- Precision: 0.8971\n",
            "- Recall: 0.8925\n",
            "===================================\n",
            "\n",
            "\n",
            "SVM\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8392\n",
            "- F1 score: 0.8370\n",
            "- Precision: 0.8419\n",
            "- Recall: 0.8392\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8208\n",
            "- F1 score: 0.8164\n",
            "- Precision: 0.8293\n",
            "- Recall: 0.8208\n",
            "===================================\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Train and score every model in `models` on the resampled data (X_res, y_res)\n",
        "report_knn = evaluate_models(X_res, y_res, models)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Full dataset: feature matrix X_full and 'Phases' target y_full\n",
        "full_data = df_final.copy()\n",
        "X_full = full_data.drop('Phases', axis=1)\n",
        "y_full = full_data['Phases']"
      ],
      "metadata": {
        "id": "_kgNC0ANuQeL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# KNN imputation (k=3) followed by robust scaling, fitted on the full feature matrix\n",
        "knn_pipeline2 = Pipeline(steps=[('imputer', KNNImputer(n_neighbors=3)), ('RobustScaler', RobustScaler())])\n",
        "X_knn2 = knn_pipeline2.fit_transform(X_full)\n",
        "\n",
        "# NOTE(review): resampling runs before train_test_split, so SMOTE samples synthesised from\n",
        "# points that land in the test split are also in training; the hold-out AUC below is likely\n",
        "# optimistic compared with the 10-fold CV (which uses the non-resampled data).\n",
        "smt = SMOTETomek(random_state=42, sampling_strategy='minority', n_jobs=-1)\n",
        "X_res, y_res = smt.fit_resample(X_knn2, y_full)\n",
        "X_train_fullkn, X_test_fullkn, y_train_fullkn, y_test_fullkn = train_test_split(X_res, y_res, test_size=0.20, random_state=42)\n",
        "\n",
        "# Random forest with default hyper-parameters, scored on the hold-out split (multi-class OvR AUC)\n",
        "model_rf_hyperkn = RandomForestClassifier().fit(X_train_fullkn, y_train_fullkn)\n",
        "probs_rf_hyper = model_rf_hyperkn.predict_proba(X_test_fullkn)\n",
        "auc_rf = roc_auc_score(y_test_fullkn, probs_rf_hyper, multi_class='ovr')\n",
        "print('AUC of RFC: {:.4f}'.format(auc_rf))\n",
        "\n",
        "# 10-fold CV on the imputed, non-resampled data. Run it once so the mean and std describe\n",
        "# the same folds (two separate cross_val_score calls refit an unseeded forest twice).\n",
        "RFC_KNN = RandomForestClassifier()\n",
        "cv_scores_rf = cross_val_score(RFC_KNN, X_knn2, y_full, cv=10, scoring='roc_auc_ovr')\n",
        "scores_cv_rf = cv_scores_rf.mean()\n",
        "scores_cv_rf_std = cv_scores_rf.std()\n",
        "print('10-fold CV mean of RFC: {:.4f}'.format(scores_cv_rf))\n",
        "print('10-fold CV std of RFC: {:.4f}'.format(scores_cv_rf_std))"
      ],
      "metadata": {
        "id": "kIlLH9Qnt9AG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c8e95b28-ffb3-4a49-c2d0-c9794b382901"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "AUC of RFC: 0.9841\n",
            "10-fold CV mean of RFC: 0.9262\n",
            "10-fold CV std of RFC: 0.0307\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Shapes of the hold-out split produced by the KNN experiment above\n",
        "for split_part in (X_train_fullkn, X_test_fullkn, y_train_fullkn, y_test_fullkn):\n",
        "    print(split_part.shape)"
      ],
      "metadata": {
        "id": "b4T_WjV6z5uy",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0a307dd9-12d6-4956-8940-0c570eac97b8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "(1113, 35)\n",
            "(279, 35)\n",
            "(1113,)\n",
            "(279,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Class balance of the KNN-experiment test split\n",
        "y_test_fullkn.value_counts(sort=True, ascending=False)"
      ],
      "metadata": {
        "id": "b5em3eaK1IBv",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d660d86f-e630-4586-d447-343293bb5e10"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "3.0    94\n",
              "0.0    78\n",
              "1.0    67\n",
              "2.0    40\n",
              "Name: Phases, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 928
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# 10-fold cross-validated ROC-AUC (one-vs-rest) of each classifier on X_knn.\n",
        "# Each model is cross-validated once and both mean and std come from that single run,\n",
        "# instead of re-running the CV (with fresh, unseeded randomness) just to get the std.\n",
        "DTC = DecisionTreeClassifier()\n",
        "RFC = RandomForestClassifier()\n",
        "XGB = XGBClassifier()\n",
        "KNN = KNeighborsClassifier()\n",
        "SVM = SVC(probability=True)\n",
        "\n",
        "cv1_fold_scores = {\n",
        "    name: cross_val_score(clf, X_knn, y_full, cv=10, scoring='roc_auc_ovr')\n",
        "    for name, clf in [('DTC', DTC), ('RFC', RFC), ('XGB', XGB), ('KNN', KNN), ('SVM', SVM)]\n",
        "}\n",
        "\n",
        "# Per-model summaries, kept under their existing names in case later cells read them\n",
        "scores1_cv_dt, scores1_std_dt = cv1_fold_scores['DTC'].mean(), cv1_fold_scores['DTC'].std()\n",
        "scores1_cv_rf, scores1_std_rf = cv1_fold_scores['RFC'].mean(), cv1_fold_scores['RFC'].std()\n",
        "scores1_cv_xg, scores1_std_xg = cv1_fold_scores['XGB'].mean(), cv1_fold_scores['XGB'].std()\n",
        "scores1_cv_kn, scores1_std_kn = cv1_fold_scores['KNN'].mean(), cv1_fold_scores['KNN'].std()\n",
        "scores1_cv_SV, scores1_std_SV = cv1_fold_scores['SVM'].mean(), cv1_fold_scores['SVM'].std()\n",
        "\n",
        "for name, fold_scores in cv1_fold_scores.items():\n",
        "    print('10-fold CV for ' + name, fold_scores.mean())\n",
        "for name, fold_scores in cv1_fold_scores.items():\n",
        "    print('std for ' + name, fold_scores.std())"
      ],
      "metadata": {
        "id": "YT423e-hupzS",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bf62567f-19af-41d2-8a3a-2d0a665e0c7f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "10-fold CV forDTC 0.7667898532261462\n",
            "10-fold CV forRFC 0.9211907342671587\n",
            "10-fold CV forXGB 0.905462515029783\n",
            "10-fold CV forKNN 0.8233527212676955\n",
            "10-fold CV forSVM 0.8677643965852303\n",
            "std forDTC 0.04311219027196229\n",
            "std for RFC 0.02890207696361916\n",
            "std for XGB 0.03213222959377166\n",
            "std for KNN 0.05565549213845232\n",
            "std for SVM 0.057071330411747924\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yjk_MgxRmkkB"
      },
      "source": [
        "### Experiment 2: SimpleImputer with median strategy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X7iqoWQUmkkB"
      },
      "outputs": [],
      "source": [
        "num_features = X.select_dtypes(exclude=\"object\").columns\n",
        "\n",
        "# Experiment 2 preprocessing: median imputation, then RobustScaler\n",
        "median_pipeline = Pipeline([\n",
        "    ('imputer', SimpleImputer(strategy='median')),\n",
        "    ('RobustScaler', RobustScaler()),\n",
        "])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M7q92I-WmkkB"
      },
      "outputs": [],
      "source": [
        "# Learn the medians and scaling statistics from X, then apply them to X\n",
        "X_median = median_pipeline.fit(X).transform(X)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EGy0ilAxmkkC"
      },
      "outputs": [],
      "source": [
        "# Oversample the minority class (SMOTE) and remove Tomek links; adjust sampling_strategy as needed.\n",
        "smt = SMOTETomek(sampling_strategy='minority', random_state=42)\n",
        "X_res1, y_res1 = smt.fit_resample(X_median, y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NfbuhaBSmkkC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "16738bb8-cd01-4254-9e55-1a98ccaa3015"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Random Forest\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.9140\n",
            "- F1 score: 0.9143\n",
            "- Precision: 0.9154\n",
            "- Recall: 0.9140\n",
            "===================================\n",
            "\n",
            "\n",
            "Decision Tree\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8351\n",
            "- F1 score: 0.8354\n",
            "- Precision: 0.8363\n",
            "- Recall: 0.8351\n",
            "===================================\n",
            "\n",
            "\n",
            "K-Neighbors Classifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8293\n",
            "- F1 score: 0.8277\n",
            "- Precision: 0.8282\n",
            "- Recall: 0.8293\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8029\n",
            "- F1 score: 0.8000\n",
            "- Precision: 0.8033\n",
            "- Recall: 0.8029\n",
            "===================================\n",
            "\n",
            "\n",
            "XGBClassifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.9443\n",
            "- F1 score: 0.9440\n",
            "- Precision: 0.9442\n",
            "- Recall: 0.9443\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8925\n",
            "- F1 score: 0.8930\n",
            "- Precision: 0.8971\n",
            "- Recall: 0.8925\n",
            "===================================\n",
            "\n",
            "\n",
            "SVM\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8392\n",
            "- F1 score: 0.8371\n",
            "- Precision: 0.8421\n",
            "- Recall: 0.8392\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8208\n",
            "- F1 score: 0.8164\n",
            "- Precision: 0.8293\n",
            "- Recall: 0.8208\n",
            "===================================\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Train and score every model in `models` on the median-imputed, resampled data\n",
        "report_median = evaluate_models(X_res1, y_res1, models)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Median imputation + robust scaling on the full feature matrix\n",
        "median_pipeline1 = Pipeline(steps=[\n",
        "    ('imputer', SimpleImputer(strategy='median')),\n",
        "    ('RobustScaler', RobustScaler())\n",
        "])\n",
        "X_median_full = median_pipeline1.fit_transform(X_full)\n",
        "\n",
        "# NOTE(review): resampling runs before train_test_split, so SMOTE samples synthesised from\n",
        "# points that land in the test split are also in training; hold-out AUCs are likely optimistic.\n",
        "smt = SMOTETomek(random_state=42, sampling_strategy='minority', n_jobs=-1)\n",
        "X_res2, y_res2 = smt.fit_resample(X_median_full, y_full)\n",
        "X_train_full_median, X_test_full_median, y_train_full_median, y_test_full_median = train_test_split(X_res2, y_res2, test_size=0.20, random_state=42)\n",
        "\n",
        "# Fit each classifier on the training split and keep its class probabilities for the test split\n",
        "model_dt = DecisionTreeClassifier().fit(X_train_full_median, y_train_full_median)\n",
        "probs2_dt = model_dt.predict_proba(X_test_full_median)\n",
        "\n",
        "model_rf = RandomForestClassifier().fit(X_train_full_median, y_train_full_median)\n",
        "probs2_rf = model_rf.predict_proba(X_test_full_median)\n",
        "\n",
        "model_xg = XGBClassifier().fit(X_train_full_median, y_train_full_median)\n",
        "probs2_xg = model_xg.predict_proba(X_test_full_median)\n",
        "\n",
        "model_kn = KNeighborsClassifier().fit(X_train_full_median, y_train_full_median)\n",
        "# Must use the KNN model here; scoring with model_xg made the KNN AUC a copy of the XGB AUC\n",
        "probs2_kn = model_kn.predict_proba(X_test_full_median)\n",
        "\n",
        "model_SV = SVC(probability=True).fit(X_train_full_median, y_train_full_median)\n",
        "probs2_SV = model_SV.predict_proba(X_test_full_median)\n",
        "\n",
        "# Multi-class ROC-AUC (one-vs-rest) on the hold-out split\n",
        "auc_dt2 = roc_auc_score(y_test_full_median, probs2_dt, multi_class='ovr')\n",
        "auc_rf2 = roc_auc_score(y_test_full_median, probs2_rf, multi_class='ovr')\n",
        "auc_xg2 = roc_auc_score(y_test_full_median, probs2_xg, multi_class='ovr')\n",
        "auc_kn2 = roc_auc_score(y_test_full_median, probs2_kn, multi_class='ovr')\n",
        "auc_SV2 = roc_auc_score(y_test_full_median, probs2_SV, multi_class='ovr')\n",
        "\n",
        "print('- ROC_AUC of DTC: {:.4f}'.format(auc_dt2))\n",
        "print('- ROC_AUC of RFC: {:.4f}'.format(auc_rf2))\n",
        "print('- ROC_AUC of XGB: {:.4f}'.format(auc_xg2))\n",
        "print('- ROC_AUC of KNN: {:.4f}'.format(auc_kn2))\n",
        "print('- ROC_AUC of SVM: {:.4f}'.format(auc_SV2))"
      ],
      "metadata": {
        "id": "Iv1P2HF6vagl",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c4ea309b-42be-44b5-9b66-02f658843794"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "- ROC_AUC of DTC: 0.8929\n",
            "- ROC_AUC of RFC: 0.9862\n",
            "- ROC_AUC of XGB: 0.9751\n",
            "- ROC_AUC of KNN: 0.9751\n",
            "- ROC_AUC of SVM: 0.9514\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# 10-fold cross-validated ROC-AUC (one-vs-rest) of each classifier on X_median_full.\n",
        "# Each model is cross-validated once and both mean and std come from that single run,\n",
        "# instead of re-running the CV (with fresh, unseeded randomness) just to get the std.\n",
        "DTC = DecisionTreeClassifier()\n",
        "RFC = RandomForestClassifier()\n",
        "XGB = XGBClassifier()\n",
        "KNN = KNeighborsClassifier()\n",
        "SVM = SVC(probability=True)\n",
        "\n",
        "cv2_fold_scores = {\n",
        "    name: cross_val_score(clf, X_median_full, y_full, cv=10, scoring='roc_auc_ovr')\n",
        "    for name, clf in [('DTC', DTC), ('RFC', RFC), ('XGB', XGB), ('KNN', KNN), ('SVM', SVM)]\n",
        "}\n",
        "\n",
        "# Per-model summaries, kept under their existing names in case later cells read them\n",
        "median2_cv_dt, median2_std_dt = cv2_fold_scores['DTC'].mean(), cv2_fold_scores['DTC'].std()\n",
        "median2_cv_rf, median2_std_rf = cv2_fold_scores['RFC'].mean(), cv2_fold_scores['RFC'].std()\n",
        "median2_cv_xg, median2_std_xg = cv2_fold_scores['XGB'].mean(), cv2_fold_scores['XGB'].std()\n",
        "median2_cv_kn, median2_std_kn = cv2_fold_scores['KNN'].mean(), cv2_fold_scores['KNN'].std()\n",
        "median2_cv_SV, median2_std_SV = cv2_fold_scores['SVM'].mean(), cv2_fold_scores['SVM'].std()\n",
        "\n",
        "for name, fold_scores in cv2_fold_scores.items():\n",
        "    print('10-fold CV for ' + name, fold_scores.mean())\n",
        "for name, fold_scores in cv2_fold_scores.items():\n",
        "    print('std for ' + name, fold_scores.std())"
      ],
      "metadata": {
        "id": "x2tr5QyFvbP1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "33623cbd-8560-4548-b922-d931b442bc39"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "10-fold CV forDTC 0.7643985387226901\n",
            "10-fold CV forRFC 0.9225299487824339\n",
            "10-fold CV forXGB 0.905462515029783\n",
            "10-fold CV forKNN 0.8233527212676955\n",
            "10-fold CV forSVM 0.8672833044572993\n",
            "std forDTC 0.04192060406958924\n",
            "std for RFC 0.027988450085827486\n",
            "std for XGB 0.03213222959377166\n",
            "std for KNN 0.05565549213845232\n",
            "std for SVM 0.056710516142963456\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VgbT9-l4mkkC"
      },
      "source": [
        "### Experiment 3: MICE imputation of missing values"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "In5Oel1cneam"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Reload the full dataset for this experiment: 'Phases' target and feature matrix\n",
        "full_data = df_final.copy()\n",
        "y_full = full_data['Phases']\n",
        "X_full = full_data.drop(columns='Phases')"
      ],
      "metadata": {
        "id": "N60dXbTDUZdi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Ci2VDdgmkkC"
      },
      "outputs": [],
      "source": [
        "import miceforest as mf\n",
        "\n",
        "# Build a miceforest imputation kernel on a copy of X (fixed seed for reproducibility)\n",
        "X_mice = X.copy()\n",
        "kernel = mf.ImputationKernel(\n",
        "  X_mice,\n",
        "  save_all_iterations=True,\n",
        "  random_state=1989\n",
        ")\n",
        "# Run the MICE algorithm for 3 iterations. Without this call the kernel only holds its\n",
        "# initial fill (per the miceforest docs), and complete_data() would return that instead.\n",
        "kernel.mice(3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zs81EaTXmkkC"
      },
      "outputs": [],
      "source": [
        "# Imputed dataset from the kernel (by default, the latest MICE iteration)\n",
        "X_mice = kernel.complete_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0aKXmwOumkkC"
      },
      "outputs": [],
      "source": [
        "# MICE has already filled the gaps, so this pipeline only rescales (RobustScaler)\n",
        "mice_pipeline = Pipeline([('RobustScaler', RobustScaler())])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u9pKQrPmmkkC"
      },
      "outputs": [],
      "source": [
        "# Scale the MICE-imputed data. X_mice is overwritten, so re-running this cell without\n",
        "# re-running the imputation cells would scale already-scaled data.\n",
        "X_mice = mice_pipeline.fit(X_mice).transform(X_mice)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dFb-8NDvmkkD"
      },
      "outputs": [],
      "source": [
        "# Oversample the minority class (SMOTE) and remove Tomek links; adjust sampling_strategy as needed.\n",
        "smt = SMOTETomek(n_jobs=-1, random_state=42, sampling_strategy='minority')\n",
        "X_res2, y_res2 = smt.fit_resample(X_mice, y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d1ytZXJRmkkD",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "250bc020-5b57-4c79-bfcf-8fb4b0aaf1a3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Random Forest\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.9176\n",
            "- F1 score: 0.9181\n",
            "- Precision: 0.9197\n",
            "- Recall: 0.9176\n",
            "===================================\n",
            "\n",
            "\n",
            "Decision Tree\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8530\n",
            "- F1 score: 0.8531\n",
            "- Precision: 0.8536\n",
            "- Recall: 0.8530\n",
            "===================================\n",
            "\n",
            "\n",
            "K-Neighbors Classifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8293\n",
            "- F1 score: 0.8277\n",
            "- Precision: 0.8282\n",
            "- Recall: 0.8293\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8029\n",
            "- F1 score: 0.8000\n",
            "- Precision: 0.8033\n",
            "- Recall: 0.8029\n",
            "===================================\n",
            "\n",
            "\n",
            "XGBClassifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.9443\n",
            "- F1 score: 0.9440\n",
            "- Precision: 0.9442\n",
            "- Recall: 0.9443\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8925\n",
            "- F1 score: 0.8930\n",
            "- Precision: 0.8971\n",
            "- Recall: 0.8925\n",
            "===================================\n",
            "\n",
            "\n",
            "SVM\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8392\n",
            "- F1 score: 0.8371\n",
            "- Precision: 0.8421\n",
            "- Recall: 0.8392\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8208\n",
            "- F1 score: 0.8164\n",
            "- Precision: 0.8293\n",
            "- Recall: 0.8208\n",
            "===================================\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Train and score every model in `models` on the MICE-imputed, resampled data\n",
        "report_mice = evaluate_models(X_res2, y_res2, models)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): despite the MICE label, this pipeline only applies RobustScaler to X_full;\n",
        "# no imputation runs here. That is only valid if X_full has no missing values -- confirm,\n",
        "# or impute X_full (e.g. with a miceforest kernel) before scaling.\n",
        "mice_pipeline1 = Pipeline(steps=[\n",
        "    ('RobustScaler', RobustScaler())\n",
        "])\n",
        "X_mice1 = mice_pipeline1.fit_transform(X_full)\n",
        "\n",
        "# NOTE(review): resampling happens before train_test_split, so synthetic samples derived from\n",
        "# test-split points are also in training; hold-out AUCs are likely optimistic.\n",
        "smt = SMOTETomek(random_state=42, sampling_strategy='minority', n_jobs=-1)\n",
        "X_res34, y_res34 = smt.fit_resample(X_mice1, y_full)\n",
        "X_train_mice12, X_test_mice12, y_train_mice12, y_test_mice12 = train_test_split(X_res34, y_res34, test_size=0.2, random_state=42)\n",
        "\n",
        "# Fit each classifier; keep hard test-set predictions (*_mice_model_cb) and class probabilities (probs3_*)\n",
        "model_dt = DecisionTreeClassifier().fit(X_train_mice12, y_train_mice12)\n",
        "DT_mice_model_cb = model_dt.predict(X_test_mice12)\n",
        "probs3_dt = model_dt.predict_proba(X_test_mice12)\n",
        "\n",
        "model_rf = RandomForestClassifier().fit(X_train_mice12, y_train_mice12)\n",
        "RF_mice_model_cb = model_rf.predict(X_test_mice12)\n",
        "probs3_rf = model_rf.predict_proba(X_test_mice12)\n",
        "\n",
        "model_xg = XGBClassifier().fit(X_train_mice12, y_train_mice12)\n",
        "XG_mice_model_cb = model_xg.predict(X_test_mice12)\n",
        "probs3_xg = model_xg.predict_proba(X_test_mice12)\n",
        "\n",
        "model_kn = KNeighborsClassifier().fit(X_train_mice12, y_train_mice12)\n",
        "KN_mice_model_cb = model_kn.predict(X_test_mice12)\n",
        "probs3_kn = model_kn.predict_proba(X_test_mice12)\n",
        "\n",
        "model_SV = SVC(probability=True).fit(X_train_mice12, y_train_mice12)\n",
        "SVM_mice_model_cb = model_SV.predict(X_test_mice12)\n",
        "probs3_SV = model_SV.predict_proba(X_test_mice12)\n",
        "\n",
        "# Multi-class ROC-AUC (one-vs-rest) on the hold-out split\n",
        "auc_dt3 = roc_auc_score(y_test_mice12, probs3_dt, multi_class='ovr')\n",
        "auc_rf3 = roc_auc_score(y_test_mice12, probs3_rf, multi_class='ovr')\n",
        "auc_xg3 = roc_auc_score(y_test_mice12, probs3_xg, multi_class='ovr')\n",
        "auc_kn3 = roc_auc_score(y_test_mice12, probs3_kn, multi_class='ovr')\n",
        "auc_SV3 = roc_auc_score(y_test_mice12, probs3_SV, multi_class='ovr')\n",
        "\n",
        "print('- AUC of DTC: {:.4f}'.format(auc_dt3))\n",
        "print('- AUC of RFC: {:.4f}'.format(auc_rf3))\n",
        "print('- AUC of XGB: {:.4f}'.format(auc_xg3))\n",
        "print('- AUC of KNN: {:.4f}'.format(auc_kn3))\n",
        "print('- AUC of SVM: {:.4f}'.format(auc_SV3))"
      ],
      "metadata": {
        "id": "BvPlGR--UFpo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "98d0b131-9c90-4a52-c3e9-52e720c41bc4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "- AUC of DTC: 0.8932\n",
            "- AUC of RFC: 0.9861\n",
            "- AUC of XGB: 0.9751\n",
            "- AUC of KNN: 0.9282\n",
            "- AUC of SVM: 0.9509\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# One-vs-rest ROC curves for the five SMOTE-Tomek models of the MICE experiment above.\n",
        "# A multi-class problem has one ROC curve per class, so plot the curves for a single\n",
        "# positive class and label each with the AUC of that same binary problem (the macro OvR\n",
        "# AUC printed earlier averages over all classes and would not match the plotted curve).\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import roc_auc_score, roc_curve\n",
        "\n",
        "ROC_POS_CLASS = 1.0  # 'Phases' label treated as positive; its probability column is found via classes_\n",
        "roc_pos_idx = list(model_rf.classes_).index(ROC_POS_CLASS)\n",
        "roc_y_true = (y_test_mice12 == ROC_POS_CLASS).astype(int)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(9, 8))\n",
        "for label, probs, color in [\n",
        "    ('ST-KNN', probs3_kn, 'red'),\n",
        "    ('ST-SVM', probs3_SV, 'green'),\n",
        "    ('ST-DTC', probs3_dt, 'blue'),\n",
        "    ('ST-RFC', probs3_rf, 'purple'),\n",
        "    ('ST-XGB', probs3_xg, 'magenta'),\n",
        "]:\n",
        "    pos_scores = probs[:, roc_pos_idx]\n",
        "    fpr, tpr, thresholds = roc_curve(roc_y_true, pos_scores)\n",
        "    ax.plot(fpr, tpr, linestyle='--', color=color,\n",
        "            label='ROC-AUC of {} = {:.4f}'.format(label, roc_auc_score(roc_y_true, pos_scores)))\n",
        "ax.plot([0, 1], [0, 1], linestyle='--', color='black')  # chance level\n",
        "ax.set_title('ROC curves (class {} vs rest) for all five SMOTE-Tomek models'.format(ROC_POS_CLASS), fontsize=20)\n",
        "ax.set_xlabel('False positive rate', fontsize=20)\n",
        "ax.set_ylabel('True positive rate', fontsize=20)\n",
        "ax.tick_params(labelsize=15)\n",
        "ax.legend(loc='best', fontsize=15)\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "goRz0DsEYq2f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mHPzp3N6mkkD"
      },
      "source": [
        "### Experiment 4: SimpleImputer with constant strategy (fill value 0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_6TzxxAomkkD"
      },
      "outputs": [],
      "source": [
        "# Create a pipeline with simple imputer with strategy constant and fill value 0\n",
        "constant_pipeline = Pipeline(steps=[\n",
        "    ('Imputer', SimpleImputer(strategy='constant', fill_value=0)),\n",
        "    ('RobustScaler', RobustScaler())\n",
        "])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vztriY6kmkkD"
      },
      "outputs": [],
      "source": [
        "X_const =constant_pipeline.fit_transform(X)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gn4rGwRwmkkD"
      },
      "outputs": [],
      "source": [
        "# Oversample the minority class with SMOTE, then clean Tomek links.\n",
        "# Change `sampling_strategy` to resample other classes as well.\n",
        "# NOTE(review): resampling happens before `evaluate_models` (presumably) splits the data,\n",
        "# so synthetic samples derived from test rows may leak into training - confirm.\n",
        "smt = SMOTETomek(sampling_strategy='minority', random_state=42, n_jobs=-1)\n",
        "X_res3, y_res3 = smt.fit_resample(X_const, y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eGPMRNjjmkkE",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a5af99da-c4df-46d3-a762-f4683f400ff1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Random Forest\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.9068\n",
            "- F1 score: 0.9070\n",
            "- Precision: 0.9081\n",
            "- Recall: 0.9068\n",
            "===================================\n",
            "\n",
            "\n",
            "Decision Tree\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8459\n",
            "- F1 score: 0.8465\n",
            "- Precision: 0.8483\n",
            "- Recall: 0.8459\n",
            "===================================\n",
            "\n",
            "\n",
            "K-Neighbors Classifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8293\n",
            "- F1 score: 0.8277\n",
            "- Precision: 0.8282\n",
            "- Recall: 0.8293\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8029\n",
            "- F1 score: 0.8000\n",
            "- Precision: 0.8033\n",
            "- Recall: 0.8029\n",
            "===================================\n",
            "\n",
            "\n",
            "XGBClassifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.9443\n",
            "- F1 score: 0.9440\n",
            "- Precision: 0.9442\n",
            "- Recall: 0.9443\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8925\n",
            "- F1 score: 0.8930\n",
            "- Precision: 0.8971\n",
            "- Recall: 0.8925\n",
            "===================================\n",
            "\n",
            "\n",
            "SVM\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8392\n",
            "- F1 score: 0.8371\n",
            "- Precision: 0.8421\n",
            "- Recall: 0.8392\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8208\n",
            "- F1 score: 0.8164\n",
            "- Precision: 0.8293\n",
            "- Recall: 0.8208\n",
            "===================================\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Train and score every model in `models` on the resampled constant-imputed data.\n",
        "report_const = evaluate_models(X_res3, y_res3, models)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Work on a copy of the final dataset; 'Phases' is the target column.\n",
        "full_data = df_final.copy()\n",
        "X_full = full_data.drop(columns='Phases')\n",
        "y_full = full_data['Phases']"
      ],
      "metadata": {
        "id": "CqgUrDjButSW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hold out the test set BEFORE any fitting or resampling. Fitting the imputer/scaler\n",
        "# or running SMOTE-Tomek on the full data lets test rows influence training (and puts\n",
        "# synthetic samples into the test set), which inflates the test scores.\n",
        "X_train_raw_constan, X_test_raw_constan, y_train_raw_constan, y_test_constan12 = train_test_split(\n",
        "    X_full, y_full, test_size=0.2, random_state=42, stratify=y_full)\n",
        "\n",
        "constant_pipeline1 = Pipeline(steps=[\n",
        "    ('Imputer', SimpleImputer(strategy='constant', fill_value=0)),\n",
        "    ('RobustScaler', RobustScaler())])\n",
        "\n",
        "# Fit the preprocessing on the training split only and reuse it for the test split.\n",
        "X_constan1 = constant_pipeline1.fit_transform(X_train_raw_constan)\n",
        "X_test_constan12 = constant_pipeline1.transform(X_test_raw_constan)\n",
        "\n",
        "# Resample the training split only; the test split keeps its real class distribution.\n",
        "smt = SMOTETomek(random_state=42, sampling_strategy='minority', n_jobs=-1)\n",
        "X_res34, y_res34 = smt.fit_resample(X_constan1, y_train_raw_constan)\n",
        "X_train_constan12, y_train_constan12 = X_res34, y_res34\n",
        "\n",
        "# Fixed seed so the reported forest is reproducible across runs.\n",
        "model_rf_constan = RandomForestClassifier(random_state=42).fit(X_train_constan12, y_train_constan12)"
      ],
      "metadata": {
        "id": "0Od8eSFXujBO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# (rows, features) of the SMOTE-Tomek output `X_res34`\n",
        "X_res34.shape"
      ],
      "metadata": {
        "id": "bSo2-ab8vfmw",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "55a782d7-cdc4-4a70-9ff1-ca9108407233"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1392, 35)"
            ]
          },
          "metadata": {},
          "execution_count": 952
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of samples per phase label in the resampled target\n",
        "y_res34.value_counts(sort=True, ascending=False)"
      ],
      "metadata": {
        "id": "FNdRriv4vjHJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1f8ccd28-d625-4b3f-806b-9598df804fd5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "3.0    434\n",
              "0.0    408\n",
              "1.0    350\n",
              "2.0    200\n",
              "Name: Phases, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 953
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00cLvBTlmkkE"
      },
      "source": [
        "## Experiment: 5 = Simple Imputer with Strategy Mean\n",
        "\n",
        "- Another strategy is to replace missing values with a column statistic instead of a constant.\n",
        "- Here each missing value is replaced by the mean of its column."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kPxglohumkkE"
      },
      "outputs": [],
      "source": [
        "# Experiment 5 preprocessing: replace missing values with the column mean,\n",
        "# then scale the features with RobustScaler (default settings).\n",
        "mean_pipeline = Pipeline(\n",
        "    steps=[\n",
        "        ('Imputer', SimpleImputer(strategy='mean')),\n",
        "        ('RobustScaler', RobustScaler()),\n",
        "    ]\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f1Kv0g93mkkE"
      },
      "outputs": [],
      "source": [
        "# Fit the mean imputer and the scaler on X and transform it in one step.\n",
        "# NOTE(review): fitted on all of X; if `evaluate_models` splits afterwards, test rows influence the means - confirm.\n",
        "X_mean = mean_pipeline.fit_transform(X)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zj6nWgwrmkkE"
      },
      "outputs": [],
      "source": [
        "# Oversample the minority class with SMOTE, then clean Tomek links.\n",
        "# Change `sampling_strategy` to resample other classes as well.\n",
        "smt = SMOTETomek(sampling_strategy='minority', random_state=42, n_jobs=-1)\n",
        "X_res4, y_res4 = smt.fit_resample(X_mean, y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OiW9jYe7mkkE",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d04b1d28-c5dd-4275-d4c5-4d51138c6602"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Random Forest\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.9140\n",
            "- F1 score: 0.9138\n",
            "- Precision: 0.9159\n",
            "- Recall: 0.9140\n",
            "===================================\n",
            "\n",
            "\n",
            "Decision Tree\n",
            "Model performance for Training set\n",
            "- Accuracy: 1.0000\n",
            "- F1 score: 1.0000\n",
            "- Precision: 1.0000\n",
            "- Recall: 1.0000\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8459\n",
            "- F1 score: 0.8467\n",
            "- Precision: 0.8493\n",
            "- Recall: 0.8459\n",
            "===================================\n",
            "\n",
            "\n",
            "K-Neighbors Classifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8293\n",
            "- F1 score: 0.8277\n",
            "- Precision: 0.8282\n",
            "- Recall: 0.8293\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8029\n",
            "- F1 score: 0.8000\n",
            "- Precision: 0.8033\n",
            "- Recall: 0.8029\n",
            "===================================\n",
            "\n",
            "\n",
            "XGBClassifier\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.9443\n",
            "- F1 score: 0.9440\n",
            "- Precision: 0.9442\n",
            "- Recall: 0.9443\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8925\n",
            "- F1 score: 0.8930\n",
            "- Precision: 0.8971\n",
            "- Recall: 0.8925\n",
            "===================================\n",
            "\n",
            "\n",
            "SVM\n",
            "Model performance for Training set\n",
            "- Accuracy: 0.8392\n",
            "- F1 score: 0.8371\n",
            "- Precision: 0.8421\n",
            "- Recall: 0.8392\n",
            "----------------------------------\n",
            "Model performance for Test set\n",
            "- Accuracy: 0.8208\n",
            "- F1 score: 0.8164\n",
            "- Precision: 0.8293\n",
            "- Recall: 0.8208\n",
            "===================================\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Train and score every model in `models` on the resampled mean-imputed data.\n",
        "report_mean = evaluate_models(X_res4, y_res4, models)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hCflOLkYmkkG"
      },
      "outputs": [],
      "source": [
        "# Hold out the test set BEFORE fitting the imputer/scaler or resampling, so that\n",
        "# neither the column means nor the synthetic SMOTE samples are built from test rows.\n",
        "X_train_raw_mean, X_test_raw_mean, y_train_raw_mean, y_test_mean12 = train_test_split(\n",
        "    X_full, y_full, test_size=0.2, random_state=42, stratify=y_full)\n",
        "\n",
        "mean_pipeline1 = Pipeline(steps=[\n",
        "    ('Imputer', SimpleImputer(strategy='mean')),\n",
        "    ('RobustScaler', RobustScaler())\n",
        "])\n",
        "\n",
        "# Column means and scaling statistics come from the training split only.\n",
        "X_mean1 = mean_pipeline1.fit_transform(X_train_raw_mean)\n",
        "X_test_mean12 = mean_pipeline1.transform(X_test_raw_mean)\n",
        "\n",
        "# Resample the training split only; the test split keeps its real class distribution.\n",
        "smt = SMOTETomek(random_state=42, sampling_strategy='minority', n_jobs=-1)\n",
        "X_res35, y_res35 = smt.fit_resample(X_mean1, y_train_raw_mean)\n",
        "X_train_mean12, y_train_mean12 = X_res35, y_res35\n",
        "\n",
        "# Fixed seed so the forest (and RF_Prob) is reproducible across runs.\n",
        "model_rf_mean = RandomForestClassifier(random_state=42).fit(X_train_mean12, y_train_mean12)\n",
        "RF_Prob = model_rf_mean.predict_proba(X_test_mean12)\n",
        "\n",
        "# Optional confusion matrix. `plot_confusion_matrix` was removed in scikit-learn 1.2,\n",
        "# so ConfusionMatrixDisplay.from_estimator is used instead. Set the flag to True to draw it.\n",
        "SHOW_CONFUSION_MATRIX = False\n",
        "if SHOW_CONFUSION_MATRIX:\n",
        "    import matplotlib.pyplot as plt\n",
        "    from sklearn.metrics import ConfusionMatrixDisplay\n",
        "    # NOTE(review): assumes encoded labels 0..3 map to these names in this order - confirm.\n",
        "    class_names = ['IM', 'BCC', 'FCC', 'FCC+BCC']\n",
        "    cmplot = ConfusionMatrixDisplay.from_estimator(\n",
        "        model_rf_mean, X_test_mean12, y_test_mean12, display_labels=class_names, cmap='gist_heat')\n",
        "    cmplot.ax_.set_title('Confusion Matrix', color='red')\n",
        "    plt.xlabel('Predicted Label', color='black')\n",
        "    plt.ylabel('True Label', color='black')\n",
        "    plt.gcf().axes[0].tick_params(color='black')\n",
        "    plt.gcf().axes[1].tick_params(color='black')\n",
        "    plt.gcf().set_size_inches(10, 6)\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    },
    "vscode": {
      "interpreter": {
        "hash": "b7082a90f0341f66b325168da8fc238f0b2aba7ee16848d917086bb4ed45c134"
      }
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}