Attention Efficiency & Long Context · Problem 3 of 5
Implement grouped-query attention (GQA) in PyTorch by repeating/broadcasting KV heads across query-head groups, and verify that it matches full multi-head attention when the number of KV heads equals the number of query heads (n_kv_heads == n_heads).
Implement the function/class skeleton in the editor. Any correct approach is accepted.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GQA(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        raise NotImplementedError

    def forward(self, x):
        raise NotImplementedError

def mha(x):
    raise NotImplementedError
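One possible way to fill in the skeleton is sketched below. It is not a reference solution, and several choices are assumptions not stated in the problem: the names GQASketch and mha_reference are hypothetical, the projections are bias-free, attention is unmasked (non-causal), and mha_reference takes the module as a second argument so that it can reuse the same weights. KV heads are repeated with repeat_interleave so that each group of consecutive query heads shares one KV head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GQASketch(nn.Module):
    """Grouped-query attention sketch: n_kv_heads K/V heads shared by n_heads query heads."""

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        # Assumption: bias-free projections; K/V project to fewer heads than Q.
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        # Shapes after transpose: (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head so consecutive query heads in a group share it.
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        # Assumption: unmasked scaled dot-product attention.
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        out = F.softmax(scores, dim=-1) @ v
        out = out.transpose(1, 2).reshape(b, t, self.n_heads * self.head_dim)
        return self.o_proj(out)


def mha_reference(x, gqa):
    """Plain multi-head attention reusing gqa's weights (needs n_kv_heads == n_heads)."""
    b, t, _ = x.shape
    h, d = gqa.n_heads, gqa.head_dim
    q = gqa.q_proj(x).view(b, t, h, d).transpose(1, 2)
    k = gqa.k_proj(x).view(b, t, h, d).transpose(1, 2)
    v = gqa.v_proj(x).view(b, t, h, d).transpose(1, 2)
    attn = F.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return gqa.o_proj((attn @ v).transpose(1, 2).reshape(b, t, h * d))


torch.manual_seed(0)
x = torch.randn(2, 5, 32)  # (batch, seq, d_model)
full = GQASketch(d_model=32, n_heads=4, n_kv_heads=4)
grouped = GQASketch(d_model=32, n_heads=4, n_kv_heads=2)
with torch.no_grad():
    # With one KV head per query head, GQA should reduce to multi-head attention.
    same = torch.allclose(full(x), mha_reference(x, full), atol=1e-6)
    grouped_out = grouped(x)
```

With n_kv_heads equal to n_heads the group size is 1, so the repeat is a no-op and the two code paths compute the same thing. With fewer KV heads, only the K and V projections shrink; the output shape is unchanged.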