diff --git a/scripts/amd/gemm/README.md b/scripts/amd/gemm/README.md index 2151108b54be..fc00d47fba24 100644 --- a/scripts/amd/gemm/README.md +++ b/scripts/amd/gemm/README.md @@ -1,44 +1,164 @@ -# GEMM tuning script v2 +# GEMM tuning script (current v3.3) -This is the v2 version of the gemm tuning script, which is based on @scxiao's v1 (https://github.com/ROCmSoftwarePlatform/triton/pull/309) and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 +## matmul kernel -### Main features -- `rocprof` is used to measure the time for kernels in the full tuning space -- Each kernel is executed 10 times and the execution time of the last instance is used -- All kernels are compiled in parallel -- Two modes for correctness checking - - During tuning, check correctness with the best perf_config for the current gemm size - - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size -- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs -- Limitations - - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs +The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/matmul_kernel.py), which includes the following features: +- grouping order of workgroup id, which is controled by `GROUP_SIZE_M`, that +implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations). +- split-k algorithm, which is controled by `SPLIT_K`. +- Bias along M dim, which is controled by `BIAS` and `bias_ptr`. +- Masked load along K dim inside the loop, which is controled by `EVEN_K`. +This means `BLOCK_SIZE_K` does not need to divide K dim. -### Usage -Go to the script dir -```bash -cd triton/scripts/amd/gemm/ +### Differences between the tutorial + +Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred as the tutorial), +the matmul kernel used in the tuning script (referred as the kerel) does not +guard load along M and N dim +([this](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py#L282-L283) shows how this is done in the tutorial). +When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will +load out-of-bound data. +In most cases this is fine, since the kernel does masked store at the end. +However, this may lead to GPU memory access fault in some cases, especially +when the tensor is large. +We will fix this issue in the future. + + +## Tuning script usage + +### Tuning mode + +The tuning script can take one or more gemm sizes and run tuning for them. +The input gemm sizes are prepared in a yaml file. Here is an example yaml file: +```yaml +- {'M': 4864, 'N': 4096, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N'} +- {'M': 512, 'N': 512, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'N'} ``` -1. Tune gemm sizes given in a yaml file and check correctness on the way -```bash -python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --compare +The tuning script works as follows +```python +./tune_gemm --gemm_size_file input.yaml [options] +``` +The following `options` are supported in the tuning mode + +- Input data types: + - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: input and output element type. + - Supported `dtype`: fp16 (default), bf16, fp8, bf8, int8, int32, fp32 +- Parallel compilation of kernels: + - `num_threads n` controls that n threads will + be used in the compilation stage. The default value is 32. + - `--no_warmup` can be used to skip the compilation stage. Thus kernels will be + compiled during the profiling stage. This increases tuning time. But it's + required for some old torch version, in which some function used in the warmup + kernel launch is not supported. +- Parallel profiling of kernels: The tuning space is first divided into a number +of tasks, which is controled by `--jobs n`. And all the tasks can be profiled in +parallel on a number of GPUs in the system. There are two ways to specify which +GPU(s) we want to use for profiling. Note that these flags cannot be use together. +By default, only one task is generated and profiled on GPU0. + - `--ngpus n`: GPU 0,1,.., n-1 will be used. + - `--gpu_ids ids`: `ids` are comma separated gpu ids and GPUs in `ids` will be used. +- General tuning control flags + - `--init_type INIT_TYPE` defines how input data are initialized. `INIT_TYPE` can be + - hpl: uniform distribution between -.5 and .5 + - trig_float: the distribution of elements in the flattened tensor follow + the `sin` function. + - zeros: initialize all data as 0, i.e. `torch.zeros` + - randn (default): normal distribution, i.e. `torch.randn` + - `--rotating_tensor SIZE`: provide the size of memory used for rotatin tensor. + The default is 0, meaning rotating tensor is not used. + - `--icahe_flush`: If true, the script will generate a kernel to flush i-cache. + The default is False. + - `--bias_vector`: If true, a bias vector along the M dim is applied. + The default is False. +- Correctness check + - `--compare` will check the correctness of the best config for each gemm size. + - `--compare_wo_tuning` will check the correctness of the config provided in + the yaml file. If this is set, user needs to provide all the parameters in + the input yaml file. Example can be found in the benchmark mode section. +- Logistics + - `--keep` can be used to keep the files generated during the tuning process. + Be default, intermediate files are removed at the end. + - `--time_breakdown`: If set, the script will print out elapsed time during + each stage of the tuning in real-time. The default is False. + - `--verbose` will enable more logging message than `--time_breakdown`, such + as output from rocprofv2 + - `--o OUTPUT` can be used to control the output filename to store the tuning + result. The default filename is `tuning_results_branchName@gitCommit_timeStamp.yaml`. + Therefore, each time the user runs the tuning script, a different output file + will be generated. + +Here are some example usages of running the script for tuning: + +Tune some gemm sizes with f16 input +```python +./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml ``` +It's recommended to use as many GPUs as possible and set `--jobs` to +a value that is 4 to 6 times the number of GPUs. -2. Tune a single gemm size -```bash -python tune_gemm.py -m 16 -n 16 -k 16 +If you are only allowed to use a subset of the GPUs, you can +```python +./tune_gemm --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml ``` +This runs the profiling on GPU 0,1,3,4. -3. Choose the file to store tuning results -```bash -python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml +For bf8 input +```python +./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8 ``` -4. Only check correctness given the tuning results -```bash -python tune_gemm.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +Check correctness of the tuned configs +```python +./tune_gemm --gemm_size_file output.yaml --compare_wo_tuning +``` + + +### Benchmark mode + +In benchmark mode, the script will run a single given config multiple times to +collect performance data. The benchmark mode works as +The tuning script works as follows +```python +./tune_gemm --gemm_size_file input.yaml [options] --benchmark ``` -Note that the tuning results file are provided as the `gemm_size_file` in this scenario. +The supported `options` are as followings +- `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode. +- `--iters n` controls the number of iterations to run the kernel. +The default value is 1000. + + +## Tuning script implementation overview + +The general idea of the tuning script can be summarized as +- Compile all the kernels in the tuning space in parallel. +- Divide the tuning space into tasks and invoke `rocprofv2` once per +task. This will save invocation overhead of the profiler. +- Profile tasks in parallel on multiple GPUs. + +For detailed implementation, please refer to the changelog of each version. + + +# Changelog + +## GEMM tuning script v1 + +Shucai (@scxiao) implemented the first version of gemm tuning script: https://github.com/ROCmSoftwarePlatform/triton/pull/309 + +## GEMM tuning script v2 + +This version is based on v1 and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 + +### Main features +- `rocprof` is used to measure the time for kernels in the full tuning space +- Each kernel is executed 10 times and the execution time of the last instance is used +- All kernels are compiled in parallel +- Two modes for correctness checking + - During tuning, check correctness with the best perf_config for the current gemm size + - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size +- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs +- Limitations + - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs ### Overview of implementations @@ -63,7 +183,7 @@ Workflow of the tuning process 5. Invoke `rocprof` on the generated script 6. Post process `results.csv` by extract the execution time of the last instance of each kernel. Pick the best one, write to file, and return. -# GEMM Tuning Script v3 +## GEMM Tuning Script v3 ### API changes @@ -89,66 +209,15 @@ This is necessary to keep each file "small" in terms of execution time. - Added error recovery. This helps when rocprof crashes in multi-processing mode. -### Example Usage - -Let's say we have an input yaml file, named `gemm_input.yaml`, that contains the following configs -```yaml -- {'M': 4864, 'N': 4096, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'} -- {'M': 8192, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N'} -``` -1. Tuning with bf8 input types with gpu 4,5,6,7, and save output to `output.yaml` -```bash -python ./tune_gemm.py --gemm_size_file gemm_input.yaml -dtype_a bf8 -dtype_b bf8 --gpu_ids 4,5,6,7 --o output.yaml -``` - -2. Check the correctness of the tuned configs -```bash -python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --compare_wo_tuning -``` - -3. Run benchmark of the tuned configs -```bash -python ./tune_gemm.py --gemm_size_file output.yaml -dtype_a bf8 -dtype_b bf8 --benchmark -``` -A sample output from `benchmark` looks like -```bash -Benchmarking gemm with bf8 inputs (peak tflops: 1298) -trans M N K TFLOPS Efficiency -NT 4864 4096 8192 841.22 65% -NT 8192 8192 8192 745.31 57% -``` - -# GEMM Tuning Script v3.1 +## GEMM Tuning Script v3.1 ### API changes - Added `matrix_instr_nonkdim` into the tuning space. Now we can tune mfma instruction size. -# One config running script - -`one_config.py` is a script that runs one given matmul config. -It is an interface to `tune_gemm.py` functionality and could be used for triton debugging. - -### Usage - -This script supports two methods to specify configuration parameters. - -Variant 1: Separate command line attributes. - -```bash -python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0 --matrix_instr_nonkdim 16 --kpack 2 -``` - -Variant 2: one-line config description. -This is how configs are printed by `tune_gemm.py` script - -```bash -python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16 -``` - -# GEMM Tuning Script v3.2 +## GEMM Tuning Script v3.2 ### API changes @@ -160,7 +229,8 @@ Rotating tensor and icache flush are to make perf numbers are closer to that in - Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, so each element of the bias vector is added to all elements of the corresponding row of the output matrix.) -# GEMM Tuning Script v3.3 + +## GEMM Tuning Script v3.3 ### API changes @@ -201,3 +271,26 @@ that cannot divide `K`. - Tuning result file is open and closed inside the tuning loop, enabling timely flush of the tuning results. - Now we use `rocprofv2` to measure kernel time. + + +# One config running script + +`one_config.py` is a script that runs one given matmul config. +It is an interface to `tune_gemm.py` functionality and could be used for triton debugging. + +## Usage + +This script supports two methods to specify configuration parameters. + +Variant 1: Separate command line attributes. + +```bash +python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0 --matrix_instr_nonkdim 16 --kpack 2 +``` + +Variant 2: one-line config description. +This is how configs are printed by `tune_gemm.py` script + +```bash +python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16 +``` diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py index c1d356074898..f460e6ec737d 100644 --- a/scripts/amd/gemm/tune_gemm.py +++ b/scripts/amd/gemm/tune_gemm.py @@ -572,7 +572,7 @@ def parse_args(): parser.add_argument("--compare_wo_tuning", action='store_true', default=False, - help="Whether check result correctness") + help="Whether check result correctness without tuning.") parser.add_argument("--benchmark", action='store_true', default=False, @@ -596,11 +596,11 @@ def parse_args(): parser.add_argument("--jobs", type=int, default=1, - help="number of generated files") + help="number of tasks during the profiling process") parser.add_argument("--iters", type=int, default=1000, - help="number of generated files") + help="number of iterations used in --benchmark mode") parser.add_argument( "--init_type", type=str, @@ -626,7 +626,7 @@ def parse_args(): parser.add_argument("--no_warmup", action='store_true', default=False, - help="Do not call the warmup kernel") + help="Whether we want to skip the compilation stage") args = parser.parse_args() if not args.o: if args.benchmark: