feat(nvidia): build pytorch to get older cuda compute capabilities and setup arm64 support #578
```diff
@@ -2,9 +2,9 @@ package eksapi
 
 import (
 	"bytes"
+	"fmt"
 	"os"
 	"text/template"
-	"fmt"
 
 	"k8s.io/klog"
 )
```

---

```diff
@@ -16,6 +16,9 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     LANG=C.UTF-8 \
     LC_ALL=C.UTF-8
 
+ARG PYTORCH_BRANCH=v2.5.0
+ARG PYTORCH_BUILD_ENV="MAX_JOBS=8 BUILD_TEST=0"
+
 ###############################################################################
 # 1) System packages
 ###############################################################################
@@ -75,3 +78,20 @@ WORKDIR /app
 COPY infer.py /app/
 COPY requirements.txt /app/
 RUN pip install --no-cache-dir -r requirements.txt
+
+###############################################################################
+# 4) Install Pytorch from Source
+###############################################################################
+# envs needed to make the path of NVCC known to the compilation
+ENV CUDA_HOME=/usr/local/cuda
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
+ENV PATH=$PATH:$CUDA_HOME/bin
+# this list could be minimized based on the supported GPUs
+ENV TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.7 8.9 9.0"
+
+RUN pip3 install typing-extensions sympy
+RUN git clone \
+    --recursive https://github.com/pytorch/pytorch.git \
+    --branch $PYTORCH_BRANCH \
+    && cd pytorch && eval "$PYTORCH_BUILD_ENV python3 setup.py install" && cd .. \
+    && rm -rf pytorch
```
Comment on lines +82 to +97:

> Any idea how long this step takes? Just curious.

> The workflow took ~5hr 30mins, the bulk of which is this step 😅
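Since the point of the `ARG`/`ENV` additions is to compile kernels for the listed compute capabilities, a quick sanity check of the resulting image can help. Here is a minimal sketch (not part of this PR) that could be run inside the built container on a GPU instance; note that `torch.cuda.get_arch_list()` returns an empty list when no GPU is visible:

```python
import torch

# Architectures this torch build was compiled for, e.g. ['sm_75', 'sm_80', ...].
# Should mirror TORCH_CUDA_ARCH_LIST; returns [] if CUDA is unavailable.
print("compiled arch list:", torch.cuda.get_arch_list())

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"device 0 compute capability: {major}.{minor}")
    # End-to-end check that kernels actually exist for this device.
    x = torch.randn(4, 4, device="cuda")
    print("matmul ok:", (x @ x).sum().item())
```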

---

```diff
@@ -1,3 +1,2 @@
-torch==2.5
 transformers==4.33
 numpy==1.26
```

---

```diff
@@ -1,3 +1,2 @@
-torch==2.3
-transformers==4.29
-numpy==1.23
+transformers==4.33
+numpy==1.26
```

---

```diff
@@ -110,8 +110,7 @@ def main():
     # Retrieve environment variables
     rank = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
     world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
-    num_gpus_per_node = int(os.getenv("NUM_GPUS_PER_NODE", "8"))
-    local_rank = rank % num_gpus_per_node
+    local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
 
     print(f"Process started for rank {rank} with local rank {local_rank}")
```

Comment on the `local_rank` change:

> @mattcjo any reason we didn't do this before, based on https://docs.open-mpi.org/en/v5.0.x/tuning-apps/environment-var.html?
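A short note on the change above: `OMPI_COMM_WORLD_LOCAL_RANK` is the per-node rank Open MPI exports to each launched process, so it can be used directly to pick a GPU without knowing the node's GPU count. A minimal sketch of the usual pattern, assuming `MASTER_ADDR`/`MASTER_PORT` are provided by the launcher (illustrative only, not the repository's actual `infer.py`):

```python
import os

import torch
import torch.distributed as dist


def init_distributed():
    # Ranks exported by Open MPI for every launched process.
    rank = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
    world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
    # Per-node rank, 0..(procs per node - 1); no NUM_GPUS_PER_NODE needed.
    local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))

    # Pin this process to one GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)

    # env:// rendezvous assumes MASTER_ADDR / MASTER_PORT are set by the launcher.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, world_size, local_rank
```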
Comment on the `MAX_JOBS` build argument:

> Just curious... Any reason for choosing a value of 8 for `MAX_JOBS`?

> Yeah, this was just manually tuned: higher values of `MAX_JOBS` hit OOM errors during the build, so I searched for a value that still finished under the 6-hour default GitHub Actions limit.
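If hard-coding `MAX_JOBS=8` ever becomes a problem on different builders, one option is to derive it from the machine's resources at build time. A rough sketch of such a heuristic (the ~2 GB-per-compile-job figure is an assumption rather than a measured value, and the script name is hypothetical):

```python
import os


def suggest_max_jobs(mem_per_job_gb: float = 2.0) -> int:
    """Cap parallel compile jobs by CPU count and available RAM (Linux only)."""
    cpus = os.cpu_count() or 1
    # /proc/meminfo reports values in kB, e.g. "MemAvailable:  12345678 kB".
    with open("/proc/meminfo") as f:
        meminfo = {line.split(":")[0]: int(line.split()[1]) for line in f}
    avail_gb = meminfo["MemAvailable"] / (1024 * 1024)
    return max(1, min(cpus, int(avail_gb // mem_per_job_gb)))


if __name__ == "__main__":
    # e.g. in the Dockerfile: MAX_JOBS=$(python3 suggest_max_jobs.py)
    print(suggest_max_jobs())
```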