class AdditiveAttention(nn.Module):
    """Additive attention.

    Each query/key pair is scored as ``w_v^T tanh(W_q q + W_k k)``; the
    scores are normalised by ``masked_softmax`` and used to take a weighted
    sum of the values.

    Args:
        key_size: Feature size of each key (input width of ``W_k``).
        query_size: Feature size of each query (input width of ``W_q``).
        num_hiddens: Width of the hidden space that queries and keys are
            projected into before being added together.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Attributes set by this class:
        attention_weights: ``None`` until the first call to ``forward``;
            afterwards the weights returned by ``masked_softmax`` for the
            most recent call (stored before dropout is applied).
    """

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        # All three projections are bias-free linear maps.
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Declared here so the attribute can be read (e.g. by plotting or
        # debugging code) before the first forward pass without raising
        # AttributeError.
        self.attention_weights = None

    def forward(self, queries, keys, values, valid_lens):
        """Compute the attention-weighted combination of ``values``.

        Args:
            queries: Tensor of shape (batch, num_queries, query_size).
            keys: Tensor of shape (batch, num_kv, key_size).
            values: Tensor of shape (batch, num_kv, value_size).
            valid_lens: Passed straight through to ``masked_softmax``.
                Presumably the number of valid key/value positions per
                batch item or per query -- confirm against
                ``masked_softmax``, which is defined outside this class.

        Returns:
            Tensor of shape (batch, num_queries, value_size).
        """
        queries, keys = self.W_q(queries), self.W_k(keys)
        # Broadcast-add every query with every key:
        #   (batch, num_queries, 1, num_hiddens)
        # + (batch, 1, num_kv, num_hiddens)
        # -> (batch, num_queries, num_kv, num_hiddens)
        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))
        # w_v maps the hidden axis to a single score; drop that size-1 axis
        # to get (batch, num_queries, num_kv).
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Dropout affects only the weights used for this output, not the
        # copy kept in self.attention_weights.
        return torch.bmm(self.dropout(self.attention_weights), values)
|