Simplified HashFormer
A basic HashFormer implementation in PyTorch. This code shows how to implement hash-based approximate attention in order to reduce the computational cost of self-attention.
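Before the implementation, here is a minimal sketch of the angular LSH idea the attention layer relies on (the function name and sizes below are illustrative, not part of HashFormer): each vector is projected onto a few random directions, and the index of the strongest signed response becomes its bucket, so vectors pointing in similar directions usually share a bucket.

# Minimal sketch of angular LSH bucketing (illustrative values, separate from the model)
def lsh_bucket_demo():
    import torch
    torch.manual_seed(0)
    d_k, n_buckets = 16, 8
    hyperplanes = torch.randn(d_k, n_buckets // 2)    # random projection directions

    def bucket(v):
        p = v @ hyperplanes                           # project onto the directions
        return int(torch.argmax(torch.cat([p, -p])))  # angular LSH bucket id

    v = torch.randn(d_k)
    w = v + 0.01 * torch.randn(d_k)                   # a nearly identical vector
    print(bucket(v), bucket(w))                       # usually land in the same bucket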
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
import math
class LSHAttention(nn.Module):
"""Locality-Sensitive Hashing Attention for HashFormer"""
def __init__(self, d_model: int, n_heads: int, n_buckets: int = 64,
n_rounds: int = 4, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_buckets = n_buckets
self.n_rounds = n_rounds
self.d_k = d_model // n_heads
# Linear projections for Q, K, V
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model)
# Random projection matrices for LSH
self.register_buffer('random_projections',
torch.randn(n_rounds, self.d_k, n_buckets // 2))
self.dropout = nn.Dropout(dropout)
self.scale = 1.0 / math.sqrt(self.d_k)
    def hash_vectors(self, vecs: torch.Tensor) -> torch.Tensor:
        """Apply LSH to map each vector to one of n_buckets buckets per round"""
        batch_size, seq_len, n_heads, d_k = vecs.shape
        # Reshape for hashing: [batch_size * seq_len * n_heads, d_k]
        vecs_flat = vecs.reshape(-1, d_k)
        # Apply the random projections for every round: [B*L*H, n_rounds, n_buckets // 2]
        projections = torch.einsum('nd,rdb->nrb', vecs_flat, self.random_projections)
        # Angular LSH: concatenate the projections with their negation and take the
        # argmax, which yields a bucket index in [0, n_buckets) for each round
        projections = torch.cat([projections, -projections], dim=-1)
        bucket_indices = torch.argmax(projections, dim=-1)  # [B*L*H, n_rounds]
        return bucket_indices.reshape(batch_size, seq_len, n_heads, self.n_rounds)
    def create_attention_mask(self, q_buckets: torch.Tensor,
                              k_buckets: torch.Tensor) -> torch.Tensor:
        """Create attention mask based on hash buckets"""
        batch_size, seq_len, n_heads, n_rounds = q_buckets.shape
        # Build one mask per hashing round
        masks = []
        for r in range(n_rounds):
            # Bucket indices for this round
            q_r = q_buckets[:, :, :, r]  # [B, L, H]
            k_r = k_buckets[:, :, :, r]  # [B, L, H]
            # Query i may attend to key j if they fall in the same bucket
            mask = (q_r.unsqueeze(2) == k_r.unsqueeze(1)).float()  # [B, L, L, H]
            masks.append(mask)
        # Combine masks from all rounds (union)
        combined_mask = torch.stack(masks, dim=-1).max(dim=-1)[0]
        combined_mask = combined_mask.permute(0, 3, 1, 2)  # [B, H, L, L]
        # Always let a token attend to itself so no row is fully masked
        eye = torch.eye(seq_len, device=q_buckets.device)
        return torch.maximum(combined_mask, eye)
def forward(self, x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, d_model = x.shape
# Linear projections
Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)
# Apply LSH to queries and keys
q_buckets = self.hash_vectors(Q)
k_buckets = self.hash_vectors(K)
        # Create hash-based attention mask from both query and key buckets
        hash_mask = self.create_attention_mask(q_buckets, k_buckets)
# Reshape for attention computation
Q = Q.permute(0, 2, 1, 3) # [B, H, L, d_k]
K = K.permute(0, 2, 1, 3) # [B, H, L, d_k]
V = V.permute(0, 2, 1, 3) # [B, H, L, d_k]
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
# Apply hash mask (only attend within same buckets)
scores = scores.masked_fill(hash_mask == 0, float('-inf'))
        # Apply an additional attention mask if provided (e.g. a [B, L] padding mask)
        if attention_mask is not None:
            if attention_mask.dim() == 2:  # [B, L] -> [B, 1, 1, L] for broadcasting
                attention_mask = attention_mask[:, None, None, :]
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
# Softmax and dropout
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, V)
# Reshape and apply output projection
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
attn_output = attn_output.view(batch_size, seq_len, d_model)
output = self.w_o(attn_output)
return output
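# A quick, illustrative sanity check for LSHAttention (arbitrary sizes, left as
# comments so the file still imports cleanly):
#   attn = LSHAttention(d_model=64, n_heads=4, n_buckets=16)
#   out = attn(torch.randn(2, 128, 64))
#   print(out.shape)  # expected: torch.Size([2, 128, 64])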
class HashFormerBlock(nn.Module):
"""Single HashFormer transformer block"""
def __init__(self, d_model: int, n_heads: int, d_ff: int,
n_buckets: int = 64, dropout: float = 0.1):
super().__init__()
self.attention = LSHAttention(d_model, n_heads, n_buckets, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# Self-attention with residual connection
attn_output = self.attention(self.norm1(x), attention_mask)
x = x + attn_output
# Feed-forward with residual connection
ffn_output = self.ffn(self.norm2(x))
x = x + ffn_output
return x
class HashFormer(nn.Module):
"""Complete HashFormer model"""
def __init__(self, vocab_size: int, d_model: int = 512, n_heads: int = 8,
n_layers: int = 6, d_ff: int = 2048, max_length: int = 512,
n_buckets: int = 64, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_length = max_length
# Embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_length, d_model)
# Transformer blocks
self.blocks = nn.ModuleList([
HashFormerBlock(d_model, n_heads, d_ff, n_buckets, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len = input_ids.shape
        assert seq_len <= self.max_length, "sequence length exceeds max_length"
# Create position indices
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
# Embeddings
token_embeds = self.token_embedding(input_ids)
pos_embeds = self.position_embedding(positions)
x = self.dropout(token_embeds + pos_embeds)
# Apply transformer blocks
for block in self.blocks:
x = block(x, attention_mask)
return self.norm(x)
# Usage example
def main():
    # Model configuration
    model = HashFormer(
        vocab_size=50000,
        d_model=512,
        n_heads=8,
        n_layers=6,
        d_ff=2048,
        max_length=1024,  # must cover the sequence length used below
        n_buckets=32,     # fewer buckets = coarser approximation
        dropout=0.1
    )
    # Example data
    batch_size, seq_len = 4, 1024
    input_ids = torch.randint(0, 50000, (batch_size, seq_len))
    # Forward pass
    output = model(input_ids)
    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")
    # Compare theoretical complexity. Note: this simplified version still builds
    # the full L x L score matrix and only masks it; the savings below require
    # computing attention bucket-by-bucket (chunked), as in Reformer.
    print("\nTheoretical complexity comparison:")
    print(f"Standard attention: O(L²) = O({seq_len**2:,})")
    print(f"LSH attention: O(L·log L) = O({seq_len * int(np.log2(seq_len)):,})")
    print(f"Theoretical speedup: ~{(seq_len**2) // (seq_len * int(np.log2(seq_len)))}x")
if __name__ == "__main__":
main()
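To make the complexity discussion concrete, the sketch below is an optional addition (the function name and the sizes are ours, not part of the original script): it reuses the LSHAttention module defined above to measure what fraction of query-key pairs the hash mask actually keeps, compared with the 100% kept by dense attention.

# Optional sketch: measure the sparsity induced by the hash mask (illustrative sizes)
def measure_mask_sparsity(seq_len=1024, d_model=512, n_heads=8, n_buckets=32):
    attn = LSHAttention(d_model, n_heads, n_buckets)
    x = torch.randn(1, seq_len, d_model)
    with torch.no_grad():
        q = attn.w_q(x).view(1, seq_len, n_heads, attn.d_k)
        k = attn.w_k(x).view(1, seq_len, n_heads, attn.d_k)
        mask = attn.create_attention_mask(attn.hash_vectors(q), attn.hash_vectors(k))
    kept = mask.mean().item()  # fraction of query-key pairs kept per head
    print(f"Hash mask keeps {kept:.1%} of {seq_len**2:,} query-key pairs")

# measure_mask_sparsity()  # uncomment to run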