From e979474decf5bfb4ce8c0292bae9016bf3f5cfbb Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Mon, 22 Jul 2024 22:27:01 -0500
Subject: [PATCH] Better README again

---
 scripts/amd/gemm/README.md    | 30 ++++++++++++++++++++++++------
 scripts/amd/gemm/tune_gemm.py | 29 +++++++++++++++++------------
 2 files changed, 41 insertions(+), 18 deletions(-)
 mode change 100644 => 100755 scripts/amd/gemm/tune_gemm.py

diff --git a/scripts/amd/gemm/README.md b/scripts/amd/gemm/README.md
index 38bd21206d71..0abb5e410821 100644
--- a/scripts/amd/gemm/README.md
+++ b/scripts/amd/gemm/README.md
@@ -37,7 +37,7 @@ The input gemm sizes are prepared in a yaml file. Here is an example yaml file:
 
 The tuning script works as follows
 ```python
-./tune_gemm --gemm_size_file input.yaml [options]
+./tune_gemm.py --gemm_size_file input.yaml [options]
 ```
 The following `options` are supported in the tuning mode
 
@@ -87,30 +87,40 @@ By default, only one task is generated and profiled on GPU0.
 result. The default filename is `tuning_results_branchName@gitCommit_timeStamp.yaml`.
 Therefore, each time the user runs the tuning script, a different output file will be generated.
 
+- Hacks
+  - `--hack_triton_compiler`: If set, the triton source code will be modified
+  to provide a static backend target so that the compiler will not query
+  GPU information. This makes sure that during the compilation stage, no
+  hip runtime kernels are launched.
+  Note that this is a very hacky option, because
+    - It modifies the triton compiler directly, which can be located
+    with `pip show triton`.
+    - It does string match and replace to modify the code.
+    - It does not restore the code when the tuning session terminates.
 
 Here are some example usages of running the script for tuning:
 
 Tune some gemm sizes with f16 input
 ```python
-./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml
+./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml
 ```
 It's recommended to use as many GPUs as possible and set `--jobs` to
 a value that is 4 to 6 times the number of GPUs.
 
 If you are only allowed to use a subset of the GPUs, you can
 ```python
-./tune_gemm --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml
+./tune_gemm.py --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml
 ```
 This runs the profiling on GPU 0,1,3,4.
 
 For bf8 input
 ```python
-./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8
+./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8
 ```
 
 Check correctness of the tuned configs
 ```python
-./tune_gemm --gemm_size_file output.yaml --compare_wo_tuning
+./tune_gemm.py --gemm_size_file output.yaml --compare_wo_tuning
 ```
 
 
@@ -120,7 +130,7 @@ In benchmark mode, the script will run a single given config multiple times to
 collect performance data. The benchmark mode works as
 The tuning script works as follows
 ```python
-./tune_gemm --gemm_size_file input.yaml [options] --benchmark
+./tune_gemm.py --gemm_size_file input.yaml [options] --benchmark
 ```
 The supported `options` are as followings
 - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode.
@@ -273,6 +283,14 @@ that cannot divide `K`.
 - Tuning result file is open and closed inside the tuning loop, enabling timely flush
 of the tuning results.
 - Now we use `rocprofv2` to measure kernel time.
+- We can use `--hack_triton_compiler` to avoid all GPU activities during the compilation
+stage. This is achieved by modifying the triton frontend compiler in the following
+places:
+  - Return True from the `is_active()` function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433)
+  - Return a statically constructed GPUTarget from the `get_current_target()`
+    function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437)
+  - Return False from the `is_active()` function in the cuda backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383)
+  - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)
 
 
 # One config running script
diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
old mode 100644
new mode 100755
index 4ffe8ab2a3ce..3fdd7da082b5
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -1,4 +1,5 @@
-# fp8
+#!/usr/bin/env python3
+
 import argparse
 import sys
 import yaml
@@ -569,10 +570,11 @@ def parse_args():
                         action='store_true',
                         default=False,
                         help="Whether check result correctness")
-    parser.add_argument("--compare_wo_tuning",
-                        action='store_true',
-                        default=False,
-                        help="Whether check result correctness without tuning.")
+    parser.add_argument(
+        "--compare_wo_tuning",
+        action='store_true',
+        default=False,
+        help="Whether check result correctness without tuning.")
     parser.add_argument("--benchmark",
                         action='store_true',
                         default=False,
@@ -605,16 +607,15 @@ def parse_args():
         "--init_type",
         type=str,
         default='randn',
-        help=
-        "Initialization type for input matrices (default uniform rand [0, 1.0)])"
-    )
+        choices=['randn', 'hpl', 'trig_float', 'zeros'],
+        help="Input tensor initialization (default normal distribution)")
     parser.add_argument(
         "--rotating_tensor",
         type=int,
         default=0,
-        help=
-        "total size (MB) of all tensors (default 0 MB (no rotating tensor), need to be larger than the L1, L2, MALL size)"
-    )
+        help="total size (MB) of all tensors (a, b, c, bias)."
+        " The default value is 0 (no rotating tensor)."
+        " When set, it needs to be larger than the L1, L2, and MALL size.")
     parser.add_argument("--bias_vector",
                         action='store_true',
                         default=False,
@@ -788,7 +789,6 @@ def main():
     if hack_triton:
         patch_triton_compiler()
 
-    configs = []
 
 
     ## Big for loop of tuning
@@ -927,6 +927,11 @@ def main():
     print(f"Tuning ends at: {end_time}")
     print(f"Total tuning time (h:m:s): {tuning_time}")
 
+    if hack_triton:
+        print(
+            "Triton compiler is hacked, don't forget to git restore the changes :)"
+        )
+
 
 if __name__ == '__main__':
     sys.exit(main())
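
For reviewers who want a feel for what the `--hack_triton_compiler` option does, below is a minimal, hypothetical sketch of the string match-and-replace technique the README describes: locate a file inside the installed triton package and rewrite a matched snippet in place. It is not the actual `patch_triton_compiler()` from `tune_gemm.py` (that function is not part of this diff); the package-relative path and the old/new strings are assumptions and will vary across triton versions.

```python
# Minimal sketch of the string match-and-replace approach described in the README.
# This is NOT the patch_triton_compiler() from tune_gemm.py; the relative path and
# the old/new strings below are illustrative assumptions only.
import importlib.util
import os


def locate_triton_file(relative_path: str) -> str:
    """Resolve a file inside the installed triton package (the same location
    that `pip show triton` reports)."""
    spec = importlib.util.find_spec("triton")
    if spec is None or spec.origin is None:
        raise RuntimeError("triton is not installed")
    pkg_root = os.path.dirname(spec.origin)  # .../site-packages/triton
    return os.path.join(pkg_root, relative_path)


def patch_file(path: str, old: str, new: str) -> None:
    """Replace `old` with `new` in `path`, in place. No backup is written,
    mirroring the README caveat that the code is not restored afterwards."""
    with open(path, "r") as f:
        src = f.read()
    if old not in src:
        raise RuntimeError(f"marker not found in {path}; triton version may differ")
    with open(path, "w") as f:
        f.write(src.replace(old, new))


if __name__ == "__main__":
    # Self-contained demo on a temporary file. For the real hack one would
    # resolve something like locate_triton_file("backends/amd/driver.py")
    # (that path is an assumption and depends on the installed triton version).
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
        tmp.write("def is_active():\n    return query_gpu()\n")
        demo_path = tmp.name
    patch_file(demo_path, old="return query_gpu()", new="return True  # hacked")
    print(open(demo_path).read())
```

As the README caveat and the new end-of-run reminder in `tune_gemm.py` both note, the option does not undo its edits; restoring the modified triton files (for example with `git restore` in an editable install) is left to the user.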