# FairCDSR / FairCDSR.py
import numpy as np
import torch
import torch.nn as nn
import math
from sklearn.preprocessing import StandardScaler


class scaled_dot_product_attention(torch.nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    Has no parameters. ``q``, ``k`` and ``v`` are (B, num_heads, seq_len, h)
    tensors, where ``d`` is the size of the last dimension of ``q``.
    """

    def forward(self, q, k, v, mask = None, dropout = None):
        """Attend ``q`` over ``k``/``v``.

        Args:
            q, k, v: (B, num_heads, seq_len, h) tensors.
            mask: optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` get a score of -1e9 before the softmax.
            dropout: optional callable (e.g. ``nn.Dropout``) applied to the
                attention probabilities.

        Returns:
            ``(attended_values, attention_probabilities)``.
        """
        scale = math.sqrt(q.size(-1))
        logits = q @ k.transpose(-2, -1) / scale
        if mask is not None:
            # Masked-out positions get a huge negative logit, so softmax gives them ~0 weight.
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = logits.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ v, weights
class MultiHeadAttention(torch.nn.Module):
    """Multi-head attention whose query/key width may differ from the value width.

    Queries and keys go through ``qk_d_model``-wide linear projections and
    values through a ``v_d_model``-wide one. Each result is split into
    ``num_heads`` heads, attended with ``scaled_dot_product_attention``,
    merged back, and passed through a final ``v_d_model -> v_d_model``
    projection (``fc``).
    """

    def __init__(self,  num_heads, qk_d_model, v_d_model, dropout_rate = 0.2):
        super().__init__()
        # Both widths must split evenly across the heads.
        assert qk_d_model % num_heads == 0
        assert v_d_model % num_heads == 0
        self.hidden_units = v_d_model
        self.num_heads = num_heads
        self.qk_head_dim, self.v_head_dim = qk_d_model // num_heads, v_d_model // num_heads
        # Layer creation order is kept as-is: it determines the RNG draws used
        # for default initialisation and the state_dict layout.
        self.W_Q = torch.nn.Linear(qk_d_model, qk_d_model)
        self.W_K = torch.nn.Linear(qk_d_model, qk_d_model)
        self.W_V = torch.nn.Linear(v_d_model, v_d_model)
        self.fc = torch.nn.Linear(v_d_model, v_d_model)
        self.attention = scaled_dot_product_attention()
        # Passed into the attention module and applied to its probabilities.
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self, queries, keys, values, mask=None):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: (N, Tq, qk_d_model) tensor.
            keys: (N, Tk, qk_d_model) tensor.
            values: (N, Tk, v_d_model) tensor.
            mask: optional tensor broadcastable to the (N, num_heads, Tq, Tk)
                score matrix; it is forwarded unchanged to ``self.attention``.

        Returns:
            ``(output, attention_weights)`` with shapes (N, Tq, v_d_model) and
            (N, num_heads, Tq, Tk).
        """
        n = queries.size(0)

        def to_heads(x, proj, head_dim):
            # Project, then reshape (N, T, H*head_dim) -> (N, H, T, head_dim).
            return proj(x).view(n, -1, self.num_heads, head_dim).transpose(1, 2)

        q = to_heads(queries, self.W_Q, self.qk_head_dim)
        k = to_heads(keys, self.W_K, self.qk_head_dim)
        v = to_heads(values, self.W_V, self.v_head_dim)
        context, weights = self.attention(q, k, v, mask=mask, dropout=self.dropout)
        # Merge heads: (N, H, Tq, d_v) -> (N, Tq, H*d_v).
        merged = context.transpose(1, 2).contiguous().view(n, -1, self.hidden_units)
        return self.fc(merged), weights


class CL_Projector(nn.Module):
    """Projection head: one square ``nn.Linear`` followed by ``ReLU``.

    The width is ``opt["hidden_units"]``. Every ``nn.Linear`` submodule gets
    Kaiming-uniform weights (fan_in, ReLU gain) and a zero bias.
    """

    def __init__(self, opt):
        super().__init__()
        width = opt["hidden_units"]
        self.dense = nn.Linear(width, width)
        self.activation = nn.ReLU()
        # Walks all submodules; only ``dense`` is a Linear here.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise ``nn.Linear`` modules in place; leave everything else untouched."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, first_token_tensor):
        """Return ``ReLU(dense(first_token_tensor))``; the last dim must be ``hidden_units``."""
        projected = self.dense(first_token_tensor)
        return self.activation(projected)

class GenderDiscriminator(nn.Module):
    """Attention-pool a sequence of embeddings, then score it with an MLP ending in ``Sigmoid``.

    ``opt['hidden_units']`` is the embedding width. ``forward`` returns the
    sigmoid output, which lies in (0, 1), together with the pooling weights.
    """

    def __init__(self,opt):
        super().__init__()
        self.opt = opt
        width = opt['hidden_units']
        # One scalar score per position; softmax-normalised in forward().
        self.attention = nn.Linear(width, 1)
        # width -> 2*width -> width -> 1 MLP. The Sequential indices are part of
        # the state_dict keys, so the layer order is significant.
        self.layer = nn.Sequential(
            nn.Linear(width, 2 * width),
            nn.ReLU(),
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Pool ``x`` along dim 1 with learned attention, then classify.

        Args:
            x: tensor whose last dim is ``hidden_units``; the softmax runs over
                dim 1, presumably (B, T, hidden_units) — TODO confirm with callers.

        Returns:
            ``(probability, attention_weights)``: the MLP output for the pooled
            vector, and the softmax weights (same shape as ``x`` except a last
            dim of 1).
        """
        scores = self.attention(x)
        attention_weights = torch.nn.functional.softmax(scores, dim=1)
        pooled = (x * attention_weights).sum(dim=1)
        return self.layer(pooled), attention_weights
        
class ClusterRepresentation(nn.Module):
    """Soft-cluster a batch of feature vectors and build a multi-interest vector per row.

    Each feature row is compared with ``num_clusters`` learnable prototypes.
    The softmax assignments pool the batch into batch-specific cluster vectors
    (assignment-weighted sums of the features). For every row, the ``topk``
    cluster vectors most similar to it are concatenated and mapped back to
    ``feature_dim`` by a small MLP.
    """

    def __init__(self, opt, feature_dim, num_clusters, topk):
        """
        Args:
            opt: configuration object; stored on the instance, not read here.
            feature_dim: width of each input feature row.
            num_clusters: number of learnable cluster prototypes.
            topk: number of most similar clusters gathered per row (must not
                exceed ``num_clusters``, otherwise ``torch.topk`` raises).
        """
        super(ClusterRepresentation, self).__init__()
        self.opt = opt
        self.num_clusters = num_clusters
        self.feature_dim = feature_dim
        self.topk = topk
        self.cluster_prototypes = nn.Parameter(torch.randn(num_clusters, feature_dim))
        # Maps the concatenated top-k cluster vectors back to feature_dim.
        self.feature_extractor = nn.Sequential(
            nn.Linear(feature_dim*topk, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim),
        )

    def forward(self, features):
        """
        Args:
            features: 2-D tensor of shape (X, feature_dim).

        Returns:
            ``(new_cluster, multi_interest)``:
            ``new_cluster`` is (num_clusters, feature_dim), the
            assignment-weighted sum of ``features`` per cluster;
            ``multi_interest`` is (X, feature_dim).
        """
        sim = features @ self.cluster_prototypes.T  # (X, num_clusters)
        # Rescale each row so its max has magnitude 1. This acts as a per-row
        # softmax temperature. The divisor is forced strictly positive:
        # - dividing by a negative max (all similarities negative) would invert
        #   the cluster ranking;
        # - a zero max would produce 0/0 = NaN and poison the whole batch.
        # When the row max is positive this is identical to dividing by the max.
        row_max = sim.max(dim=-1, keepdim=True).values
        sim = sim / row_max.abs().clamp_min(torch.finfo(sim.dtype).eps)
        weight = torch.softmax(sim, dim=-1)  # (X, num_clusters)
        new_cluster = weight.T @ features  # (num_clusters, feature_dim)
        new_sim = features @ new_cluster.T  # (X, num_clusters)
        _, top_k_indice = torch.topk(new_sim, self.topk, dim=-1)  # (X, topk)
        multi_interest = new_cluster[top_k_indice]  # (X, topk, feature_dim)
        multi_interest = self.feature_extractor(
            multi_interest.reshape(-1, self.feature_dim * self.topk)
        )  # (X, feature_dim)
        return new_cluster, multi_interest