This is a placeholder post with some math, code, and images to make sure everything renders nicely.
## Some math
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

$$h_t = A_t \odot h_{t-1} + B_t \odot x_t, \qquad y_t = C_t^\top h_t$$

Some inline math: $A_t, B_t, C_t$, $\mathcal{O}(N)$.
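As a sanity check that the attention formula renders the way it computes, here is a minimal sketch of it in PyTorch (the function name and shapes are illustrative, not from any particular library):

```python
import math
import torch

def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # softmax(Q K^T / sqrt(d_k)) V, matching the formula above
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    return torch.softmax(scores, dim=-1) @ V

Q = torch.randn(2, 5, 16)  # (batch, seq_len, d_k)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)
out = attention(Q, K, V)   # (batch, seq_len, d_k)
```

The $\mathcal{O}(N)$ note applies to the linear recurrence, not to this: the `scores` matrix here is quadratic in sequence length.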
## Some code
A minimal PyTorch implementation of a single linear-RNN step:
```python
import torch
import torch.nn as nn


class LinearRNNCell(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.A = nn.Parameter(torch.ones(d_model) * 0.9)
        self.B = nn.Linear(d_model, d_model, bias=False)
        self.C = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self, x: torch.Tensor, h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # x, h : (batch, d_model)
        h_new = self.A * h + self.B(x)  # diagonal recurrence (A is a fixed decay here)
        y = self.C(h_new)               # read-out
        return y, h_new
```
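A single step is only half the story; a minimal usage sketch unrolls the cell over a sequence, one step per timestep, which is where the $\mathcal{O}(N)$ cost comes from. The cell is restated here so the snippet runs on its own; shapes and the sequence length are illustrative:

```python
import torch
import torch.nn as nn


class LinearRNNCell(nn.Module):
    # Restated from the post, for a self-contained snippet.
    def __init__(self, d_model: int):
        super().__init__()
        self.A = nn.Parameter(torch.ones(d_model) * 0.9)
        self.B = nn.Linear(d_model, d_model, bias=False)
        self.C = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, h):
        h_new = self.A * h + self.B(x)
        return self.C(h_new), h_new


cell = LinearRNNCell(d_model=8)
x = torch.randn(4, 16, 8)        # (batch, seq_len, d_model)
h = torch.zeros(4, 8)            # initial hidden state

ys = []
for t in range(x.shape[1]):      # one step per timestep: linear in seq_len
    y, h = cell(x[:, t], h)
    ys.append(y)
y_seq = torch.stack(ys, dim=1)   # (batch, seq_len, d_model)
```

Because the recurrence carries only a fixed-size state `h`, memory stays constant in sequence length, unlike attention's quadratic score matrix.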
## Some images
That's all. Hopefully this gets replaced by actual stuff soon.