{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "import numpy as np\n", "import torch\n", "from torchvision import models\n", "from torchvision.models import AlexNet_Weights, VGG16_Weights\n", "import matplotlib.pyplot as plt\n", "%matplotlib ipympl\n", "\n", "# sys.path.append('..')\n", "# from stimulus import draw_bar\n", "# from mapping import RfMapper\n", "# import constants as c" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Please specify some details here:\n", "model = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).to(c.DEVICE)\n", "model_name = 'alexnet'\n", "# model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).to(c.DEVICE)\n", "# model_name = 'vgg16'\n", "conv_i = 1 # 0 is Conv1, 1 is Conv2, etc.\n", "xn = yn = 227" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The RF mapper is for Conv2 (not Conv1) with input shape (yn = 227, xn = 227).\n" ] } ], "source": [ "# Not going to use the mapper object. 
Just using it to get info about the\n", "# conv layer.\n", "mapper = RfMapper(model, conv_i, (yn, xn))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c511944a736148979e5de634a4f2c37f", "version_major": 2, "version_minor": 0 }, "image/png": "", "text/html": [ "\n", " <div style=\"display: inline-block;\">\n", " <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n", " Figure\n", " </div>\n", " <img src='' width=640.0/>\n", " </div>\n", " " ], "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib ipympl\n", "class Cursor(object):\n", " def __init__(self, ax):\n", " self.ax = ax\n", " self.xn = 227\n", " self.yn = 227\n", " self.theta = 45\n", " self.blen = 50\n", " self.bwid = 25\n", " self.laa = 0.5\n", " self.fgval = 1\n", " self.bgval = -1\n", "\n", " # text location in axes coords\n", " self.txt = ax.text(15, 15, '', transform=ax.transAxes)\n", "\n", " def mouse_move(self, event):\n", " if not event.inaxes:\n", " return\n", "\n", " x, y = event.xdata, event.ydata\n", " bar = draw_bar(self.xn, self.yn, x, y, self.theta,\n", " self.blen, self.bwid, self.laa, self.fgval, self.bgval)\n", " self.txt.set_text('x=%1.2f, y=%1.2f' % (x, y))\n", " self.ax.imshow(bar, cmap='gray')\n", " \n", " \n", "\n", "fig, ax = plt.subplots()\n", "cursor = Cursor(ax)\n", "plt.connect('motion_notify_event', cursor.mouse_move)\n", "\n", "# ax.plot(t, s, 'o')\n", "plt.axis([0, 227, 0, 227])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "93ec8a6a950c4dd5a732b4c39973d823", "version_major": 2, "version_minor": 0 }, "image/png": "", "text/html": [ "\n", " <div style=\"display: inline-block;\">\n", 
" <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n", " Figure\n", " </div>\n", " <img src='' width=640.0/>\n", " </div>\n", " " ], "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "x = 0, y = 0\n" ] } ], "source": [ "# from __future__ import print_function\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "%matplotlib ipympl\n", "\n", "\n", "class Cursor(object):\n", " def __init__(self, ax):\n", " self.ax = ax\n", " self.lx = ax.axhline(color='k') # the horiz line\n", " self.ly = ax.axvline(color='k') # the vert line\n", " self.x = 0\n", " self.y = 0\n", "\n", " # text location in axes coords\n", " self.txt = ax.text(0.7, 0.9, '', transform=ax.transAxes)\n", "\n", " def mouse_move(self, event):\n", " if not event.inaxes:\n", " return\n", "\n", " self.x, self.y = event.xdata, event.ydata\n", " # update the line positions\n", " self.lx.set_ydata(self.y)\n", " self.ly.set_xdata(self.x)\n", " \n", " # bar = draw_bar(self.xn, self.yn, x, y, self.theta,\n", " # self.blen, self.bwid, self.laa, self.fgval, self.bgval)\n", " # bar = np.random.rand((100,100))\n", " # self.ax.imshow(, cmap='gray')\n", "\n", " self.txt.set_text('x=%1.2f, y=%1.2f' % (self.x, self.y))\n", " plt.draw()\n", " \n", " def print_xy(self):\n", " print(f\"x = {self.x}, y = {self.y}\")\n", "\n", "t = np.arange(0.0, 1.0, 0.01)\n", "s = np.sin(2*2*np.pi*t)\n", "fig, ax = plt.subplots()\n", "\n", "cursor = Cursor(ax)\n", "# cursor = SnaptoCursor(ax, t, s)\n", "plt.connect('motion_notify_event', cursor.mouse_move)\n", "\n", "# ax.plot(t, s, 'o')\n", "plt.axis([0, 10, 0, 10])\n", "plt.show()\n", "\n", "cursor.print_xy()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
# %% Draggable bar chart demo
import numpy as np
import matplotlib.pyplot as plt


class DraggableRectangle:
    """Let the user drag a rectangle patch around with the mouse.

    Parameters
    ----------
    rect : matplotlib.patches.Rectangle
        The patch to move; it must already belong to a figure because the
        callbacks are registered on ``rect.figure.canvas``.
    """

    def __init__(self, rect):
        self.rect = rect
        # (x0, y0, xpress, ypress) while a drag is in progress, else None.
        self.press = None
        # Callback ids returned by mpl_connect; None until connect() runs so
        # that disconnect() is safe to call at any time.
        self.cidpress = None
        self.cidrelease = None
        self.cidmotion = None

    def connect(self):
        """Connect to all the events we need."""
        canvas = self.rect.figure.canvas
        self.cidpress = canvas.mpl_connect(
            'button_press_event', self.on_press)
        self.cidrelease = canvas.mpl_connect(
            'button_release_event', self.on_release)
        self.cidmotion = canvas.mpl_connect(
            'motion_notify_event', self.on_motion)

    def on_press(self, event):
        """On button press, check the mouse is over us and store some data."""
        if event.inaxes != self.rect.axes:
            return

        contains, _ = self.rect.contains(event)
        if not contains:
            return
        print('event contains', self.rect.xy)
        x0, y0 = self.rect.xy
        self.press = x0, y0, event.xdata, event.ydata

    def on_motion(self, event):
        """On motion, move the rect if a drag that started on it is active."""
        if self.press is None:
            return
        if event.inaxes != self.rect.axes:
            return
        x0, y0, xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        self.rect.set_x(x0 + dx)
        self.rect.set_y(y0 + dy)

        # draw_idle() schedules a redraw instead of forcing a full
        # synchronous draw for every single motion event.
        self.rect.figure.canvas.draw_idle()

    def on_release(self, event):
        """On release, reset the press data."""
        # Every rectangle receives every release event; only the one that
        # was being dragged has anything to reset or redraw.
        if self.press is None:
            return
        self.press = None
        self.rect.figure.canvas.draw_idle()

    def disconnect(self):
        """Disconnect all the stored connection ids."""
        canvas = self.rect.figure.canvas
        for cid in (self.cidpress, self.cidrelease, self.cidmotion):
            if cid is not None:
                canvas.mpl_disconnect(cid)
        self.cidpress = self.cidrelease = self.cidmotion = None


fig = plt.figure()
ax = fig.add_subplot(111)
# Seeded so the bar heights are the same on every Restart & Run All.
rng = np.random.default_rng(0)
rects = ax.bar(range(10), 20 * rng.random(10))
# Keep the wrappers alive: the canvas holds only weak references to bound
# methods used as callbacks.
drs = []
for rect in rects:
    dr = DraggableRectangle(rect)
    dr.connect()
    drs.append(dr)

plt.show()