Better README again
zhanglx13 committed Jul 23, 2024
1 parent a61967b commit e979474
Showing 2 changed files with 41 additions and 18 deletions.
30 changes: 24 additions & 6 deletions scripts/amd/gemm/README.md
@@ -37,7 +37,7 @@ The input gemm sizes are prepared in a yaml file. Here is an example yaml file:
The tuning script works as follows
```python
-./tune_gemm --gemm_size_file input.yaml [options]
+./tune_gemm.py --gemm_size_file input.yaml [options]
```
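As a rough illustration of the `--gemm_size_file` input (the README's real example yaml is collapsed in this diff, and the key names below are assumptions, not taken from this commit), such a file could be generated like this:
```python
# Hypothetical generator for a gemm-size yaml consumed by tune_gemm.py.
# The key names (M, N, K) are assumed; consult the README's actual example.
import yaml

gemm_sizes = [
    {"M": 1024, "N": 1024, "K": 1024},
    {"M": 4096, "N": 4096, "K": 8192},
]
with open("input.yaml", "w") as f:
    yaml.safe_dump(gemm_sizes, f, sort_keys=False)
```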
The following `options` are supported in the tuning mode

@@ -87,30 +87,40 @@ By default, only one task is generated and profiled on GPU0.
result. The default filename is `tuning_results_branchName@gitCommit_timeStamp.yaml`.
Therefore, each time the user runs the tuning script, a different output file
will be generated.
+- Hacks
+  - `--hack_triton_compiler`: If set, the triton source code will be modified
+    to provide a static backend target so that the compiler will not query
+    GPU information. This makes sure that during the compilation stage, no
+    hip runtime kernels are launched.
+    Note that this is a very hacky option (see the sketch below), because
+    - It modifies the triton compiler in place, at the location reported by
+      `pip show triton`.
+    - It does string match-and-replace to modify the code.
+    - It does not restore the code when the tuning session terminates.
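A minimal sketch of the locate-and-patch idea behind `--hack_triton_compiler`; the path and patterns here are illustrative assumptions, and the real patch logic lives in `tune_gemm.py`:
```python
# Sketch: locate triton the way `pip show triton` reports it, then apply a
# destructive string replacement. Path and patterns are assumptions.
import subprocess

info = subprocess.check_output(["pip", "show", "triton"], text=True)
location = next(line.split(":", 1)[1].strip()
                for line in info.splitlines() if line.startswith("Location:"))
driver_py = f"{location}/triton/backends/amd/driver.py"  # assumed layout

with open(driver_py) as f:
    src = f.read()
patched = src.replace("old snippet", "new snippet")  # placeholder patterns
# Writing `patched` back is what makes this destructive: nothing restores
# the original file when the tuning session ends.
```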

Here are some example usages of running the script for tuning:

Tune some gemm sizes with f16 input
```python
-./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml
+./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml
```
It's recommended to use as many GPUs as possible and set `--jobs` to
a value that is 4 to 6 times the number of GPUs.

If you are only allowed to use a subset of the GPUs, you can
```python
-./tune_gemm --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml
+./tune_gemm.py --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml
```
This runs the profiling on GPU 0,1,3,4.
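A comma-separated `--gpu_ids` value is presumably split into integer device indices; a trivial sketch, not necessarily the script's exact parsing:
```python
# Turn "0,1,3,4" into integer GPU ids for task dispatch.
gpu_ids = [int(i) for i in "0,1,3,4".split(",")]
print(gpu_ids)  # [0, 1, 3, 4]
```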

For bf8 input
```python
-./tune_gemm --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8
+./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8
```

Check correctness of the tuned configs
```python
-./tune_gemm --gemm_size_file output.yaml --compare_wo_tuning
+./tune_gemm.py --gemm_size_file output.yaml --compare_wo_tuning
```


@@ -120,7 +130,7 @@ In benchmark mode, the script will run a single given config multiple times to
collect performance data. The benchmark mode works as follows
```python
-./tune_gemm --gemm_size_file input.yaml [options] --benchmark
+./tune_gemm.py --gemm_size_file input.yaml [options] --benchmark
```
The supported `options` are as follows
- `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode.
@@ -273,6 +283,14 @@ that cannot divide `K`.
- Tuning result file is opened and closed inside the tuning loop, enabling timely
  flush of the tuning results.
- Now we use `rocprofv2` to measure kernel time.
+- We can use `--hack_triton_compiler` to avoid all GPU activities during the
+  compilation stage. This is achieved by modifying the triton frontend compiler
+  in the following places (see the sketch after this list):
+  - Return True from the `is_active()` function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433)
+  - Return a statically constructed GPUTarget from the `get_current_target()`
+    function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437)
+  - Return False from the `is_active()` function in the cuda backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383)
+  - Statically set `device` and `stream` in [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)
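As a sketch of the second patch point, a statically constructed target could look like this; the arch string and warp size are assumptions (an MI300-class GPU), and the import path follows triton's 3.x layout:
```python
# Sketch: a fixed GPUTarget that get_current_target() could return, so the
# compiler never queries the device. Backend/arch/warp size are assumed.
from triton.backends.compiler import GPUTarget

static_target = GPUTarget("hip", "gfx942", 64)  # backend, arch, warp size
```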


# One config running script
29 changes: 17 additions & 12 deletions scripts/amd/gemm/tune_gemm.py
mode changed 100644 → 100755 (file made executable)
@@ -1,4 +1,5 @@
-# fp8
+#!/usr/bin/env python3
+
import argparse
import sys
import yaml
@@ -569,10 +570,11 @@ def parse_args():
                        action='store_true',
                        default=False,
                        help="Whether check result correctness")
parser.add_argument("--compare_wo_tuning",
action='store_true',
default=False,
help="Whether check result correctness without tuning.")
parser.add_argument(
"--compare_wo_tuning",
action='store_true',
default=False,
help="Whether check result correctness without tuning.")
parser.add_argument("--benchmark",
action='store_true',
default=False,
@@ -605,16 +607,15 @@ def parse_args():
"--init_type",
type=str,
default='randn',
-        help=
-        "Initialization type for input matrices (default uniform rand [0, 1.0)])"
-    )
+        choices=['randn', 'hpl', 'trig_float', 'zeros'],
+        help="Input tensor initialization (default normal distribution)")
    parser.add_argument(
        "--rotating_tensor",
        type=int,
        default=0,
-        help=
-        "total size (MB) of all tensors (default 0 MB (no rotating tensor), need to be larger than the L1, L2, MALL size)"
-    )
+        help="total size (MB) of all tensors (a, b, c, bias)."
+        " The default value is 0 (no rotating tensor)."
+        " When set, it needs to be larger than the L1, L2, MALL sizes.")
parser.add_argument("--bias_vector",
action='store_true',
default=False,
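The `--rotating_tensor` guidance in the new help text amounts to a simple sizing check: the combined buffers must exceed the cache hierarchy so each benchmark iteration touches cold memory. A sketch with placeholder cache sizes (not queried from any GPU):
```python
# Check that a candidate --rotating_tensor value exceeds L2 + MALL, so no
# benchmark run is served entirely from cache. Sizes are assumed examples.
l2_mb, mall_mb = 8, 256      # placeholder cache sizes in MB
rotating_mb = 512            # candidate --rotating_tensor value
assert rotating_mb > l2_mb + mall_mb, "rotating buffers would fit in cache"
```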
@@ -788,7 +789,6 @@ def main():
    if hack_triton:
        patch_triton_compiler()

-
    configs = []

    ## Big for loop of tuning
@@ -927,6 +927,11 @@ def main():
print(f"Tuning ends at: {end_time}")
print(f"Total tuning time (h:m:s): {tuning_time}")

if hack_triton:
print(
"Triton compiler is hacked, don't forget to git restore the changes :)"
)


if __name__ == '__main__':
    sys.exit(main())
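The mode change (100644 → 100755) together with the new shebang is what lets the README examples invoke the script directly as `./tune_gemm.py`; a Python equivalent of marking the file executable:
```python
# Equivalent of `chmod +x tune_gemm.py`: add execute bits to the file mode.
import os
import stat

st = os.stat("tune_gemm.py")
os.chmod("tune_gemm.py", st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
```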
