Install the kernels using the following commands:
git clone https://github.com/IST-DASLab/gemm_fp8.git
cd gemm_fp8
pip install -e . # or pip install .
Then, the kernel can be used as follows:
import torch
import gemm_fp8
y = gemm_fp8.matmul(a, b, alpha=1.0)
where a and b are the input matrices (in torch.float8_e4m3fn format) and alpha is the scaling factor (a float).
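The text above does not say how alpha should be chosen. As a minimal sketch, assuming per-tensor symmetric scaling (an assumption, not something this repository states), each input would be divided by its own scale before conversion to FP8, and alpha would be the product of the two scales so that the output is mapped back to the original range. The helpers per_tensor_scale and combined_alpha are hypothetical names for illustration and are not part of gemm_fp8; the only fixed fact used is that 448 is the largest finite value in float8_e4m3fn.

```python
# Hypothetical helpers for illustration; they are not part of gemm_fp8.
E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def per_tensor_scale(values):
    """Return a scale s such that every value / s lies within the E4M3 range."""
    amax = max(abs(v) for v in values)
    # An all-zero tensor needs no scaling; avoid returning a zero scale.
    return amax / E4M3_MAX if amax > 0 else 1.0


def combined_alpha(a_values, b_values):
    """Return the alpha that undoes both input scalings after the FP8 product.

    If a_q = a / s_a and b_q = b / s_b, then a @ b = (s_a * s_b) * (a_q @ b_q).
    """
    return per_tensor_scale(a_values) * per_tensor_scale(b_values)


# Flattened high-precision entries of the two matrices (toy values).
a_values = [0.5, -2.0, 1.25]
b_values = [4.0, -8.0, 0.0]
alpha = combined_alpha(a_values, b_values)
```

Under this assumption, the resulting alpha would be passed as the alpha argument of gemm_fp8.matmul after the scaled inputs have been converted to torch.float8_e4m3fn.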
Run the following command to benchmark the kernel:
python benchmark.py