# Recurrent-Neural-Networks/rnn/generate.py
# https://github.com/spro/char-rnn.pytorch

import torch
import os
import argparse

from rnn.helpers import *
from rnn.model import *

def generate(decoder, prime_str='A', predict_len=100, temperature=0.8, device=None):
    """Sample text from a character-level RNN decoder.

    The decoder is first run over ``prime_str`` to build up its hidden
    state. Then ``predict_len`` characters are sampled one at a time, and
    each sampled character is fed back in as the next input.

    Args:
        decoder: Model exposing ``init_hidden(batch_size, device=...)`` and
            called as ``decoder(input, hidden) -> (output, hidden)``.
        prime_str: Non-empty seed text. It is included at the start of the
            returned string.
        predict_len: Number of characters to sample after ``prime_str``.
        temperature: Positive divisor applied to the decoder output before
            exponentiation. Lower values make sampling greedier.
        device: Device the input tensors are moved to (passed to ``.to``).

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.

    Raises:
        ValueError: If ``prime_str`` is empty or ``temperature`` is not
            positive.
    """
    if not prime_str:
        raise ValueError("prime_str must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Sampling never needs gradients. Without no_grad, the hidden state would
    # drag an ever-growing autograd graph through every generated step.
    with torch.no_grad():
        hidden = decoder.init_hidden(1, device=device)
        prime_input = char_tensor(prime_str).unsqueeze(0).to(device)

        # Use all but the last priming character to "build up" the hidden
        # state; the last one becomes the first input of the sampling loop.
        for p in range(len(prime_str) - 1):
            _, hidden = decoder(prime_input[:, p], hidden)

        inp = prime_input[:, -1]
        generated = []

        for _ in range(predict_len):
            output, hidden = decoder(inp, hidden)

            # Sample from the network as a multinomial distribution.
            # Subtracting the max before exp() leaves the relative weights
            # unchanged but prevents overflow to inf for large outputs or
            # small temperatures.
            scaled = output.view(-1).div(temperature)
            output_dist = (scaled - scaled.max()).exp()
            top_i = torch.multinomial(output_dist, 1).item()

            # Record the predicted character and use it as the next input.
            predicted_char = all_characters[top_i]
            generated.append(predicted_char)
            inp = char_tensor(predicted_char).unsqueeze(0).to(device)

    return prime_str + ''.join(generated)