Skip to content

Commit

Permalink
Make pool-related CLI arguments respect profile (#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Mar 1, 2024
1 parent 6086cfb commit 38115c8
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 148 deletions.
71 changes: 16 additions & 55 deletions src/dstack/_internal/cli/commands/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
InstanceRuntime,
SSHKey,
)
from dstack._internal.core.models.pools import Instance, Pool
from dstack._internal.core.models.profiles import (
DEFAULT_POOL_TERMINATION_IDLE_TIME,
Profile,
SpotPolicy,
TerminationPolicy,
parse_max_duration,
)
from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE
from dstack._internal.core.models.runs import InstanceStatus, Requirements
Expand Down Expand Up @@ -105,11 +101,6 @@ def _register(self) -> None:
add_parser = subparsers.add_parser(
"add", help="Add instance to pool", formatter_class=self._parser.formatter_class
)
add_parser.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
add_parser.add_argument(
"-y", "--yes", help="Don't ask for confirmation", action="store_true"
)
Expand All @@ -127,8 +118,7 @@ def _register(self) -> None:
add_parser.add_argument(
"--name", dest="instance_name", help="Set the name of the instance"
)
add_parser.add_argument("--idle-duration", dest="idle_duration", help="Idle duration")
register_profile_args(add_parser)
register_profile_args(add_parser, pool_add=True)
register_resource_args(add_parser)
add_parser.set_defaults(subfunc=self._add)

Expand Down Expand Up @@ -241,35 +231,7 @@ def _add(self, args: argparse.Namespace) -> None:
)

profile = load_profile(Path.cwd(), args.profile)
apply_profile_args(args, profile)
profile.pool_name = args.pool_name

spot = None
if profile.spot_policy == SpotPolicy.SPOT:
spot = True
if profile.spot_policy == SpotPolicy.ONDEMAND:
spot = False

requirements = Requirements(
resources=resources,
max_price=args.max_price,
spot=spot,
)

idle_duration = parse_max_duration(args.idle_duration)
if idle_duration is None:
profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME
profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE
elif idle_duration == "off":
profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME
profile.termination_policy = TerminationPolicy.DONT_DESTROY
elif isinstance(idle_duration, int):
profile.termination_idle_time = idle_duration
profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE
else:
raise CLIError(
f"Invalid format --idle-duration {args.idle_duration!r}. It must be literal string 'off' or an integer number with an suffix s|m|h|d|w "
)
apply_profile_args(args, profile, pool_add=True)

# Add remote instance
if args.remote:
Expand All @@ -286,11 +248,21 @@ def _add(self, args: argparse.Namespace) -> None:
# TODO(egor-s): print on success
return

spot = None
if profile.spot_policy == SpotPolicy.SPOT:
spot = True
if profile.spot_policy == SpotPolicy.ONDEMAND:
spot = False
requirements = Requirements(
resources=resources,
max_price=profile.max_price,
spot=spot,
)

with console.status("Getting instances..."):
pool_offers = self.api.runs.get_offers(profile, requirements)

offers = [o for o in pool_offers.instances if o.instance_runtime == InstanceRuntime.SHIM]

offers = [o for o in pool_offers.instances]
print_offers_table(pool_offers.pool_name, profile, requirements, offers)
if not offers:
console.print("\nThere are no offers with these criteria. Exiting...")
Expand All @@ -304,6 +276,8 @@ def _add(self, args: argparse.Namespace) -> None:
pub_key = SSHKey(public=user_pub_key)
try:
with console.status("Creating instance..."):
# TODO: Instance name is not passed, so --instance does not work.
# There is profile.instance_name but it makes sense for `dstack run` only.
instance = self.api.runs.create_instance(
pool_offers.pool_name, profile, requirements, pub_key
)
Expand Down Expand Up @@ -375,17 +349,6 @@ def print_offers_table(
) -> None:
pretty_req = requirements.pretty_format(resources_only=True)
max_price = f"${requirements.max_price:g}" if requirements.max_price else "-"
max_duration = (
f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-"
)

# TODO: improve retry policy
# retry_policy = profile.retry_policy
# retry_policy = (
# (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes")
# if retry_policy.retry
# else "no"
# )

if requirements.spot is None:
spot_policy = "auto"
Expand All @@ -404,9 +367,7 @@ def th(s: str) -> str:
props.add_row(th("Pool name"), pool_name)
props.add_row(th("Min resources"), pretty_req)
props.add_row(th("Max price"), max_price)
props.add_row(th("Max duration"), max_duration)
props.add_row(th("Spot policy"), spot_policy)
# props.add_row(th("Retry policy"), retry_policy)

offers_table = Table(box=None)
offers_table.add_column("#")
Expand Down
64 changes: 5 additions & 59 deletions src/dstack/_internal/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
from dstack._internal.cli.utils.run import print_run_plan
from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError
from dstack._internal.core.models.configurations import ConfigurationType
from dstack._internal.core.models.profiles import (
DEFAULT_RUN_TERMINATION_IDLE_TIME,
CreationPolicy,
TerminationPolicy,
parse_max_duration,
)
from dstack._internal.core.models.runs import JobErrorCode
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -84,29 +78,6 @@ def _register(self):
type=int,
default=3,
)
self._parser.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
self._parser.add_argument(
"--reuse",
dest="creation_policy_reuse",
action="store_true",
help="Reuse instance from pool",
)
self._parser.add_argument(
"--idle-duration",
dest="idle_duration",
type=str,
help="Idle time before instance termination",
)
self._parser.add_argument(
"--instance",
dest="instance_name",
metavar="NAME",
help="Reuse instance from pool with name [code]NAME[/]",
)
register_profile_args(self._parser)

def _command(self, args: argparse.Namespace):
Expand All @@ -118,31 +89,6 @@ def _command(self, args: argparse.Namespace):
self._parser.print_help()
return

termination_policy_idle = DEFAULT_RUN_TERMINATION_IDLE_TIME
termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE

if args.idle_duration is not None:
try:
termination_policy_idle = int(args.idle_duration)
except ValueError:
termination_policy_idle = parse_max_duration(args.idle_duration)

creation_policy = (
CreationPolicy.REUSE if args.creation_policy_reuse else CreationPolicy.REUSE_OR_CREATE
)

if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None:
console.print(
"[warning]If the flag --reuse is set, the argument --idle-duration will be skipped[/]"
)
termination_policy = TerminationPolicy.DONT_DESTROY

if args.instance_name is not None and termination_policy_idle is not None:
console.print(
f"[warning]--idle-duration won't be applied to the instance {args.instance_name!r}[/]"
)
termination_policy = TerminationPolicy.DONT_DESTROY

super()._command(args)
try:
repo = self.api.repos.load(Path.cwd())
Expand Down Expand Up @@ -176,11 +122,11 @@ def _command(self, args: argparse.Namespace):
max_price=profile.max_price,
working_dir=args.working_dir,
run_name=args.run_name,
pool_name=args.pool_name,
instance_name=args.instance_name,
creation_policy=creation_policy,
termination_policy=termination_policy,
termination_policy_idle=termination_policy_idle,
pool_name=profile.pool_name,
instance_name=profile.instance_name,
creation_policy=profile.creation_policy,
termination_policy=profile.termination_policy,
termination_policy_idle=profile.termination_idle_time,
)
except ConfigurationError as e:
raise CLIError(str(e))
Expand Down
123 changes: 96 additions & 27 deletions src/dstack/_internal/cli/services/configurators/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
import os

from dstack._internal.core.models.profiles import (
CreationPolicy,
Profile,
ProfileRetryPolicy,
SpotPolicy,
TerminationPolicy,
parse_duration,
parse_max_duration,
)


def register_profile_args(parser: argparse.ArgumentParser):
def register_profile_args(parser: argparse.ArgumentParser, pool_add: bool = False):
"""
Registers `parser` with `dstack run` and `dstack pool add`
CLI arguments that override `profiles.yml` settings.
"""
profile_group = parser.add_argument_group("Profile")
profile_group.add_argument(
"--profile",
Expand All @@ -26,9 +32,14 @@ def register_profile_args(parser: argparse.ArgumentParser):
help="The maximum price per hour, in dollars",
dest="max_price",
)
profile_group.add_argument(
"--max-duration", type=max_duration, dest="max_duration", metavar="DURATION"
)
if not pool_add:
profile_group.add_argument(
"--max-duration",
type=max_duration,
dest="max_duration",
help="The maximum duration of the run",
metavar="DURATION",
)
profile_group.add_argument(
"-b",
"--backend",
Expand All @@ -45,6 +56,41 @@ def register_profile_args(parser: argparse.ArgumentParser):
dest="regions",
help="The regions that will be tried for provisioning",
)
if pool_add:
pools_group_exc = parser
else:
pools_group = parser.add_argument_group("Pools")
pools_group_exc = pools_group.add_mutually_exclusive_group()
pools_group_exc.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
pools_group_exc.add_argument(
"--reuse",
dest="creation_policy_reuse",
action="store_true",
help="Reuse instance from pool",
)
pools_group_exc.add_argument(
"--dont-destroy",
dest="dont_destroy",
action="store_true",
help="Do not destroy instance after the run is finished",
)
pools_group_exc.add_argument(
"--idle-duration",
dest="idle_duration",
type=str,
help="Time to wait before destroying the idle instance",
)
if not pool_add:
pools_group_exc.add_argument(
"--instance",
dest="instance_name",
metavar="NAME",
help="Reuse instance from pool with name [code]NAME[/]",
)

spot_group = parser.add_argument_group("Spot policy")
spot_group_exc = spot_group.add_mutually_exclusive_group()
Expand Down Expand Up @@ -77,39 +123,62 @@ def register_profile_args(parser: argparse.ArgumentParser):
help="One of %s" % ", ".join([f"[code]{i.value}[/]" for i in SpotPolicy]),
)

retry_group = parser.add_argument_group("Retry policy")
retry_group_exc = retry_group.add_mutually_exclusive_group()
retry_group_exc.add_argument("--retry", action="store_const", dest="retry_policy", const=True)
retry_group_exc.add_argument(
"--no-retry", action="store_const", dest="retry_policy", const=False
)
retry_group_exc.add_argument(
"--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION"
)
if not pool_add:
retry_group = parser.add_argument_group("Retry policy")
retry_group_exc = retry_group.add_mutually_exclusive_group()
retry_group_exc.add_argument(
"--retry", action="store_const", dest="retry_policy", const=True
)
retry_group_exc.add_argument(
"--no-retry", action="store_const", dest="retry_policy", const=False
)
retry_group_exc.add_argument(
"--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION"
)


def apply_profile_args(args: argparse.Namespace, profile: Profile):
if args.max_price is not None:
profile.max_price = args.max_price
if args.max_duration is not None:
profile.max_duration = args.max_duration
def apply_profile_args(args: argparse.Namespace, profile: Profile, pool_add: bool = False):
"""
Overrides `profile` settings with arguments registered by `register_profile_args()`.
"""
# TODO: Re-assigned profile attributes are not validated by pydantic.
# So the validation will only be done by the server.
# Consider setting validate_assignment=True for modified pydantic models.
if args.backends:
profile.backends = args.backends
if args.regions:
profile.regions = args.regions
if args.max_price is not None:
profile.max_price = args.max_price
if not pool_add:
if args.max_duration is not None:
profile.max_duration = args.max_duration

if args.pool_name:
profile.pool_name = args.pool_name
if args.idle_duration is not None:
profile.termination_idle_time = args.idle_duration
if args.dont_destroy:
profile.termination_policy = TerminationPolicy.DONT_DESTROY
if not pool_add:
if args.instance_name:
profile.instance_name = args.instance_name
if args.creation_policy_reuse:
profile.creation_policy = CreationPolicy.REUSE

if args.spot_policy is not None:
profile.spot_policy = args.spot_policy

if args.retry_policy is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = args.retry_policy
elif args.retry_limit is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = True
profile.retry_policy.limit = args.retry_limit
if not pool_add:
if args.retry_policy is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = args.retry_policy
elif args.retry_limit is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = True
profile.retry_policy.limit = args.retry_limit


def max_duration(v: str) -> int:
Expand Down
Loading

0 comments on commit 38115c8

Please sign in to comment.