import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import collections
from transformers import WordpieceTokenizer
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
class Bert_Layer(nn.Module):
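    '''
    BERT-based encoder that jointly encodes the question spans, column names and
    table names of a batch, and returns per-span sentence representations together
    with LSTM-pooled column and table representations.
    '''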
    # NOTE: the original string constants in this list appear to have been lost in
    # the source (only empty strings remain); the list is only consumed by the
    # commented-out extend_bert_vocab call in __init__.
    back_type = ['', '', '', '', '', '']
def extend_bert_vocab(self, words_to_extend):
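        '''
        Add whole-word tokens to the BERT tokenizer vocabulary (skipping words that
        are already present), rebuild the id-to-token map and wordpiece tokenizer,
        and resize the encoder's token embeddings accordingly.
        '''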
# print(all_words)
init_len = len(self.tokenizer.vocab)
cur_ind = init_len
for i in words_to_extend:
if i in self.tokenizer.vocab:
continue
self.tokenizer.vocab[i] = cur_ind
cur_ind += 1
print(f"extend bert tokenizer with extra {cur_ind - init_len} words!")
self.tokenizer.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.tokenizer.vocab.items()])
self.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.tokenizer.vocab,
unk_token=self.tokenizer.unk_token)
        self.encoder.resize_token_embeddings(cur_ind)  # public API; grows the word-embedding matrix to the new vocab size
def __init__(self, args):
super(Bert_Layer, self).__init__()
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self.encoder = BertModel.from_pretrained('bert-base-uncased')
# self.extend_bert_vocab(self.back_type)
# for param in self.encoder.embeddings.word_embeddings.parameters():
# param.requires_grad = False
if args.cuda:
self.new_long_tensor = torch.cuda.LongTensor
self.new_tensor = torch.cuda.FloatTensor
else:
self.new_long_tensor = torch.LongTensor
self.new_tensor = torch.FloatTensor
torch.manual_seed(0)
self.dim_transform = nn.Linear(768, args.hidden_size)
torch.manual_seed(1)
self.column_enc = nn.LSTM(768, args.col_embed_size // 2, bidirectional=True,
batch_first=True)
torch.manual_seed(2)
self.table_enc = nn.LSTM(768, args.col_embed_size // 2, bidirectional=True,
batch_first=True)
def forward(self, src_sents, src_sents_len, col_names, table_names):
        '''
        :param src_sents: [[span, span, ..., span]] * batch_size; span = [word1, word2, ...]
                          (a word may be a schema placeholder such as 'column' or 'table')
        :param src_sents_len: [number of spans] * batch_size, in descending order
        :param col_names: [[col, col, ..., col]] * batch_size; col = [word1, word2, ...]
        :param table_names: [[table, table, ..., table]] * batch_size; table = [word1, word2, ...]
        :return: padded sentence, column and table representations plus the pooled BERT output
        '''
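        # Illustrative example of the expected (already tokenized) inputs; the
        # concrete values below are hypothetical:
        #   src_sents     = [[['how', 'many'], ['singers'], ['are', 'there']]]
        #   src_sents_len = [3]
        #   col_names     = [[['*'], ['singer', 'id'], ['name']]]
        #   table_names   = [[['singer'], ['concert']]]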
# print(f"src_sents:\t{src_sents}")
# print(f"src_sents_len:\t{src_sents_len}")
# print(f"col_names:\t{col_names}")
bert_input, infos = self._formatting(src_sents, col_names, table_names)
# print(f"bert_input:\t{bert_input}")
padded_input = self._pad_input(bert_input)
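        # The call below assumes the older transformers tuple API, where BertModel
        # returns (sequence_output, pooled_output); on recent versions the model
        # returns a ModelOutput and .last_hidden_state / .pooler_output would be
        # needed instead. Note that no attention_mask is passed, so padding
        # positions are attended to as well.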
hidden_states, last_cell = self.encoder(padded_input)
# print(hidden_states.shape)
# print(infos)
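        # Re-split the BERT hidden states according to the layout produced by
        # _formatting: position 0 is [CLS]; the sentence spans follow, and each
        # column / table name is preceded by a separator token (hence the extra
        # span_ptr += 1 before every slice). Sentence spans are mean-pooled;
        # column and table wordpiece states are kept for the LSTM encoders below.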
sent_bert_outs, col_bert_outs, table_bert_outs = [], [], []
col_split, table_split = [0], [0]
col_lens, table_lens = [], []
for batch_iter, formatted_info in enumerate(infos):
sent_h = []
span_ptr = 1
for i in formatted_info['sent']:
span_h = torch.mean(hidden_states[batch_iter, span_ptr:span_ptr + i, :], dim=0, keepdim=True)
sent_h.append(span_h)
span_ptr += i
sent_bert_outs.append(torch.cat(sent_h, dim=0))
col_split.append(col_split[-1] + len(formatted_info['col']))
col_lens += formatted_info['col']
for i in formatted_info['col']:
span_ptr += 1
col_bert_outs.append(hidden_states[batch_iter, span_ptr:span_ptr + i, :])
span_ptr += i
table_split.append(table_split[-1] + len(formatted_info['table']))
table_lens += formatted_info['table']
for i in formatted_info['table']:
span_ptr += 1
table_bert_outs.append(hidden_states[batch_iter, span_ptr:span_ptr + i, :])
span_ptr += i
assert src_sents_len == [x.shape[0] for x in
sent_bert_outs], f'{src_sents_len} vs {[x.shape[0] for x in sent_bert_outs]}'
sent_outs = pad_sequence(sent_bert_outs, batch_first=True) # bsize * max_sents_len * 768
sent_outs = self.dim_transform(sent_outs)
print("dummy2")
print(sent_outs[0])
col_bert_outs = pad_sequence(col_bert_outs, batch_first=True) # total_col_num * max_col_len * 768
table_bert_outs = pad_sequence(table_bert_outs, batch_first=True) # total_table_num * max_table_len * 768
col_lens = self.new_long_tensor(col_lens)
table_lens = self.new_long_tensor(table_lens)
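        # Encode each column / table name with its BiLSTM and keep the final forward
        # and backward hidden states; concatenating them gives one col_embed_size
        # vector per column (resp. table), which is then regrouped per example
        # using col_split / table_split.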
_, (col_last_states, _) = rnn_wrapper(self.column_enc, col_bert_outs, col_lens)
col_lstm_outs = torch.cat([col_last_states[0], col_last_states[1]], -1) # total_col_num * hidden_size
assert col_lstm_outs.shape[0] == col_split[-1]
col_outs = [col_lstm_outs[col_split[i]:col_split[i + 1]] for i in range(len(col_split) - 1)]
col_outs = pad_sequence(col_outs, batch_first=True)
_, (table_last_states, _) = rnn_wrapper(self.table_enc, table_bert_outs, table_lens)
table_lstm_outs = torch.cat([table_last_states[0], table_last_states[1]], -1)
assert table_lstm_outs.shape[0] == table_split[-1]
table_outs = [table_lstm_outs[table_split[i]:table_split[i + 1]] for i in range(len(table_split) - 1)]
table_outs = pad_sequence(table_outs, batch_first=True)
# print(f'sent_bert_outs:\t{sent_bert_outs.shape}')
# print(f'col_bert_outs:\t{col_bert_outs.shape}')
# print(f'table_bert_outs:\t{table_bert_outs.shape}')
# print(sent_bert_outs)
return sent_outs, col_outs, table_outs, last_cell
def _formatting(self, src_sents, col_names, table_names):
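        '''
        Build one token sequence per example:
            [CLS] <sentence spans> [SEP] <col_1> [SEP] <col_2> [SEP] ... <table_1> [SEP] <table_2> [SEP] ...
        and record the token length of every span, column and table so that
        forward() can split the BERT outputs back apart.
        '''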
formatted_inputs = []
formatted_infos = []
for src_sent, col_name, table_name in zip(src_sents, col_names, table_names):
formatted_info = {"sent": [], "col": [], "table": []}
formatted = [self.tokenizer.cls_token]
for span in src_sent:
formatted_info['sent'].append(len(span))
formatted += span
formatted += [self.tokenizer.sep_token]
for i in col_name:
formatted_info['col'].append(len(i))
formatted += i
formatted += [self.tokenizer.sep_token]
for i in table_name:
formatted_info['table'].append(len(i))
formatted += i
formatted += [self.tokenizer.sep_token]
formatted_inputs.append(formatted)
formatted_infos.append(formatted_info)
return formatted_inputs, formatted_infos
def _pad_input(self, bert_inp):
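        '''
        Pad every formatted token sequence to the batch maximum with the
        tokenizer's pad token and convert the batch to a LongTensor of ids.
        '''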
bert_lens = [len(i) for i in bert_inp]
max_len = max(bert_lens)
padded_inp = [p + [self.tokenizer.pad_token] * (max_len - i) for i, p in zip(bert_lens, bert_inp)]
padded_inp = self.new_long_tensor([self.tokenizer.convert_tokens_to_ids(i) for i in padded_inp])
return padded_inp
def rnn_wrapper(encoder, inputs, lens, cell='lstm'):
"""
@args:
encoder(nn.Module): rnn series bidirectional encoder, batch_first=True
inputs(torch.FloatTensor): rnn inputs, bsize x max_seq_len x in_dim
lens(torch.LongTensor): seq len for each sample, bsize
@return:
out(torch.FloatTensor): output of encoder, bsize x max_seq_len x hidden_dim*2
hidden_states(tuple or torch.FloatTensor): final hidden states, num_layers*2 x bsize x hidden_dim
"""
# rerank according to lens and temporarily remove empty inputs
sorted_lens, sort_key = torch.sort(lens, descending=True)
nonzero_index = torch.sum(sorted_lens > 0).item()
sorted_inputs = torch.index_select(inputs, dim=0, index=sort_key[:nonzero_index])
# forward non empty inputs
packed_inputs = pack_padded_sequence(sorted_inputs, sorted_lens[:nonzero_index].tolist(), batch_first=True)
    packed_out, h = encoder(packed_inputs)  # packed_out is a PackedSequence
out, _ = pad_packed_sequence(packed_out, batch_first=True)
if cell.upper() == 'LSTM':
h, c = h
# pad zeros due to empty inputs
pad_zeros = torch.zeros(sorted_lens.size(0) - out.size(0), out.size(1), out.size(2)).type_as(out).to(out.device)
sorted_out = torch.cat([out, pad_zeros], dim=0)
pad_hiddens = torch.zeros(h.size(0), sorted_lens.size(0) - h.size(1), h.size(2)).type_as(h).to(h.device)
sorted_hiddens = torch.cat([h, pad_hiddens], dim=1)
if cell.upper() == 'LSTM':
pad_cells = torch.zeros(c.size(0), sorted_lens.size(0) - c.size(1), c.size(2)).type_as(c).to(c.device)
sorted_cells = torch.cat([c, pad_cells], dim=1)
    # restore the original batch order according to sort_key
    out_index = sort_key.unsqueeze(-1).unsqueeze(-1).expand(*sorted_out.size())
    out = torch.zeros_like(sorted_out).scatter_(0, out_index, sorted_out)
    hidden_index = sort_key.unsqueeze(0).unsqueeze(-1).expand(*sorted_hiddens.size())
    hiddens = torch.zeros_like(sorted_hiddens).scatter_(1, hidden_index, sorted_hiddens)
    if cell.upper() == 'LSTM':
        cells = torch.zeros_like(sorted_cells).scatter_(1, hidden_index, sorted_cells)
        return out, (hiddens.contiguous(), cells.contiguous())
    return out, hiddens.contiguous()
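

# Minimal smoke test for rnn_wrapper (an illustrative sketch, not part of the
# original training pipeline): encode a toy batch that contains a zero-length
# sequence and check the output shapes.
if __name__ == '__main__':
    enc = nn.LSTM(16, 4, bidirectional=True, batch_first=True)
    toy_inputs = torch.randn(3, 5, 16)          # bsize=3, max_len=5, in_dim=16
    toy_lens = torch.LongTensor([5, 0, 3])      # one empty sequence on purpose
    out, (h, c) = rnn_wrapper(enc, toy_inputs, toy_lens)
    print(out.shape)  # torch.Size([3, 5, 8])
    print(h.shape)    # torch.Size([2, 3, 4])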