{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Predict with Pre-trained Model\n", " \n", "This notebook shows how to load a pre-trained model and use it to predict multi-pitch estimates for a previously unseen audio file.\n", "\n", "© Christof Weiss and Geoffroy Peeters, Télécom Paris 2021" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "basepath = os.path.abspath(os.path.dirname(os.path.dirname('.')))\n", "sys.path.append(basepath)\n", "import numpy as np, os, scipy, scipy.spatial, matplotlib.pyplot as plt, IPython.display as ipd\n", "from numba import jit\n", "import librosa\n", "import libfmp.b, libfmp.c3, libfmp.c5\n", "import pandas as pd, pickle, re\n", "from numba import jit\n", "import torch.utils.data\n", "import torch.nn as nn\n", "import libdl.data_preprocessing\n", "from libdl.data_loaders import dataset_context\n", "from libdl.nn_models import basic_cnn_segm_sigmoid, deep_cnn_segm_sigmoid, simple_u_net_largekernels\n", "from libdl.nn_models import simple_u_net_doubleselfattn, simple_u_net_doubleselfattn_twolayers\n", "from libdl.nn_models import u_net_blstm_varlayers, simple_u_net_polyphony_classif_softmax\n", "from libdl.data_preprocessing import compute_hopsize_cqt, compute_hcqt, compute_efficient_hcqt, compute_annotation_array_nooverlap\n", "from torchinfo import summary" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.
Specify and load model (Example)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "├─LayerNorm: 1-1 [1, 174, 6, 216] 2,592\n", "├─double_conv: 1-2 [1, 32, 174, 216] --\n", "| └─Sequential: 2-1 [1, 32, 174, 216] --\n", "| | └─Conv2d: 3-1 [1, 32, 174, 216] 43,232\n", "| | └─BatchNorm2d: 3-2 [1, 32, 174, 216] 64\n", "| | └─ReLU: 3-3 [1, 32, 174, 216] --\n", "| | └─Dropout: 3-4 [1, 32, 174, 216] --\n", "| | └─Conv2d: 3-5 [1, 32, 174, 216] 230,432\n", "| | └─BatchNorm2d: 3-6 [1, 32, 174, 216] 64\n", "| | └─ReLU: 3-7 [1, 32, 174, 216] --\n", "| | └─Dropout: 3-8 [1, 32, 174, 216] --\n", "├─Sequential: 1-3 [1, 64, 87, 108] --\n", "| └─MaxPool2d: 2-2 [1, 32, 87, 108] --\n", "| └─double_conv: 2-3 [1, 64, 87, 108] --\n", "| | └─Sequential: 3-9 [1, 64, 87, 108] 1,382,784\n", "├─Sequential: 1-4 [1, 128, 43, 54] --\n", "| └─MaxPool2d: 2-4 [1, 64, 43, 54] --\n", "| └─double_conv: 2-5 [1, 128, 43, 54] --\n", "| | └─Sequential: 3-10 [1, 128, 43, 54] 1,991,424\n", "├─Sequential: 1-5 [1, 256, 21, 27] --\n", "| └─MaxPool2d: 2-6 [1, 128, 21, 27] --\n", "| └─double_conv: 2-7 [1, 256, 21, 27] --\n", "| | └─Sequential: 3-11 [1, 256, 21, 27] 2,459,136\n", "├─Sequential: 1-6 [1, 256, 10, 13] --\n", "| └─MaxPool2d: 2-8 [1, 256, 10, 13] --\n", "| └─double_conv: 2-9 [1, 256, 10, 13] --\n", "| | └─Sequential: 3-12 [1, 256, 10, 13] 1,181,184\n", "├─unet_up_concat_padding: 1-7 [1, 512, 21, 27] --\n", "| └─Upsample: 2-10 [1, 256, 20, 26] --\n", "├─double_conv: 1-8 [1, 128, 21, 27] --\n", "| └─Sequential: 2-11 [1, 128, 21, 27] --\n", "| | └─Conv2d: 3-13 [1, 256, 21, 27] 1,179,904\n", "| | └─BatchNorm2d: 3-14 [1, 256, 21, 27] 512\n", "| | └─ReLU: 3-15 [1, 256, 21, 27] --\n", "| | 
└─Dropout: 3-16 [1, 256, 21, 27] --\n", "| | └─Conv2d: 3-17 [1, 128, 21, 27] 295,040\n", "| | └─BatchNorm2d: 3-18 [1, 128, 21, 27] 256\n", "| | └─ReLU: 3-19 [1, 128, 21, 27] --\n", "| | └─Dropout: 3-20 [1, 128, 21, 27] --\n", "├─unet_up_concat_padding: 1-9 [1, 256, 43, 54] --\n", "| └─Upsample: 2-12 [1, 128, 42, 54] --\n", "├─double_conv: 1-10 [1, 64, 43, 54] --\n", "| └─Sequential: 2-13 [1, 64, 43, 54] --\n", "| | └─Conv2d: 3-21 [1, 128, 43, 54] 819,328\n", "| | └─BatchNorm2d: 3-22 [1, 128, 43, 54] 256\n", "| | └─ReLU: 3-23 [1, 128, 43, 54] --\n", "| | └─Dropout: 3-24 [1, 128, 43, 54] --\n", "| | └─Conv2d: 3-25 [1, 64, 43, 54] 204,864\n", "| | └─BatchNorm2d: 3-26 [1, 64, 43, 54] 128\n", "| | └─ReLU: 3-27 [1, 64, 43, 54] --\n", "| | └─Dropout: 3-28 [1, 64, 43, 54] --\n", "├─unet_up_concat_padding: 1-11 [1, 128, 87, 108] --\n", "| └─Upsample: 2-14 [1, 64, 86, 108] --\n", "├─double_conv: 1-12 [1, 32, 87, 108] --\n", "| └─Sequential: 2-15 [1, 32, 87, 108] --\n", "| | └─Conv2d: 3-29 [1, 64, 87, 108] 663,616\n", "| | └─BatchNorm2d: 3-30 [1, 64, 87, 108] 128\n", "| | └─ReLU: 3-31 [1, 64, 87, 108] --\n", "| | └─Dropout: 3-32 [1, 64, 87, 108] --\n", "| | └─Conv2d: 3-33 [1, 32, 87, 108] 165,920\n", "| | └─BatchNorm2d: 3-34 [1, 32, 87, 108] 64\n", "| | └─ReLU: 3-35 [1, 32, 87, 108] --\n", "| | └─Dropout: 3-36 [1, 32, 87, 108] --\n", "├─unet_up_concat_padding: 1-13 [1, 64, 174, 216] --\n", "| └─Upsample: 2-16 [1, 32, 174, 216] --\n", "├─double_conv: 1-14 [1, 128, 174, 216] --\n", "| └─Sequential: 2-17 [1, 128, 174, 216] --\n", "| | └─Conv2d: 3-37 [1, 32, 174, 216] 460,832\n", "| | └─BatchNorm2d: 3-38 [1, 32, 174, 216] 64\n", "| | └─ReLU: 3-39 [1, 32, 174, 216] --\n", "| | └─Dropout: 3-40 [1, 32, 174, 216] --\n", "| | └─Conv2d: 3-41 [1, 128, 174, 216] 921,728\n", "| | └─BatchNorm2d: 3-42 [1, 128, 174, 216] 256\n", "| | └─ReLU: 3-43 [1, 128, 174, 216] --\n", "| | └─Dropout: 3-44 [1, 128, 174, 216] --\n", "├─Sequential: 1-15 [1, 180, 174, 72] --\n", "| └─Conv2d: 2-18 [1, 180, 
174, 72] 207,540\n", "| └─LeakyReLU: 2-19 [1, 180, 174, 72] --\n", "| └─MaxPool2d: 2-20 [1, 180, 174, 72] --\n", "| └─Dropout: 2-21 [1, 180, 174, 72] --\n", "├─Sequential: 1-16 [1, 150, 100, 72] --\n", "| └─Conv2d: 2-22 [1, 150, 100, 72] 2,025,150\n", "| └─LeakyReLU: 2-23 [1, 150, 100, 72] --\n", "| └─Dropout: 2-24 [1, 150, 100, 72] --\n", "├─Sequential: 1-17 [1, 1, 100, 72] --\n", "| └─Conv2d: 2-25 [1, 100, 100, 72] 15,100\n", "| └─LeakyReLU: 2-26 [1, 100, 100, 72] --\n", "| └─Dropout: 2-27 [1, 100, 100, 72] --\n", "| └─Conv2d: 2-28 [1, 1, 100, 72] 101\n", "| └─Sigmoid: 2-29 [1, 1, 100, 72] --\n", "├─Sequential: 1-18 [1, 24, 7, 1] --\n", "| └─Conv2d: 2-30 [1, 128, 9, 9] 327,808\n", "| └─LeakyReLU: 2-31 [1, 128, 9, 9] --\n", "| └─MaxPool2d: 2-32 [1, 128, 8, 3] --\n", "| └─Dropout: 2-33 [1, 128, 8, 3] --\n", "| └─Conv2d: 2-34 [1, 24, 7, 1] 18,456\n", "==========================================================================================\n", "Total params: 14,597,963\n", "Trainable params: 14,597,963\n", "Non-trainable params: 0\n", "Total mult-adds (G): 109.75\n", "==========================================================================================\n", "Input size (MB): 0.90\n", "Forward/backward pass size (MB): 228.60\n", "Params size (MB): 58.39\n", "Estimated Total Size (MB): 287.89\n", "==========================================================================================" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Pre-trained checkpoints are read from <basepath>/models_pretrained\n",
"dir_models = os.path.join(basepath, 'models_pretrained')\n",
"num_octaves_inp = 6\n",
"num_output_bins, min_pitch = 72, 24\n",
"\n",
"\n",
"# Polyphony U-Net trained in recommended MusicNet split (test set MuN-10full):\n",
"model_params = {'n_chan_input': 6,\n",
"                'n_chan_layers': [128,180,150,100],\n",
"                'n_ch_out': 2,\n",
"                'n_bins_in': num_octaves_inp*12*3,\n",
"                'n_bins_out': num_output_bins,\n",
"                'a_lrelu': 0.3,\n",
"                'p_dropout': 0.2,\n",
"                'scalefac': 2,\n",
"                'num_polyphony_steps': 24\n",
"                }\n",
"mp = model_params\n",
"\n",
"fn_model = 'RETRAIN4_exp195f_musicnet_aligned_unet_extremelylarge_polyphony_softmax_rerun1.pt'\n",
"# NOTE: 'n_ch_out' is listed in model_params but is not passed to the constructor below.\n",
"model = simple_u_net_polyphony_classif_softmax(\n",
"    n_chan_input=mp['n_chan_input'], n_chan_layers=mp['n_chan_layers'],\n",
"    n_bins_in=mp['n_bins_in'], n_bins_out=mp['n_bins_out'], a_lrelu=mp['a_lrelu'], p_dropout=mp['p_dropout'],\n",
"    scalefac=mp['scalefac'], num_polyphony_steps=mp['num_polyphony_steps'])\n",
"\n",
"\n",
"path_trained_model = os.path.join(dir_models, fn_model)\n",
"\n",
"# torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
"# map_location places the loaded weights on the CPU.\n",
"model.load_state_dict(torch.load(path_trained_model, map_location=torch.device('cpu')))\n",
"\n",
"# Switch to evaluation mode before predicting (the model contains Dropout and BatchNorm layers)\n",
"model.eval()\n",
"# Channel and bin counts are taken from model_params so the summary follows the configuration;\n",
"# the input layout is presumably (batch, channels, frames, bins) - TODO confirm\n",
"summary(model, input_size=(1, mp['n_chan_input'], 174, mp['n_bins_in']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Load Test Audio and compute HCQT" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"fs = 22050\n",
"\n",
"audio_folder = os.path.join(basepath, 'data', 'MusicNet', 'audio')\n",
"fn_audio = '2382_Beethoven_OP130_StringQuartet.wav'\n",
"\n",
"path_audio = os.path.join(audio_folder, fn_audio)\n",
"# sr=fs asks librosa to deliver the audio at the sampling rate configured above\n",
"f_audio, fs_load = librosa.load(path_audio, sr=fs)\n",
"\n",
"bins_per_semitone = 3\n",
"num_octaves = 6\n",
"n_bins = bins_per_semitone*12*num_octaves\n",
"num_harmonics = 5\n",
"num_subharmonics = 1\n",
"center_bins=True\n",
"\n",
"# Pass the settings defined above instead of repeating literals, so that editing\n",
"# fs or center_bins actually changes the computed HCQT.\n",
"f_hcqt, fs_hcqt, hopsize_cqt = compute_efficient_hcqt(\n",
"    f_audio, fs=fs, fmin=librosa.note_to_hz('C1'), fs_hcqt_target=50,\n",
"    bins_per_octave=bins_per_semitone*12, num_octaves=num_octaves,\n",
"    num_harmonics=num_harmonics, num_subharmonics=num_subharmonics, center_bins=center_bins)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.
Predict Pitches" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(160, 72)\n" ] } ], "source": [
"# Set test parameters\n",
"test_params = {'batch_size': 1,\n",
"               'shuffle': False,\n",
"               'num_workers': 1\n",
"               }\n",
"device = 'cpu'\n",
"\n",
"test_dataset_params = {'context': 75,\n",
"                       'stride': 1,\n",
"                       'compression': 10\n",
"                       }\n",
"half_context = test_dataset_params['context']//2\n",
"\n",
"# Reverse the axis order of the HCQT; axis 1 of the result is the one padded below\n",
"inputs = np.transpose(f_hcqt, (2, 1, 0))\n",
"targets = np.zeros(inputs.shape[1:]) # need dummy targets to use dataset object\n",
"\n",
"# Zero-pad half_context entries before and half_context+1 after, presumably so that a\n",
"# full context window exists around the first and last frame - TODO confirm in dataset_context\n",
"inputs_context = torch.from_numpy(np.pad(inputs, ((0, 0), (half_context, half_context+1), (0, 0))))\n",
"targets_context = torch.from_numpy(np.pad(targets, ((half_context, half_context+1), (0, 0))))\n",
"\n",
"test_set = dataset_context(inputs_context, targets_context, test_dataset_params)\n",
"test_generator = torch.utils.data.DataLoader(test_set, **test_params)\n",
"\n",
"max_frames = 160  # only the first max_frames batches are predicted\n",
"pred_frames = []\n",
"\n",
"# Inference only: no_grad() avoids recording an autograd graph for every batch\n",
"with torch.no_grad():\n",
"    for k, (test_batch, test_labels) in enumerate(test_generator):\n",
"        if k >= max_frames:\n",
"            break\n",
"        # Model computations\n",
"        y_pred, n_pred = model(test_batch)\n",
"        # Squeeze dims 2 and 1 so that each batch contributes rows of num_output_bins values\n",
"        pred_log = torch.squeeze(torch.squeeze(y_pred.to('cpu'), 2), 1).detach().numpy()\n",
"        pred_frames.append(pred_log)\n",
"\n",
"# Concatenate once instead of calling np.append per batch (which copies the whole array every time).\n",
"# The empty float64 seed keeps the dtype and the (0, num_output_bins) shape of the former accumulator.\n",
"pred_tot = np.concatenate([np.zeros((0, num_output_bins))] + pred_frames, axis=0)\n",
"\n",
"predictions = pred_tot\n",
"\n",
"print(predictions.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.
Plot Predictions" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 720x252 with 2 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"# Excerpt to display, in seconds\n",
"start_sec = 0\n",
"show_sec = 3.5\n",
"\n",
"# Second (narrow) axes is handed to plot_matrix together with the main axes\n",
"fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 0.05]}, figsize=(10, 3.5))\n",
"# plot_matrix returns three values; unpack them so that im is the image handle, not the whole tuple\n",
"pfig, pax, im = libfmp.b.plot_matrix(predictions.T[:, int(start_sec*fs_hcqt):int(show_sec*fs_hcqt)], Fs=fs_hcqt, ax=ax, cmap='gray_r', ylabel='MIDI pitch')\n",
"# One tick per octave; labels are the tick positions shifted by the lowest predicted pitch\n",
"pitch_ticks = np.arange(0, num_output_bins+1, 12)\n",
"ax[0].set_yticks(pitch_ticks)\n",
"ax[0].set_yticklabels([str(min_pitch+tick) for tick in pitch_ticks])\n",
"ax[0].set_title('Multi-pitch predictions')\n",
"plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Compare with Annotations (Ground Truth)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 720x252 with 2 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"annot_folder = os.path.join(basepath, 'data', 'MusicNet', 'csv')\n",
"# The annotation file shares the audio file's base name\n",
"fn_annot = os.path.join(annot_folder, os.path.splitext(fn_audio)[0]+'.csv')\n",
"\n",
"# Divisor for the first two CSV columns - presumably sample positions at 44.1 kHz\n",
"# that are converted to seconds here; verify against the MusicNet CSV format\n",
"fs_annot = 44100\n",
"\n",
"df = pd.read_csv(fn_annot, sep=',')\n",
"# Keep columns 0, 1 and 3 - presumably onset, offset and pitch; verify against the CSV header\n",
"note_events = df.to_numpy()[:,(0,1,3)]\n",
"note_events[:,:2] /= fs_annot\n",
"# Append an all-zero fourth column\n",
"note_events = np.append(note_events, np.zeros((note_events.shape[0], 1)), axis=1)\n",
" \n",
"f_annot_pitch = compute_annotation_array_nooverlap(note_events.copy(), f_hcqt, fs_hcqt, annot_type='pitch', shorten=1.0)\n",
"plt.rcParams.update({'font.size': 12})\n",
"fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 0.05]}, figsize=(10, 3.5))\n",
"\n",
"# Show rows min_pitch .. min_pitch+num_output_bins of the annotation array for the same excerpt\n",
"cfig, cax, cim = libfmp.b.plot_matrix(f_annot_pitch[min_pitch:min_pitch+num_output_bins+1, int(start_sec*fs_hcqt):int(show_sec*fs_hcqt)], ax=ax, Fs=fs_hcqt, cmap='gray_r', ylabel='MIDI pitch')\n",
"# NOTE(review): plt.ylim acts on the current axes; the ax[1].set_ylim([0, 1]) call further down\n",
"# appears to restore the range of the second axes afterwards - confirm which axes is intended\n",
"plt.ylim([0, num_output_bins+1])\n",
"annot_ticks = np.arange(0, num_output_bins+1, 12)\n",
"ax[0].set_yticks(annot_ticks)\n",
"ax[0].set_yticklabels([str(min_pitch+tick) for tick in annot_ticks])\n",
"ax[0].set_title('Multi-pitch annotations (piano roll)')\n",
"ax[1].set_ylim([0, 1])\n",
"plt.tight_layout()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }