Skip to content

Commit

Permalink
chore(setup): put setup function under a `if __name__ == '__main__'…
Browse files Browse the repository at this point in the history
…` block
  • Loading branch information
XuehaiPan committed Feb 22, 2025
1 parent 7e56376 commit 5def973
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
rev: v0.9.7
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -36,7 +36,7 @@ repos:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
rev: 7.1.2
hooks:
- id: flake8
additional_dependencies:
Expand Down
88 changes: 43 additions & 45 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import TYPE_CHECKING, Generator
from typing import TYPE_CHECKING

from setuptools import setup


if TYPE_CHECKING:
from collections.abc import Generator
from types import ModuleType


Expand Down Expand Up @@ -70,48 +71,45 @@ def vcs_version(name: str, path: Path | str) -> Generator[ModuleType]:
file.write(content)


extra_requirements = {
'lint': [
'ruff',
'pylint[spelling]',
'mypy',
'typing-extensions',
'pre-commit',
],
'cuda10': ['nvidia-ml-py == 11.450.51'],
}


with vcs_version(
name='nvitop.version',
path=HERE / 'nvitop' / 'version.py',
) as version:
for pynvml_major in sorted(
{int(pynvml.partition('.')[0]) for pynvml in version.PYNVML_VERSION_CANDIDATES},
):
pynvml_range = [
pynvml
for pynvml in version.PYNVML_VERSION_CANDIDATES
if pynvml.startswith(f'{pynvml_major}.')
]
if len(pynvml_range) == 1:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py == {pynvml_range[0]}',
if __name__ == '__main__':
extra_requirements = {
'lint': [
'ruff',
'pylint[spelling]',
'mypy',
'typing-extensions',
'pre-commit',
],
'cuda10': ['nvidia-ml-py == 11.450.51'],
}

with vcs_version(name='nvitop.version', path=HERE / 'nvitop' / 'version.py') as version:
for pynvml_major in sorted(
{int(pynvml.partition('.')[0]) for pynvml in version.PYNVML_VERSION_CANDIDATES},
):
pynvml_range = [
pynvml
for pynvml in version.PYNVML_VERSION_CANDIDATES
if pynvml.startswith(f'{pynvml_major}.')
]
elif len(pynvml_range) >= 2:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py >= {pynvml_range[0]}, <= {pynvml_range[-1]}',
]
extra_requirements.update(
{
# The identifier could not start with numbers, add a prefix `pynvml-`
f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}']
for pynvml in version.PYNVML_VERSION_CANDIDATES
},
)

setup(
name='nvitop',
version=version.__version__,
extras_require=extra_requirements,
)
if len(pynvml_range) == 1:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py == {pynvml_range[0]}',
]
elif len(pynvml_range) >= 2:
extra_requirements[f'cuda{pynvml_major}'] = [
f'nvidia-ml-py >= {pynvml_range[0]}, <= {pynvml_range[-1]}',
]
extra_requirements.update(
{
# The identifier could not start with numbers, add a prefix `pynvml-`
f'pynvml-{pynvml}': [f'nvidia-ml-py == {pynvml}']
for pynvml in version.PYNVML_VERSION_CANDIDATES
},
)

setup(
name='nvitop',
version=version.__version__,
extras_require=extra_requirements,
)

0 comments on commit 5def973

Please sign in to comment.