From ee3dd4c2783220d6193e1bb78c5c081d963c4e27 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 1 Mar 2024 08:59:08 +0500 Subject: [PATCH] Make pool-related CLI arguments respect profile --- src/dstack/_internal/cli/commands/pool.py | 71 +++------- src/dstack/_internal/cli/commands/run.py | 64 +-------- .../cli/services/configurators/profile.py | 123 ++++++++++++++---- src/dstack/_internal/cli/utils/run.py | 9 ++ src/dstack/_internal/core/models/profiles.py | 11 +- src/dstack/_internal/server/services/runs.py | 4 +- src/dstack/api/server/__init__.py | 4 +- 7 files changed, 138 insertions(+), 148 deletions(-) diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index 881477876..85dd640e2 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -18,16 +18,12 @@ from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, - InstanceRuntime, SSHKey, ) from dstack._internal.core.models.pools import Instance, Pool from dstack._internal.core.models.profiles import ( - DEFAULT_POOL_TERMINATION_IDLE_TIME, Profile, SpotPolicy, - TerminationPolicy, - parse_max_duration, ) from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE from dstack._internal.core.models.runs import InstanceStatus, Requirements @@ -105,11 +101,6 @@ def _register(self) -> None: add_parser = subparsers.add_parser( "add", help="Add instance to pool", formatter_class=self._parser.formatter_class ) - add_parser.add_argument( - "--pool", - dest="pool_name", - help="The name of the pool. If not set, the default pool will be used", - ) add_parser.add_argument( "-y", "--yes", help="Don't ask for confirmation", action="store_true" ) @@ -127,8 +118,7 @@ def _register(self) -> None: add_parser.add_argument( "--name", dest="instance_name", help="Set the name of the instance" ) - add_parser.add_argument("--idle-duration", dest="idle_duration", help="Idle duration") - register_profile_args(add_parser) + register_profile_args(add_parser, pool_add=True) register_resource_args(add_parser) add_parser.set_defaults(subfunc=self._add) @@ -241,35 +231,7 @@ def _add(self, args: argparse.Namespace) -> None: ) profile = load_profile(Path.cwd(), args.profile) - apply_profile_args(args, profile) - profile.pool_name = args.pool_name - - spot = None - if profile.spot_policy == SpotPolicy.SPOT: - spot = True - if profile.spot_policy == SpotPolicy.ONDEMAND: - spot = False - - requirements = Requirements( - resources=resources, - max_price=args.max_price, - spot=spot, - ) - - idle_duration = parse_max_duration(args.idle_duration) - if idle_duration is None: - profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME - profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - elif idle_duration == "off": - profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME - profile.termination_policy = TerminationPolicy.DONT_DESTROY - elif isinstance(idle_duration, int): - profile.termination_idle_time = idle_duration - profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - else: - raise CLIError( - f"Invalid format --idle-duration {args.idle_duration!r}. It must be literal string 'off' or an integer number with an suffix s|m|h|d|w " - ) + apply_profile_args(args, profile, pool_add=True) # Add remote instance if args.remote: @@ -286,11 +248,21 @@ def _add(self, args: argparse.Namespace) -> None: # TODO(egor-s): print on success return + spot = None + if profile.spot_policy == SpotPolicy.SPOT: + spot = True + if profile.spot_policy == SpotPolicy.ONDEMAND: + spot = False + requirements = Requirements( + resources=resources, + max_price=profile.max_price, + spot=spot, + ) + with console.status("Getting instances..."): pool_offers = self.api.runs.get_offers(profile, requirements) - offers = [o for o in pool_offers.instances if o.instance_runtime == InstanceRuntime.SHIM] - + offers = [o for o in pool_offers.instances] print_offers_table(pool_offers.pool_name, profile, requirements, offers) if not offers: console.print("\nThere are no offers with these criteria. Exiting...") @@ -304,6 +276,8 @@ def _add(self, args: argparse.Namespace) -> None: pub_key = SSHKey(public=user_pub_key) try: with console.status("Creating instance..."): + # TODO: Instance name is not passed, so --instance does not work. + # There is profile.instance_name but it makes sense for `dstack run` only. instance = self.api.runs.create_instance( pool_offers.pool_name, profile, requirements, pub_key ) @@ -375,17 +349,6 @@ def print_offers_table( ) -> None: pretty_req = requirements.pretty_format(resources_only=True) max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" - max_duration = ( - f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-" - ) - - # TODO: improve retry policy - # retry_policy = profile.retry_policy - # retry_policy = ( - # (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes") - # if retry_policy.retry - # else "no" - # ) if requirements.spot is None: spot_policy = "auto" @@ -404,9 +367,7 @@ def th(s: str) -> str: props.add_row(th("Pool name"), pool_name) props.add_row(th("Min resources"), pretty_req) props.add_row(th("Max price"), max_price) - props.add_row(th("Max duration"), max_duration) props.add_row(th("Spot policy"), spot_policy) - # props.add_row(th("Retry policy"), retry_policy) offers_table = Table(box=None) offers_table.add_column("#") diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index b804eb452..203e11a8f 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,12 +17,6 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType -from dstack._internal.core.models.profiles import ( - DEFAULT_RUN_TERMINATION_IDLE_TIME, - CreationPolicy, - TerminationPolicy, - parse_max_duration, -) from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager from dstack._internal.utils.logging import get_logger @@ -84,29 +78,6 @@ def _register(self): type=int, default=3, ) - self._parser.add_argument( - "--pool", - dest="pool_name", - help="The name of the pool. If not set, the default pool will be used", - ) - self._parser.add_argument( - "--reuse", - dest="creation_policy_reuse", - action="store_true", - help="Reuse instance from pool", - ) - self._parser.add_argument( - "--idle-duration", - dest="idle_duration", - type=str, - help="Idle time before instance termination", - ) - self._parser.add_argument( - "--instance", - dest="instance_name", - metavar="NAME", - help="Reuse instance from pool with name [code]NAME[/]", - ) register_profile_args(self._parser) def _command(self, args: argparse.Namespace): @@ -118,31 +89,6 @@ def _command(self, args: argparse.Namespace): self._parser.print_help() return - termination_policy_idle = DEFAULT_RUN_TERMINATION_IDLE_TIME - termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - - if args.idle_duration is not None: - try: - termination_policy_idle = int(args.idle_duration) - except ValueError: - termination_policy_idle = parse_max_duration(args.idle_duration) - - creation_policy = ( - CreationPolicy.REUSE if args.creation_policy_reuse else CreationPolicy.REUSE_OR_CREATE - ) - - if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None: - console.print( - "[warning]If the flag --reuse is set, the argument --idle-duration will be skipped[/]" - ) - termination_policy = TerminationPolicy.DONT_DESTROY - - if args.instance_name is not None and termination_policy_idle is not None: - console.print( - f"[warning]--idle-duration won't be applied to the instance {args.instance_name!r}[/]" - ) - termination_policy = TerminationPolicy.DONT_DESTROY - super()._command(args) try: repo = self.api.repos.load(Path.cwd()) @@ -176,11 +122,11 @@ def _command(self, args: argparse.Namespace): max_price=profile.max_price, working_dir=args.working_dir, run_name=args.run_name, - pool_name=args.pool_name, - instance_name=args.instance_name, - creation_policy=creation_policy, - termination_policy=termination_policy, - termination_policy_idle=termination_policy_idle, + pool_name=profile.pool_name, + instance_name=profile.instance_name, + creation_policy=profile.creation_policy, + termination_policy=profile.termination_policy, + termination_policy_idle=profile.termination_idle_time, ) except ConfigurationError as e: raise CLIError(str(e)) diff --git a/src/dstack/_internal/cli/services/configurators/profile.py b/src/dstack/_internal/cli/services/configurators/profile.py index 2f6191847..1b02cfdee 100644 --- a/src/dstack/_internal/cli/services/configurators/profile.py +++ b/src/dstack/_internal/cli/services/configurators/profile.py @@ -2,15 +2,21 @@ import os from dstack._internal.core.models.profiles import ( + CreationPolicy, Profile, ProfileRetryPolicy, SpotPolicy, + TerminationPolicy, parse_duration, parse_max_duration, ) -def register_profile_args(parser: argparse.ArgumentParser): +def register_profile_args(parser: argparse.ArgumentParser, pool_add: bool = False): + """ + Registers `parser` with `dstack run` and `dstack pool add` + CLI arguments that override `profiles.yml` settings. + """ profile_group = parser.add_argument_group("Profile") profile_group.add_argument( "--profile", @@ -26,9 +32,14 @@ def register_profile_args(parser: argparse.ArgumentParser): help="The maximum price per hour, in dollars", dest="max_price", ) - profile_group.add_argument( - "--max-duration", type=max_duration, dest="max_duration", metavar="DURATION" - ) + if not pool_add: + profile_group.add_argument( + "--max-duration", + type=max_duration, + dest="max_duration", + help="The maximum duration of the run", + metavar="DURATION", + ) profile_group.add_argument( "-b", "--backend", @@ -45,6 +56,41 @@ def register_profile_args(parser: argparse.ArgumentParser): dest="regions", help="The regions that will be tried for provisioning", ) + if pool_add: + pools_group_exc = parser + else: + pools_group = parser.add_argument_group("Pools") + pools_group_exc = pools_group.add_mutually_exclusive_group() + pools_group_exc.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool. If not set, the default pool will be used", + ) + pools_group_exc.add_argument( + "--reuse", + dest="creation_policy_reuse", + action="store_true", + help="Reuse instance from pool", + ) + pools_group_exc.add_argument( + "--dont-destroy", + dest="dont_destroy", + action="store_true", + help="Do not destroy instance after the run is finished", + ) + pools_group_exc.add_argument( + "--idle-duration", + dest="idle_duration", + type=str, + help="Time to wait before destroying the idle instance", + ) + if not pool_add: + pools_group_exc.add_argument( + "--instance", + dest="instance_name", + metavar="NAME", + help="Reuse instance from pool with name [code]NAME[/]", + ) spot_group = parser.add_argument_group("Spot policy") spot_group_exc = spot_group.add_mutually_exclusive_group() @@ -77,39 +123,62 @@ def register_profile_args(parser: argparse.ArgumentParser): help="One of %s" % ", ".join([f"[code]{i.value}[/]" for i in SpotPolicy]), ) - retry_group = parser.add_argument_group("Retry policy") - retry_group_exc = retry_group.add_mutually_exclusive_group() - retry_group_exc.add_argument("--retry", action="store_const", dest="retry_policy", const=True) - retry_group_exc.add_argument( - "--no-retry", action="store_const", dest="retry_policy", const=False - ) - retry_group_exc.add_argument( - "--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION" - ) + if not pool_add: + retry_group = parser.add_argument_group("Retry policy") + retry_group_exc = retry_group.add_mutually_exclusive_group() + retry_group_exc.add_argument( + "--retry", action="store_const", dest="retry_policy", const=True + ) + retry_group_exc.add_argument( + "--no-retry", action="store_const", dest="retry_policy", const=False + ) + retry_group_exc.add_argument( + "--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION" + ) -def apply_profile_args(args: argparse.Namespace, profile: Profile): - if args.max_price is not None: - profile.max_price = args.max_price - if args.max_duration is not None: - profile.max_duration = args.max_duration +def apply_profile_args(args: argparse.Namespace, profile: Profile, pool_add: bool = False): + """ + Overrides `profile` settings with arguments registered by `register_profile_args()`. + """ + # TODO: Re-assigned profile attributes are not validated by pydantic. + # So the validation will only be done by the server. + # Consider setting validate_assignment=True for modified pydantic models. if args.backends: profile.backends = args.backends if args.regions: profile.regions = args.regions + if args.max_price is not None: + profile.max_price = args.max_price + if not pool_add: + if args.max_duration is not None: + profile.max_duration = args.max_duration + + if args.pool_name: + profile.pool_name = args.pool_name + if args.idle_duration is not None: + profile.termination_idle_time = args.idle_duration + if args.dont_destroy: + profile.termination_policy = TerminationPolicy.DONT_DESTROY + if not pool_add: + if args.instance_name: + profile.instance_name = args.instance_name + if args.creation_policy_reuse: + profile.creation_policy = CreationPolicy.REUSE if args.spot_policy is not None: profile.spot_policy = args.spot_policy - if args.retry_policy is not None: - if not profile.retry_policy: - profile.retry_policy = ProfileRetryPolicy() - profile.retry_policy.retry = args.retry_policy - elif args.retry_limit is not None: - if not profile.retry_policy: - profile.retry_policy = ProfileRetryPolicy() - profile.retry_policy.retry = True - profile.retry_policy.limit = args.retry_limit + if not pool_add: + if args.retry_policy is not None: + if not profile.retry_policy: + profile.retry_policy = ProfileRetryPolicy() + profile.retry_policy.retry = args.retry_policy + elif args.retry_limit is not None: + if not profile.retry_policy: + profile.retry_policy = ProfileRetryPolicy() + profile.retry_policy.retry = True + profile.retry_policy.limit = args.retry_limit def max_duration(v: str) -> int: diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 27cd18b33..b34c156f1 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -4,6 +4,7 @@ from dstack._internal.cli.utils.common import console from dstack._internal.core.models.instances import InstanceAvailability, InstanceType +from dstack._internal.core.models.profiles import TerminationPolicy from dstack._internal.core.models.runs import RunPlan from dstack._internal.utils.common import pretty_date from dstack.api import Run @@ -28,6 +29,11 @@ def print_run_plan(run_plan: RunPlan, offers_limit: int = 3): if retry_policy.retry else "no" ) + creation_policy = run_plan.run_spec.profile.creation_policy + termination_policy = run_plan.run_spec.profile.termination_policy + termination_idle_time = f"{run_plan.run_spec.profile.termination_idle_time}s" + if termination_policy == TerminationPolicy.DONT_DESTROY: + termination_idle_time = "-" if req.spot is None: spot_policy = "auto" @@ -48,6 +54,9 @@ def th(s: str) -> str: props.add_row(th("Max duration"), max_duration) props.add_row(th("Spot policy"), spot_policy) props.add_row(th("Retry policy"), retry_policy) + props.add_row(th("Creation policy"), creation_policy) + props.add_row(th("Termination policy"), termination_policy) + props.add_row(th("Termination idle time"), termination_idle_time) offers = Table(box=None) offers.add_column("#") diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 3c3694db7..c240e828f 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -121,18 +121,21 @@ class Profile(ForbidExtra): instance_name: Annotated[Optional[str], Field(description="The name of the instance")] creation_policy: Annotated[ Optional[CreationPolicy], Field(description="The policy for using instances from the pool") - ] + ] = CreationPolicy.REUSE_OR_CREATE termination_policy: Annotated[ Optional[TerminationPolicy], Field(description="The policy for termination instances") - ] + ] = TerminationPolicy.DESTROY_AFTER_IDLE termination_idle_time: Annotated[ - int, - Field(description="Seconds to wait before destroying the instance"), + Optional[Union[str, int]], + Field(description="Time to wait before destroying the idle instance"), ] = DEFAULT_RUN_TERMINATION_IDLE_TIME _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration ) + _validate_termination_idle_time = validator( + "termination_idle_time", pre=True, allow_reuse=True + )(parse_duration) class ProfilesConfig(ForbidExtra): diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 1f3a835fa..574f80b94 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -172,7 +172,7 @@ async def get_run_plan( _validate_run_name(run_spec.run_name) profile = run_spec.profile - creation_policy = profile.creation_policy + creation_policy = profile.creation_policy or CreationPolicy.REUSE_OR_CREATE pool = await get_or_create_pool_by_name( session=session, project=project, pool_name=profile.pool_name @@ -199,7 +199,7 @@ async def get_run_plan( job_offers: List[InstanceOfferWithAvailability] = [] job_offers.extend(pool_offers) - if creation_policy is None or creation_policy == CreationPolicy.REUSE_OR_CREATE: + if creation_policy == CreationPolicy.REUSE_OR_CREATE: offers = await get_offers_by_requirements( project=project, profile=profile, diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 955bf956e..a98008698 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -1,4 +1,5 @@ import os +import pprint import time from typing import Dict, List, Optional, Type @@ -121,7 +122,8 @@ def _request( code = kwargs.pop("code") raise _server_client_errors[code](**kwargs) if resp.status_code == 422: - logger.debug("Server validation error: %s", resp.text) + formatted_error = pprint.pformat(resp.json()) + raise ClientError(f"Server validation error: \n{formatted_error}") resp.raise_for_status() return resp