Skip to content

Commit

Permalink
Replace message arg with kwargs for tqdm (#13)
Browse files Browse the repository at this point in the history
* replace message arg with tqdm kwargs

* added additional tqdm arguments example to readme

* lint with black
  • Loading branch information
mdmould authored Nov 21, 2023
1 parent 9034145 commit 54733b2
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

will update every other step.

### Progress bar options

Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm)
progress bar constructor. For example:

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, desc='progress bar', position=0, leave=False)
def step(carry, x):
return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

## Why JAX-tqdm?

JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions),
Expand Down
31 changes: 18 additions & 13 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def scan_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
message: typing.Optional[str] = None,
**kwargs,
) -> typing.Callable:
"""
tqdm progress bar for a JAX scan
Expand All @@ -17,19 +17,19 @@ def scan_tqdm(
----------
n : int
Number of scan steps/iterations.
print_rate: int
print_rate : int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
message : str
Optional string to prepend to tqdm progress bar.
**kwargs
Extra keyword arguments to pass to tqdm.
Returns
-------
typing.Callable:
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)

def _scan_tqdm(func):
"""Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
Expand All @@ -55,7 +55,7 @@ def wrapper_progress_bar(carry, x):
def loop_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
message: typing.Optional[str] = None,
**kwargs,
) -> typing.Callable:
"""
tqdm progress bar for a JAX fori_loop
Expand All @@ -67,16 +67,16 @@ def loop_tqdm(
print_rate: int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
message : str
Optional string to prepend to tqdm progress bar.
**kwargs
Extra keyword arguments to pass to tqdm.
Returns
-------
typing.Callable:
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)

def _loop_tqdm(func):
"""
Expand All @@ -95,14 +95,19 @@ def wrapper_progress_bar(i, val):


def build_tqdm(
n: int, print_rate: typing.Optional[int], message: typing.Optional[str] = None
n: int,
print_rate: typing.Optional[int],
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
"""

if message is None:
message = f"Running for {n:,} iterations"
desc = kwargs.pop("desc", f"Running for {n:,} iterations")
message = kwargs.pop("message", desc)
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
kwargs.pop(kwarg, None)

tqdm_bars = {}

if print_rate is None:
Expand All @@ -122,7 +127,7 @@ def build_tqdm(
remainder = n % print_rate

def _define_tqdm(arg, transform):
tqdm_bars[0] = tqdm(range(n))
tqdm_bars[0] = tqdm(range(n), **kwargs)
tqdm_bars[0].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jax-tqdm"
version = "0.1.1"
version = "0.1.2"
description = "Tqdm progress bar for JAX scans and loops"
authors = [
"Jeremie Coullon <[email protected]>",
Expand Down

0 comments on commit 54733b2

Please sign in to comment.