danieldk's picture
danieldk HF Staff
Build uploaded using `kernels`.
be136a5 verified
raw
history blame contribute delete
496 Bytes
import torch
from ._ops import ops
def gemm_4bit_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: int,
) -> torch.Tensor:
    """Call the compiled ``gemm_4bit_forward`` op with a bfloat16 input.

    The op is always handed a bfloat16 ``input``. When the caller's tensor
    has a different dtype it is cast to bfloat16 before the call and the
    op's result is cast back to the caller's dtype afterwards. When the
    input is already bfloat16, the op's result is returned untouched.

    Args:
        input: Tensor passed to the op as its first argument, after the
            bfloat16 cast described above.
        weight: Forwarded to the op unchanged (presumably the packed
            4-bit weights -- confirm against the kernel source).
        absmax: Forwarded to the op unchanged (presumably the per-block
            scaling factors -- confirm against the kernel source).
        blocksize: Forwarded to the op unchanged (presumably the
            quantization block size -- confirm).
        quant_type: Forwarded to the op unchanged; the meaning of the
            integer code is defined by the kernel, not by this wrapper.

    Returns:
        The op's output, converted to ``input``'s original dtype when a
        cast to bfloat16 was required.
    """
    caller_dtype = input.dtype
    needs_cast = caller_dtype != torch.bfloat16

    # Only non-bfloat16 inputs are converted; bfloat16 goes through as is.
    activations = input.to(torch.bfloat16) if needs_cast else input

    result = ops.gemm_4bit_forward(
        activations, weight, absmax, blocksize, quant_type
    )

    # Mirror the input-side cast so callers get their own dtype back.
    return result.to(caller_dtype) if needs_cast else result