转载

OpenAI研究 Triton 简介:用于神经网络的开源 GPU 编程

我们发布了 Triton 1.0,这是一种类似于 Python 的开源编程语言,它使没有 CUDA 经验的研究人员能够编写高效的 GPU 代码——大部分时间与专家能够编写的代码相当。

为什么重要

Triton 可以用相对较少的努力达到最高的硬件性能;例如,它可用于编写 FP16 矩阵乘法内核,其性能与 cuBLAS 的性能相匹配——这是许多 GPU 程序员无法做到的——不到 25 行代码。我们的研究人员已经使用它来生成比等效的 Torch 实现效率高出 2 倍的内核,我们很高兴与社区合作,让每个人都能更轻松地使用 GPU 编程。

深度学习领域的新颖研究思想通常是使用原生框架运算符的组合来实现的。虽然方便,但这种方法通常需要创建(和/或移动)许多临时张量,这可能会损害神经网络的大规模性能。这些问题可以通过编写专门的 GPU 内核来缓解,但由于 GPU 编程的许多复杂性,这样做可能会非常困难。1 ,23个而且,尽管最近出现了各种系统45个 为了使这个过程更容易,我们发现它们要么过于冗长,缺乏灵活性,要么生成代码的速度明显慢于我们手动调整的基线。这促使我们扩展和改进 Triton6个,一种最新的语言和编译器,其原始创建者现在在 OpenAI 工作。

GPU编程的挑战

现代 GPU 的架构可大致分为三个主要组件——DRAM、SRAM 和 ALU——在优化 CUDA 代码时必须考虑其中的每一个:

  • 来自 DRAM 的内存传输必须 合并 为大型事务,以利用现代内存接口的大总线宽度。
  • 数据必须在重新使用之前手动存储到 SRAM 中,并进行管理以最大限度地减少检索时共享内存库的冲突。
  • 计算必须在流式多处理器 (SM) 之间和内部仔细划分和调度,以促进指令/线程级并行性并利用专用 ALU(例如,张量核心)。
图形处理器架构

GPU 的基本架构。

推理所有这些因素可能具有挑战性,即使对于具有多年经验的经验丰富的 CUDA 程序员也是如此。Triton 的目的是完全自动化这些优化,以便开发人员可以更好地专注于他们的并行代码的高级逻辑。Triton 的目标是广泛适用,因此不会自动安排跨 SM 的工作——将一些重要的算法考虑(例如平铺、SM 间同步)留给开发人员自行决定。

CUDA特里顿
内存合并手动的自动的
共享内存管理手动的自动的
调度(在 SM 内)手动的自动的
调度(跨 SM)手动的手动的

CUDA 与 Triton 中的编译器优化。

编程模型

program_id在所有可用的领域特定语言和 JIT 编译器中,Triton 可能与 Numba 最相似:内核被定义为经过修饰的 Python 函数,并在所谓的 实例网格上与不同的同时启动 。然而,如下面的代码片段所示,相似之处就此止步:Triton 通过对 (尺寸为 2 的幂的小数组)而不是单指令多线程 (SIMT) 的操作公开了实例内并行性7 执行模型。 在这样做时,Triton 有效地抽象掉了与 CUDA 线程块内的并发相关的所有问题 (例如,内存合并、共享内存同步/冲突、张量核心调度)。

BLOCK = 512

# This is a GPU kernel in Numba.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
   # In Numba/CUDA, each kernel 
   # instance itself uses an SIMT execution
   # model, where instructions are executed in
   # parallel for different values of threadIdx
   tid = threadIdx.x
   bid = blockIdx.x
   # scalar index
   idx = bid * BLOCK + tid
   if id < N:
     # There is no pointer in Numba.
     # Z,X,Y are dense tensors
     Z[idx] = X[idx] + Y[idx]


...
grid = (ceil_div(N, BLOCK),)
block = (BLOCK,)
add[grid, block](x, y, z, x.shape[0])
BLOCK = 512

# This is a GPU kernel in Triton.
# Different instances of this
# function may run in parallel.
@jit
def add(X, Y, Z, N):
   # In Triton, each kernel instance
   # executes block operations on a
   # single thread: there is no construct
   # analogous to threadIdx
   pid = program_id(0)
   # block of indices
   idx = pid * BLOCK + arange(BLOCK)
   mask = idx < N
   # Triton uses pointer arithmetics  
   # rather than indexing operators
   x = load(X + idx, mask=mask)
   y = load(Y + idx, mask=mask)
   store(Z + idx, x + y, mask=mask)


...
grid = (ceil_div(N, BLOCK),)
# no thread-block
add[grid](x, y, z, x.shape[0])

Triton 中的矢量加法。

虽然这对于令人尴尬的并行(即逐元素)计算可能不是特别有用,但它可以大大简化更复杂的 GPU 程序的开发。

例如,考虑融合 softmax 内核(下图)的情况,其中每个实例对给定输入张量X ∈R M × N的不同行进行归一化。这种并行化策略的标准 CUDA 实现可能难以编写,需要线程之间的显式同步,因为它们同时减少 X的同一行。Triton 消除了大部分这种复杂性,其中每个内核实例加载感兴趣的行并使用类似 NumPy 的原语对其进行顺序规范化。

import triton
import triton.language as tl

@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
    # row index
    m = tl.program_id(0)
    # col indices
    # this specific kernel only works for matrices that 
    # have less than BLOCK_SIZE columns
    BLOCK_SIZE = 1024
    n = tl.arange(0, BLOCK_SIZE)
    # the memory address of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n * stride_xn
    # load input data; pad out-of-bounds elements with 0 
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # compute numerically-stable softmax
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # write back to Y
    Y = Y + m * stride_ym + n * stride_yn
    tl.store(Y, y, mask=n < N)

import torch
# Allocate input/output tensors
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)
# SPMD launch grid
grid = (X.shape[0], )
# enqueue GPU kernel
softmax[grid](Y, Y.stride(0), Y.stride(1), 
              X, X.stride(0), X.stride(1),
              X.shape[0]    , X.shape[1])
无效的

请注意,Triton JIT 将 X 和 Y 视为 指针 而不是张量;我们觉得保留对内存访问的低级控制对于处理更复杂的数据结构(例如,块稀疏张量)很重要。

 重要的是,softmax 的这种特殊实现在整个规范化过程中将X的行保留 在 SRAM 中,这在适用时最大化了数据重用(~<32K 列)。这与 PyTorch 的内部 CUDA 代码不同,后者使用临时内存使其更通用但速度明显较慢(下图)。这里的底线并不是 Triton 本身就更好,而是它简化了专用内核的开发,这些内核比通用库中的内核要快得多。

chart = RuntimeError: 获取失败
M=4096 的融合 softmax 的 A100 性能。

Torch (v1.9) JIT 的较低性能凸显了从高级张量操作序列自动生成 CUDA 代码的难度。

将 softmax 与 Torch JIT 融合。

@torch.jit.script
def softmax(x):
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(x)
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]
无效的

矩阵乘法

能够为逐元素操作和归约编写融合内核很重要,但考虑到神经网络中矩阵乘法任务的重要性,这还不够。事实证明,Triton 也非常适合这些,仅需约 25 行 Python 代码即可达到最佳性能。另一方面,在 CUDA 中实现类似的东西会花费 更多的精力 ,甚至可能会降低性能。

Triton 中的矩阵乘法。

@triton.jit
def matmul(A, B, C, M, N, K, stride_am, stride_ak, 
            stride_bk, stride_bn, stride_cm, stride_cn,
            **META):
    # extract metaparameters
    BLOCK_M, GROUP_M = META['BLOCK_M'], META['GROUP_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    # programs are grouped together to improve L2 hit rate
    _pid_m = tl.program_id(0)
    _pid_n = tl.program_id(1)
    pid_m = _pid_m // GROUP_M
    pid_n = (_pid_n * GROUP_M) + (_pid_m % GROUP_M)
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # rk denotes a range of indices for columns 
    # (resp. rows) of A (resp. B)
    rk = tl.arange(0, BLOCK_K)
    # the memory addresses of elements in the first block of
    # A and B can be computed using numpy-style broadcasting
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk [:, None] * stride_bk  + rn[None, :] * stride_bn)
    # initialize and iteratively update accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        # block level matrix multiplication
        acc += tl.dot(a, b)
        # increment pointers so that the next blocks of A and B
        # are loaded during the next iteration
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # fuse leaky ReLU if desired
    # acc = tl.where(acc >= 0, acc, alpha * acc)
    # write back result
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)
无效的

手写矩阵乘法内核的一个重要优势是它们可以根据需要进行定制,以适应其输入(例如,切片)和输出(例如,Leaky ReLU)的融合变换。如果没有像 Triton 这样的系统,对于没有特殊 GPU 编程专业知识的开发人员来说,矩阵乘法内核的重要修改将是遥不可及的。

chart = RuntimeError: 获取失败

矩阵乘法的 V100 张量核心性能,具有适当调整的 BLOCK 值米的, 阻止否的, 阻止钾的, 集团米的.

高层系统架构

Triton 的良好性能来自于以 Triton-IR 为中心的模块化系统架构,这是一种基于 LLVM 的中间表示,其中多维值块是一等公民。

Python
海卫一红外
LLVM-IR
PTX
@jit
def add(X, Y, Z, N):
   pid = program_id(0)
   idx= pid * 512 + arange(512)
   mask = idx < N
   x = load(X + idx, mask=mask)
   y = load(Y + idx, mask=mask)
   store(Z + idx, x + y, mask=mask)
def void add(i32* X .aligned(16) , i32* Y .aligned(16) , i32* Z .aligned(16) , i32 N .multipleof(2) )
{
entry:
  %0 = get_program_id[0] i32;
  %1 = mul i32 %0, 512;
  %3 = make_range[0 : 512] i32<512>;
  %4 = splat i32<512> %1;
  %6 = add i32<512> %4, %3;
  %9 = splat i32<512> N;
  %11 = icmp_slt i1<512> %6, %9;
  %14 = splat i32*<512> X;
  %16 = getelementptr i32*<512> %14, %6;
  %19 = broadcast i1<512> %11;
  %21 = splat i32<512> undef;
  %22 = masked_load i32<512> %16, %19, %21;
  %26 = splat i32*<512> Y;
  %28 = getelementptr i32*<512> %26, %6;
  %31 = broadcast i1<512> %11;
  %33 = splat i32<512> undef;
  %34 = masked_load i32<512> %28, %31, %33;
  %38 = splat i32*<512> Z;
  %40 = getelementptr i32*<512> %38, %6;
  %43 = add i32<512> %22, %34;
  %46 = broadcast i32<512> %43;
  %48 = broadcast i1<512> %11;
  masked_store void %40, %46, %48;
  ret void;
}
.visible .entry add(
    .param .u64 add_param_0, .param .u64 add_param_1,
    .param .u64 add_param_2, .param .u32 add_param_3
)
.maxntid 128, 1, 1
{
    .reg .pred     %p<4>;
    .reg .b32     %r<18>;
    .reg .b64     %rd<8>;
    ld.param.u64     %rd4, [add_param_0];
    ld.param.u64     %rd5, [add_param_1];
    mov.u32     %r13, %tid.x;
    ld.param.u32     %r14, [add_param_3];
    shl.b32     %r15, %r13, 2;
    mov.u32     %r16, %ctaid.x;
    mad.lo.s32     %r17, %r16, 512, %r15;
    setp.ge.s32     %p3, %r17, %r14;
    setp.lt.s32     %p1, %r17, %r14;
    mul.wide.s32     %rd7, %r17, 4;
    add.s64     %rd2, %rd4, %rd7;
    @%p1 ld.global.cg.v4.b32 {%r5,%r6,%r7,%r8}, [ %rd2 + 0];
    add.s64     %rd3, %rd5, %rd7;
    @%p1 ld.global.cg.v4.b32 {%r9,%r10,%r11,%r12}, [ %rd3 + 0];
    @%p3 bra     LBB0_2;
    ld.param.u64     %rd6, [add_param_2];
    add.s64     %rd1, %rd6, %rd7;
    add.s32     %r1, %r5, %r9;
    add.s32     %r2, %r6, %r10;
    add.s32     %r3, %r7, %r11;
    add.s32     %r4, %r8, %r12;
    st.global.v4.u32     [%rd1], {%r1, %r2, %r3, %r4};
LBB0_2:
    ret;
}
Triton 的高级架构。

装饰 @triton.jit 器通过遍历提供的 Python 函数的抽象语法树 (AST) 来工作,以便使用常见的 SSA 构造算法即时生成 Triton-IR。8个 生成的 IR 代码随后由我们的编译器后端进行简化、优化和自动并行化,然后再转换为高质量的 LLVM-IR——最终是 PTX——以便在最新的 NVIDIA GPU 上执行。目前不支持 CPU 和 AMD GPU,但我们欢迎旨在解决此限制的社区贡献。

编译器后端

我们发现通过 Triton-IR 使用块程序表示允许我们的编译器自动执行各种重要的程序优化。例如,通过查看计算密集型块级操作(例如,  tl.dot)的操作数,可以将数据自动存储到共享内存中,并使用标准的活性分析技术进行分配/同步。

Triton 编译器通过分析计算密集型操作中使用的块变量的有效范围来分配共享内存。

另一方面,Triton 程序可以通过同时执行不同的内核实例来高效自动地并行化(1)跨 SM,以及(2)在 SM 内通过分析每个块级操作的迭代空间并将其充分划分到不同的 SIMD单位,如下图。

逐元素
S1 float A[4,4] = ...
S2 float B[4,4] = ...
S3 float C[4,4] = A + B
FP16 矩阵乘法。乘法
S1 half A[4,2] = ...
S2 half B[2,2] = ...
S3 float C[4,2] = dot(A,B)
矢量化
张量化
SM
显卡
  1. 由三个语句组成的Triton 程序PS1的定义, S2S3
  1. 的迭代空间S3
  1. 映射S3到流式多处理器 (SM)
  1. 将P映射到 GPU
Triton 中的自动并行化。每个块级操作都定义了一个块迭代空间,该空间自动并行化以利用流式多处理器 (SM) 上可用的资源。

贡献

我们打算让 Triton 成为一个社区驱动的项目。请随意在 GitHub上创建我们的存储库!

如果您有兴趣加入我们的团队并致力于 Triton 和 GPU 内核, 我们正在招聘

参考

  1. Yan, D.、Wang, W. 和 Chu, X.(2020 年 5 月)。 揭秘张量核心以优化半精度矩阵乘法。2020年 IEEE 国际并行和分布式处理研讨会 (IPDPS)。IEEE。↩︎

  2. Tillet, P.、Kung, HT 和 Cox, D.(2019 年 6 月)。 Triton:一种用于平铺神经网络计算的中间语言和编译器。在 第三届 ACM SIGPLAN 机器学习和编程语言国际研讨会论文集 (第 10-19 页)中。↩︎

  3. Braun, M.、Buchwald, S.、Hack, S.、Leißa, R.、Mallon, C. 和 Zwinkau, A.(2013 年 3 月)。 简单高效的静态单赋值形式构造。在 国际编译器构造会议 (第 102-122 页)中。斯普林格,柏林,海德堡。↩︎

作者

致谢

大研(科大)、DeepSpeed(微软)、Anthropic