Attention Efficiency & Long Context · Problem 2 of 5
Implement a key/value (KV) cache and an incremental decode loop that feeds one token at a time to a small transformer.
Fill in the function and class skeleton below. Any correct approach is accepted.
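The core property a KV cache must preserve can be checked in isolation. The sketch below uses a single head with random weights; all names (`Wq`, `Wk`, `Wv`, `k_cache`, `v_cache`) are hypothetical. It computes causal attention twice: once over the full sequence with a lower-triangular mask, and once token by token while appending each new key and value to a growing cache. The two results should agree.

```python
import torch

torch.manual_seed(0)
d, T = 8, 5
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))  # hypothetical projections
x = torch.randn(T, d)  # one sequence of T token vectors

# Full (non-cached) causal attention over the whole sequence at once.
q, k, v = x @ Wq, x @ Wk, x @ Wv
mask = torch.tril(torch.ones(T, T)).bool()
scores = (q @ k.T) / d ** 0.5
scores = scores.masked_fill(~mask, float("-inf"))
full_out = scores.softmax(dim=-1) @ v

# Incremental attention: one token per step, keys/values appended to a cache.
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
step_outs = []
for t in range(T):
    xt = x[t : t + 1]                        # (1, d): only the newest token
    k_cache = torch.cat([k_cache, xt @ Wk])  # grow the cache by one row
    v_cache = torch.cat([v_cache, xt @ Wv])
    qt = xt @ Wq
    att = (qt @ k_cache.T) / d ** 0.5        # (1, t + 1); no mask needed because
    step_outs.append(att.softmax(dim=-1) @ v_cache)  # the cache holds only the past
inc_out = torch.cat(step_outs)

print(torch.allclose(full_out, inc_out, atol=1e-5))
```

Each incremental step projects only the newest token and scores one query against the cache, instead of recomputing keys, values, and scores for the whole prefix.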
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        raise NotImplementedError

    def forward(self, x, cache=None):
        raise NotImplementedError


@torch.no_grad()
def decode(model, embed, unembed, prompt_ids, n_new):
    raise NotImplementedError
```
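One possible way to fill in the skeleton is sketched below; it is not a reference solution. It assumes a cache format of `None` on the first call and a `(k, v)` tuple shaped `(batch, n_heads, past_len, d_head)` afterwards, a `model` callable that returns `(hidden, new_cache)`, an `embed` that maps ids to vectors, an `unembed` that maps hidden states to vocabulary logits, and greedy (argmax) token selection. For brevity, `model` is a single `CachedAttention` layer with no positional encoding, residual connections, or MLP. `decode_no_cache` is a hypothetical helper used only to check that caching does not change the output.

```python
import torch
import torch.nn as nn


class CachedAttention(nn.Module):
    """Causal multi-head self-attention with an optional KV cache.

    Assumed cache format: None on the first call, afterwards a tuple
    (k, v) with each tensor shaped (batch, n_heads, past_len, d_head).
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # fused Q, K, V projection
        self.proj = nn.Linear(d_model, d_model)

    def _split_heads(self, t):
        # (B, T, C) -> (B, n_heads, T, d_head)
        B, T, _ = t.shape
        return t.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, cache=None):
        B, T, C = x.shape
        q, k, v = (self._split_heads(t) for t in self.qkv(x).split(C, dim=-1))
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=2)  # append new keys on the time axis
            v = torch.cat([past_v, v], dim=2)
        S = k.shape[2]  # cached length + new length
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5  # (B, H, T, S)
        # New query i sits at absolute position S - T + i, so it may attend to
        # keys 0 .. S - T + i: every cached key plus a causal prefix of new ones.
        mask = torch.tril(torch.ones(T, S, device=x.device), diagonal=S - T).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
        y = scores.softmax(dim=-1) @ v  # (B, H, T, d_head)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.proj(y), (k, v)


@torch.no_grad()
def decode(model, embed, unembed, prompt_ids, n_new):
    """Greedy decoding that reuses the KV cache between steps."""
    ids = prompt_ids
    x = embed(prompt_ids)  # first step (prefill): the whole prompt at once
    cache = None
    for _ in range(n_new):
        hidden, cache = model(x, cache)
        next_id = unembed(hidden[:, -1]).argmax(dim=-1, keepdim=True)  # (B, 1)
        ids = torch.cat([ids, next_id], dim=1)
        x = embed(next_id)  # later steps: only the newest token
    return ids


@torch.no_grad()
def decode_no_cache(model, embed, unembed, prompt_ids, n_new):
    """Hypothetical reference: recompute the full sequence every step."""
    ids = prompt_ids
    for _ in range(n_new):
        hidden, _ = model(embed(ids), None)
        next_id = unembed(hidden[:, -1]).argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return ids


torch.manual_seed(0)
vocab, d_model = 50, 32
embed = nn.Embedding(vocab, d_model)
attn = CachedAttention(d_model, n_heads=4)
unembed = nn.Linear(d_model, vocab)
prompt = torch.randint(0, vocab, (2, 5))

ids = decode(attn, embed, unembed, prompt, n_new=6)
ref = decode_no_cache(attn, embed, unembed, prompt, n_new=6)
print(ids.shape, torch.equal(ids, ref))
```

The subtle piece is the mask offset `S - T`: during prefill (`S == T`) it reduces to the usual lower-triangular causal mask, and during single-token steps (`T == 1`) it makes the one row all ones, so the new query sees every cached key. In a multi-layer model each layer would keep its own `(k, v)` pair, so the cache would become a list with one entry per layer.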