← Каталог
Устройство трансформеров — теория и практика с нуля — Scaled dot-product attention
Фрагмент из «Устройство трансформеров — теория и практика с нуля»: Scaled dot-product attention.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def scaled_dot_product_attention(q, k, v, mask=None):
# q, k, v: (batch, heads, seq_len, d_k)
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, v), weights import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def scaled_dot_product_attention(q, k, v, mask=None):
# q, k, v: (batch, heads, seq_len, d_k)
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, v), weights