{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Describe the SGD algorithm for minimizing the empirical risk and implement it"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
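    "For a set of $n$ observations $\\{z_i = (x_i, y_i)\\}_{i=1}^{n}$, the empirical risk is \n",
    "$R(w) = \\frac{1}{n} \\sum_{i=1}^n (y_i - w^T x_i)^2$, and its gradient is \n",
    "$\\nabla R(w) = -\\frac{2}{n} \\sum_{i=1}^n x_i (y_i - w^T x_i)$. Computing this gradient requires a pass over the whole dataset, so SGD replaces it with an estimate computed on a random mini-batch $B_k$ of $K$ observations:\n",
    "\n",
    "- Start from a value $w = w_0$.\n",
    "- For $N$ iterations, draw a batch $B_k$ of $K$ indices, compute the mini-batch gradient $g_k = -\\frac{2}{K} \\sum_{i \\in B_k} x_i (y_i - w_k^T x_i)$, which is an unbiased estimate of $\\nabla R(w_k)$, and set $w_{k+1} = w_k - \\eta_k g_k$.\n",
    "\n",
    "The implementation below keeps the step size constant ($\\eta_k$ = `lr`) and scales the batch sums by $\\frac{1}{n}$ instead of $\\frac{1}{K}$, which amounts to using the step size `lr` $\\cdot K / n$ on $g_k$."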
   ]
  },
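  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The claim that the mini-batch gradient is an unbiased estimate of $\\nabla R(w)$ can be checked numerically. The sketch below uses illustrative random data (the sizes, the seed and the helper name `batch_gradient` are assumptions, not part of the exercise): averaging the $\\frac{1}{K}$-scaled gradients over a partition of the $n$ indices into batches of size $K$ recovers the full gradient exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def batch_gradient(x, y, w, batch):\n",
    "    # gradient of (1/K) * sum_{i in batch} (y_i - w^T x_i)^2, with K = len(batch)\n",
    "    residual = y[batch] - x[batch] @ w\n",
    "    return -2.0 / len(batch) * x[batch].T @ residual\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_demo, d_demo, K_demo = 60, 3, 10   # n_demo divisible by K_demo so the batches partition the data\n",
    "x_demo = rng.normal(size=(n_demo, d_demo))\n",
    "y_demo = np.sign(x_demo @ np.array([1.0, -2.0, 0.5]))\n",
    "w_demo = rng.normal(size=d_demo)\n",
    "\n",
    "full_grad = batch_gradient(x_demo, y_demo, w_demo, np.arange(n_demo))\n",
    "batches = rng.permutation(n_demo).reshape(-1, K_demo)\n",
    "batch_grads = [batch_gradient(x_demo, y_demo, w_demo, b) for b in batches]\n",
    "print(np.allclose(np.mean(batch_grads, axis=0), full_grad))   # True"
   ]
  },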
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x: (n, d) array of observations in R^d, y: length-n array of labels in {-1, 1}\n",
    "def SGD(x, y, lr = 0.1, N_iter = 500, K = 10):\n",
    "    (n, d) = x.shape\n",
    "    w = np.random.rand(d)      # random initialization w_0\n",
    "    Rs = []                    # batch squared residuals (scaled by 1/n) at each iteration\n",
    "    ws = []                    # iterates w_k\n",
    "    \n",
    "    indices = np.arange(n)\n",
    "    for _ in range(N_iter):\n",
    "        # draw a mini-batch of K distinct indices (all n indices if K >= n)\n",
    "        np.random.shuffle(indices)\n",
    "        batch = indices[:K]\n",
    "        residual = y[batch] - x[batch] @ w\n",
    "        # batch sums are scaled by 1/n, not 1/K: for K <= n this is a step of lr * K / n on the 1/K-scaled gradient\n",
    "        R = np.sum(residual**2) / n\n",
    "        grad = -2 / n * x[batch].T @ residual\n",
    "        ws.append(w)\n",
    "        w = w - lr * grad\n",
    "        Rs.append(R)\n",
    "    \n",
    "    return w, ws, Rs"
   ]
  },
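  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The description of the algorithm allows a step size $\\eta_k$ that changes with $k$, while `SGD` keeps `lr` constant. A common choice is a decreasing schedule such as $\\eta_k = \\eta_0 / (1 + k/\\tau)$, which satisfies $\\sum_k \\eta_k = \\infty$ and $\\sum_k \\eta_k^2 < \\infty$. The function below is a hypothetical variant written for illustration (the name `sgd_decaying` and the defaults for `eta0` and `tau` are assumptions); it also uses the $\\frac{1}{K}$ scaling of the mini-batch gradient. On the 3-point toy example every batch is the whole dataset, so it reduces to gradient descent with decreasing steps and should approach $w = (2, -1)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sgd_decaying(x, y, eta0=0.1, tau=1000.0, N_iter=5000, K=10, seed=None):\n",
    "    # hypothetical SGD variant with step size eta_k = eta0 / (1 + k / tau)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    n, d = x.shape\n",
    "    w = rng.random(d)                                          # random initialization w_0\n",
    "    for k in range(N_iter):\n",
    "        batch = rng.choice(n, size=min(K, n), replace=False)   # min(K, n) distinct indices\n",
    "        residual = y[batch] - x[batch] @ w\n",
    "        grad = -2.0 / len(batch) * x[batch].T @ residual       # 1/K-scaled mini-batch gradient\n",
    "        w = w - eta0 / (1.0 + k / tau) * grad\n",
    "    return w\n",
    "\n",
    "w_decay = sgd_decaying(np.array([[0, 0], [1, 1], [0, 1]]), np.array([-1, 1, -1]), seed=0)\n",
    "print(w_decay)"
   ]
  },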
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.99999539 -0.99999715]\n"
     ]
    }
   ],
   "source": [
    "y = np.array([-1,1,-1])\n",
    "x = np.array([[0,0],[1,1],[0,1]])\n",
    "\n",
    "print(SGD(x, y)[0])"
   ]
  },
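  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On this toy example the minimizer is known in closed form: the first observation $x_1 = (0, 0)$ contributes $(-1)^2 = 1$ to the sum of squares whatever $w$, and the two remaining equations $w_1 + w_2 = 1$ and $w_2 = -1$ give $w = (2, -1)$, which the SGD output above approaches. The sketch below confirms this with NumPy's least-squares solver (the variable names are illustrative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x_toy = np.array([[0, 0], [1, 1], [0, 1]], dtype=float)\n",
    "y_toy = np.array([-1, 1, -1], dtype=float)\n",
    "\n",
    "# np.linalg.lstsq returns the solution, the residual sum of squares, the rank and the singular values\n",
    "w_ls, rss, rank, _ = np.linalg.lstsq(x_toy, y_toy, rcond=None)\n",
    "print(w_ls, rss)   # approximately [ 2. -1.] and [1.]"
   ]
  },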
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Sample a set of observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x19da7bf0550>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "N_pts = 100\n",
    "x = np.random.normal(0,1,size=(N_pts,2))\n",
    "w0 = np.array([-2.0,1])\n",
    "w0 = w0 / np.linalg.norm(w0)\n",
    "y = np.dot(x,w0) > 0\n",
    "y = 2*y -1\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.scatter(*x[y == 1].T, color='b', s=10, label=r'$y_i=1$')\n",
    "plt.scatter(*x[y == -1].T, color='r', s=10, label=r'$y_i=-1$')\n",
    "plt.xlabel(r\"$x_1$\", fontsize=16)\n",
    "plt.ylabel(r\"$x_2$\", fontsize=16)\n",
    "plt.title(\"Observations Sample\", fontsize=18)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Test the algorithm you wrote at the first question over these observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w estimated :  [-0.88882056  0.4582554 ]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "w, ws, Rs= SGD(x, y)\n",
    "# distance between the normalized iterate and the (unit-norm) true direction w0\n",
    "dist = [np.linalg.norm(wk / np.linalg.norm(wk) - w0) for wk in ws]\n",
    "\n",
    "print(\"w estimated : \", w/np.linalg.norm(w))\n",
    "\n",
    "plt.plot(Rs)\n",
    "plt.title(\"Empirical risk\")\n",
    "plt.show()\n",
    "\n",
    "plt.plot(dist)\n",
    "plt.title(\"||w - w0||\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Add Gaussian noise to your observations and perform the optimisation again. Compare with the result of question 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x19da7ae6278>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "x_noise = x + np.random.normal(scale = 0.2, size = (N_pts,2))\n",
    "\n",
    "plt.scatter(*x_noise[y == 1].T, color='b', s=10, label=r'$y_i=1$')\n",
    "plt.scatter(*x_noise[y == -1].T, color='r', s=10, label=r'$y_i=-1$')\n",
    "plt.xlabel(r\"$x_1$\", fontsize=16)\n",
    "plt.ylabel(r\"$x_2$\", fontsize=16)\n",
    "plt.title(\"Observations Sample\", fontsize=18)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w estimated :  [-0.8911448   0.45371901]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "w, ws, Rs= SGD(x_noise, y)\n",
    "# distance between the normalized iterate and the (unit-norm) true direction w0\n",
    "dist = [np.linalg.norm(wk / np.linalg.norm(wk) - w0) for wk in ws]\n",
    "\n",
    "print(\"w estimated : \", w/np.linalg.norm(w))\n",
    "\n",
    "plt.plot(Rs)\n",
    "plt.title(\"Empirical risk\")\n",
    "plt.show()\n",
    "\n",
    "plt.plot(dist)\n",
    "plt.title(\"||w - w0||\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
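    "Both normalized estimates are close to the true direction $w_0 = (-2, 1)/\\sqrt{5} \\approx (-0.894, 0.447)$: on these runs the distance is about $0.012$ without noise and about $0.007$ with noise. The two runs use different random initializations and mini-batches, so a single run cannot show that one estimate is more accurate than the other. Because the noise is isotropic and independent of the labels, it is not expected to rotate the least-squares direction, only to shrink $w$ and raise the lowest achievable risk."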
   ]
  },
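  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why isotropic noise should not rotate the estimated direction: for standard Gaussian $x$ and $y = \\mathrm{sign}(w_0^T x)$ with $\\|w_0\\| = 1$, $\\mathbb{E}[x y] = \\sqrt{2/\\pi}\\, w_0$. Adding independent noise $\\varepsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$ leaves $\\mathbb{E}[(x + \\varepsilon) y]$ unchanged and turns the input covariance into $(1 + \\sigma^2) I$, so the population least-squares solution $\\Sigma^{-1} \\mathbb{E}[x y]$ stays proportional to $w_0$. The sketch below checks this on a large synthetic sample with the closed-form solution instead of SGD; the sample size, the seed and the helper name `ls_direction` are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "n_big, sigma = 200_000, 0.2\n",
    "w0_unit = np.array([-2.0, 1.0]) / np.sqrt(5.0)\n",
    "x_big = rng.normal(size=(n_big, 2))\n",
    "y_big = np.where(x_big @ w0_unit > 0, 1.0, -1.0)\n",
    "x_big_noisy = x_big + rng.normal(scale=sigma, size=(n_big, 2))\n",
    "\n",
    "def ls_direction(features, labels):\n",
    "    # closed-form least-squares solution, normalized to unit length\n",
    "    w_hat = np.linalg.lstsq(features, labels, rcond=None)[0]\n",
    "    return w_hat / np.linalg.norm(w_hat)\n",
    "\n",
    "for name, feats in [('clean', x_big), ('noisy', x_big_noisy)]:\n",
    "    print(name, np.linalg.norm(ls_direction(feats, y_big) - w0_unit))"
   ]
  },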
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Test the algorithm on the Breast Cancer Wisconsin (Diagnostic) Data Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "569  observations in the dataset\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>diagnosis</th>\n",
       "      <th>radius_mean</th>\n",
       "      <th>texture_mean</th>\n",
       "      <th>perimeter_mean</th>\n",
       "      <th>area_mean</th>\n",
       "      <th>smoothness_mean</th>\n",
       "      <th>compactness_mean</th>\n",
       "      <th>concavity_mean</th>\n",
       "      <th>concave points_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>texture_worst</th>\n",
       "      <th>perimeter_worst</th>\n",
       "      <th>area_worst</th>\n",
       "      <th>smoothness_worst</th>\n",
       "      <th>compactness_worst</th>\n",
       "      <th>concavity_worst</th>\n",
       "      <th>concave points_worst</th>\n",
       "      <th>symmetry_worst</th>\n",
       "      <th>fractal_dimension_worst</th>\n",
       "      <th>Unnamed: 32</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>842302</td>\n",
       "      <td>M</td>\n",
       "      <td>17.99</td>\n",
       "      <td>10.38</td>\n",
       "      <td>122.80</td>\n",
       "      <td>1001.0</td>\n",
       "      <td>0.11840</td>\n",
       "      <td>0.27760</td>\n",
       "      <td>0.3001</td>\n",
       "      <td>0.14710</td>\n",
       "      <td>...</td>\n",
       "      <td>17.33</td>\n",
       "      <td>184.60</td>\n",
       "      <td>2019.0</td>\n",
       "      <td>0.1622</td>\n",
       "      <td>0.6656</td>\n",
       "      <td>0.7119</td>\n",
       "      <td>0.2654</td>\n",
       "      <td>0.4601</td>\n",
       "      <td>0.11890</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>842517</td>\n",
       "      <td>M</td>\n",
       "      <td>20.57</td>\n",
       "      <td>17.77</td>\n",
       "      <td>132.90</td>\n",
       "      <td>1326.0</td>\n",
       "      <td>0.08474</td>\n",
       "      <td>0.07864</td>\n",
       "      <td>0.0869</td>\n",
       "      <td>0.07017</td>\n",
       "      <td>...</td>\n",
       "      <td>23.41</td>\n",
       "      <td>158.80</td>\n",
       "      <td>1956.0</td>\n",
       "      <td>0.1238</td>\n",
       "      <td>0.1866</td>\n",
       "      <td>0.2416</td>\n",
       "      <td>0.1860</td>\n",
       "      <td>0.2750</td>\n",
       "      <td>0.08902</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>84300903</td>\n",
       "      <td>M</td>\n",
       "      <td>19.69</td>\n",
       "      <td>21.25</td>\n",
       "      <td>130.00</td>\n",
       "      <td>1203.0</td>\n",
       "      <td>0.10960</td>\n",
       "      <td>0.15990</td>\n",
       "      <td>0.1974</td>\n",
       "      <td>0.12790</td>\n",
       "      <td>...</td>\n",
       "      <td>25.53</td>\n",
       "      <td>152.50</td>\n",
       "      <td>1709.0</td>\n",
       "      <td>0.1444</td>\n",
       "      <td>0.4245</td>\n",
       "      <td>0.4504</td>\n",
       "      <td>0.2430</td>\n",
       "      <td>0.3613</td>\n",
       "      <td>0.08758</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>84348301</td>\n",
       "      <td>M</td>\n",
       "      <td>11.42</td>\n",
       "      <td>20.38</td>\n",
       "      <td>77.58</td>\n",
       "      <td>386.1</td>\n",
       "      <td>0.14250</td>\n",
       "      <td>0.28390</td>\n",
       "      <td>0.2414</td>\n",
       "      <td>0.10520</td>\n",
       "      <td>...</td>\n",
       "      <td>26.50</td>\n",
       "      <td>98.87</td>\n",
       "      <td>567.7</td>\n",
       "      <td>0.2098</td>\n",
       "      <td>0.8663</td>\n",
       "      <td>0.6869</td>\n",
       "      <td>0.2575</td>\n",
       "      <td>0.6638</td>\n",
       "      <td>0.17300</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>84358402</td>\n",
       "      <td>M</td>\n",
       "      <td>20.29</td>\n",
       "      <td>14.34</td>\n",
       "      <td>135.10</td>\n",
       "      <td>1297.0</td>\n",
       "      <td>0.10030</td>\n",
       "      <td>0.13280</td>\n",
       "      <td>0.1980</td>\n",
       "      <td>0.10430</td>\n",
       "      <td>...</td>\n",
       "      <td>16.67</td>\n",
       "      <td>152.20</td>\n",
       "      <td>1575.0</td>\n",
       "      <td>0.1374</td>\n",
       "      <td>0.2050</td>\n",
       "      <td>0.4000</td>\n",
       "      <td>0.1625</td>\n",
       "      <td>0.2364</td>\n",
       "      <td>0.07678</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 33 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         id diagnosis  radius_mean  texture_mean  perimeter_mean  area_mean  \\\n",
       "0    842302         M        17.99         10.38          122.80     1001.0   \n",
       "1    842517         M        20.57         17.77          132.90     1326.0   \n",
       "2  84300903         M        19.69         21.25          130.00     1203.0   \n",
       "3  84348301         M        11.42         20.38           77.58      386.1   \n",
       "4  84358402         M        20.29         14.34          135.10     1297.0   \n",
       "\n",
       "   smoothness_mean  compactness_mean  concavity_mean  concave points_mean  \\\n",
       "0          0.11840           0.27760          0.3001              0.14710   \n",
       "1          0.08474           0.07864          0.0869              0.07017   \n",
       "2          0.10960           0.15990          0.1974              0.12790   \n",
       "3          0.14250           0.28390          0.2414              0.10520   \n",
       "4          0.10030           0.13280          0.1980              0.10430   \n",
       "\n",
       "   ...  texture_worst  perimeter_worst  area_worst  smoothness_worst  \\\n",
       "0  ...          17.33           184.60      2019.0            0.1622   \n",
       "1  ...          23.41           158.80      1956.0            0.1238   \n",
       "2  ...          25.53           152.50      1709.0            0.1444   \n",
       "3  ...          26.50            98.87       567.7            0.2098   \n",
       "4  ...          16.67           152.20      1575.0            0.1374   \n",
       "\n",
       "   compactness_worst  concavity_worst  concave points_worst  symmetry_worst  \\\n",
       "0             0.6656           0.7119                0.2654          0.4601   \n",
       "1             0.1866           0.2416                0.1860          0.2750   \n",
       "2             0.4245           0.4504                0.2430          0.3613   \n",
       "3             0.8663           0.6869                0.2575          0.6638   \n",
       "4             0.2050           0.4000                0.1625          0.2364   \n",
       "\n",
       "   fractal_dimension_worst  Unnamed: 32  \n",
       "0                  0.11890          NaN  \n",
       "1                  0.08902          NaN  \n",
       "2                  0.08758          NaN  \n",
       "3                  0.17300          NaN  \n",
       "4                  0.07678          NaN  \n",
       "\n",
       "[5 rows x 33 columns]"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv(\"data.csv\",header = 0)\n",
    "print(df.shape[0], \" observations in the dataset\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.drop(columns=\"Unnamed: 32\")   # empty trailing column\n",
    "x = df.iloc[:, 2:]                      # features: every column after 'id' and 'diagnosis'\n",
    "y = df.diagnosis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "le = LabelEncoder()\n",
    "y = le.fit_transform(y)   # classes are sorted: 'B' (benign) -> 0, 'M' (malignant) -> 1"
   ]
  },
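  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`LabelEncoder` numbers the classes in sorted order, so `'B'` (benign) becomes 0 and `'M'` (malignant) becomes 1, whereas `SGD` was written for labels in $\\{-1, 1\\}$. The sketch below reproduces the same encoding with `np.unique` on a few illustrative labels and applies the affine map $y \\mapsto 2y - 1$ to obtain $\\{-1, 1\\}$ labels. It is not applied to `y` in the cells that follow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "diagnosis_demo = np.array(['M', 'B', 'B', 'M'])   # illustrative labels\n",
    "classes, codes = np.unique(diagnosis_demo, return_inverse=True)\n",
    "print(classes, codes)   # ['B' 'M'] [1 0 0 1]\n",
    "labels_pm1 = 2 * codes - 1\n",
    "print(labels_pm1)       # [ 1 -1 -1  1]"
   ]
  },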
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# standardize each feature (mean and std computed on all 569 rows, before the train/test split)\n",
    "x = (x - x.mean()) / x.std()\n",
    "x_np = x.to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "x_train, x_test, y_train, y_test = train_test_split(x_np, y, test_size=0.20, random_state=1)"
   ]
  },
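  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standardization above uses the mean and variance of all 569 rows, so the test rows influence the preprocessing. A leakage-free alternative is to split the raw features first and estimate the statistics on the training rows only. The helper `standardize_train_test` below is a hypothetical sketch; `sklearn.preprocessing.StandardScaler` fitted on the training split does the same, except that it divides by $n$ rather than $n - 1$ when computing the standard deviation. With the notebook's data it would be applied to the two splits of `df.iloc[:, 2:].to_numpy()` produced by the same `train_test_split` call, before running `SGD`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def standardize_train_test(x_train_raw, x_test_raw):\n",
    "    # statistics estimated on the training rows only, then applied to both splits\n",
    "    mean = x_train_raw.mean(axis=0)\n",
    "    std = x_train_raw.std(axis=0, ddof=1)\n",
    "    return (x_train_raw - mean) / std, (x_test_raw - mean) / std\n",
    "\n",
    "# illustrative check on synthetic data: the training split ends up centred with unit variance\n",
    "rng = np.random.default_rng(2)\n",
    "raw = rng.normal(loc=5.0, scale=3.0, size=(100, 4))\n",
    "tr, te = standardize_train_test(raw[:80], raw[80:])\n",
    "print(np.allclose(tr.mean(axis=0), 0.0), np.allclose(tr.std(axis=0, ddof=1), 1.0))   # True True"
   ]
  },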
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w estimated :  [ 0.38127251  0.04831026 -0.18291229 -0.01426605  0.28136352  0.30300141\n",
      " -0.01443595 -0.38641629  0.25410684 -0.11719687 -0.05078908  0.07906987\n",
      " -0.097882   -0.04137002  0.16112217 -0.16299693  0.22066419 -0.22795517\n",
      "  0.12206528 -0.09284573  0.28597531 -0.01611214 -0.13893487  0.1677988\n",
      " -0.08172362  0.00078073  0.25929497 -0.09928938 -0.09400798  0.02985984]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "w, ws, Rs= SGD(x_train,y_train, lr = 0.001, N_iter=200)\n",
    "\n",
    "print(\"w estimated : \", w/np.linalg.norm(w))\n",
    "\n",
    "plt.plot(Rs)\n",
    "plt.title(\"Empirical risk\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = (x_test @ w > 0).astype(int)   # predict class 1 when w^T x > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "confusion matrix : \n",
      "[[56 16]\n",
      " [11 31]]\n",
      "----------- \n",
      "f1 score : \n",
      "0.6966292134831461\n",
      "----------- \n",
      "accuracy : \n",
      "0.7631578947368421\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score, accuracy_score, confusion_matrix\n",
    "\n",
    "def print_confusion(cm):\n",
    "    # pretty-print a 2x2 confusion matrix (rows: ground truth, columns: prediction)\n",
    "    print('Confusion matrix:')\n",
    "    print(\"{0:20} {1:10} {2}\".format(\"GT \\\\ Pred\", \"0\", \"1\"))\n",
    "    print(\"{0:20} {1:10} {2}\".format(\"0\", str(cm[0,0]), str(cm[0,1])))\n",
    "    print(\"{0:20} {1:10} {2}\".format(\"1\", str(cm[1,0]), str(cm[1,1])))\n",
    "\n",
    "print(\"confusion matrix : \")\n",
    "print(confusion_matrix(y_test, y_pred))\n",
    "print(\"----------- \")\n",
    "print(\"f1 score : \")\n",
    "print(f1_score(y_test,y_pred))\n",
    "print(\"----------- \")\n",
    "print(\"accuracy : \")\n",
    "print(accuracy_score(y_test, y_pred))"
   ]
  },
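  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A closed-form baseline helps judge the SGD result. The sketch below (hypothetical helper `fit_predict_pm1`, not part of the original solution) maps the $\\{0, 1\\}$ labels to $\\{-1, 1\\}$ so that thresholding at $0$ matches the labels, appends a constant feature as an intercept, solves the least-squares problem with `np.linalg.lstsq`, and returns predictions in $\\{0, 1\\}$ for the `sklearn` metrics. The synthetic data only checks that the function behaves sensibly; on the notebook's data it would be called as `fit_predict_pm1(x_train, y_train, x_test)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fit_predict_pm1(x_tr, y01_tr, x_te):\n",
    "    # hypothetical least-squares baseline with {-1, 1} targets and an intercept\n",
    "    y_pm1 = 2.0 * np.asarray(y01_tr) - 1.0\n",
    "    design_tr = np.hstack([x_tr, np.ones((len(x_tr), 1))])\n",
    "    design_te = np.hstack([x_te, np.ones((len(x_te), 1))])\n",
    "    w_hat = np.linalg.lstsq(design_tr, y_pm1, rcond=None)[0]\n",
    "    return (design_te @ w_hat > 0).astype(int)   # class 1 when the fitted value is positive\n",
    "\n",
    "# synthetic check: a linear rule with an offset, 150 training and 50 test points\n",
    "rng = np.random.default_rng(3)\n",
    "x_syn = rng.normal(size=(200, 3))\n",
    "y_syn = (x_syn @ np.array([1.0, -1.0, 0.5]) + 0.3 > 0).astype(int)\n",
    "pred = fit_predict_pm1(x_syn[:150], y_syn[:150], x_syn[150:])\n",
    "print((pred == y_syn[150:]).mean())"
   ]
  },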
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
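    "This quick least-squares classifier reaches an accuracy of 0.763 and an F1 score of 0.697 on the 114 held-out observations, which leaves clear room for improvement. Part of the gap likely comes from the setup: `LabelEncoder` returns labels in $\\{0, 1\\}$, whereas `SGD` was written for labels in $\\{-1, 1\\}$, and the prediction thresholds $w^T x$ at $0$, the midpoint of $\\{-1, 1\\}$ rather than of $\\{0, 1\\}$. The model also has no intercept and is trained for only 200 iterations with `lr = 0.001`. Mapping the labels to $\\{-1, 1\\}$, adding a constant feature and training longer are natural first steps."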
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}