import numpy as np
import torch
import torch.nn as nn
import math
from sklearn.preprocessing import StandardScaler
class scaled_dot_product_attention(torch.nn.Module):
    def forward(self, q, k, v, mask=None, dropout=None):  # q, k, v are each (B, num_heads, seq_len, head_dim)
        # scaled dot product: Q K^T / sqrt(d_k) -> (B, num_heads, seq_len, seq_len)
        score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)  # block positions where mask == 0 with a large negative score
prob_score = torch.softmax(score, dim = -1)
if dropout is not None:
prob_score = dropout(prob_score)
attention_val = torch.matmul(prob_score, v)
return attention_val, prob_score
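
# Hedged usage sketch (not part of the original code): exercises scaled_dot_product_attention
# with illustrative, assumed sizes so the (B, num_heads, seq_len, head_dim) contract is concrete.
def _demo_scaled_dot_product_attention():
    attn = scaled_dot_product_attention()
    q = torch.randn(2, 4, 10, 8)                          # (B, num_heads, seq_len, head_dim)
    k = torch.randn(2, 4, 10, 8)
    v = torch.randn(2, 4, 10, 8)
    causal_mask = torch.tril(torch.ones(10, 10)).bool()   # broadcasts over batch and heads
    out, weights = attn(q, k, v, mask=causal_mask)
    assert out.shape == (2, 4, 10, 8)
    assert weights.shape == (2, 4, 10, 10)
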
class MultiHeadAttention(torch.nn.Module):
def __init__(self, num_heads, qk_d_model, v_d_model, dropout_rate = 0.2):
super(MultiHeadAttention, self).__init__()
        assert qk_d_model % num_heads == 0, "qk_d_model must be divisible by num_heads"
        assert v_d_model % num_heads == 0, "v_d_model must be divisible by num_heads"
        self.hidden_units = v_d_model
        self.num_heads = num_heads
        self.qk_head_dim = qk_d_model // num_heads  # per-head query/key dimension
        self.v_head_dim = v_d_model // num_heads    # per-head value dimension
self.W_Q = torch.nn.Linear(qk_d_model, qk_d_model)
self.W_K = torch.nn.Linear(qk_d_model, qk_d_model)
self.W_V = torch.nn.Linear(v_d_model, v_d_model)
self.fc = torch.nn.Linear(v_d_model, v_d_model)
self.attention = scaled_dot_product_attention()
self.dropout = torch.nn.Dropout(p=dropout_rate)
def forward(self, queries, keys, values, mask=None):
# queries, keys, values: (N, T, C)
        batch_size = queries.size(0)
        # project, then split into heads: (N, T, C) -> (N, T, num_heads, head_dim)
        q = self.W_Q(queries).view(batch_size, -1, self.num_heads, self.qk_head_dim)
        k = self.W_K(keys).view(batch_size, -1, self.num_heads, self.qk_head_dim)
        v = self.W_V(values).view(batch_size, -1, self.num_heads, self.v_head_dim)
        # move the head axis forward: (N, num_heads, T, head_dim)
        q, k, v = [x.transpose(1, 2) for x in (q, k, v)]
        attention_output, attention_weights = self.attention(q, k, v, mask=mask, dropout=self.dropout)
        # merge heads back: (N, num_heads, T, v_head_dim) -> (N, T, v_d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_units)
outputs = self.fc(attention_output)
return outputs, attention_weights
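
# Hedged usage sketch (not part of the original code): runs MultiHeadAttention as self-attention
# with assumed sizes (qk_d_model = v_d_model = 64, 4 heads, batch of 2, sequence length 10).
def _demo_multi_head_attention():
    mha = MultiHeadAttention(num_heads=4, qk_d_model=64, v_d_model=64)
    x = torch.randn(2, 10, 64)                            # (N, T, C)
    out, weights = mha(x, x, x, mask=None)
    assert out.shape == (2, 10, 64)                       # (N, T, v_d_model)
    assert weights.shape == (2, 4, 10, 10)                # (N, num_heads, T, T)
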
class CL_Projector(nn.Module):
    """Projection head: a single dense layer with ReLU, applied to a pooled (e.g. first-token) representation."""
    def __init__(self, opt):
        super().__init__()
        self.dense = nn.Linear(opt["hidden_units"], opt["hidden_units"])
        self.activation = nn.ReLU()
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def forward(self, first_token_tensor):
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
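
# Hedged usage sketch (not part of the original code): the opt dict below is an assumption;
# only "hidden_units" is required by CL_Projector.
def _demo_cl_projector():
    opt = {"hidden_units": 64}
    projector = CL_Projector(opt)
    pooled = projector(torch.randn(2, 64))                # e.g. a first-token / pooled representation
    assert pooled.shape == (2, 64)
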
class GenderDiscriminator(nn.Module):
    def __init__(self, opt):
        super(GenderDiscriminator, self).__init__()
        self.opt = opt
        # scores each position for attention pooling over the sequence dimension
        self.attention = nn.Linear(opt['hidden_units'], 1)
self.layer = nn.Sequential(
nn.Linear(opt['hidden_units'], 2*opt['hidden_units']),
nn.ReLU(),
nn.Linear(2*opt['hidden_units'], opt['hidden_units']),
nn.ReLU(),
nn.Linear(opt['hidden_units'],1),
nn.Sigmoid()
)
    def forward(self, x):
        # x: (N, T, hidden_units); attention-pool over T, then classify the pooled vector
        attention_weights = torch.nn.functional.softmax(self.attention(x), dim=1)
        weighted_average = torch.sum(x * attention_weights, dim=1)
        return self.layer(weighted_average), attention_weights
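
# Hedged usage sketch (not part of the original code): feeds assumed (N, T, hidden_units)
# sequence representations through GenderDiscriminator.
def _demo_gender_discriminator():
    opt = {'hidden_units': 64}
    disc = GenderDiscriminator(opt)
    x = torch.randn(2, 10, 64)                            # (N, T, hidden_units)
    prob, attn_weights = disc(x)
    assert prob.shape == (2, 1)                           # sigmoid probability per sequence
    assert attn_weights.shape == (2, 10, 1)               # attention over the T positions
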
class ClusterRepresentation(nn.Module):
    def __init__(self, opt, feature_dim, num_clusters, topk):  # topk: number of most similar clusters kept per sample
super(ClusterRepresentation, self).__init__()
self.opt = opt
self.num_clusters = num_clusters
self.feature_dim = feature_dim
self.topk = topk
self.cluster_prototypes = nn.Parameter(torch.randn(num_clusters, feature_dim))
self.feature_extractor = nn.Sequential(
nn.Linear(feature_dim*topk, feature_dim),
nn.ReLU(),
nn.Linear(feature_dim, feature_dim),
)
    def forward(self, features):
        # features: (N, feature_dim)
        sim = features @ self.cluster_prototypes.T                  # (N, num_clusters)
        sim = sim / sim.max(-1, keepdim=True)[0]                    # normalize each row by its largest similarity
        weight = torch.softmax(sim, dim=-1)
        new_cluster = weight.T @ features                           # soft re-estimate of prototypes, (num_clusters, feature_dim)
        new_sim = features @ new_cluster.T                          # (N, num_clusters)
        _, top_k_indices = torch.topk(new_sim, self.topk, dim=-1)   # (N, topk)
        multi_interest = new_cluster[top_k_indices]                 # (N, topk, feature_dim)
        multi_interest = self.feature_extractor(multi_interest.reshape(-1, self.feature_dim * self.topk))  # (N, feature_dim)
        return new_cluster, multi_interest
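
# Hedged usage sketch (not part of the original code): sizes and the opt dict are assumptions
# chosen only to illustrate the shapes returned by ClusterRepresentation.
def _demo_cluster_representation():
    opt = {'hidden_units': 64}
    cluster_rep = ClusterRepresentation(opt, feature_dim=64, num_clusters=8, topk=3)
    features = torch.randn(16, 64)                        # (N, feature_dim)
    new_cluster, multi_interest = cluster_rep(features)
    assert new_cluster.shape == (8, 64)                   # re-estimated cluster prototypes
    assert multi_interest.shape == (16, 64)               # per-sample multi-interest representation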