【DeepSeek-模型解读】DeepSeek-V3模型特性之FP8混合精度训练
作者:昇腾实战派
DeepSeek知识地图:https://blog.csdn.net/weixin_45216014/article/details/156450562
1. 介绍
受此前低精度训练优势的启发,文章提出了使用FP8数据格式的细粒度混合精度框架用于DeepSeek-V3的训练。
尽管低精度训练有巨大的潜力,但经常被激活值、权重、梯度的异常值限制。虽然在推理量化上已经有重大进步,但是低精度技术在大规模语言模型预训练中的成功应用相对较少。
为了应对这一挑战并且有效地拓宽FP8格式的动态范围,文章引入了细粒度量化策略:切片分组和分块分组量化。在提高累加精度的过程中,相关的反量化开销得到大幅缓解,这是提高FP8 GEMM (Genaral Matrix Multiplication)精度的关键。此外,为了进一步减少MoE训练中的内存和通信开销,文章以FP8格式缓存和分发激活值,同时以BF16格式存储低精度的优化器状态。
文章在与DeepSeek-V2-Lite和DeepSeek-V2相似的两种模型规模上验证了FP8混合精度框架,训练了1T个tokens。值得注意的是,FP8训练模型的相对损失误差始终保持在0.25%以下,这一水平完全在训练随机性的可接受范围内。
论文链接:DeepSeek-V3 Technical Report
2. 混合精度框架
在FP8混合精度训练框架中,大多数计算密集型操作使用FP8,另一方面为了平衡训练的效率和稳定性,少量关键的算子仍然沿用原始的数据格式。框架总览如图1所示。
首先,为了加速模型训练,最主要的计算操作(GEMM,通用矩阵乘)使用FP8格式。这些GEMM操作的输入tensor是FP8,输出tensor是BF16 或者 FP32。如图1所示:有三处线性操作用到GEMM算子,分别命名为Fprop(前向,计算每个专家的输出)、Dgrad(反向,计算每个专家的输入数据梯度)、Wgrad(反向,计算每个专家的权重梯度),所有GEMM使用FP8执行。理论上这种设计的计算速度相比于之前BF16要快1倍。除此之外,Wgrad GEMM可以存储FP8激活值用于反向传播,这极大降低了内存消耗。
尽管FP8格式效率高,但是某些算子由于对低精度计算敏感,它们仍需要高精度数据格式。并且,一些低消耗的算子由于在整个训练过程的开销是微不足道的,他们也能使用高精度数据格式。因为上述原因,文章对以下组件保留了原始的精度(BF16或FP32): the embedding module, the output head, MoE gating modules, normalization operators, attention operators。这些针对性的高精度保留确保了DeekSeek-V3的训练稳定性。另外为了更好地保证数值稳定性,文章使用高精度存储 master weights(权重的全精度副本,FP32), weight gradients(FP32), and optimizer states(BF16)。尽管这些高精度组件会引入内存开销,但是可以在分布式训练中通过切分将数据分配到各张卡上使得影响最小化。
图1 FP8混合精度训练框架总览
3. 通过量化和乘法提高精度
基于FP8混合精度框架,文章引入几种提高低精度训练精度的策略,重点在量化方法和乘法过程。
图2
3.1 细粒度量化
在低精度训练框架中,由于FP8格式指数位减少,动态表示范围受限,因此上溢出和下溢出都是挑战。为了让输入分布和FP8格式的表示范围对齐,通常做法是将输入张量的最大绝对值缩放到FP8的最大可表示值。这种方法使得低精度训练对激活值中的异常值极为敏感,从而可能严重减低量化精度。为了解决这一问题,文章提出了一种细粒度的量化方法,该方法在更精细的层面上应用缩放。如图2的(a)部分所示:
(1)对于激活值,文章以1*128的切片为单位(即每个token的每128个通道)对元素进行分组和缩放;
(2)对于权重,文章以128*128的块为单位(即每128个输入通道和每128个输出通道)对元素进行分组和缩放。
这种方法通过根据更小的元素组调整缩放比例,确保量化过程能够更好地适应异常值。
上述方法的关键修改是引入GEMM算子沿内部维度的逐组缩放因子,这一功能在标准的FP8 GEMM中不直接支持。然而结合文章的精确FP32累加策略,它可以高效地实现。
3.2 提高累加精度
低精度GEMM操作常常会遇到下溢问题(隐藏层维度越大,加法次数越多,下溢问题越明显),GEMM精度很大程度上依赖于高精度累加,高精度累加通常以FP32精度进行。然而,文章观察到H800 GPUs上的FP8 GEMM的累加精度被限制在保留大约14位,这显著低于FP32的累加精度。当内部维度K较大时,这一问题将变得更加明显,而在大规模模型训练中,增加批量大小和模型宽度是典型场景。以K=4096的两个随机矩阵的GEMM操作为例,在文章的初步测试中,Tensor Cores中有限的累加精度导致了接近2%的最大相对误差。尽管存在这些问题,有限的累加精度仍然是一些FP8框架中的默认选项,严重限制了训练精度。
为了解决这一问题,文章采用了提升至CUDA Cores的策略以实现更高精度。方法如图2的Figure 7(b)所示。具体来说,在Tensor Cores上执行MMA(Matrix Multiply-Accumulate,矩阵乘加)时,中间结果使用有限的位宽进行累加。一旦达到NC的间隔,这些部分结果将被复制到CUDA Cores上的FP32寄存器中,并在那里执行全精度的FP32累加。如前所述,文章的细粒度量化沿内部维度K应用逐组缩放因子。这些缩放因子可以在CUDA Cores上高效地作为反量化过程进行乘法运算,且只需最小的额外计算成本。
值得注意的是,这一修改降低了单个warpgroup的WGMMA(Warpgroup-level Matrix Multiply-Accumulate)指令的发出率。然而,在H800架构上,通常会有两个WGMMA同时进行:当一个warpgroup执行提升操作时,另一个warpgroup能够执行MMA操作。这种设计实现了两种操作的重叠,从而保持了Tensor Cores的高利用率。根据文章的实验,将NC设置为128个元素(相当于4个WGMMA)是最小的累加间隔,可以在不引入显著开销的情况下显著提高精度。
3.3 尾数和指数
在以往使用混合精度FP8的工作中,Fprop采用E4M3(4位指数3位尾数),Dgrad和Wgrad采用E5M2(5位指数2位尾数)。相比之下,文章为了更高的精度对所有张量采用E4M3格式。这种方法的可行性归功于文章的细粒度量化策略,即切片和分块后缩放。通过操作更小的元素组,有效地在组内元素间共享指数位(少了异常值的影响,数据分布差异不会很大),从而减轻了有限动态范围的影响。(E4M3相比于E5M2的动态范围更有限)。
3.4 在线量化
张量级量化框架中采用了延迟量化的技术,该技术通过维护前几次迭代中的最大绝对值历史来推断当前值。为了保证准确的缩放比例并简化框架,文章在线计算每个1128激活切片或128128权重块的最大绝对值。基于此,文章推导出缩放因子,然后在线将激活值或权重量化为FP8格式。
4 低精度存储和通信
结合文章的FP8训练框架,文章通过将缓存的激活值和优化器状态压缩为更低精度的格式,进一步减少了内存消耗和通信开销。
4.1 低精度优化器状态
文章采用BF16数据格式而非FP32来跟踪AdamW优化器中的一阶和二阶动量,这样不会导致明显的性能下降。然而,主权重master weight(由优化器存储)和梯度(用于batch size累加)仍保留为FP32格式,以确保整个训练过程中的数值稳定性。
4.2 低精度激活值
如图1所示,Wgrad操作使用FP8格式。为了减少内存开销,将激活值以FP8格式缓存以用于线性算子的反向传播是一个自然的选择。然而,为了实现低开销的高精度训练,文章对一些操作进行了特殊考虑:
(1)注意力算子之后的线性层的输入:这些激活值会被用于反向的注意力算子,因此对精度是敏感的。文章专门对这些激活值采用了E5M6格式。此外,在反向传播时,这些激活值将从1128的量化切片转换为1281的切片。为避免引入额外的量化误差,所有量化因子是2的整数幂。
(2)MoE中SwiGLU算子的输入:为了进一步降低内存成本,文章缓存了SwiGLU算子的输入,并在反向传播过程中重新计算其输出。这些激活值同样用细粒度量化方法存储为FP8,在内存效率和计算准确性取得平衡。
4.3 低精度通信
通信带宽是MOE模型训练的一个关键瓶颈。为了缓解这一瓶颈,文章在MOE上投影(up-projections)之前将激活值量化为FP8,然后应用分发组件,这与MoE上投影中的FP8前向传播兼容。和“注意力算子之后的线性层的输入”一样,激活值的缩放因子是2的整数幂。类似的策略也被应用于MoE下投影之前的激活梯度。对于前向和反向的组合组件,文章将其保留为BF16格式,以在训练流程的关键部分保留训练精度。
5. 代码解读
5.1 DeepSeek-V3源码解读
源码地址 https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
5.1.1 代码结构
代码主要分为以下几个部分:
- 量化操作(act_quant_kernel 和 act_quant)
- 反量化操作(weight_dequant_kernel 和 weight_dequant)
- FP8矩阵乘法(fp8_gemm_kernel 和 fp8_gemm)
5.1.2 量化操作
act_quant_kernel 函数定义
- @triton.jit:这是一个装饰器,表示这个函数是一个Triton内核函数,将在GPU上执行。
- x_ptr:输入张量的指针,指向需要量化的数据。
- y_ptr:输出张量的指针,指向量化后的数据。
- s_ptr:输出张量的指针,指向缩放因子。
- BLOCK_SIZE:一个常量表达式,表示每个实例处理的数据块大小。
act_quant_kernel 函数主要功能是对输入张量进行量化操作。它首先计算当前数据块的最大绝对值(line15-18),然后根据这个最大值计算缩放因子(line18),接着将输入数据除以缩放因子得到量化后的数据(line19-20),最后将量化后的数据和缩放因子存储到输出张量中(line21-22)。这个过程确保了量化后的数据能够适应FP8格式的有限表示范围。
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
Returns:
None
"""
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
act_quant 函数定义
- x:输入张量,需要进行量化的数据。
- block_size:量化时使用的块大小,默认为128。
- 返回值:返回一个元组,包含量化后的张量(FP8格式)和缩放因子张量(FP32格式)。
act_quant 函数的主要功能是对输入张量x进行量化操作。它首先检查输入张量是否是连续的(Line14),并且最后一个维度的大小是否能被block_size整除(Line15)。然后,它创建用于存储量化结果和缩放因子的张量(Line16-17),并调用Triton内核函数 act_quant_kernel 进行实际的量化操作(Line18-19)。最后,返回量化后的张量和缩放因子张量。
- Line16:创建一个空张量y,用于存储量化后的数据。数据类型为fp8 e4m3格式。
- Line17:创建一个空张量s,用于存储缩放因子。s除了最后一个维度外,其他维度与x相同,最后一个维度的大小是x.size(-1) // block_size。数据类型为fp32。
- Line18:定义一个网格函数grid,用于动态计算Triton内核的执行网格大小。网格大小是输入张量总元素除以块大小的向上取整,确保所有数据都能被处理。
- Line19:调用Triton内核函数act_quant_kernel,并指定执行网格大小和参数。内核函数会根据网格大小启动多个线程块,并行处理出入张量x,并将量化结果写入y和s。[grid]是triton的语法,用于制定内核的执行网格大小。grid是一个函数,返回网格的大小。
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
5.1.3 反量化操作
weight_dequant_kernel 函数定义
- x_ptr:量化后的权重张量的指针,指向需要反量化的数据。
- s_ptr:缩放因子的指针,指向每个块的缩放因子。
- y_ptr:输出张量的指针,指向反量化后的权重数据。
- M:权重矩阵的行数。
- N:权重矩阵的列数。
- BLOCK_SIZE:一个常量表达式,表示每个线程块处理的数据块大小。
weight_dequant_kernel 对量化后的权重进行反量化操作(内核函数),函数体分析如下:
Line17:获取当前线程块在行维度(axis=0)的ID,用于处理行方向的数据。
Line18:获取当前线程块在列维度(axis=1)的ID,用于处理列方向的数据。
Line19:计算列方向上的块数,用于后续计算缩放因子的索引。
Line20:计算当前线程块的行偏移量。
Line21:计算当前线程块的列偏移量。
Line22:计算全局内存中的偏移量,offs表示当前线程块处理的每个元素在全局内存中的位置,形状为(BLOCK_SIZE, BLOCK_SIZE)。
Line23:这个掩码确保只处理不超过矩阵边界的元素,形状为(BLOCK_SIZE, BLOCK_SIZE)。
Line24:加载量化后的权重数据。
Line25:加载缩放因子。
Line26:反量化操作。
Line27:存储反量化后的权重数据。
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
weight_dequant 函数定义
- x:量化后的权重张量,形状为(M, N)。
- s:缩放因子张量,形状为(M, N)。
- BLOCK_SIZE:块大小,用于分块处理,默认为128。
- 返回值:返回量化后的权重张量,形状与x相同。
weight_dequant_kernel 对大规模权重矩阵进行并行反量化操作,函数体分析如下:
Line16:检查权重张量x和缩放因子张量s是否是连续的。
Line17:确保输入张量的形状是(M, N),符合权重矩阵的要求。
Line19:创建一个与输入张量x形状相同的空张量,数据类型为默认的浮点类型(通常是fp32)。
Line20:定义一个网格函数。返回一个元组(行方向的块数,列方向的块数)。
Line21-22:调用Triton内核函数weight_dequant_kernel,并行处理每个网格。反量化的权重存储在y,最后返回y。
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M, N).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
5.1.4 FP8矩阵乘法
fp8_gemm_kernel 函数定义
@triton.autotune:自动调优装饰器,根据fp8_gemm_configs中的配置自动选择最优的内核参数。
- a_ptr:第一个输入矩阵A的指针。
- b_ptr:第二个输入矩阵B的指针。
- c_ptr:输出矩阵C的指针。
- a_s_ptr:矩阵A的缩放因子指针。
- b_s_ptr:矩阵B的缩放因子指针。
- M:矩阵A和C的行数。
- N:矩阵B和C的列数。
- K:矩阵A的列数和B的行数。
- BLOCK_SIZE_M:块大小,用于分块处理M维度。
- BLOCK_SIZE_N:块大小,用于分块处理N维度。
- BLOCK_SIZE_K:块大小,用于分块处理K维度。
fp8_gemm_kernel 并行处理FP8精度的大规模矩阵乘法操作。函数体分析如下:
- Line33-34:获取当前线程块在行维度和列维度的ID。
- Line35:计算隐藏层K维度的块数k。
- Line36-37:计算当前线程块在行方向和列方向的全局偏移量offs_m和offs_n,形状分别为(BLOCK_SIZE_M, )和(BLOCK_SIZE_N, )。
- Line38:块内的隐藏层维度偏移量offs_k。
- Line39-40:计算矩阵A和B的内存地址。
- Line41-42:计算缩放因子的内存地址。
- Line44:创建一个全零的张量,用于累加矩阵乘法的中间结果。形状为(BLOCK_SIZE_M, BLOCK_SIZE_N),数据类型为float32。
- Line45-54:分块计算矩阵乘法。
- Line45:遍历隐藏层维度k上的所有块。
- Line46-47:从矩阵A和矩阵B中加载当前块的数据,使用掩码确保只加载有效的元素。
- Line48-49:加载矩阵A和矩阵B的缩放因子。
- Line50:计算矩阵乘法的中间结果,并乘以缩放因子,累加到accumulate中。
- Line51-54:更新矩阵和缩放因子的内存地址,指向下一个块。
- Line55:将累加器的结果转换为输出矩阵C的数据类型。
- Line56-57:计算当前线程块在行方向和列方向的全局偏移量offs_m和offs_n。
- Line58:计算输出矩阵C中当前块的内存地址。
- Line59:创建掩码,确保只存储有效的元素。
- Line60:将结果存储到输出矩阵C中,使用掩码确保只存储有效的元素。
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
fp8_gemm 函数定义
- a:第一个输入矩阵,形状为(M, K)。
- a_s:矩阵A的缩放因子,形状为(M, )。
- b:第二个输入矩阵,形状为(N, K)。值得注意的是,在常规的矩阵乘法规则下,b矩阵的形状应该是(K, N)。但是下面源码18行把b矩阵的行数定义为N,因此在实际调用该函数时应确保b的形状为(N, K)。
- b_s:矩阵B的缩放因子,形状为(N, )。
- 返回值:返回矩阵乘法的结果,形状为(M, N)。
fp8_gemm 并行处理FP8精度的大规模矩阵乘法操作。函数体分析如下:
Line14-15:检查输入张量是否连续。确保输入张量在内存中按顺序存储可以高效进行内存访问。
Line16-18:获取矩阵的维度M, N, K。
Line19:创建输出张量,用于存储矩阵乘法的结果。
Line20:定义一个网格函数。
Line21:调用内核函数执行FP8精度的矩阵乘法操作。
Line22:返回矩阵乘法的结果。
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐
所有评论(0)