import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# RMSNorm, apply_rope, and MOEConfig are assumed to be defined elsewhere in the project.


class Attention(nn.Module):
    """Multi-head Latent Attention (MLA)-style attention: queries and keys/values go
    through low-rank bottlenecks, with a small decoupled RoPE component per head."""

    def __init__(self, model_args: MOEConfig):
        super().__init__()
        d_model = model_args.d_model
        self.num_heads = model_args.num_heads
        self.head_dim = model_args.d_model // model_args.num_heads
        self.attn_dropout = nn.Dropout(model_args.dropout)
        self.res_dropout = nn.Dropout(model_args.dropout)
        # Flash-attention availability flag; this reference implementation keeps the
        # explicit matmul/softmax path and does not use the fused kernel.
        self.flash_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")

        self.q_lora_rank = model_args.q_lora_rank
        self.qk_rope_head_dim = model_args.qk_rope_head_dim
        self.kv_lora_rank = model_args.kv_lora_rank
        self.v_head_dim = model_args.v_head_dim
        self.qk_nope_head_dim = model_args.qk_nope_head_dim
        # Each query/key head concatenates a non-positional ("nope") part and a RoPE part.
        self.q_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim

        # Low-rank query path: d_model -> q_lora_rank -> num_heads * q_head_dim.
        self.q_a_proj = nn.Linear(d_model, model_args.q_lora_rank, bias=False)
        self.q_a_layernorm = RMSNorm(model_args.q_lora_rank)
        self.q_b_proj = nn.Linear(model_args.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

        # KV path: compress to a shared latent of size kv_lora_rank plus a single
        # MQA-style RoPE key of size qk_rope_head_dim shared across all heads.
        self.kv_a_proj_with_mqa = nn.Linear(
            d_model,
            model_args.kv_lora_rank + model_args.qk_rope_head_dim,
            bias=False,
        )
        self.kv_a_layernorm = RMSNorm(model_args.kv_lora_rank)
        # Expand the latent back into per-head non-positional keys and values.
        self.kv_b_proj = nn.Linear(
            model_args.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        # Queries: down-project, normalize, up-project, then split each head into
        # its non-positional and rotary components.
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
        q = q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: compress to the shared latent plus a single rope key.
        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(batch, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        # Expand the latent into per-head non-positional keys and values.
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
              .view(batch, seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
              .transpose(1, 2))
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]

        # Apply RoPE only to the decoupled rope components.
        q_pe, k_pe = apply_rope(q_pe, k_pe, freqs_cis)
        k_pe = k_pe.transpose(2, 1)
        q_pe = q_pe.transpose(2, 1)

        # Reassemble full query/key heads as [nope | rope]; the single rope key is
        # broadcast across all heads.
        query_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
        key_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        # Scaled dot-product attention; scale by the actual query/key head dimension
        # (q_head_dim = qk_nope_head_dim + qk_rope_head_dim), not d_model // num_heads.
        attn_mtx = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.q_head_dim)
        attn_mtx = attn_mtx + mask[:, :, :seq_len, :seq_len]  # additive causal mask
        attn_mtx = F.softmax(attn_mtx.float(), dim=-1).type_as(key_states)
        attn_mtx = self.attn_dropout(attn_mtx)

        output = torch.matmul(attn_mtx, value_states)  # (batch, num_heads, seq_len, v_head_dim)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads * self.v_head_dim)
        output = self.o_proj(output)
        output = self.res_dropout(output)
        return output
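

# Why the low-rank KV path matters: a back-of-the-envelope comparison. This is a
# sketch with illustrative, assumed dimensions (not DeepSeek's real configuration).
# Per token, standard multi-head attention keeps full per-head keys and values,
# whereas the layout above only needs the compressed latent (kv_lora_rank) plus the
# single shared rope key (qk_rope_head_dim) produced by kv_a_proj_with_mqa.
if __name__ == "__main__":
    num_heads = 32                                                 # assumed example values
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
    kv_lora_rank = 512

    per_token_full = num_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)
    per_token_latent = kv_lora_rank + qk_rope_head_dim
    print(f"full K/V per token:      {per_token_full}")   # 32 * 320 = 10240 values
    print(f"latent + shared rope key: {per_token_latent}")  # 512 + 64 = 576 values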