import torch
import triton
import triton.language as tl
from torch import nn
from triton import Config
from typing import Any, Optional


def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Helper function to get constant values for the current platform.

    Returns:
        pt_dtype (torch.dtype): The correct torch fp8 datatype.
        tl_dtype (tl.dtype): The correct triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    pt_fp8_dtype = torch.float8_e4m3fn
    tl_fp8_dtype = tl.float8e4nv
    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
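
# Note: for torch.float8_e4m3fn, torch.finfo(torch.float8_e4m3fn).max == 448.0,
# so get_fp8_constants() currently evaluates to
# (torch.float8_e4m3fn, tl.float8e4nv, 448.0, 1e-12).

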
@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
"""Quantize and scale each row. |
|
|
|
|
|
Scale per row i is computed as MAX_FP8 / max(abs(A[i, :])) |
|
|
|
|
|
Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles |
|
|
in a max pass then scale/quantize pass. |
|
|
|
|
|
Todo: |
|
|
* Better tiling schemes. |
|
|
|
|
|
Args: |
|
|
A (Tensor): higher precision input tensor of 4 dimension. |
|
|
A_scale (Tensor): [B * M * N] reciprocal scale tensor per row. |
|
|
A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale |
|
|
scale_ub (Tensor): [1] Maximum value allowed for scale. |
|
|
B (int): Size of dimenion 0 |
|
|
M (int): Size of dimenion 1 |
|
|
N (int): Size of dimenion 2 |
|
|
K (int): Size of dimenion 3 (input row size) |
|
|
K_fp8 (int): Size of dimenion 3 for A_fp8 (output row size, can be >= K) |
|
|
stride_ab (int): Stride of b dimension of A. |
|
|
stride_am (int): Stride of m dimension of A. |
|
|
stride_an (int): Stride of n dimension of A. |
|
|
stride_ak (int): Stride of k dimension of A. |
|
|
stride_ob (int): Stride of b dimension of output. |
|
|
stride_om (int): Stride of m dimension of output. |
|
|
stride_on (int): Stride of n dimension of output. |
|
|
stride_ok (int): Stride of k dimension of output. |
|
|
stride_zb (int): Stride of b dimension of jagged index. |
|
|
stride_zm (int): Stride of m dimension of jagged index. |
|
|
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype. |
|
|
MAX_FP8 (float): Maxmimum expressible value for FP8. |
|
|
EPS (float): Epsilon value for numerical stability. |
|
|
CLAMP_MAX (bool): Whethar to apply scale_ub. |
|
|
JAGGED (bool): Whether to use jagged indexing. |
|
|
BLOCK_SIZE (int): Block size for reduction. |
|
|
USE_INT64 (bool): Whether to use int64 indexing for large inputs. |
|
|
""" |
    pid = tl.program_id(0)
    # Use int64 indexing when offsets may exceed the int32 range.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    a_offset_base = (
        pid // (M * N) * stride_ab
        + (pid % (M * N)) // N * stride_am
        + (pid % (M * N)) % N * stride_an
    )
    a_fp8_offset_base = (
        pid // (M * N) * stride_ob
        + (pid % (M * N)) // N * stride_om
        + (pid % (M * N)) % N * stride_on
    )

    K_in = K
    if JAGGED:
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # Rows beyond the jagged boundary are treated as empty and quantized to zero.
        if current_row >= group_rows:
            K_in = 0

    # First pass: compute the maximum absolute value of the row.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE

    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)

    a_scale = MAX_FP8 / cur_max
    tl.store(A_scale + pid, 1.0 / a_scale)
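    # Second pass: rescale, clamp to the fp8 range, and store the quantized row.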
    n_offset = tl.arange(0, BLOCK_SIZE)

    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )

        a_fp8 = a * a_scale

        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE


def quantize_fp8_per_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): Higher precision input tensor with up to 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates the number of nonzero rows in each group; trailing rows are quantized to zero.
        align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels that require specific sizes (e.g., a multiple of 16).

    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.
    """
    if a.device.type == "meta":
        pt_dtype, _, _, _ = get_fp8_constants()
        a_shape = list(a.shape)
        if align_rows_to is not None:
            last_dim = a_shape[-1]
            padded_last_dim = (
                (last_dim + align_rows_to - 1) // align_rows_to
            ) * align_rows_to
            a_shape[-1] = padded_last_dim

        return (
            torch.empty(a_shape, device="meta", dtype=pt_dtype),
            torch.empty(a_shape[:-1], device="meta", dtype=torch.float32),
        )

    if scale_ub is not None and scale_ub.device != a.device:
        raise ValueError("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise ValueError("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports input tensors with up to 4 dimensions."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])

    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows,), dtype=torch.float32, device=a.device)

    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = (
            (last_dim + align_rows_to - 1) // align_rows_to
        ) * align_rows_to
        a_fp8 = torch.empty(
            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
        )
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # Use int64 indexing when the tensor is too large for int32 offsets.
    use_int64 = a.numel() > (2**31 - 1)
    grid = (num_rows,)
    _kernel_quantize_fp8_row[grid](
        a,
        a_scale,
        a_fp8,
        scale_ub,
        zero_start_index_M,
        a.shape[0],
        a.shape[1],
        a.shape[2],
        a.shape[3],
        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.stride(3),
        a_fp8.stride(0),
        a_fp8.stride(1),
        a_fp8.stride(2),
        a_fp8.stride(3),
        zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
        zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        JAGGED=zero_start_index_M is not None,
        USE_INT64=use_int64,
    )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
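

# Illustrative usage sketch (not part of the original module): it assumes a CUDA
# device with fp8 (e4m3) support and uses arbitrary shapes. Each row of `x` gets
# its own scale; dequantizing with the returned reciprocal scale approximately
# recovers the input.
if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
        x_fp8, x_scale = quantize_fp8_per_row(x)
        # x_fp8 keeps x's shape; x_scale holds one reciprocal scale per row.
        print(x_fp8.shape, x_fp8.dtype, x_scale.shape)
        # Rough sanity check: x ~= x_fp8 * x_scale[:, None].
        x_deq = x_fp8.to(torch.float32) * x_scale.unsqueeze(-1)
        print((x_deq - x.float()).abs().max())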