# Voice-Recognition-Training.py
import os
import sys
import time
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn


# This is a good set of 20 test cases, as they're mostly internally consistent and present in most speakers:
# 014 should always be included
# 002, 012, 013, 016, 018, 020, 032, 038, 043, 049, 054, 056, 062, 064, 067, 081, 083, 084, 094

class VoiceRecognitionDataset(Dataset):

    test_samples = ["014", "002", "012", "013", "016", "018", "020", "032", "038", "043", "049", "054", "056", "062", "064", "067", "081", "083", "084", "094"]

    def __init__(self, data_path, is_testdata):
        self.data_path = data_path
        self.num_speakers = len(os.listdir(data_path))
        print("Speaker folders found: " + str(self.num_speakers))
        self.speaker_list = []

        self.utterance_count = 0
        self.utterance_list = []       # (speaker, file path) pairs
        self.audio_sample_list = []    # pre-loaded mel spectrograms
        # Search through all speaker folders (e.g. p274)
        for speaker_dir in os.listdir(data_path):
            self.speaker_list.append(speaker_dir)
            print(speaker_dir)
            for file in os.listdir(os.path.join(data_path, speaker_dir)):
                if file == "log.txt":
                    continue
                file_num = file[5:8]
                mic_num = file[-5:-4]
                if mic_num != "2":
                    continue
                # Keep the held-out utterances for the test set and everything
                # else for training.
                if (file_num in self.test_samples) == is_testdata:
                    file_path = os.path.join(data_path, speaker_dir, file)
                    self.utterance_count += 1
                    self.utterance_list.append((speaker_dir, file_path))
                    self.audio_sample_list.append(
                        torch.transpose(torch.load(file_path), 1, 2).squeeze())
        print("Number of speakers: " + str(len(self.speaker_list)))

    def __len__(self):
        return self.utterance_count

    def __getitem__(self, idx):
        speaker, _ = self.utterance_list[idx]
        # Spectrograms are pre-loaded in __init__, so this is just a list lookup.
        mel_spec = self.audio_sample_list[idx]
        speaker = self.get_speaker_id(speaker)
        return mel_spec, speaker
    
    def get_speaker_id(self, speaker):
        return self.speaker_list.index(speaker)
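
# A minimal usage sketch for the dataset (illustrative only; assumes data_path
# contains one folder per speaker, e.g. "p274/", holding mel-spectrogram
# tensors that were saved with torch.save):
#
#   dataset = VoiceRecognitionDataset("path/to/mels/", is_testdata=False)
#   mel_spec, speaker_id = dataset[0]   # mel_spec: (frames, 128) tensor, speaker_id: int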


def collate_fn_padded(batch):
    '''
    Pads a batch of variable-length mel spectrograms to the length of the
    longest utterance in the batch so they can be stacked into one tensor.
    Labels are returned as a plain list and converted to a tensor in
    train()/test().
    '''
    ## get labels
    labels = [y for (x, y) in batch]
    ## pad the spectrograms (they are already tensors, so just cast and move to device)
    batch = [x.float().to(device) for (x, y) in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    batch = torch.transpose(batch, 0, 1)  # (T, B, n_mels) -> (B, T, n_mels)
    return batch, labels
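
# Rough illustration of what collate_fn_padded produces (the 128 mel-bin width
# is an assumption chosen to match the LSTM's input_size below):
#
#   fake_batch = [(torch.randn(200, 128), 3), (torch.randn(150, 128), 7)]
#   padded, labels = collate_fn_padded(fake_batch)
#   # padded.shape == (2, 200, 128); labels == [3, 7]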

def prepare_data(data_path):
    training_data = VoiceRecognitionDataset(data_path, is_testdata=False)
    test_data = VoiceRecognitionDataset(data_path, is_testdata=True)
    train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True, collate_fn=collate_fn_padded)
    test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True, collate_fn=collate_fn_padded)
    return train_dataloader, test_dataloader

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Model's basic structure (BatchNorm is applied after every layer):
#
#   Mel Spectrogram -> Convolutional Block -> LSTM -> Linear Block -> Encoding -> Prediction

class VoiceRecognitionModel(nn.Module):

    def __init__(self, num_speakers):
        super().__init__()
        self.num_speakers = num_speakers        

        self.convolution_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding='same'),
            nn.BatchNorm2d(32)
        )
        # Intermediate convolutional layers (currently one)
        for _ in range(1):
            self.convolution_block.append(nn.Sequential(
                nn.Conv2d(32, 32, 3, padding='same'),
                nn.BatchNorm2d(32)
            ))
        self.convolution_block.append(nn.Sequential(
            nn.Conv2d(32, 1, 3, padding='same'),
            nn.BatchNorm2d(1)
        ))

        # Named "blstm" but currently unidirectional; only the final hidden state is used.
        self.blstm = nn.LSTM(input_size=128, hidden_size=128, bidirectional=False, batch_first=True)
        self.blstm_batchnorm = nn.BatchNorm1d(128)

        self.linear_block = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU()
        )

        self.encoding = nn.Sequential(
            nn.Linear(32, 16), # embedding size; the best value is an open question
            nn.BatchNorm1d(16),
            nn.ReLU()
        )

        self.prediction = nn.Sequential(
            nn.Linear(16, self.num_speakers),
            # NLLLoss (used in __main__) expects log-probabilities, so this must
            # be LogSoftmax rather than Softmax.
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        x = x.unsqueeze(1)                            # (B, T, n_mels) -> (B, 1, T, n_mels)
        x_hat = self.convolution_block(x)
        x_hat = x_hat.squeeze(1)                      # drop the channel dim: (B, T, n_mels)
        output, (h_n, c_n) = self.blstm(x_hat)
        x_hat = self.blstm_batchnorm(h_n.squeeze(0))  # final hidden state: (B, 128)
        x_hat = self.linear_block(x_hat)
        x_hat = self.encoding(x_hat)
        pred = self.prediction(x_hat)
        return pred
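
# Quick shape sanity check (illustrative only; the 128 mel bins and 250 frames
# are assumptions chosen to match the LSTM's input_size above):
#
#   m = VoiceRecognitionModel(num_speakers=110)
#   dummy = torch.randn(4, 250, 128)   # (batch, frames, mel bins)
#   print(m(dummy).shape)              # expected: torch.Size([4, 110])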


# Runs one epoch of training over the training dataloader.
def train(dataloader, model, loss_fn, optimizer):    
    size = len(dataloader.dataset)
    model.train()
    start = time.time()
    for batch, (X, y) in enumerate(dataloader):
        y = torch.tensor(y).to(device)            
        X = X.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    end = time.time()
    print("Epoch Time: " + str((end - start)/60) + " minutes.")
    return True


# Evaluates loss and accuracy over the held-out utterances.
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), torch.tensor(y).to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


if __name__=="__main__":
    train_dataloader, test_dataloader = prepare_data(sys.argv[1]) #put proper path here

    model = VoiceRecognitionModel(num_speakers=110)
    model.to(device)
    loss_fn = nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    print("Model:")
    print(model)

    epochs = 100
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")

    torch.save(model.state_dict(), "model.pth")
    print("Saved PyTorch Model State to model.pth")