5. deep learning operators
5.1. flash attention
5.2. transformer decoder k/v cache
5.3. LSTM
The multi-layer LSTM can be implemented as:
from collections import OrderedDict
from typing import Optional

import torch


class MyLSTM:
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
    ):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        # Per-layer weights, stored transposed so forward() can compute x @ W.
        self.wi = []
        self.wh = []
        self.bias = []
        self.wi_reverse = []
        self.wh_reverse = []
        self.bias_reverse = []

    def set_weights(self, state_dict: OrderedDict):
        # state_dict uses the key names of torch.nn.LSTM.state_dict().
        for layer_id in range(self.num_layers):
            self.wi.append(state_dict[f"weight_ih_l{layer_id}"].T)
            self.wh.append(state_dict[f"weight_hh_l{layer_id}"].T)
            # The two bias vectors are always summed, so fold them into one.
            self.bias.append(
                state_dict[f"bias_ih_l{layer_id}"]
                + state_dict[f"bias_hh_l{layer_id}"]
            )
            if self.bidirectional:
                # NOTE: the reverse-direction weights are loaded here, but
                # forward() below only runs the forward direction.
                self.wi_reverse.append(
                    state_dict[f"weight_ih_l{layer_id}_reverse"].T
                )
                self.wh_reverse.append(
                    state_dict[f"weight_hh_l{layer_id}_reverse"].T
                )
                self.bias_reverse.append(
                    state_dict[f"bias_ih_l{layer_id}_reverse"]
                    + state_dict[f"bias_hh_l{layer_id}_reverse"]
                )

    def _cell(self, x_t, h, c, wi, wh, bias):
        # Gate pre-activations, packed as [i | f | g | o] along dim 1.
        ifgo = torch.matmul(x_t, wi) + torch.matmul(h, wh) + bias
        i, f, g, o = ifgo.chunk(4, dim=1)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)
        return h_next, c_next

    def forward(
        self,
        x,
        h_0: Optional[torch.Tensor] = None,
        c_0: Optional[torch.Tensor] = None,
    ):
        # x is time-major: (seq_len, batch_size, input_size).
        seq_len, batch_size, input_size = x.shape
        assert input_size == self.input_size
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        if h_0 is None:
            h_n = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        else:
            h_n = h_0.clone()
        if c_0 is None:
            c_n = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        else:
            c_n = c_0.clone()
        out = torch.empty(
            seq_len, batch_size, self.hidden_size, device=x.device, dtype=x.dtype
        )
        for t in range(seq_len):
            # Layer 0 reads x[t]; layer k >= 1 reads the fresh output of layer k - 1.
            layer_input = x[t]
            for layer_id in range(self.num_layers):
                h_n[layer_id], c_n[layer_id] = self._cell(
                    layer_input,
                    h_n[layer_id],
                    c_n[layer_id],
                    self.wi[layer_id],
                    self.wh[layer_id],
                    self.bias[layer_id],
                )
                layer_input = h_n[layer_id]
            # The sequence output is the hidden state of the last layer.
            out[t] = h_n[self.num_layers - 1]
        return out, h_n, c_n
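One way to gain confidence in a hand-written forward pass like the one above is to compare it with torch.nn.LSTM on random input. The sketch below is self-contained and assumes a unidirectional, time-major LSTM with zero initial state; `lstm_reference` is a hypothetical helper written for this check, not part of the class above.

```python
import torch


def lstm_reference(x, state_dict, num_layers, hidden_size):
    """Unidirectional LSTM forward pass; x has shape (seq_len, batch, input_size)."""
    seq_len, batch_size, _ = x.shape
    h = torch.zeros(num_layers, batch_size, hidden_size, dtype=x.dtype)
    c = torch.zeros_like(h)
    outputs = []
    for t in range(seq_len):
        layer_input = x[t]
        for layer in range(num_layers):
            gates = (
                layer_input @ state_dict[f"weight_ih_l{layer}"].T
                + h[layer] @ state_dict[f"weight_hh_l{layer}"].T
                + state_dict[f"bias_ih_l{layer}"]
                + state_dict[f"bias_hh_l{layer}"]
            )
            # PyTorch packs the gates in the order input, forget, cell, output.
            i, f, g, o = gates.chunk(4, dim=1)
            c[layer] = torch.sigmoid(f) * c[layer] + torch.sigmoid(i) * torch.tanh(g)
            h[layer] = torch.sigmoid(o) * torch.tanh(c[layer])
            layer_input = h[layer]
        outputs.append(h[num_layers - 1].clone())
    return torch.stack(outputs), h, c


torch.manual_seed(0)
ref = torch.nn.LSTM(input_size=3, hidden_size=4, num_layers=2)
x = torch.randn(5, 2, 3)
with torch.no_grad():
    ref_out, (ref_h, ref_c) = ref(x)
    out, h, c = lstm_reference(x, ref.state_dict(), num_layers=2, hidden_size=4)
max_diff = (out - ref_out).abs().max().item()
print(f"max abs difference vs torch.nn.LSTM: {max_diff:.2e}")
```

The difference should be at the level of float32 rounding; a larger gap usually points to a wrong gate order or a missing transpose.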
5.4. Variance computation
Definition of variance (two-pass method):
\[\begin{aligned}
\bar{x}_n &= \frac{1}{n} \sum_{i=1}^{n} x_i \\
\sigma^2_n &= \frac{1}{n} \sum_{i=1}^{n} (x_i - \bar{x}_n)^2
\end{aligned}
\]
Simplified variance computation (naive method, one pass):
\[\begin{aligned}
\sigma^2_n &= \frac{1}{n} \sum_{i=1}^{n} x_i^2 - \bar{x}_n^2
\end{aligned}
\]
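The two formulas are algebraically equal but behave very differently in floating point. The sketch below uses plain Python floats (float64) and four samples shifted by a large offset, an illustrative data set chosen here so that the exact variance is 22.5; the function names are assumptions made for this example.

```python
def two_pass_variance(xs):
    """Population variance from the definition: mean first, then squared deviations."""
    n = len(xs)
    mean = sum(xs) / n
    return sum((x - mean) * (x - mean) for x in xs) / n


def naive_variance(xs):
    """Population variance as E[x^2] - E[x]^2, accumulated in one pass."""
    n = len(xs)
    mean = sum(xs) / n
    mean_of_squares = sum(x * x for x in xs) / n
    return mean_of_squares - mean * mean


# Deviations from the mean are -6, -3, 3, 6, so the exact variance is 22.5.
data = [1e9 + 4.0, 1e9 + 7.0, 1e9 + 13.0, 1e9 + 16.0]
print("two-pass:", two_pass_variance(data))
print("naive   :", naive_variance(data))
```

The squares are close to 1e18, where adjacent float64 values are 128 apart, so the naive result cannot resolve a variance of 22.5 at all, while the two-pass result only ever works with small deviations.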
5.4.1. Welford's algorithm for variance
Welford's algorithm is also a one-pass method, but its error is much smaller than that of the naive method.
Recurrence for the mean:
(1)\[\bar{x}_n = \frac{(n-1)\bar{x}_{n-1} + x_n }{n} = \bar{x}_{n-1} + \frac{x_n - \bar{x}_{n-1}}{n}\]
Recurrence for the variance:
\[\begin{aligned}
& n \sigma_n^2 - (n-1) \sigma_{n-1}^2 \\
&= \sum_{i=1}^{n} (x_i - \bar{x}_n)^2 - \sum_{i=1}^{n-1} (x_i - \bar{x}_{n-1})^2 \\
&= (x_n - \bar{x}_n)^2 + \sum_{i=1}^{n-1} \left( (x_i - \bar{x}_n)^2 - (x_i - \bar{x}_{n-1})^2 \right) \\
&= (x_n - \bar{x}_n)^2 + \sum_{i=1}^{n-1} (2 x_i - \bar{x}_n - \bar{x}_{n-1}) (\bar{x}_{n-1} - \bar{x}_n) \\
&= (x_n - \bar{x}_n)^2 + (n-1) (\bar{x}_{n-1} - \bar{x}_n)^2 \\
&= (x_n - \bar{x}_n)^2 + (\bar{x}_n - x_n) (\bar{x}_{n-1} - \bar{x}_n) \\
&= (x_n - \bar{x}_n) (x_n - \bar{x}_{n-1})
\end{aligned}
\]
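The identity derived above can be checked exactly with rational arithmetic, which rules out rounding as a source of agreement or disagreement. This is a small sanity check, not a proof; the helper names are chosen for this example only.

```python
from fractions import Fraction


def mean(xs):
    return sum(xs, Fraction(0)) / len(xs)


def sum_sq_dev(xs):
    """n * sigma_n^2: sum of squared deviations from the mean."""
    m = mean(xs)
    return sum(((x - m) ** 2 for x in xs), Fraction(0))


def identity_holds(xs):
    """Check n*sigma_n^2 - (n-1)*sigma_{n-1}^2 == (x_n - mean_n) * (x_n - mean_{n-1})."""
    prev = xs[:-1]  # the first n - 1 samples
    x_n = xs[-1]
    lhs = sum_sq_dev(xs) - sum_sq_dev(prev)
    rhs = (x_n - mean(xs)) * (x_n - mean(prev))
    return lhs == rhs


samples = [Fraction(v) for v in (3, 1, 4, 1, 5, 9, 2, 6)]
# Check every prefix with at least two samples.
print(all(identity_holds(samples[:k]) for k in range(2, len(samples) + 1)))
```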
Define \(M_n\) (usually written \(M_{2,n}\) in the literature) as:
\[M_{n} = \sum_{i=1}^n (x_i - \bar{x}_n)^2
\]
This gives Welford's algorithm:
(2)\[ \begin{aligned}
M_{n} &= M_{n-1} + (x_n - \bar{x}_n) (x_n - \bar{x}_{n-1}) \\
\sigma_n^2 &= M_{n} / n
\end{aligned}\]
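When \(x_n\) lies far from the mean, \(x_n - \bar{x}_n\) is the smaller factor and \(x_n - \bar{x}_{n-1}\) the larger one, because \(x_n - \bar{x}_n = \frac{n-1}{n} (x_n - \bar{x}_{n-1})\). Both factors are linear in the deviation of \(x_n\) from the mean, so the update never forms the large squared terms \(x_i^2\) whose subtraction loses precision in the naive method.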
A Python implementation example:
from typing import Tuple

import numpy as np


def welford_online_update(
    x: float, count: int, s: float, a: float
) -> Tuple[float, float]:
    """
    Apply one Welford step for the new sample x.

    count: number of samples seen so far, including x
    a: average of x_i over the previous count - 1 samples
    s: M_{2,n} over the previous count - 1 samples
    Returns the updated (s, a).
    """
    b = a + (x - a) / count  # new mean, recurrence (1)
    s += (x - b) * (x - a)  # M_n update, recurrence (2)
    return s, b


def calculate_mean_and_var(arr: np.ndarray) -> Tuple[float, float]:
    a: float = 0
    s: float = 0
    for i in range(arr.size):
        s, a = welford_online_update(arr[i], i + 1, s, a)
    # Population variance: divide M_n by n.
    return a, s / arr.size


if __name__ == "__main__":
    arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32)
    mean, var = calculate_mean_and_var(arr)
    print(f"ref average = {arr.mean()}, var = {arr.var()}")
    print(f"welford average = {mean}, var = {var}")
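5.4.2. C++ implementation example
The demo only handles sample counts that are a multiple of 4.
C++ implementation example: WelfordCpp, which supports x86 SSE and ARM NEON instructions.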
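A 4-wide SIMD version naturally keeps one running state per lane, which leaves four partial results to combine at the end. The sketch below illustrates that idea in plain Python; it is an assumption about how such a demo could be organised, not a description of WelfordCpp itself. The merge step uses the standard pairwise combination formula for two (count, mean, M2) states, and all function names are hypothetical.

```python
def welford_update(state, x):
    """Add one sample to a (count, mean, m2) state."""
    count, mean, m2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return count, mean, m2


def merge(a, b):
    """Combine two (count, mean, m2) states computed on disjoint samples."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


def lanewise_mean_and_var(xs, lanes=4):
    """Emulate a SIMD loop: lane k sees samples k, k + lanes, k + 2 * lanes, ..."""
    assert len(xs) > 0 and len(xs) % lanes == 0
    states = [(0, 0.0, 0.0)] * lanes
    for start in range(0, len(xs), lanes):
        for lane in range(lanes):
            states[lane] = welford_update(states[lane], xs[start + lane])
    total = states[0]
    for state in states[1:]:
        total = merge(total, state)
    n, mean, m2 = total
    return mean, m2 / n


print(lanewise_mean_and_var([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
```

Within one SIMD loop every lane has the same count, so a single shared 1/count can serve all four lanes.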
5.4.3. Division and Newton-Raphson iteration
The mean recurrence (1) requires the division 1/n, but division has high latency.
Newton's method (Newton-Raphson method):
\[x_1 = x_0 - \frac{f(x_0)}{f^{\prime}(x_0)}\]
Newton–Raphson division is a fast method to calculate the reciprocal of a number \(a\). We can define \(f(x) = 1/x - a\) and thus \(f^{\prime}(x) = -1/x^2\).
Then Newton's iteration is:
\[\begin{aligned}
x_{n+1} &= x_n - \frac{f(x_n)}{f^{\prime}(x_n)} \\
&= x_n - \frac{\frac{1}{x_n} - a}{-\frac{1}{x_n^2}} \\
&= x_n (2 - a x_n)
\end{aligned}
\]
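The iteration above can be tried out in a few lines of Python. Writing \(\epsilon_n = 1 - a x_n\), one step gives \(\epsilon_{n+1} = \epsilon_n^2\), so the relative error squares on every iteration and the method converges when \(0 < x_0 < 2/a\) for positive \(a\). The starting value 0.3 for \(a = 3\) is an arbitrary choice for illustration.

```python
def newton_reciprocal(a, x0, iterations):
    """Approximate 1/a using only multiplication and subtraction."""
    x = x0
    for _ in range(iterations):
        x = x * (2.0 - a * x)
    return x


a = 3.0
for k in range(6):
    x = newton_reciprocal(a, 0.3, k)
    # The relative error 1 - a*x is roughly squared from one row to the next.
    print(k, x, abs(1.0 - a * x))
```

Because the error squares each time, a hardware estimate that is accurate to a few bits needs only a couple of iterations to approach single precision.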
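Why not choose a different function, such as \(f(x) = a x - 1\)? The choice is mainly driven by convergence and convergence speed. In addition, the Newton step for \(f(x) = a x - 1\) is \(x_{n+1} = x_n - (a x_n - 1)/a\), which itself requires a division by \(a\), whereas the iteration above needs only multiplication and subtraction.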
An ARM NEON implementation example:
#include <arm_neon.h>

float32x4_t fast_reciprocal(float32x4_t a) {
    // Initial low-precision estimate of 1/a.
    float32x4_t recip = vrecpeq_f32(a);
    // Two Newton-Raphson iterations: recip = recip * (2 - a * recip)
    recip = vmulq_f32(recip, vrecpsq_f32(recip, a));
    recip = vmulq_f32(recip, vrecpsq_f32(recip, a));
    return recip;
}
vrecpsq_f32(x, a) computes \(2.0 - a * x\).
TODO
CUDA layernorm