5. deep learning operators

5.1. flash attention

5.2. transformer decoder k/v cache

5.3. LSTM

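A multi-layer LSTM forward pass can be implemented with plain PyTorch tensor ops as follows. It is meant to reproduce torch.nn.LSTM (batch_first=False, no projection) from that module's state_dict. When bidirectional=True the reverse-direction weights are loaded, but forward only evaluates the forward direction: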

from collections import OrderedDict
from typing import Optional

import torch


class MyLSTM:
    """Inference-only multi-layer LSTM that runs on weights taken from torch.nn.LSTM."""
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
    ):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        self.wi = []
        self.wh = []
        self.bias = []

        self.wi_reverse = []
        self.wh_reverse = []
        self.bias_reverse = []

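    def set_weights(self, state_dict: OrderedDict):
        # Start from empty lists so that calling set_weights twice does not
        # append a second copy of every layer.
        self.wi, self.wh, self.bias = [], [], []
        self.wi_reverse, self.wh_reverse, self.bias_reverse = [], [], []

        for layer_id in range(self.num_layers):
            # weight_ih_l{k} is (4 * hidden_size, in_features), gates stacked as
            # [i, f, g, o]; transpose it so a (batch, in_features) input is x @ W.
            self.wi.append(state_dict[f"weight_ih_l{layer_id}"].T)
            self.wh.append(state_dict[f"weight_hh_l{layer_id}"].T)
            # bias_ih and bias_hh are always summed, so fold them into one vector.
            self.bias.append(
                state_dict[f"bias_ih_l{layer_id}"]
                + state_dict[f"bias_hh_l{layer_id}"]
            )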

            if self.bidirectional:
                self.wi_reverse.append(
                    state_dict[f"weight_ih_l{layer_id}_reverse"].T
                )
                self.wh_reverse.append(
                    state_dict[f"weight_hh_l{layer_id}_reverse"].T
                )
                self.bias_reverse.append(
                    state_dict[f"bias_ih_l{layer_id}_reverse"]
                    + state_dict[f"bias_hh_l{layer_id}_reverse"]
                )

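    def forward(
        self,
        x: torch.Tensor,
        h_0: Optional[torch.Tensor] = None,
        c_0: Optional[torch.Tensor] = None,
    ):
        # x: (seq_len, batch, input_size); h_0 and c_0: (num_layers, batch, hidden_size).
        # Returns out: (seq_len, batch, hidden_size) plus the final h_n and c_n.
        if self.bidirectional:
            # The *_reverse weights are loaded by set_weights, but this reference
            # only evaluates the forward direction.
            raise NotImplementedError("bidirectional forward is not implemented")

        seq_len, batch_size, input_size = x.shape

        assert input_size == self.input_size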

        if h_0 is None:
            h_n = torch.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                device=x.device,
                dtype=x.dtype,
            )
        else:
            h_n = h_0.clone()

        if c_0 is None:
            c_n = torch.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                device=x.device,
                dtype=x.dtype,
            )
        else:
            c_n = c_0.clone()

        out = torch.empty(
            seq_len, batch_size, self.hidden_size, device=x.device, dtype=x.dtype
        )

        for t in range(seq_len):
            # layer 0: its input at step t is x[t]; the gates come packed as [i, f, g, o]
            ifgo = (
                torch.matmul(x[t], self.wi[0])
                + torch.matmul(h_n[0], self.wh[0])
                + self.bias[0]
            )

            i = ifgo[:, 0 : self.hidden_size]
            f = ifgo[:, self.hidden_size : self.hidden_size * 2]
            g = ifgo[:, self.hidden_size * 2 : self.hidden_size * 3]
            o = ifgo[:, self.hidden_size * 3 : self.hidden_size * 4]

            i = torch.sigmoid(i)
            f = torch.sigmoid(f)
            g = torch.tanh(g)
            o = torch.sigmoid(o)

            c_n[0] = f * c_n[0] + i * g
            h_n[0] = o * torch.tanh(c_n[0])

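            # layers >= 1: the input is h_n[layer_id - 1], which the layer below has
            # already updated for step t; h_n[layer_id] still holds step t - 1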
            for layer_id in range(1, self.num_layers):
                ifgo = (
                    torch.matmul(h_n[layer_id - 1], self.wi[layer_id])
                    + torch.matmul(h_n[layer_id], self.wh[layer_id])
                    + self.bias[layer_id]
                )

                i = ifgo[:, 0 : self.hidden_size]
                f = ifgo[:, self.hidden_size : self.hidden_size * 2]
                g = ifgo[:, self.hidden_size * 2 : self.hidden_size * 3]
                o = ifgo[:, self.hidden_size * 3 : self.hidden_size * 4]

                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                g = torch.tanh(g)
                o = torch.sigmoid(o)

                c_n[layer_id] = f * c_n[layer_id] + i * g
                h_n[layer_id] = o * torch.tanh(c_n[layer_id])

            out[t] = h_n[self.num_layers - 1].clone()

        return out, h_n, c_n
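For reference, one time step of a layer in the loop above evaluates the standard LSTM cell equations, with the gates in the order \(i, f, g, o\) that PyTorch uses when stacking weight_ih_l{k}, weight_hh_l{k} and the biases. Here \(\sigma\) is the logistic sigmoid (not the standard deviation of the next section), \(\odot\) is the element-wise product, the layer index is dropped, and the input \(x_t\) is x[t] for layer 0 and the hidden state of the layer below for layers \(\ge 1\):

\[\begin{aligned} i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned}\]

The code computes all four pre-activations at once as x_t W_ih^T + h_{t-1} W_hh^T + (b_ih + b_hh) and then slices the result into the four gates.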

5.4. Variance computation

Definition of the variance (two-pass method):

\[\begin{aligned} \bar{x}_n &= \frac{1}{n} \sum_{i=1}^{n} x_i \\ \sigma^2_n &= \frac{1}{n} \sum_{i=1}^{n} (x_i - \bar{x}_n)^2 \end{aligned} \]

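A simplified formula (naive method) needs only one pass, accumulating \(\sum x_i\) and \(\sum x_i^2\). However, it subtracts two large, nearly equal numbers, so it loses precision badly (catastrophic cancellation) when the mean is large relative to the standard deviation: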

\[\begin{aligned} \sigma^2_n &= \frac{1}{n} \sum_{i=1}^{n} x_i^2 - \bar{x}_n^2 \end{aligned} \]
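A minimal NumPy sketch of the cancellation problem, under illustrative assumptions (the helper names naive_var and two_pass_var and the data are made up for this example). Both formulas are evaluated in float32 on four samples with population variance 22.5, shifted by 10000. The squares are about 1e8, where neighbouring float32 values are 8 apart, so the naive difference can only be a multiple of 8. The two-pass deviations, in contrast, are small integers and come out exact.

```python
import numpy as np


def naive_var(x: np.ndarray) -> np.floating:
    # One pass: E[x^2] - (E[x])^2, evaluated in the input dtype (float32 here).
    mean = x.mean()
    return (x * x).mean() - mean * mean


def two_pass_var(x: np.ndarray) -> np.floating:
    # Pass 1: the mean; pass 2: the mean of squared deviations from it.
    mean = x.mean()
    d = x - mean
    return (d * d).mean()


# Spread of 4, 7, 13, 16 (population variance 22.5) on top of an offset of 10000.
x = np.array([4, 7, 13, 16], dtype=np.float32) + np.float32(10000)
naive = naive_var(x)
two_pass = two_pass_var(x)
print(f"naive = {naive}, two-pass = {two_pass}, exact = 22.5")
```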

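5.4.1. Computing the variance with Welford's algorithm

Welford's algorithm is also a one-pass method, but its rounding error is much smaller than that of the naive method.

Recurrence for the mean: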

\[\bar{x}_n = \frac{(n-1)\bar{x}_{n-1} + x_n }{n} = \bar{x}_{n-1} + \frac{x_n - \bar{x}_{n-1}}{n} \tag{1}\]

Recurrence for the variance:

\[\begin{aligned} & n \sigma_n^2 - (n-1) \sigma_{n-1}^2 \\ &= \sum_{i=1}^{n} (x_i - \bar{x}_n)^2 - \sum_{i=1}^{n-1} (x_i - \bar{x}_{n-1})^2 \\ &= (x_n - \bar{x}_n)^2 + \sum_{i=1}^{n-1} \left( (x_i - \bar{x}_n)^2 - (x_i - \bar{x}_{n-1})^2 \right) \\ &= (x_n - \bar{x}_n)^2 + \sum_{i=1}^{n-1} (2 x_i - \bar{x}_n - \bar{x}_{n-1}) (\bar{x}_{n-1} - \bar{x}_n) \\ &= (x_n - \bar{x}_n)^2 + (\bar{x}_n - x_n) (\bar{x}_{n-1} - \bar{x}_n) \\ &= (x_n - \bar{x}_n) (x_n - \bar{x}_{n-1}) \end{aligned} \]

The factorization step uses \(a^2 - b^2 = (a - b)(a + b)\) with \(a = x_i - \bar{x}_n\) and \(b = x_i - \bar{x}_{n-1}\). The step after it uses \(\sum_{i=1}^{n-1} x_i = (n-1)\bar{x}_{n-1} = n\bar{x}_n - x_n\), which gives \(\sum_{i=1}^{n-1} (2 x_i - \bar{x}_n - \bar{x}_{n-1}) = (n-1)(\bar{x}_{n-1} - \bar{x}_n) = \bar{x}_n - x_n\).
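To double-check the derivation, a small sketch evaluates both sides of the identity \(n \sigma_n^2 - (n-1) \sigma_{n-1}^2 = (x_n - \bar{x}_n)(x_n - \bar{x}_{n-1})\) with exact rational arithmetic (Python's fractions.Fraction), so rounding cannot hide a sign or index mistake. The helper names and the sample data are illustrative.

```python
from fractions import Fraction
from typing import List


def mean(xs: List[Fraction]) -> Fraction:
    return sum(xs, Fraction(0)) / len(xs)


def sum_sq_dev(xs: List[Fraction]) -> Fraction:
    # n * sigma_n^2 = sum_i (x_i - mean_n)^2
    m = mean(xs)
    return sum(((x - m) ** 2 for x in xs), Fraction(0))


def check_recurrence(xs: List[Fraction]) -> bool:
    # For every n >= 2, compare the left- and right-hand sides of the identity.
    for n in range(2, len(xs) + 1):
        lhs = sum_sq_dev(xs[:n]) - sum_sq_dev(xs[: n - 1])
        rhs = (xs[n - 1] - mean(xs[:n])) * (xs[n - 1] - mean(xs[: n - 1]))
        if lhs != rhs:
            return False
    return True


data = [Fraction(v) for v in (3, -1, 4, 1, 5, 9, 2, 6)]
print(check_recurrence(data))
```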

Define \(M_n\) (usually written \(M_{2,n}\) in the literature) as:

\[M_{n} = \sum_{i=1}^n (x_i - \bar{x}_n)^2 \]

This yields Welford's algorithm:

\[ \begin{aligned} M_{n} &= M_{n-1} + (x_n - \bar{x}_n) (x_n - \bar{x}_{n-1}) \\ \sigma_n^2 &= M_{n} / n \end{aligned} \tag{2}\]

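The two factors are closely related: since \(\bar{x}_n\) moves toward \(x_n\), \(x_n - \bar{x}_n = \frac{n-1}{n}(x_n - \bar{x}_{n-1})\), so \(x_n - \bar{x}_n\) is the smaller of the two in magnitude and both always have the same sign. The increment is therefore \(\frac{n-1}{n}(x_n - \bar{x}_{n-1})^2 \ge 0\). It is computed from deviations relative to the running mean rather than from raw squares, so \(M_n\) can never become negative and the cancellation of the naive method is avoided.

A Python implementation: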

from typing import Tuple

import numpy as np


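def welford_online_update(
    x: float, count: int, s: float, a: float
) -> Tuple[float, float]:
    """One step of Welford's recurrence (2).

    x: the new sample x_n
    count: n, the number of samples including x
    s: M_{2,n-1}, the sum of squared deviations of the first n - 1 samples
    a: average of the first n - 1 samples
    Returns the updated (M_{2,n}, average of the first n samples).
    """
    b = a + (x - a) / count  # mean recurrence (1)
    s += (x - b) * (x - a)  # M_n = M_{n-1} + (x_n - mean_n)(x_n - mean_{n-1})

    return s, b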


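def calculate_mean_and_var(arr: np.ndarray) -> Tuple[float, float]:
    """Return the mean and the population variance (divided by n) of a 1-D array."""
    a: float = 0.0
    s: float = 0.0

    for i in range(arr.size):
        # float() keeps the accumulation in Python float (float64)
        s, a = welford_online_update(float(arr[i]), i + 1, s, a)

    return a, s / arr.size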


if __name__ == "__main__":

    arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32)

    mean, var = calculate_mean_and_var(arr)

    print(f"ref     average = {arr.mean()}, var = {arr.var()}")
    print(f"welford average = {mean}, var = {var}")
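To make the accuracy claim concrete, the following sketch runs recurrence (2) entirely in float32 on samples with a small spread around a large offset. welford_float32 is a hypothetical helper, not part of the code above. Because every update works with deviations from the running mean, the intermediate values stay small and the result recovers the exact population variance 22.5 on this data. By contrast, a float32 evaluation of \(\frac{1}{n}\sum x_i^2 - \bar{x}_n^2\) works with squares near 1e8, where float32 values are spaced 8 apart.

```python
import numpy as np


def welford_float32(arr: np.ndarray):
    """Welford's recurrence evaluated entirely in float32 (illustration only)."""
    mean = np.float32(0.0)
    m2 = np.float32(0.0)
    for i, x in enumerate(arr.astype(np.float32), start=1):
        n = np.float32(i)  # keep the divisor in float32 as well
        new_mean = mean + (x - mean) / n
        m2 = m2 + (x - new_mean) * (x - mean)
        mean = new_mean
    return mean, m2 / np.float32(arr.size)


# Population variance 22.5 on top of a large offset (10000).
x = np.array([4, 7, 13, 16], dtype=np.float32) + np.float32(10000)
mean, var = welford_float32(x)
print(f"welford float32: mean = {mean}, var = {var}")
```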

5.4.2. C++ implementation example

The demo only handles sample counts that are a multiple of 4.

The C++ example WelfordCpp supports x86 SSE and ARM NEON instructions.
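The WelfordCpp source is not reproduced here. As a hedged illustration of how a 4-wide SIMD version can be organized, the sketch below assumes that each of 4 lanes runs recurrence (2) on every 4th sample (which would explain the multiple-of-4 restriction). It also assumes that the lane states are combined at the end with the pairwise merge formula of Chan et al.: \(n = n_A + n_B\), \(\delta = \bar{x}_B - \bar{x}_A\), \(\bar{x} = \bar{x}_A + \delta n_B / n\), \(M = M_A + M_B + \delta^2 n_A n_B / n\). The lane layout and the names welford_update, merge and lanewise_mean_var are assumptions, not taken from WelfordCpp.

```python
from typing import List, Tuple

import numpy as np

State = Tuple[int, float, float]  # (count, mean, M2)


def welford_update(state: State, x: float) -> State:
    n, mean, m2 = state
    n += 1
    new_mean = mean + (x - mean) / n
    m2 += (x - new_mean) * (x - mean)
    return n, new_mean, m2


def merge(a: State, b: State) -> State:
    # Chan et al.'s pairwise combination of two Welford states.
    na, mean_a, m2a = a
    nb, mean_b, m2b = b
    n = na + nb
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return n, mean, m2


def lanewise_mean_var(arr: np.ndarray, lanes: int = 4) -> Tuple[float, float]:
    assert arr.size > 0 and arr.size % lanes == 0, "size must be a multiple of lanes"
    states: List[State] = [(0, 0.0, 0.0)] * lanes
    for start in range(0, arr.size, lanes):
        # One "vector" step: lane k consumes element start + k.
        states = [welford_update(states[k], float(arr[start + k])) for k in range(lanes)]
    total = states[0]
    for s in states[1:]:
        total = merge(total, s)
    n, mean, m2 = total
    return mean, m2 / n


if __name__ == "__main__":
    arr = np.arange(1, 17, dtype=np.float64)
    mean, var = lanewise_mean_var(arr)
    print(f"lanes: mean = {mean}, var = {var}; numpy: {arr.mean()}, {arr.var()}")
```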

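5.4.3. Division and the Newton-Raphson iteration

The mean recurrence (1) requires the division 1/n, and floating-point division has a considerably higher latency than multiplication on typical CPUs.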

Newton's method (the Newton-Raphson method) refines an approximation \(x_0\) of a root of \(f(x) = 0\):

\[x_1 = x_0 - \frac{f(x_0)}{f^{\prime}(x_0)}\]

Newton-Raphson division is a fast way to compute the reciprocal of a number \(a\). Define \(f(x) = 1/x - a\), whose root is \(x = 1/a\); then \(f^{\prime}(x) = -1/x^2\).

Then Newton's iteration is:

\[\begin{aligned} x_{n+1} &= x_n - \frac{f(x_n)}{f^{\prime}(x_n)} \\ &= x_n - \frac{\frac{1}{x_n} - a}{-\frac{1}{x_n^2}} \\ &= x_n (2 - a x_n) \end{aligned} \]

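Why not choose another function, such as \(f(x) = a x - 1\)? Its Newton step is \(x_{n+1} = x_n - (a x_n - 1)/a\), which itself requires dividing by \(a\), the very operation we want to avoid. With \(f(x) = 1/x - a\) the update needs only multiplications and one subtraction, and it converges quadratically: writing \(e_n = 1 - a x_n\), we get \(e_{n+1} = 1 - a x_n (2 - a x_n) = e_n^2\). The number of correct bits therefore roughly doubles per iteration, provided \(0 < a x_0 < 2\).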
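A minimal scalar sketch of the iteration \(x_{n+1} = x_n (2 - a x_n)\) in plain Python (newton_reciprocal is an illustrative name). The usage prints the error \(1 - a x_n\) for a starting point inside \(0 < a x_0 < 2\), where the error is squared at every step, and for one outside it, where the iterates blow up.

```python
from typing import List


def newton_reciprocal(a: float, x0: float, iters: int) -> List[float]:
    """Return x_0, ..., x_iters of x_{n+1} = x_n * (2 - a * x_n).

    Converges to 1/a when 0 < a * x0 < 2.
    """
    xs = [x0]
    for _ in range(iters):
        x = xs[-1]
        xs.append(x * (2.0 - a * x))  # only multiply and subtract, no division
    return xs


if __name__ == "__main__":
    a = 3.0
    for x0 in (0.3, 1.0):  # a * x0 = 0.9 converges, a * x0 = 3.0 diverges
        print(f"x0 = {x0}")
        for n, x in enumerate(newton_reciprocal(a, x0, iters=5)):
            print(f"  n = {n}: x = {x!r}, 1 - a*x = {1.0 - a * x!r}")
```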

An ARM NEON implementation:

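#include <arm_neon.h>

// Approximate 1/a in each of the four float lanes without a division instruction.
float32x4_t fast_reciprocal(float32x4_t a) {
    // Initial hardware estimate (roughly 8 bits of precision).
    float32x4_t recip = vrecpeq_f32(a);

    // Newton-Raphson iteration two times: recip = recip * (2 - recip * a)
    recip = vmulq_f32(recip, vrecpsq_f32(recip, a));
    recip = vmulq_f32(recip, vrecpsq_f32(recip, a));

    return recip;
}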

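vrecpsq_f32(recip, a) computes \(2.0 - a \cdot \mathrm{recip}\) in each lane, i.e. the factor \(2 - a x_n\) of the Newton step; multiplying it by recip with vmulq_f32 completes one iteration.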
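To see why two iterations are used, a NumPy float32 sketch that mimics fast_reciprocal lane by lane. The real vrecpeq_f32 uses a hardware estimate; here it is replaced by a hypothetical stand-in, the exact reciprocal scaled by \(1 - 2^{-8}\), i.e. an initial relative error of about 8 bits. One step leaves a relative error near \(2^{-16}\), and the second brings it down to float32 rounding level. fast_reciprocal_emulated and its parameters are illustrative names.

```python
import numpy as np


def fast_reciprocal_emulated(
    a: np.ndarray, steps: int = 2, rel_err0: float = 2.0 ** -8
) -> np.ndarray:
    """float32 emulation of the NEON fast_reciprocal above (illustration only)."""
    a = a.astype(np.float32)
    # Hypothetical stand-in for vrecpeq_f32: 1/a with a relative error of rel_err0.
    recip = (np.float32(1.0) / a) * np.float32(1.0 - rel_err0)
    two = np.float32(2.0)
    for _ in range(steps):
        # vrecpsq_f32(recip, a) is 2 - recip * a; vmulq_f32 multiplies it by recip.
        recip = recip * (two - recip * a)
    return recip


if __name__ == "__main__":
    a = np.linspace(0.5, 1000.0, 2001, dtype=np.float32)
    for steps in (0, 1, 2):
        r = fast_reciprocal_emulated(a, steps=steps)
        err = np.max(np.abs(r.astype(np.float64) * a.astype(np.float64) - 1.0))
        print(f"{steps} step(s): max relative error = {err:.3e}")
```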

To do:

CUDA LayerNorm

5.4.4. Related links

Reference 1:

https://mp.weixin.qq.com/s/t0x782mDkMo-ZBVEbK8gPg

Reference 2:

https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/