{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pycaret.regression import *\n", "import pandas as pd\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
stepsdistancerunDistancecalories
01901127917944
199566912224
279505749263182
3141995517236
4136792010933
\n", "
" ], "text/plain": [ " steps distance runDistance calories\n", "0 1901 1279 179 44\n", "1 995 669 122 24\n", "2 7950 5749 263 182\n", "3 1419 955 172 36\n", "4 1367 920 109 33" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('../../data/MiFit/Export data/ACTIVITY/ACTIVITY_1704203167901.csv')\n", "df1 = pd.read_csv('../../data/MiFit/Export data 170420/ACTIVITY/ACTIVITY_1704202151453.csv')\n", "df = pd.concat([df, df1]) \n", "df = df.reset_index(drop=True)\n", "\n", "# drop date\n", "df = df.drop(columns=['date'])\n", "\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ " Model MAE MSE RMSE \\\n", "lr Linear Regression 4.7829 45.0952 6.6513 \n", "ridge Ridge Regression 4.7829 45.0952 6.6513 \n", "llar Lasso Least Angle Regression 4.7837 45.0909 6.6511 \n", "br Bayesian Ridge 4.7844 45.0890 6.6511 \n", "lasso Lasso Regression 4.9302 47.7252 6.8417 \n", "en Elastic Net 4.9296 47.7112 6.8408 \n", "huber Huber Regressor 4.7640 48.5392 6.8969 \n", "et Extra Trees Regressor 5.5507 77.8235 8.6873 \n", "catboost CatBoost Regressor 5.6985 85.5234 8.9096 \n", "gbr Gradient Boosting Regressor 5.7667 81.8716 8.8728 \n", "rf Random Forest Regressor 5.7330 82.1203 8.9179 \n", "knn K Neighbors Regressor 5.9197 95.4778 9.4639 \n", "xgboost Extreme Gradient Boosting 6.1679 107.6687 10.0229 \n", "omp Orthogonal Matching Pursuit 6.7567 110.3968 10.3185 \n", "dt Decision Tree Regressor 6.3781 114.0464 10.4010 \n", "ada AdaBoost Regressor 9.1984 142.5733 11.7862 \n", "lightgbm Light Gradient Boosting Machine 6.9481 187.2160 13.1084 \n", "lar Least Angle Regression 9.7533 225.9983 14.1739 \n", "par Passive Aggressive Regressor 10.0749 269.2730 14.9989 \n", "dummy Dummy Regressor 86.8189 10798.3589 102.9788 \n", "\n", " R2 RMSLE MAPE TT (Sec) \n", "lr 0.9956 0.5003 0.1049 0.004 \n", "ridge 0.9956 0.5003 0.1049 0.003 \n", "llar 0.9956 0.5000 0.1049 0.004 \n", "br 0.9956 0.4994 0.1047 0.003 \n", "lasso 0.9954 0.4383 0.0899 0.005 \n", "en 0.9954 0.4384 0.0899 0.003 \n", "huber 0.9953 0.4554 0.0939 0.004 \n", "et 0.9926 0.0725 0.0620 0.022 \n", "catboost 0.9923 0.0909 0.0615 0.145 \n", "gbr 0.9923 0.1229 0.0612 0.010 \n", "rf 0.9922 0.0760 0.0642 0.028 \n", "knn 0.9912 0.0781 0.0640 0.005 \n", "xgboost 0.9901 0.0937 0.0722 0.010 \n", "omp 0.9895 0.2696 0.0793 0.003 \n", "dt 0.9894 0.0928 0.0757 0.003 \n", "ada 0.9861 0.9246 0.3523 0.009 \n", "lightgbm 0.9830 0.0815 0.0693 0.212 \n", "lar 0.9797 0.7802 0.2691 0.003 \n", "par 0.9750 0.1154 0.0963 0.003 \n", "dummy -0.0389 1.8488 3.6105 0.003 \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r" ] } ], "source": [ "setup(df, target=\"calories\", verbose=False, session_id=42, html=False)\n", "setup_df = pull()\n", "best_model = compare_models()\n", "compare_df = pull()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_model(best_model, plot='feature')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 MAEMSERMSER2RMSLEMAPE
Fold      
06.3405100.116610.00580.99070.09750.0732
18.6717241.856915.55180.98570.11150.0898
24.045134.80485.89960.99550.10600.0688
35.228355.17687.42810.99240.12640.0867
46.931087.61299.36020.99230.07750.0644
55.817791.71919.57700.99130.07500.0598
65.324361.68017.85370.99190.08470.0742
77.1095173.976513.19000.98440.08650.0697
85.703896.06159.80110.98800.11160.0880
96.5070133.682311.56210.98910.06040.0469
Mean6.1679107.668710.02290.99010.09370.0722
Std1.196758.40192.68510.00320.01930.0128
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr = create_model('xgboost')\n", "params = lr.get_params()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pycaret.classification import *\n", "from pycaret.datasets import get_data\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "ename": "ConnectionError", "evalue": "HTTPSConnectionPool(host='raw.githubusercontent.com', port=443): Max retries exceeded with url: /pycaret/datasets/main/data/common/diabetes.csv (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mgaierror\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connection.py:174\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 174\u001b[0m conn \u001b[38;5;241m=\u001b[39m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_connection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dns_host\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mport\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mextra_kw\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketTimeout:\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/util/connection.py:72\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m six\u001b[38;5;241m.\u001b[39mraise_from(\n\u001b[1;32m 69\u001b[0m LocationParseError(\u001b[38;5;124mu\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, label empty or too long\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m host), \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 70\u001b[0m )\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m res \u001b[38;5;129;01min\u001b[39;00m \u001b[43msocket\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetaddrinfo\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhost\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mport\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfamily\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msocket\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSOCK_STREAM\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 73\u001b[0m af, socktype, proto, canonname, sa \u001b[38;5;241m=\u001b[39m res\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/socket.py:918\u001b[0m, in \u001b[0;36mgetaddrinfo\u001b[0;34m(host, port, family, type, proto, flags)\u001b[0m\n\u001b[1;32m 917\u001b[0m addrlist \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 918\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m res \u001b[38;5;129;01min\u001b[39;00m \u001b[43m_socket\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetaddrinfo\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhost\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mport\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfamily\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mproto\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflags\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 919\u001b[0m af, socktype, proto, canonname, sa \u001b[38;5;241m=\u001b[39m res\n", "\u001b[0;31mgaierror\u001b[0m: [Errno 8] nodename nor servname provided, or not known", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mNewConnectionError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connectionpool.py:715\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 714\u001b[0m \u001b[38;5;66;03m# Make the request on the httplib connection object.\u001b[39;00m\n\u001b[0;32m--> 715\u001b[0m httplib_response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[43m \u001b[49m\u001b[43mconn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 717\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 718\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 719\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout_obj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 720\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 721\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 722\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 723\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 725\u001b[0m \u001b[38;5;66;03m# If we're going to release the connection in ``finally:``, then\u001b[39;00m\n\u001b[1;32m 726\u001b[0m \u001b[38;5;66;03m# the response doesn't need to know about the connection. Otherwise\u001b[39;00m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;66;03m# it will also try to release it and we'll have a double-release\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \u001b[38;5;66;03m# mess.\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connectionpool.py:404\u001b[0m, in \u001b[0;36mHTTPConnectionPool._make_request\u001b[0;34m(self, conn, method, url, timeout, chunked, **httplib_request_kw)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 404\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_conn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 405\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (SocketTimeout, BaseSSLError) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 406\u001b[0m \u001b[38;5;66;03m# Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout.\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connectionpool.py:1058\u001b[0m, in \u001b[0;36mHTTPSConnectionPool._validate_conn\u001b[0;34m(self, conn)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(conn, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msock\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m): \u001b[38;5;66;03m# AppEngine might not have `.sock`\u001b[39;00m\n\u001b[0;32m-> 1058\u001b[0m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1060\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn\u001b[38;5;241m.\u001b[39mis_verified:\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connection.py:363\u001b[0m, in \u001b[0;36mHTTPSConnection.connect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconnect\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;66;03m# Add certificate verification\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msock \u001b[38;5;241m=\u001b[39m conn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_new_conn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 364\u001b[0m hostname \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhost\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connection.py:186\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m NewConnectionError(\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to establish a new connection: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m e\n\u001b[1;32m 188\u001b[0m )\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m conn\n", "\u001b[0;31mNewConnectionError\u001b[0m: : Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mMaxRetryError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/adapters.py:486\u001b[0m, in \u001b[0;36mHTTPAdapter.send\u001b[0;34m(self, request, stream, timeout, verify, cert, proxies)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 486\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 487\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 491\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 492\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreload_content\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecode_content\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 495\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_retries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 496\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 498\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 500\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ProtocolError, \u001b[38;5;167;01mOSError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m err:\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/connectionpool.py:799\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 797\u001b[0m e \u001b[38;5;241m=\u001b[39m ProtocolError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConnection aborted.\u001b[39m\u001b[38;5;124m\"\u001b[39m, e)\n\u001b[0;32m--> 799\u001b[0m retries \u001b[38;5;241m=\u001b[39m \u001b[43mretries\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mincrement\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 800\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_pool\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_stacktrace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msys\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexc_info\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 801\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 802\u001b[0m retries\u001b[38;5;241m.\u001b[39msleep()\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/urllib3/util/retry.py:592\u001b[0m, in \u001b[0;36mRetry.increment\u001b[0;34m(self, method, url, response, error, _pool, _stacktrace)\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_retry\u001b[38;5;241m.\u001b[39mis_exhausted():\n\u001b[0;32m--> 592\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MaxRetryError(_pool, url, error \u001b[38;5;129;01mor\u001b[39;00m ResponseError(cause))\n\u001b[1;32m 594\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIncremented Retry for (url=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m): \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, url, new_retry)\n", "\u001b[0;31mMaxRetryError\u001b[0m: HTTPSConnectionPool(host='raw.githubusercontent.com', port=443): Max retries exceeded with url: /pycaret/datasets/main/data/common/diabetes.csv (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mConnectionError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mget_data\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdiabetes\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m df\u001b[38;5;241m.\u001b[39mdescribe()\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pycaret/datasets.py:116\u001b[0m, in \u001b[0;36mget_data\u001b[0;34m(dataset, folder, save_copy, profile, verbose, address)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misfile(filename):\n\u001b[1;32m 115\u001b[0m data \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(filename)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[43mrequests\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcomplete_address\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mstatus_code \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m200\u001b[39m:\n\u001b[1;32m 117\u001b[0m data \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(complete_address)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m dataset \u001b[38;5;129;01min\u001b[39;00m sktime_datasets:\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/api.py:73\u001b[0m, in \u001b[0;36mget\u001b[0;34m(url, params, **kwargs)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget\u001b[39m(url, params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 63\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Sends a GET request.\u001b[39;00m\n\u001b[1;32m 64\u001b[0m \n\u001b[1;32m 65\u001b[0m \u001b[38;5;124;03m :param url: URL for the new :class:`Request` object.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;124;03m :rtype: requests.Response\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mget\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/api.py:59\u001b[0m, in \u001b[0;36mrequest\u001b[0;34m(method, url, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# By using the 'with' statement we are sure the session is closed, thus we\u001b[39;00m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# avoid leaving sockets open which can trigger a ResourceWarning in some\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# cases, and look like a memory leak in others.\u001b[39;00m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m sessions\u001b[38;5;241m.\u001b[39mSession() \u001b[38;5;28;01mas\u001b[39;00m session:\n\u001b[0;32m---> 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msession\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/sessions.py:589\u001b[0m, in \u001b[0;36mSession.request\u001b[0;34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[0m\n\u001b[1;32m 584\u001b[0m send_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 585\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m\"\u001b[39m: timeout,\n\u001b[1;32m 586\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mallow_redirects\u001b[39m\u001b[38;5;124m\"\u001b[39m: allow_redirects,\n\u001b[1;32m 587\u001b[0m }\n\u001b[1;32m 588\u001b[0m send_kwargs\u001b[38;5;241m.\u001b[39mupdate(settings)\n\u001b[0;32m--> 589\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msend_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/sessions.py:703\u001b[0m, in \u001b[0;36mSession.send\u001b[0;34m(self, request, **kwargs)\u001b[0m\n\u001b[1;32m 700\u001b[0m start \u001b[38;5;241m=\u001b[39m preferred_clock()\n\u001b[1;32m 702\u001b[0m \u001b[38;5;66;03m# Send the request\u001b[39;00m\n\u001b[0;32m--> 703\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43madapter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;66;03m# Total elapsed time of the request (approximately)\u001b[39;00m\n\u001b[1;32m 706\u001b[0m elapsed \u001b[38;5;241m=\u001b[39m preferred_clock() \u001b[38;5;241m-\u001b[39m start\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/requests/adapters.py:519\u001b[0m, in \u001b[0;36mHTTPAdapter.send\u001b[0;34m(self, request, stream, timeout, verify, cert, proxies)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e\u001b[38;5;241m.\u001b[39mreason, _SSLError):\n\u001b[1;32m 516\u001b[0m \u001b[38;5;66;03m# This branch is for urllib3 v1.22 and later.\u001b[39;00m\n\u001b[1;32m 517\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m SSLError(e, request\u001b[38;5;241m=\u001b[39mrequest)\n\u001b[0;32m--> 519\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mConnectionError\u001b[39;00m(e, request\u001b[38;5;241m=\u001b[39mrequest)\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ClosedPoolError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 522\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mConnectionError\u001b[39;00m(e, request\u001b[38;5;241m=\u001b[39mrequest)\n", "\u001b[0;31mConnectionError\u001b[0m: HTTPSConnectionPool(host='raw.githubusercontent.com', port=443): Max retries exceeded with url: /pycaret/datasets/main/data/common/diabetes.csv (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))" ] } ], "source": [ "df = get_data('diabetes')\n", "df.describe()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 6264 entries, 0 to 6263\n", "Data columns (total 18 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 age 6264 non-null int64 \n", " 1 gender 6264 non-null int64 \n", " 2 height 6264 non-null float64\n", " 3 weight 6264 non-null float64\n", " 4 steps 6264 non-null float64\n", " 5 hear_rate 6264 non-null float64\n", " 6 calories 6264 non-null float64\n", " 7 distance 6264 non-null float64\n", " 8 entropy_heart 6264 non-null float64\n", " 9 entropy_setps 6264 non-null float64\n", " 10 resting_heart 6264 non-null float64\n", " 11 corr_heart_steps 6264 non-null float64\n", " 12 norm_heart 6264 non-null float64\n", " 13 intensity_karvonen 6264 non-null float64\n", " 14 sd_norm_heart 6264 non-null float64\n", " 15 steps_times_distance 6264 non-null float64\n", " 16 device 6264 non-null object \n", " 17 activity 6264 non-null object \n", "dtypes: float64(14), int64(2), object(2)\n", "memory usage: 881.0+ KB\n" ] } ], "source": [ "df = pd.read_csv('../../data/aw_fb/aw_fb_data.csv')\n", "df = df.drop(['Unnamed: 0', 'X1'], axis=1)\n", "df.info()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Number of times pregnantPlasma glucose concentration a 2 hours in an oral glucose tolerance testDiastolic blood pressure (mm Hg)Triceps skin fold thickness (mm)2-Hour serum insulin (mu U/ml)Body mass index (weight in kg/(height in m)^2)Diabetes pedigree functionAge (years)Class variable
061487235033.60.627501
11856629026.60.351310
28183640023.30.672321
318966239428.10.167210
40137403516843.12.288331
\n", "
" ], "text/plain": [ " Number of times pregnant \\\n", "0 6 \n", "1 1 \n", "2 8 \n", "3 1 \n", "4 0 \n", "\n", " Plasma glucose concentration a 2 hours in an oral glucose tolerance test \\\n", "0 148 \n", "1 85 \n", "2 183 \n", "3 89 \n", "4 137 \n", "\n", " Diastolic blood pressure (mm Hg) Triceps skin fold thickness (mm) \\\n", "0 72 35 \n", "1 66 29 \n", "2 64 0 \n", "3 66 23 \n", "4 40 35 \n", "\n", " 2-Hour serum insulin (mu U/ml) \\\n", "0 0 \n", "1 0 \n", "2 0 \n", "3 94 \n", "4 168 \n", "\n", " Body mass index (weight in kg/(height in m)^2) Diabetes pedigree function \\\n", "0 33.6 0.627 \n", "1 26.6 0.351 \n", "2 23.3 0.672 \n", "3 28.1 0.167 \n", "4 43.1 2.288 \n", "\n", " Age (years) Class variable \n", "0 50 1 \n", "1 31 0 \n", "2 32 1 \n", "3 21 0 \n", "4 33 1 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Number of times pregnantPlasma glucose concentration a 2 hours in an oral glucose tolerance testDiastolic blood pressure (mm Hg)Triceps skin fold thickness (mm)2-Hour serum insulin (mu U/ml)Body mass index (weight in kg/(height in m)^2)Diabetes pedigree functionAge (years)Class variable
061487235033.60.627501
11856629026.60.351310
28183640023.30.672321
318966239428.10.167210
\n", "
" ], "text/plain": [ " Number of times pregnant \\\n", "0 6 \n", "1 1 \n", "2 8 \n", "3 1 \n", "\n", " Plasma glucose concentration a 2 hours in an oral glucose tolerance test \\\n", "0 148 \n", "1 85 \n", "2 183 \n", "3 89 \n", "\n", " Diastolic blood pressure (mm Hg) Triceps skin fold thickness (mm) \\\n", "0 72 35 \n", "1 66 29 \n", "2 64 0 \n", "3 66 23 \n", "\n", " 2-Hour serum insulin (mu U/ml) \\\n", "0 0 \n", "1 0 \n", "2 0 \n", "3 94 \n", "\n", " Body mass index (weight in kg/(height in m)^2) Diabetes pedigree function \\\n", "0 33.6 0.627 \n", "1 26.6 0.351 \n", "2 23.3 0.672 \n", "3 28.1 0.167 \n", "\n", " Age (years) Class variable \n", "0 50 1 \n", "1 31 0 \n", "2 32 1 \n", "3 21 0 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# df = get_data('diabetes')÷\n", "# group by last column\n", "train_df = df.groupby(df.columns[-1]).head(2)\n", "# drop train_df from df\n", "# df = df.drop(train_df.index)\n", "train_df" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "le = LabelEncoder()\n", "df_aw['activity'] = le.fit_transform(df_aw['activity'])\n", "df_fb['activity'] = le.fit_transform(df_fb['activity'])" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agegenderheightweightstepshear_ratecaloriesdistanceentropy_heartentropy_setpsresting_heartcorr_heart_stepsnorm_heartintensity_karvonensd_norm_heartsteps_times_distanceactivity
0201168.065.410.77142978.5313020.3445330.0083276.2216126.11634959.0000001.00000019.5313020.1385201.0000000.0896920
1201168.065.411.47532578.4533903.2876250.0088966.2216126.11634959.0000001.00000019.4533900.1379671.0000000.1020880
2201168.065.412.17922178.5408259.4840000.0094666.2216126.11634959.0000001.00000019.5408250.1385871.0000000.1152870
3201168.065.412.88311778.62826010.1545560.0100356.2216126.11634959.0000001.00000019.6282600.1392081.0000000.1292860
4201168.065.413.58701378.71569510.8251110.0106056.2216126.11634959.0000000.98281619.7156950.1398280.2415670.1440880
......................................................
3651460157.571.4163.000000157.2500000.7015000.0752006.1624275.65531079.4217951.00000077.8282050.8228987.27020412.2576003
3652460157.571.46.666667157.3076920.7015000.0754756.1624275.65531079.4217951.00000077.8858970.8235081.0000000.5031673
3653460157.571.46.750000156.2500000.7320000.0756956.1624275.65531079.4217951.00000076.8282050.8123251.0000000.5109413
3654460157.571.46.791667158.0909090.6125000.0772706.1624275.65531079.4217951.00000078.6691140.8317891.0000000.5247923
3655460157.571.46.750000157.2307690.6710000.0759656.1624275.65531079.4217951.00000077.8089740.8226951.0000000.5127643
\n", "

3656 rows × 17 columns

\n", "
" ], "text/plain": [ " age gender height weight steps hear_rate calories \\\n", "0 20 1 168.0 65.4 10.771429 78.531302 0.344533 \n", "1 20 1 168.0 65.4 11.475325 78.453390 3.287625 \n", "2 20 1 168.0 65.4 12.179221 78.540825 9.484000 \n", "3 20 1 168.0 65.4 12.883117 78.628260 10.154556 \n", "4 20 1 168.0 65.4 13.587013 78.715695 10.825111 \n", "... ... ... ... ... ... ... ... \n", "3651 46 0 157.5 71.4 163.000000 157.250000 0.701500 \n", "3652 46 0 157.5 71.4 6.666667 157.307692 0.701500 \n", "3653 46 0 157.5 71.4 6.750000 156.250000 0.732000 \n", "3654 46 0 157.5 71.4 6.791667 158.090909 0.612500 \n", "3655 46 0 157.5 71.4 6.750000 157.230769 0.671000 \n", "\n", " distance entropy_heart entropy_setps resting_heart corr_heart_steps \\\n", "0 0.008327 6.221612 6.116349 59.000000 1.000000 \n", "1 0.008896 6.221612 6.116349 59.000000 1.000000 \n", "2 0.009466 6.221612 6.116349 59.000000 1.000000 \n", "3 0.010035 6.221612 6.116349 59.000000 1.000000 \n", "4 0.010605 6.221612 6.116349 59.000000 0.982816 \n", "... ... ... ... ... ... \n", "3651 0.075200 6.162427 5.655310 79.421795 1.000000 \n", "3652 0.075475 6.162427 5.655310 79.421795 1.000000 \n", "3653 0.075695 6.162427 5.655310 79.421795 1.000000 \n", "3654 0.077270 6.162427 5.655310 79.421795 1.000000 \n", "3655 0.075965 6.162427 5.655310 79.421795 1.000000 \n", "\n", " norm_heart intensity_karvonen sd_norm_heart steps_times_distance \\\n", "0 19.531302 0.138520 1.000000 0.089692 \n", "1 19.453390 0.137967 1.000000 0.102088 \n", "2 19.540825 0.138587 1.000000 0.115287 \n", "3 19.628260 0.139208 1.000000 0.129286 \n", "4 19.715695 0.139828 0.241567 0.144088 \n", "... ... ... ... ... \n", "3651 77.828205 0.822898 7.270204 12.257600 \n", "3652 77.885897 0.823508 1.000000 0.503167 \n", "3653 76.828205 0.812325 1.000000 0.510941 \n", "3654 78.669114 0.831789 1.000000 0.524792 \n", "3655 77.808974 0.822695 1.000000 0.512764 \n", "\n", " activity \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", "... ... \n", "3651 3 \n", "3652 3 \n", "3653 3 \n", "3654 3 \n", "3655 3 \n", "\n", "[3656 rows x 17 columns]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_aw" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agegenderheightweightstepshear_ratecaloriesdistanceentropy_heartentropy_setpsresting_heartcorr_heart_stepsnorm_heartintensity_karvonensd_norm_heartsteps_times_distanceactivity
0201168.065.410.77142978.5313020.3445330.0083276.2216126.11634959.0000001.00000019.5313020.1385201.0000000.089692Lying
1201168.065.411.47532578.4533903.2876250.0088966.2216126.11634959.0000001.00000019.4533900.1379671.0000000.102088Lying
2201168.065.412.17922178.5408259.4840000.0094666.2216126.11634959.0000001.00000019.5408250.1385871.0000000.115287Lying
3201168.065.412.88311778.62826010.1545560.0100356.2216126.11634959.0000001.00000019.6282600.1392081.0000000.129286Lying
4201168.065.413.58701378.71569510.8251110.0106056.2216126.11634959.0000000.98281619.7156950.1398280.2415670.144088Lying
......................................................
3651460157.571.4163.000000157.2500000.7015000.0752006.1624275.65531079.4217951.00000077.8282050.8228987.27020412.257600Running 7 METs
3652460157.571.46.666667157.3076920.7015000.0754756.1624275.65531079.4217951.00000077.8858970.8235081.0000000.503167Running 7 METs
3653460157.571.46.750000156.2500000.7320000.0756956.1624275.65531079.4217951.00000076.8282050.8123251.0000000.510941Running 7 METs
3654460157.571.46.791667158.0909090.6125000.0772706.1624275.65531079.4217951.00000078.6691140.8317891.0000000.524792Running 7 METs
3655460157.571.46.750000157.2307690.6710000.0759656.1624275.65531079.4217951.00000077.8089740.8226951.0000000.512764Running 7 METs
\n", "

3656 rows × 17 columns

\n", "
" ], "text/plain": [ " age gender height weight steps hear_rate calories \\\n", "0 20 1 168.0 65.4 10.771429 78.531302 0.344533 \n", "1 20 1 168.0 65.4 11.475325 78.453390 3.287625 \n", "2 20 1 168.0 65.4 12.179221 78.540825 9.484000 \n", "3 20 1 168.0 65.4 12.883117 78.628260 10.154556 \n", "4 20 1 168.0 65.4 13.587013 78.715695 10.825111 \n", "... ... ... ... ... ... ... ... \n", "3651 46 0 157.5 71.4 163.000000 157.250000 0.701500 \n", "3652 46 0 157.5 71.4 6.666667 157.307692 0.701500 \n", "3653 46 0 157.5 71.4 6.750000 156.250000 0.732000 \n", "3654 46 0 157.5 71.4 6.791667 158.090909 0.612500 \n", "3655 46 0 157.5 71.4 6.750000 157.230769 0.671000 \n", "\n", " distance entropy_heart entropy_setps resting_heart corr_heart_steps \\\n", "0 0.008327 6.221612 6.116349 59.000000 1.000000 \n", "1 0.008896 6.221612 6.116349 59.000000 1.000000 \n", "2 0.009466 6.221612 6.116349 59.000000 1.000000 \n", "3 0.010035 6.221612 6.116349 59.000000 1.000000 \n", "4 0.010605 6.221612 6.116349 59.000000 0.982816 \n", "... ... ... ... ... ... \n", "3651 0.075200 6.162427 5.655310 79.421795 1.000000 \n", "3652 0.075475 6.162427 5.655310 79.421795 1.000000 \n", "3653 0.075695 6.162427 5.655310 79.421795 1.000000 \n", "3654 0.077270 6.162427 5.655310 79.421795 1.000000 \n", "3655 0.075965 6.162427 5.655310 79.421795 1.000000 \n", "\n", " norm_heart intensity_karvonen sd_norm_heart steps_times_distance \\\n", "0 19.531302 0.138520 1.000000 0.089692 \n", "1 19.453390 0.137967 1.000000 0.102088 \n", "2 19.540825 0.138587 1.000000 0.115287 \n", "3 19.628260 0.139208 1.000000 0.129286 \n", "4 19.715695 0.139828 0.241567 0.144088 \n", "... ... ... ... ... \n", "3651 77.828205 0.822898 7.270204 12.257600 \n", "3652 77.885897 0.823508 1.000000 0.503167 \n", "3653 76.828205 0.812325 1.000000 0.510941 \n", "3654 78.669114 0.831789 1.000000 0.524792 \n", "3655 77.808974 0.822695 1.000000 0.512764 \n", "\n", " activity \n", "0 Lying \n", "1 Lying \n", "2 Lying \n", "3 Lying \n", "4 Lying \n", "... ... \n", "3651 Running 7 METs \n", "3652 Running 7 METs \n", "3653 Running 7 METs \n", "3654 Running 7 METs \n", "3655 Running 7 METs \n", "\n", "[3656 rows x 17 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_aw = df[df['device'] == 'apple watch']\n", "df_fb = df[df['device'] == 'fitbit']\n", "df_aw = df_aw.drop('device', axis=1)\n", "df_fb = df_fb.drop('device',axis=1)\n", "df_aw" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['Lying', 'Running 3 METs', 'Running 5 METs', 'Running 7 METs',\n", " 'Self Pace walk', 'Sitting'], dtype=object)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_aw['activity'].unique()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def stratified_partition_with_all_values(df, column, n_partitions, partition_id):\n", " # Group the data by the column of interest\n", " grouped = df.groupby(column)\n", "\n", " # Initialize empty list to store partitions\n", " partitions = [[] for _ in range(n_partitions)]\n", "\n", " # Iterate over groups\n", " for name, group in grouped:\n", " # Randomly shuffle the data within the group\n", " group = group.sample(frac=1).reset_index(drop=True)\n", "\n", " # Calculate the number of samples in each partition for this group\n", " samples_per_partition = len(group) // n_partitions\n", "\n", " # Distribute the data evenly among partitions, ensuring each partition has all values\n", " for i in range(n_partitions):\n", " start_idx = i * samples_per_partition\n", " end_idx = (i + 1) * samples_per_partition\n", " if i == n_partitions - 1:\n", " end_idx = None # Include remaining samples in the last partition\n", " partition_data = group.iloc[start_idx:end_idx]\n", " partitions[i].append(partition_data)\n", "\n", " # Concatenate data frames in each partition\n", " partitions = [pd.concat(partition) for partition in partitions]\n", "\n", " return partitions[partition_id].reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 1, 2, 3, 4, 5])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = stratified_partition_with_all_values(df_aw, 'activity', 3, 2)\n", "df['activity'].unique()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 DescriptionValue
0Session id42
1Targetactivity
2Target typeMulticlass
3Original data shape(1222, 17)
4Transformed data shape(1222, 17)
5Transformed train set shape(855, 17)
6Transformed test set shape(367, 17)
7Numeric features16
8PreprocessTrue
9Imputation typesimple
10Numeric imputationmean
11Categorical imputationmode
12Fold GeneratorStratifiedKFold
13Fold Number10
14CPU Jobs-1
15Use GPUFalse
16Log ExperimentFalse
17Experiment Nameclf-default-name
18USI16c3
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Name Log Loss\n", "Display Name Log Loss\n", "Score Function \n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
NameReferenceTurbomodel_type
ID
lrLogistic Regressionsklearn.linear_model._logistic.LogisticRegressionTruelinear
knnK Neighbors Classifiersklearn.neighbors._classification.KNeighborsCl...Trueother
nbNaive Bayessklearn.naive_bayes.GaussianNBTrueother
dtDecision Tree Classifiersklearn.tree._classes.DecisionTreeClassifierTruetree
svmSVM - Linear Kernelsklearn.linear_model._stochastic_gradient.SGDC...Truelinear
rbfsvmSVM - Radial Kernelsklearn.svm._classes.SVCFalseother
gpcGaussian Process Classifiersklearn.gaussian_process._gpc.GaussianProcessC...Falseother
mlpMLP Classifiersklearn.neural_network._multilayer_perceptron....Falseother
ridgeRidge Classifiersklearn.linear_model._ridge.RidgeClassifierTruelinear
rfRandom Forest Classifiersklearn.ensemble._forest.RandomForestClassifierTrueensemble
qdaQuadratic Discriminant Analysissklearn.discriminant_analysis.QuadraticDiscrim...Trueother
adaAda Boost Classifiersklearn.ensemble._weight_boosting.AdaBoostClas...Trueensemble
gbcGradient Boosting Classifiersklearn.ensemble._gb.GradientBoostingClassifierTrueensemble
ldaLinear Discriminant Analysissklearn.discriminant_analysis.LinearDiscrimina...Trueother
etExtra Trees Classifiersklearn.ensemble._forest.ExtraTreesClassifierTrueensemble
xgboostExtreme Gradient Boostingxgboost.sklearn.XGBClassifierTrueensemble
lightgbmLight Gradient Boosting Machinelightgbm.sklearn.LGBMClassifierTrueensemble
catboostCatBoost Classifiercatboost.core.CatBoostClassifierTrueensemble
dummyDummy Classifiersklearn.dummy.DummyClassifierTrueother
\n", "" ], "text/plain": [ " Name \\\n", "ID \n", "lr Logistic Regression \n", "knn K Neighbors Classifier \n", "nb Naive Bayes \n", "dt Decision Tree Classifier \n", "svm SVM - Linear Kernel \n", "rbfsvm SVM - Radial Kernel \n", "gpc Gaussian Process Classifier \n", "mlp MLP Classifier \n", "ridge Ridge Classifier \n", "rf Random Forest Classifier \n", "qda Quadratic Discriminant Analysis \n", "ada Ada Boost Classifier \n", "gbc Gradient Boosting Classifier \n", "lda Linear Discriminant Analysis \n", "et Extra Trees Classifier \n", "xgboost Extreme Gradient Boosting \n", "lightgbm Light Gradient Boosting Machine \n", "catboost CatBoost Classifier \n", "dummy Dummy Classifier \n", "\n", " Reference Turbo model_type \n", "ID \n", "lr sklearn.linear_model._logistic.LogisticRegression True linear \n", "knn sklearn.neighbors._classification.KNeighborsCl... True other \n", "nb sklearn.naive_bayes.GaussianNB True other \n", "dt sklearn.tree._classes.DecisionTreeClassifier True tree \n", "svm sklearn.linear_model._stochastic_gradient.SGDC... True linear \n", "rbfsvm sklearn.svm._classes.SVC False other \n", "gpc sklearn.gaussian_process._gpc.GaussianProcessC... False other \n", "mlp sklearn.neural_network._multilayer_perceptron.... False other \n", "ridge sklearn.linear_model._ridge.RidgeClassifier True linear \n", "rf sklearn.ensemble._forest.RandomForestClassifier True ensemble \n", "qda sklearn.discriminant_analysis.QuadraticDiscrim... True other \n", "ada sklearn.ensemble._weight_boosting.AdaBoostClas... True ensemble \n", "gbc sklearn.ensemble._gb.GradientBoostingClassifier True ensemble \n", "lda sklearn.discriminant_analysis.LinearDiscrimina... True other \n", "et sklearn.ensemble._forest.ExtraTreesClassifier True ensemble \n", "xgboost xgboost.sklearn.XGBClassifier True ensemble \n", "lightgbm lightgbm.sklearn.LGBMClassifier True ensemble \n", "catboost catboost.core.CatBoostClassifier True ensemble \n", "dummy sklearn.dummy.DummyClassifier True other " ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_type = {\n", " \"linear\": [\n", " \"lr\",\n", " \"ridge\",\n", " \"svm\",\n", " \"lasso\",\n", " \"en\",\n", " \"lar\",\n", " \"llar\",\n", " \"omp\",\n", " \"br\",\n", " \"ard\",\n", " \"par\",\n", " \"ransac\",\n", " \"tr\",\n", " \"huber\",\n", " \"kr\",\n", " ],\n", " \"tree\": [\"dt\"],\n", " \"ensemble\": [\n", " \"rf\",\n", " \"et\",\n", " \"gbc\",\n", " \"gbr\",\n", " \"xgboost\",\n", " \"lightgbm\",\n", " \"catboost\",\n", " \"ada\",\n", " ],\n", "}\n", "\n", "models = exp.models()\n", "\n", "def fil(x):\n", "\treturn False\n", "\n", "# add model type to models dataframe based on index\n", "models['model_type'] = models.index.map(lambda x: 'linear' if x in model_type['linear'] else 'tree' if x in model_type['tree'] else 'ensemble' if x in model_type['ensemble'] else 'other')\n", "models" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "def get_model_weights(model):\n", " \"\"\"\n", " Returns model weights (coefficients) if the model supports them.\n", " For ensemble models, it returns feature importances.\n", " \"\"\"\n", " model_type = exp._get_model_id(model)\n", " print(model_type)\n", " if model_type in model_type['linear']:\n", " return model.coef_\n", " elif model_type in model_type['tree']:\n", " return model.feature_importances_\n", " else:\n", " raise ValueError(\n", " f\"Model type {model_type} does not support weight extraction.\")\n", "\n", "\n", "def set_model_weights(model, weights):\n", " \"\"\"\n", " Sets model weights (coefficients) if the model allows it.\n", " For ensemble models, setting weights is not allowed as it doesn't make sense.\n", " \"\"\"\n", " model_type = type(model)\n", " if model_type in [LogisticRegression, DecisionTreeClassifier]:\n", " model.coef_ = weights\n", " elif model_type is RandomForestClassifier:\n", " raise ValueError(\n", " \"Cannot set weights for ensemble models like RandomForest.\")\n", " else:\n", " raise ValueError(\n", " f\"Model type {model_type} does not support setting weights.\")" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 AccuracyAUCRecallPrec.F1KappaMCCLog Loss
Fold        
00.72090.91840.72090.72020.71740.66440.66591.0252
10.65120.86520.65120.66490.65140.57860.58081.6663
20.67440.92240.67440.69740.67480.60670.61111.0495
30.63950.90450.63950.64450.63850.56540.56661.1714
40.66280.88170.66280.67390.66090.59450.59681.3829
50.61180.87430.61180.60380.60180.53090.53301.5517
60.70590.90910.70590.70860.70360.64490.64661.1767
70.70590.91540.70590.71910.70330.64570.64881.1926
80.65880.89730.65880.66320.65950.58920.58981.2476
90.76470.92610.76470.77130.76410.71580.71701.0081
Mean0.67960.90140.67960.68670.67750.61360.61561.2472
Std0.04250.02010.04250.04450.04360.05140.05130.2112
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# best = exp.compare_models(cross_validation=False)\n", "model = exp.create_model('lightgbm', train_model=True)\n", "# metrics = exp.pull()" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'age': 504,\n", " 'gender': 77,\n", " 'height': 462,\n", " 'weight': 541,\n", " 'steps': 1607,\n", " 'hear_rate': 1823,\n", " 'calories': 2296,\n", " 'distance': 1470,\n", " 'entropy_heart': 644,\n", " 'entropy_setps': 714,\n", " 'resting_heart': 478,\n", " 'corr_heart_steps': 1686,\n", " 'norm_heart': 1722,\n", " 'intensity_karvonen': 1260,\n", " 'sd_norm_heart': 1888,\n", " 'steps_times_distance': 802}" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.booster_.dump_model()['feature_importances']" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'ExtraTreesClassifier' object has no attribute 'named_parameters'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[75], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mbest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnamed_parameters\u001b[49m()\n", "\u001b[0;31mAttributeError\u001b[0m: 'ExtraTreesClassifier' object has no attribute 'named_parameters'" ] } ], "source": [ "best.named_parameters()" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'boosting_type': 'gbdt',\n", " 'objective': None,\n", " 'num_leaves': 31,\n", " 'max_depth': -1,\n", " 'learning_rate': 0.1,\n", " 'n_estimators': 100,\n", " 'subsample_for_bin': 200000,\n", " 'min_split_gain': 0.0,\n", " 'min_child_weight': 0.001,\n", " 'min_child_samples': 20,\n", " 'subsample': 1.0,\n", " 'subsample_freq': 0,\n", " 'colsample_bytree': 1.0,\n", " 'reg_alpha': 0.0,\n", " 'reg_lambda': 0.0,\n", " 'random_state': 42,\n", " 'n_jobs': -1,\n", " 'importance_type': 'split',\n", " '_Booster': ,\n", " '_evals_result': {},\n", " '_best_score': defaultdict(collections.OrderedDict, {}),\n", " '_best_iteration': 0,\n", " '_other_params': {},\n", " '_objective': 'multiclass',\n", " 'class_weight': None,\n", " '_class_weight': None,\n", " '_class_map': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5},\n", " '_n_features': 16,\n", " '_n_features_in': 16,\n", " '_classes': array([0, 1, 2, 3, 4, 5], dtype=int8),\n", " '_n_classes': 6,\n", " '_le': LabelEncoder(),\n", " 'fitted_': True}" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vars(model)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 1.0,\n", " 'class_weight': None,\n", " 'dual': False,\n", " 'fit_intercept': True,\n", " 'intercept_scaling': 1,\n", " 'l1_ratio': None,\n", " 'max_iter': 1000,\n", " 'multi_class': 'auto',\n", " 'n_jobs': None,\n", " 'penalty': 'l2',\n", " 'random_state': 42,\n", " 'solver': 'lbfgs',\n", " 'tol': 0.0001,\n", " 'verbose': 0,\n", " 'warm_start': False}" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.get_params()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 ModelAccuracyAUCRecallPrec.F1KappaMCCLog Loss
0CatBoost Classifier0.70570.92740.70570.70900.70610.64470.64520
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.predict_model(best)\n", "df = exp.pull()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Model': 'CatBoost Classifier',\n", " 'Accuracy': 0.7057,\n", " 'AUC': 0.9274,\n", " 'Recall': 0.7057,\n", " 'Prec.': 0.709,\n", " 'F1': 0.7061,\n", " 'Kappa': 0.6447,\n", " 'MCC': 0.6452,\n", " 'Log Loss': 0}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.iloc[0].to_dict()\n", "# remove model key" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "ename": "KeyError", "evalue": "'model'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pandas/core/indexes/base.py:3802\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3801\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3802\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3803\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pandas/_libs/index.pyx:138\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pandas/_libs/index.pyx:165\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5745\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5753\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", "\u001b[0;31mKeyError\u001b[0m: 'model'", "\nThe above exception was the direct cause of the following exception:\n", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[17], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pandas/core/generic.py:4243\u001b[0m, in \u001b[0;36mNDFrame.__delitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 4238\u001b[0m deleted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 4239\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m deleted:\n\u001b[1;32m 4240\u001b[0m \u001b[38;5;66;03m# If the above loop ran and didn't delete anything because\u001b[39;00m\n\u001b[1;32m 4241\u001b[0m \u001b[38;5;66;03m# there was no match, this call should raise the appropriate\u001b[39;00m\n\u001b[1;32m 4242\u001b[0m \u001b[38;5;66;03m# exception:\u001b[39;00m\n\u001b[0;32m-> 4243\u001b[0m loc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maxes\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4244\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mgr \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mgr\u001b[38;5;241m.\u001b[39midelete(loc)\n\u001b[1;32m 4246\u001b[0m \u001b[38;5;66;03m# delete from the caches\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/envs/3.10env/lib/python3.8/site-packages/pandas/core/indexes/base.py:3804\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3802\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_engine\u001b[38;5;241m.\u001b[39mget_loc(casted_key)\n\u001b[1;32m 3803\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m-> 3804\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 3805\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 3806\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3807\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n", "\u001b[0;31mKeyError\u001b[0m: 'model'" ] } ], "source": [ ".pop('key', None)" ] }, { "cell_type": "code", "execution_count": 145, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[None,\n", " ('objective', 'multi:softprob'),\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " ('verbosity', 0),\n", " ('booster', 'gbtree'),\n", " ('tree_method', 'auto'),\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " ('missing', nan),\n", " None,\n", " ('random_state', 42),\n", " ('n_jobs', -1),\n", " None,\n", " None,\n", " None,\n", " ('device', 'cpu'),\n", " None,\n", " ('enable_categorical', False),\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " None,\n", " ('kwargs', {'train': True}),\n", " ('n_classes_', 6),\n", " ('_Booster', )]" ] }, "execution_count": 145, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[(f,getattr(best, f)) if getattr(best,f) is not None else None for f in vars(best)]" ] }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgboost.sklearn.XGBClassifier" ] }, "execution_count": 133, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(best)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'objective': 'multi:softprob',\n", " 'base_score': None,\n", " 'booster': 'gbtree',\n", " 'callbacks': None,\n", " 'colsample_bylevel': None,\n", " 'colsample_bynode': None,\n", " 'colsample_bytree': None,\n", " 'device': 'cpu',\n", " 'early_stopping_rounds': None,\n", " 'enable_categorical': False,\n", " 'eval_metric': None,\n", " 'feature_types': None,\n", " 'gamma': None,\n", " 'grow_policy': None,\n", " 'importance_type': None,\n", " 'interaction_constraints': None,\n", " 'learning_rate': None,\n", " 'max_bin': None,\n", " 'max_cat_threshold': None,\n", " 'max_cat_to_onehot': None,\n", " 'max_delta_step': None,\n", " 'max_depth': None,\n", " 'max_leaves': None,\n", " 'min_child_weight': None,\n", " 'missing': nan,\n", " 'monotone_constraints': None,\n", " 'multi_strategy': None,\n", " 'n_estimators': None,\n", " 'n_jobs': -1,\n", " 'num_parallel_tree': None,\n", " 'random_state': 42,\n", " 'reg_alpha': None,\n", " 'reg_lambda': None,\n", " 'sampling_method': None,\n", " 'scale_pos_weight': None,\n", " 'subsample': None,\n", " 'tree_method': 'auto',\n", " 'validate_parameters': None,\n", " 'verbosity': 0}" ] }, "execution_count": 130, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best.get_params()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Running 3 METs', 'Lying', 'Sitting', 'Self Pace walk', 'Running 5 METs', 'Running 7 METs']\n", "Categories (6, object): ['Lying', 'Running 3 METs', 'Running 5 METs', 'Running 7 METs', 'Self Pace walk', 'Sitting']" ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = exp.y_test\n", "y.unique()" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 ModelAccuracyAUCRecallPrec.F1KappaMCCLog Loss
0Random Forest Classifier0.67300.90800.67300.68150.67500.60600.60670
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agegenderheightweightstepshear_ratecaloriesdistanceentropy_heartentropy_setpsresting_heartcorr_heart_stepsnorm_heartintensity_karvonensd_norm_heartsteps_times_distanceactivityprediction_labelprediction_score
299250166.068.000000106.000000119.0000004.6278000.0768306.1754855.84189177.0397640.17547841.9602360.3557155.3148128.143980Running 3 METsRunning 3 METs0.69
251231181.095.19999770.00000065.29068017.2764000.0527206.2597616.25976156.333332-1.0000008.9573450.0636780.9057293.690367LyingRunning 5 METs0.43
374231178.077.3000034.92381065.30477916.8444000.0036825.9438936.20945455.0000001.00000010.3047820.0725691.4744220.018131Running 3 METsLying0.86
1167250160.057.700001342.53247193.00000019.9540000.1887816.1421476.02055278.5313030.74171314.4686980.1242283.09552964.663727SittingSitting0.56
420290159.055.000000119.000000112.0593030.7350000.0448406.2094545.97283580.621513-0.17175431.4377860.2848183.2821605.335960Running 3 METsRunning 3 METs0.88
............................................................
423460157.571.40000264.000000114.2843550.7580000.0342006.1624275.65531079.4217910.88291634.8625640.3686112.9319902.188800Running 3 METsRunning 3 METs0.62
979310158.059.0999983.28571492.0000000.2520000.0024976.1952966.00115384.1999970.2586837.8000000.0744272.6239640.008205Self Pace walkLying0.61
1064191183.065.69999729.33333490.9206310.3236000.0218136.3037816.27846434.1538470.20590156.7667880.3402342.0412930.639858SittingSitting0.54
943220168.062.000000566.59997660.33333217.2387280.3889986.0751656.15308756.200001-0.9679084.1333330.0291491.009217220.406265Self Pace walkSelf Pace walk0.52
284231178.077.3000033.67074862.38287018.2136000.0026875.9438936.20945455.0000001.0000007.3828690.0519921.4744220.009863Running 3 METsRunning 3 METs0.53
\n", "

367 rows × 19 columns

\n", "
" ], "text/plain": [ " age gender height weight steps hear_rate calories \\\n", "299 25 0 166.0 68.000000 106.000000 119.000000 4.627800 \n", "251 23 1 181.0 95.199997 70.000000 65.290680 17.276400 \n", "374 23 1 178.0 77.300003 4.923810 65.304779 16.844400 \n", "1167 25 0 160.0 57.700001 342.532471 93.000000 19.954000 \n", "420 29 0 159.0 55.000000 119.000000 112.059303 0.735000 \n", "... ... ... ... ... ... ... ... \n", "423 46 0 157.5 71.400002 64.000000 114.284355 0.758000 \n", "979 31 0 158.0 59.099998 3.285714 92.000000 0.252000 \n", "1064 19 1 183.0 65.699997 29.333334 90.920631 0.323600 \n", "943 22 0 168.0 62.000000 566.599976 60.333332 17.238728 \n", "284 23 1 178.0 77.300003 3.670748 62.382870 18.213600 \n", "\n", " distance entropy_heart entropy_setps resting_heart corr_heart_steps \\\n", "299 0.076830 6.175485 5.841891 77.039764 0.175478 \n", "251 0.052720 6.259761 6.259761 56.333332 -1.000000 \n", "374 0.003682 5.943893 6.209454 55.000000 1.000000 \n", "1167 0.188781 6.142147 6.020552 78.531303 0.741713 \n", "420 0.044840 6.209454 5.972835 80.621513 -0.171754 \n", "... ... ... ... ... ... \n", "423 0.034200 6.162427 5.655310 79.421791 0.882916 \n", "979 0.002497 6.195296 6.001153 84.199997 0.258683 \n", "1064 0.021813 6.303781 6.278464 34.153847 0.205901 \n", "943 0.388998 6.075165 6.153087 56.200001 -0.967908 \n", "284 0.002687 5.943893 6.209454 55.000000 1.000000 \n", "\n", " norm_heart intensity_karvonen sd_norm_heart steps_times_distance \\\n", "299 41.960236 0.355715 5.314812 8.143980 \n", "251 8.957345 0.063678 0.905729 3.690367 \n", "374 10.304782 0.072569 1.474422 0.018131 \n", "1167 14.468698 0.124228 3.095529 64.663727 \n", "420 31.437786 0.284818 3.282160 5.335960 \n", "... ... ... ... ... \n", "423 34.862564 0.368611 2.931990 2.188800 \n", "979 7.800000 0.074427 2.623964 0.008205 \n", "1064 56.766788 0.340234 2.041293 0.639858 \n", "943 4.133333 0.029149 1.009217 220.406265 \n", "284 7.382869 0.051992 1.474422 0.009863 \n", "\n", " activity prediction_label prediction_score \n", "299 Running 3 METs Running 3 METs 0.69 \n", "251 Lying Running 5 METs 0.43 \n", "374 Running 3 METs Lying 0.86 \n", "1167 Sitting Sitting 0.56 \n", "420 Running 3 METs Running 3 METs 0.88 \n", "... ... ... ... \n", "423 Running 3 METs Running 3 METs 0.62 \n", "979 Self Pace walk Lying 0.61 \n", "1064 Sitting Sitting 0.54 \n", "943 Self Pace walk Self Pace walk 0.52 \n", "284 Running 3 METs Running 3 METs 0.53 \n", "\n", "[367 rows x 19 columns]" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.predict_model(best)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Initiated. . . . . . . . . . . . . . . . . .15:55:35
Status. . . . . . . . . . . . . . . . . .Selecting Estimator
Estimator. . . . . . . . . . . . . . . . . .Extreme Gradient Boosting
\n", "
" ], "text/plain": [ " \n", " \n", "Initiated . . . . . . . . . . . . . . . . . . 15:55:35\n", "Status . . . . . . . . . . . . . . . . . . Selecting Estimator\n", "Estimator . . . . . . . . . . . . . . . . . . Extreme Gradient Boosting" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = exp.create_model('xgboost', train_model=False)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 AccuracyAUCRecallPrec.F1KappaMCCLog Loss
Fold        
00.72090.93410.72090.72120.72070.66310.6632-0.0000
10.60470.86910.60470.61330.60580.52310.5244-0.0000
20.68600.91370.68600.70590.68760.62160.6246-0.0000
30.63950.90220.63950.64780.64110.56660.5675-0.0000
40.61630.87440.61630.62110.61360.53700.5388-0.0000
50.63530.86830.63530.63860.63400.55990.5609-0.0000
60.69410.91510.69410.70270.69250.63120.6334-0.0000
70.65880.90640.65880.66420.65730.58880.5904-0.0000
80.62350.88610.62350.63740.62550.54610.5474-0.0000
90.75290.92150.75290.76110.75360.70250.7037-0.0000
Mean0.66320.89910.66320.67130.66320.59400.59540.0000
Std0.04630.02210.04630.04630.04640.05590.05580.0000
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Fitting 10 folds for each of 10 candidates, totalling 100 fits\n", "Original model was better than the tuned model, hence it will be returned. NOTE: The display metrics are for the tuned model (not the original one).\n" ] } ], "source": [ "tuned = exp.tune_model(model)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('_Booster', ), ('n_classes_', 6)]" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# find diff between attributes in model and tuned\n", "[(i, getattr(tuned, i)) for i in set(vars(tuned).keys()) - set(vars(model).keys())]" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['feature_names_in_',\n", " 'n_features_in_',\n", " 'n_outputs_',\n", " 'classes_',\n", " 'n_classes_',\n", " 'estimator_',\n", " 'estimators_',\n", " array(['age', 'gender', 'height', 'weight', 'steps', 'hear_rate',\n", " 'calories', 'distance', 'entropy_heart', 'entropy_setps',\n", " 'resting_heart', 'corr_heart_steps', 'norm_heart',\n", " 'intensity_karvonen', 'sd_norm_heart', 'steps_times_distance'],\n", " dtype=object),\n", " 16,\n", " 1,\n", " array([0, 1, 2, 3, 4, 5], dtype=int8),\n", " 6,\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=None, splitter='best'),\n", " [DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1608637542, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1273642419, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1935803228, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787846414, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=996406378, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1201263687, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=423734972, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=415968276, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=670094950, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1914837113, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=669991378, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=429389014, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=249467210, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1972458954, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1572714583, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1433267572, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=434285667, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=613608295, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=893664919, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=648061058, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=88409749, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=242285876, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2018247425, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=953477463, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1427830251, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1883569565, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=911989541, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=3344769, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=780932287, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2114032571, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787716372, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=504579232, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1306710475, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=479546681, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=106328085, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=30349564, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1855189739, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=99052376, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1250819632, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=106406362, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=480404538, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1717389822, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=599121577, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=200427519, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1254751707, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2034764475, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1573512143, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=999745294, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1958805693, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=389151677, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1224821422, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=508464061, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=857592370, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1642661739, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=61136438, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2075460851, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=396917567, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2004731384, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=199502978, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1545932260, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=461901618, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=774414982, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=732395540, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1934879560, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=279394470, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=56972561, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1927948675, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1899242072, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1999874363, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=271820813, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1324556529, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1655351289, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1308306184, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=68574553, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=419498548, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=991681409, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=791274835, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1035196507, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1890440558, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787110843, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=524150214, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=472432043, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2126768636, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1431061255, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=147697582, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=744595490, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1758017741, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1679592528, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1111451555, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=782698033, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=698027879, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1096768899, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1338788865, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1826030589, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=86191493, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=893102645, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=200619113, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=290770691, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=793943861, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=134489564, splitter='best')]]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from flwr.common import NDArrays\n", "\n", "def get_model_parameters(model) -> NDArrays:\n", " \"\"\"Returns the parameters of a sklearn LogisticRegression model.\"\"\"\n", " attrs = [v for v in vars(model)\n", " if v.endswith(\"_\") and not v.startswith(\"__\")]\n", " params = attrs\n", " params += [getattr(model, v) for v in vars(model)\n", " if v.endswith(\"_\") and not v.startswith(\"__\")]\n", "\n", " return params\n", "\n", "params = get_model_parameters(model)\n", "params" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'bootstrap': True,\n", " 'ccp_alpha': 0.0,\n", " 'class_weight': None,\n", " 'criterion': 'gini',\n", " 'max_depth': None,\n", " 'max_features': 'sqrt',\n", " 'max_leaf_nodes': None,\n", " 'max_samples': None,\n", " 'min_impurity_decrease': 0.0,\n", " 'min_samples_leaf': 1,\n", " 'min_samples_split': 2,\n", " 'min_weight_fraction_leaf': 0.0,\n", " 'n_estimators': 100,\n", " 'n_jobs': -1,\n", " 'oob_score': False,\n", " 'random_state': 42,\n", " 'verbose': 0,\n", " 'warm_start': False}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.get_params()" ] }, { "cell_type": "code", "execution_count": 339, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'estimator': DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=None, splitter='best'),\n", " 'n_estimators': 100,\n", " 'estimator_params': ('criterion',\n", " 'max_depth',\n", " 'min_samples_split',\n", " 'min_samples_leaf',\n", " 'min_weight_fraction_leaf',\n", " 'max_features',\n", " 'max_leaf_nodes',\n", " 'min_impurity_decrease',\n", " 'random_state',\n", " 'ccp_alpha'),\n", " 'base_estimator': 'deprecated',\n", " 'bootstrap': True,\n", " 'oob_score': False,\n", " 'n_jobs': -1,\n", " 'random_state': 42,\n", " 'verbose': 0,\n", " 'warm_start': False,\n", " 'class_weight': None,\n", " 'max_samples': None,\n", " 'criterion': 'gini',\n", " 'max_depth': None,\n", " 'min_samples_split': 2,\n", " 'min_samples_leaf': 1,\n", " 'min_weight_fraction_leaf': 0.0,\n", " 'max_features': 'sqrt',\n", " 'max_leaf_nodes': None,\n", " 'min_impurity_decrease': 0.0,\n", " 'ccp_alpha': 0.0,\n", " 'feature_names_in_': array(['age', 'gender', 'height', 'weight', 'steps', 'hear_rate',\n", " 'calories', 'distance', 'entropy_heart', 'entropy_setps',\n", " 'resting_heart', 'corr_heart_steps', 'norm_heart',\n", " 'intensity_karvonen', 'sd_norm_heart', 'steps_times_distance'],\n", " dtype=object),\n", " 'n_features_in_': 16,\n", " 'n_outputs_': 1,\n", " 'classes_': array([0, 1, 2, 3, 4, 5]),\n", " 'n_classes_': 6,\n", " 'estimator_': DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=None, splitter='best'),\n", " 'estimators_': [DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1608637542, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1273642419, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1935803228, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787846414, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=996406378, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1201263687, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=423734972, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=415968276, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=670094950, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1914837113, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=669991378, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=429389014, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=249467210, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1972458954, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1572714583, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1433267572, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=434285667, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=613608295, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=893664919, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=648061058, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=88409749, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=242285876, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2018247425, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=953477463, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1427830251, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1883569565, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=911989541, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=3344769, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=780932287, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2114032571, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787716372, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=504579232, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1306710475, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=479546681, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=106328085, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=30349564, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1855189739, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=99052376, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1250819632, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=106406362, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=480404538, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1717389822, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=599121577, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=200427519, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1254751707, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2034764475, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1573512143, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=999745294, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1958805693, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=389151677, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1224821422, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=508464061, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=857592370, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1642661739, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=61136438, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2075460851, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=396917567, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2004731384, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=199502978, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1545932260, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=461901618, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=774414982, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=732395540, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1934879560, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=279394470, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=56972561, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1927948675, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1899242072, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1999874363, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=271820813, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1324556529, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1655351289, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1308306184, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=68574553, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=419498548, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=991681409, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=791274835, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1035196507, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1890440558, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=787110843, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=524150214, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=472432043, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=2126768636, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1431061255, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=147697582, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=744595490, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1758017741, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1679592528, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1111451555, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=782698033, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=698027879, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1096768899, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1338788865, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=1826030589, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=86191493, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=893102645, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=200619113, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=290770691, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=793943861, splitter='best'),\n", " DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='sqrt', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " random_state=134489564, splitter='best')]}" ] }, "execution_count": 339, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def set_model_params(model, params: NDArrays):\n", " \"\"\"Sets the parameters of a sklean model.\"\"\"\n", " for i in range(0, len(params) // 2):\n", " k, v = params[i], params[i+len(params) // 2]\n", " setattr(model, k, v)\n", " return model\n", "\n", "vars(set_model_params(model, params))" ] }, { "cell_type": "code", "execution_count": 341, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[100, None, 2, 1]" ] }, "execution_count": 341, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# random classifier attribut\n", "attrs = [\"n_estimators\", \"max_depth\",\n", " \"min_samples_split\", \"min_samples_leaf\"]\n", "[getattr(tuned, v) for v in attrs]" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Initiated. . . . . . . . . . . . . . . . . .09:36:24
Status. . . . . . . . . . . . . . . . . .Selecting Estimator
Estimator. . . . . . . . . . . . . . . . . .Random Forest Classifier
\n", "
" ], "text/plain": [ " \n", " \n", "Initiated . . . . . . . . . . . . . . . . . . 09:36:24\n", "Status . . . . . . . . . . . . . . . . . . Selecting Estimator\n", "Estimator . . . . . . . . . . . . . . . . . . Random Forest Classifier" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "unfitted = exp.create_model('rf', train_model=False)" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[110, 95, 101, 115, 116, 105, 109, 97, 116, 111, 114, 115],\n", " 100,\n", " [109,\n", " 105,\n", " 110,\n", " 95,\n", " 115,\n", " 97,\n", " 109,\n", " 112,\n", " 108,\n", " 101,\n", " 115,\n", " 95,\n", " 115,\n", " 112,\n", " 108,\n", " 105,\n", " 116],\n", " 2,\n", " [109, 105, 110, 95, 115, 97, 109, 112, 108, 101, 115, 95, 108, 101, 97, 102],\n", " 1]" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import utils\n", "params = utils.get_model_parameters(best)\n", "params" ] }, { "cell_type": "code", "execution_count": 340, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 ModelAccuracyAUCRecallPrec.F1KappaMCCLog Loss
0Random Forest Classifier0.83060.96880.83060.83070.83020.79610.79620
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# setattr(model, '_label_binarizer', getattr(tuned, '_label_binarizer'))\n", "df = exp.predict_model(model)" ] }, { "cell_type": "code", "execution_count": 301, "metadata": {}, "outputs": [], "source": [ "from sklearn.utils.validation import check_is_fitted\n", "check_is_fitted(model)" ] }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0 9.362\n", " Name: Log Loss, dtype: float64,\n", " 0 0.7403\n", " Name: Accuracy, dtype: float64)" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = exp.pull()\n", "df['Log Loss'], df['Accuracy']" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-8.54252966])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr.intercept_" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 1.0,\n", " 'class_weight': None,\n", " 'dual': False,\n", " 'fit_intercept': True,\n", " 'intercept_scaling': 1,\n", " 'l1_ratio': None,\n", " 'max_iter': 1000,\n", " 'multi_class': 'auto',\n", " 'n_jobs': None,\n", " 'penalty': 'l2',\n", " 'random_state': 42,\n", " 'solver': 'lbfgs',\n", " 'tol': 0.0001,\n", " 'verbose': 0,\n", " 'warm_start': False}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr.get_params()" ] }, { "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
crimzninduschasnoxrmagedisradtaxptratioblacklstatmedv
00.0063218.02.3100.5386.57565.24.0900129615.3396.904.9824.0
10.027310.07.0700.4696.42178.94.9671224217.8396.909.1421.6
20.027290.07.0700.4697.18561.14.9671224217.8392.834.0334.7
30.032370.02.1800.4586.99845.86.0622322218.7394.632.9433.4
40.069050.02.1800.4587.14754.26.0622322218.7396.905.3336.2
\n", "
" ], "text/plain": [ " crim zn indus chas nox rm age dis rad tax ptratio \\\n", "0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 \n", "1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 \n", "2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 \n", "3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 \n", "4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 \n", "\n", " black lstat medv \n", "0 396.90 4.98 24.0 \n", "1 396.90 9.14 21.6 \n", "2 392.83 4.03 34.7 \n", "3 394.63 2.94 33.4 \n", "4 396.90 5.33 36.2 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 DescriptionValue
0Session id2305
1Targetmedv
2Target typeRegression
3Original data shape(506, 14)
4Transformed data shape(506, 14)
5Transformed train set shape(354, 14)
6Transformed test set shape(152, 14)
7Numeric features13
8PreprocessTrue
9Imputation typesimple
10Numeric imputationmean
11Categorical imputationmode
12Fold GeneratorKFold
13Fold Number10
14CPU Jobs-1
15Use GPUFalse
16Log ExperimentFalse
17Experiment Namereg-default-name
18USI5063
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 ModelMAEMSERMSER2RMSLEMAPETT (Sec)
gbrGradient Boosting Regressor2.351310.57753.25230.88160.16440.12770.0700
catboostCatBoost Regressor2.358212.17893.48980.86370.16270.12350.6200
xgboostExtreme Gradient Boosting2.424113.46203.66910.84940.17010.12750.1300
rfRandom Forest Regressor2.546313.53743.67930.84850.18470.14290.0800
etExtra Trees Regressor2.380313.94843.73480.84390.17260.12590.0700
lightgbmLight Gradient Boosting Machine2.644615.87803.98470.82230.19420.14750.1300
adaAdaBoost Regressor3.039518.80474.33640.78960.21770.17700.0400
dtDecision Tree Regressor3.360523.39514.83690.73820.20980.17180.0100
lrLinear Regression3.705226.36965.13510.70490.26460.18750.0100
larLeast Angle Regression3.731526.72225.16940.70100.26450.18840.0100
ridgeRidge Regression3.741027.00795.19690.69780.29170.19140.0100
brBayesian Ridge3.785827.59155.25280.69130.28930.19570.0100
enElastic Net4.050831.57155.61890.64670.31070.20640.0100
lassoLasso Regression4.166033.29695.77030.62740.32040.21220.0000
llarLasso Least Angle Regression4.166333.29955.77060.62740.32060.21220.0000
huberHuber Regressor4.094534.13715.84270.61800.31130.20440.0200
knnK Neighbors Regressor4.819745.47616.74360.49120.27210.23070.0000
ompOrthogonal Matching Pursuit6.001169.84498.35730.21850.34260.30170.0100
dummyDummy Regressor6.894889.96639.4851-0.00670.41010.38090.0000
parPassive Aggressive Regressor8.6557126.617911.2525-0.41680.60460.37500.0100
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelMAEMSERMSER2RMSLEMAPETT (Sec)
gbrGradient Boosting Regressor2.351310.57753.25230.88160.16440.12770.07
catboostCatBoost Regressor2.358212.17893.48980.86370.16270.12350.62
xgboostExtreme Gradient Boosting2.424113.46203.66910.84940.17010.12750.13
rfRandom Forest Regressor2.546313.53743.67930.84850.18470.14290.08
etExtra Trees Regressor2.380313.94843.73480.84390.17260.12590.07
lightgbmLight Gradient Boosting Machine2.644615.87803.98470.82230.19420.14750.13
adaAdaBoost Regressor3.039518.80474.33640.78960.21770.17700.04
dtDecision Tree Regressor3.360523.39514.83690.73820.20980.17180.01
lrLinear Regression3.705226.36965.13510.70490.26460.18750.01
larLeast Angle Regression3.731526.72225.16940.70100.26450.18840.01
ridgeRidge Regression3.741027.00795.19690.69780.29170.19140.01
brBayesian Ridge3.785827.59155.25280.69130.28930.19570.01
enElastic Net4.050831.57155.61890.64670.31070.20640.01
lassoLasso Regression4.166033.29695.77030.62740.32040.21220.00
llarLasso Least Angle Regression4.166333.29955.77060.62740.32060.21220.00
huberHuber Regressor4.094534.13715.84270.61800.31130.20440.02
knnK Neighbors Regressor4.819745.47616.74360.49120.27210.23070.00
ompOrthogonal Matching Pursuit6.001169.84498.35730.21850.34260.30170.01
dummyDummy Regressor6.894889.96639.4851-0.00670.41010.38090.00
parPassive Aggressive Regressor8.6557126.617911.2525-0.41680.60460.37500.01
\n", "
" ], "text/plain": [ " Model MAE MSE RMSE R2 \\\n", "gbr Gradient Boosting Regressor 2.3513 10.5775 3.2523 0.8816 \n", "catboost CatBoost Regressor 2.3582 12.1789 3.4898 0.8637 \n", "xgboost Extreme Gradient Boosting 2.4241 13.4620 3.6691 0.8494 \n", "rf Random Forest Regressor 2.5463 13.5374 3.6793 0.8485 \n", "et Extra Trees Regressor 2.3803 13.9484 3.7348 0.8439 \n", "lightgbm Light Gradient Boosting Machine 2.6446 15.8780 3.9847 0.8223 \n", "ada AdaBoost Regressor 3.0395 18.8047 4.3364 0.7896 \n", "dt Decision Tree Regressor 3.3605 23.3951 4.8369 0.7382 \n", "lr Linear Regression 3.7052 26.3696 5.1351 0.7049 \n", "lar Least Angle Regression 3.7315 26.7222 5.1694 0.7010 \n", "ridge Ridge Regression 3.7410 27.0079 5.1969 0.6978 \n", "br Bayesian Ridge 3.7858 27.5915 5.2528 0.6913 \n", "en Elastic Net 4.0508 31.5715 5.6189 0.6467 \n", "lasso Lasso Regression 4.1660 33.2969 5.7703 0.6274 \n", "llar Lasso Least Angle Regression 4.1663 33.2995 5.7706 0.6274 \n", "huber Huber Regressor 4.0945 34.1371 5.8427 0.6180 \n", "knn K Neighbors Regressor 4.8197 45.4761 6.7436 0.4912 \n", "omp Orthogonal Matching Pursuit 6.0011 69.8449 8.3573 0.2185 \n", "dummy Dummy Regressor 6.8948 89.9663 9.4851 -0.0067 \n", "par Passive Aggressive Regressor 8.6557 126.6179 11.2525 -0.4168 \n", "\n", " RMSLE MAPE TT (Sec) \n", "gbr 0.1644 0.1277 0.07 \n", "catboost 0.1627 0.1235 0.62 \n", "xgboost 0.1701 0.1275 0.13 \n", "rf 0.1847 0.1429 0.08 \n", "et 0.1726 0.1259 0.07 \n", "lightgbm 0.1942 0.1475 0.13 \n", "ada 0.2177 0.1770 0.04 \n", "dt 0.2098 0.1718 0.01 \n", "lr 0.2646 0.1875 0.01 \n", "lar 0.2645 0.1884 0.01 \n", "ridge 0.2917 0.1914 0.01 \n", "br 0.2893 0.1957 0.01 \n", "en 0.3107 0.2064 0.01 \n", "lasso 0.3204 0.2122 0.00 \n", "llar 0.3206 0.2122 0.00 \n", "huber 0.3113 0.2044 0.02 \n", "knn 0.2721 0.2307 0.00 \n", "omp 0.3426 0.3017 0.01 \n", "dummy 0.4101 0.3809 0.00 \n", "par 0.6046 0.3750 0.01 " ] }, "execution_count": 163, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# write example for regression\n", "from pycaret.regression import RegressionExperiment\n", "exp1 = RegressionExperiment()\n", "df = get_data('boston')\n", "exp1.setup(data=df, target='medv')\n", "# exp1.add_metric('mae', 'Mean Absolute Error', 'mean_absolute_error', greater_is_better=False)\n", "# exp1.add_metric('r2', 'R^2', 'r2', greater_is_better=True)\n", "# add loss function\n", "\n", "best = exp1.compare_models(cross_validation=False)\n", "df = exp1.pull()\n", "df" ] }, { "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelMAEMSERMSER2RMSLEMAPETT (Sec)
\n", "
" ], "text/plain": [ "Empty DataFrame\n", "Columns: [Model, MAE, MSE, RMSE, R2, RMSLE, MAPE, TT (Sec)]\n", "Index: []" ] }, "execution_count": 164, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ALLOWED_MODEL = ['lr', 'lf']\n", "# get only allowed indices from df\n" ] }, { "cell_type": "code", "execution_count": 175, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_values(['Gradient Boosting Regressor', 'CatBoost Regressor', 'Extreme Gradient Boosting', 'Random Forest Regressor', 'Extra Trees Regressor', 'Light Gradient Boosting Machine', 'AdaBoost Regressor', 'Decision Tree Regressor', 'Linear Regression', 'Least Angle Regression', 'Ridge Regression', 'Bayesian Ridge', 'Elastic Net', 'Lasso Regression', 'Lasso Least Angle Regression', 'Huber Regressor', 'K Neighbors Regressor', 'Orthogonal Matching Pursuit', 'Dummy Regressor', 'Passive Aggressive Regressor'])" ] }, "execution_count": 175, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get key value pairs of df where key is df index and value is model name\n", "dfl = df.to_dict()['Model']\n", "dfl.values()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 2 }