Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pool-related CLI arguments respect profile #954

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 16 additions & 55 deletions src/dstack/_internal/cli/commands/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
InstanceRuntime,
SSHKey,
)
from dstack._internal.core.models.pools import Instance, Pool
from dstack._internal.core.models.profiles import (
DEFAULT_POOL_TERMINATION_IDLE_TIME,
Profile,
SpotPolicy,
TerminationPolicy,
parse_max_duration,
)
from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE
from dstack._internal.core.models.runs import InstanceStatus, Requirements
Expand Down Expand Up @@ -105,11 +101,6 @@ def _register(self) -> None:
add_parser = subparsers.add_parser(
"add", help="Add instance to pool", formatter_class=self._parser.formatter_class
)
add_parser.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
add_parser.add_argument(
"-y", "--yes", help="Don't ask for confirmation", action="store_true"
)
Expand All @@ -127,8 +118,7 @@ def _register(self) -> None:
add_parser.add_argument(
"--name", dest="instance_name", help="Set the name of the instance"
)
add_parser.add_argument("--idle-duration", dest="idle_duration", help="Idle duration")
register_profile_args(add_parser)
register_profile_args(add_parser, pool_add=True)
register_resource_args(add_parser)
add_parser.set_defaults(subfunc=self._add)

Expand Down Expand Up @@ -241,35 +231,7 @@ def _add(self, args: argparse.Namespace) -> None:
)

profile = load_profile(Path.cwd(), args.profile)
apply_profile_args(args, profile)
profile.pool_name = args.pool_name

spot = None
if profile.spot_policy == SpotPolicy.SPOT:
spot = True
if profile.spot_policy == SpotPolicy.ONDEMAND:
spot = False

requirements = Requirements(
resources=resources,
max_price=args.max_price,
spot=spot,
)

idle_duration = parse_max_duration(args.idle_duration)
if idle_duration is None:
profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME
profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE
elif idle_duration == "off":
profile.termination_idle_time = DEFAULT_POOL_TERMINATION_IDLE_TIME
profile.termination_policy = TerminationPolicy.DONT_DESTROY
elif isinstance(idle_duration, int):
profile.termination_idle_time = idle_duration
profile.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE
else:
raise CLIError(
f"Invalid format --idle-duration {args.idle_duration!r}. It must be literal string 'off' or an integer number with an suffix s|m|h|d|w "
)
apply_profile_args(args, profile, pool_add=True)

# Add remote instance
if args.remote:
Expand All @@ -286,11 +248,21 @@ def _add(self, args: argparse.Namespace) -> None:
# TODO(egor-s): print on success
return

spot = None
if profile.spot_policy == SpotPolicy.SPOT:
spot = True
if profile.spot_policy == SpotPolicy.ONDEMAND:
spot = False
requirements = Requirements(
resources=resources,
max_price=profile.max_price,
spot=spot,
)

with console.status("Getting instances..."):
pool_offers = self.api.runs.get_offers(profile, requirements)

offers = [o for o in pool_offers.instances if o.instance_runtime == InstanceRuntime.SHIM]

offers = [o for o in pool_offers.instances]
print_offers_table(pool_offers.pool_name, profile, requirements, offers)
if not offers:
console.print("\nThere are no offers with these criteria. Exiting...")
Expand All @@ -304,6 +276,8 @@ def _add(self, args: argparse.Namespace) -> None:
pub_key = SSHKey(public=user_pub_key)
try:
with console.status("Creating instance..."):
# TODO: Instance name is not passed, so --instance does not work.
# There is profile.instance_name but it makes sense for `dstack run` only.
instance = self.api.runs.create_instance(
pool_offers.pool_name, profile, requirements, pub_key
)
Expand Down Expand Up @@ -375,17 +349,6 @@ def print_offers_table(
) -> None:
pretty_req = requirements.pretty_format(resources_only=True)
max_price = f"${requirements.max_price:g}" if requirements.max_price else "-"
max_duration = (
f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-"
)

# TODO: improve retry policy
# retry_policy = profile.retry_policy
# retry_policy = (
# (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes")
# if retry_policy.retry
# else "no"
# )

if requirements.spot is None:
spot_policy = "auto"
Expand All @@ -404,9 +367,7 @@ def th(s: str) -> str:
props.add_row(th("Pool name"), pool_name)
props.add_row(th("Min resources"), pretty_req)
props.add_row(th("Max price"), max_price)
props.add_row(th("Max duration"), max_duration)
props.add_row(th("Spot policy"), spot_policy)
# props.add_row(th("Retry policy"), retry_policy)

offers_table = Table(box=None)
offers_table.add_column("#")
Expand Down
64 changes: 5 additions & 59 deletions src/dstack/_internal/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
from dstack._internal.cli.utils.run import print_run_plan
from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError
from dstack._internal.core.models.configurations import ConfigurationType
from dstack._internal.core.models.profiles import (
DEFAULT_RUN_TERMINATION_IDLE_TIME,
CreationPolicy,
TerminationPolicy,
parse_max_duration,
)
from dstack._internal.core.models.runs import JobErrorCode
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -84,29 +78,6 @@ def _register(self):
type=int,
default=3,
)
self._parser.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
self._parser.add_argument(
"--reuse",
dest="creation_policy_reuse",
action="store_true",
help="Reuse instance from pool",
)
self._parser.add_argument(
"--idle-duration",
dest="idle_duration",
type=str,
help="Idle time before instance termination",
)
self._parser.add_argument(
"--instance",
dest="instance_name",
metavar="NAME",
help="Reuse instance from pool with name [code]NAME[/]",
)
register_profile_args(self._parser)

def _command(self, args: argparse.Namespace):
Expand All @@ -118,31 +89,6 @@ def _command(self, args: argparse.Namespace):
self._parser.print_help()
return

termination_policy_idle = DEFAULT_RUN_TERMINATION_IDLE_TIME
termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE

if args.idle_duration is not None:
try:
termination_policy_idle = int(args.idle_duration)
except ValueError:
termination_policy_idle = parse_max_duration(args.idle_duration)

creation_policy = (
CreationPolicy.REUSE if args.creation_policy_reuse else CreationPolicy.REUSE_OR_CREATE
)

if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None:
console.print(
"[warning]If the flag --reuse is set, the argument --idle-duration will be skipped[/]"
)
termination_policy = TerminationPolicy.DONT_DESTROY

if args.instance_name is not None and termination_policy_idle is not None:
console.print(
f"[warning]--idle-duration won't be applied to the instance {args.instance_name!r}[/]"
)
termination_policy = TerminationPolicy.DONT_DESTROY

super()._command(args)
try:
repo = self.api.repos.load(Path.cwd())
Expand Down Expand Up @@ -176,11 +122,11 @@ def _command(self, args: argparse.Namespace):
max_price=profile.max_price,
working_dir=args.working_dir,
run_name=args.run_name,
pool_name=args.pool_name,
instance_name=args.instance_name,
creation_policy=creation_policy,
termination_policy=termination_policy,
termination_policy_idle=termination_policy_idle,
pool_name=profile.pool_name,
instance_name=profile.instance_name,
creation_policy=profile.creation_policy,
termination_policy=profile.termination_policy,
termination_policy_idle=profile.termination_idle_time,
)
except ConfigurationError as e:
raise CLIError(str(e))
Expand Down
123 changes: 96 additions & 27 deletions src/dstack/_internal/cli/services/configurators/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
import os

from dstack._internal.core.models.profiles import (
CreationPolicy,
Profile,
ProfileRetryPolicy,
SpotPolicy,
TerminationPolicy,
parse_duration,
parse_max_duration,
)


def register_profile_args(parser: argparse.ArgumentParser):
def register_profile_args(parser: argparse.ArgumentParser, pool_add: bool = False):
"""
Registers `parser` with `dstack run` and `dstack pool add`
CLI arguments that override `profiles.yml` settings.
"""
profile_group = parser.add_argument_group("Profile")
profile_group.add_argument(
"--profile",
Expand All @@ -26,9 +32,14 @@ def register_profile_args(parser: argparse.ArgumentParser):
help="The maximum price per hour, in dollars",
dest="max_price",
)
profile_group.add_argument(
"--max-duration", type=max_duration, dest="max_duration", metavar="DURATION"
)
if not pool_add:
profile_group.add_argument(
"--max-duration",
type=max_duration,
dest="max_duration",
help="The maximum duration of the run",
metavar="DURATION",
)
profile_group.add_argument(
"-b",
"--backend",
Expand All @@ -45,6 +56,41 @@ def register_profile_args(parser: argparse.ArgumentParser):
dest="regions",
help="The regions that will be tried for provisioning",
)
if pool_add:
pools_group_exc = parser
else:
pools_group = parser.add_argument_group("Pools")
pools_group_exc = pools_group.add_mutually_exclusive_group()
pools_group_exc.add_argument(
"--pool",
dest="pool_name",
help="The name of the pool. If not set, the default pool will be used",
)
pools_group_exc.add_argument(
"--reuse",
dest="creation_policy_reuse",
action="store_true",
help="Reuse instance from pool",
)
pools_group_exc.add_argument(
"--dont-destroy",
dest="dont_destroy",
action="store_true",
help="Do not destroy instance after the run is finished",
)
pools_group_exc.add_argument(
"--idle-duration",
dest="idle_duration",
type=str,
help="Time to wait before destroying the idle instance",
)
if not pool_add:
pools_group_exc.add_argument(
"--instance",
dest="instance_name",
metavar="NAME",
help="Reuse instance from pool with name [code]NAME[/]",
)

spot_group = parser.add_argument_group("Spot policy")
spot_group_exc = spot_group.add_mutually_exclusive_group()
Expand Down Expand Up @@ -77,39 +123,62 @@ def register_profile_args(parser: argparse.ArgumentParser):
help="One of %s" % ", ".join([f"[code]{i.value}[/]" for i in SpotPolicy]),
)

retry_group = parser.add_argument_group("Retry policy")
retry_group_exc = retry_group.add_mutually_exclusive_group()
retry_group_exc.add_argument("--retry", action="store_const", dest="retry_policy", const=True)
retry_group_exc.add_argument(
"--no-retry", action="store_const", dest="retry_policy", const=False
)
retry_group_exc.add_argument(
"--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION"
)
if not pool_add:
retry_group = parser.add_argument_group("Retry policy")
retry_group_exc = retry_group.add_mutually_exclusive_group()
retry_group_exc.add_argument(
"--retry", action="store_const", dest="retry_policy", const=True
)
retry_group_exc.add_argument(
"--no-retry", action="store_const", dest="retry_policy", const=False
)
retry_group_exc.add_argument(
"--retry-limit", type=retry_limit, dest="retry_limit", metavar="DURATION"
)


def apply_profile_args(args: argparse.Namespace, profile: Profile):
if args.max_price is not None:
profile.max_price = args.max_price
if args.max_duration is not None:
profile.max_duration = args.max_duration
def apply_profile_args(args: argparse.Namespace, profile: Profile, pool_add: bool = False):
"""
Overrides `profile` settings with arguments registered by `register_profile_args()`.
"""
# TODO: Re-assigned profile attributes are not validated by pydantic.
# So the validation will only be done by the server.
# Consider setting validate_assignment=True for modified pydantic models.
if args.backends:
profile.backends = args.backends
if args.regions:
profile.regions = args.regions
if args.max_price is not None:
profile.max_price = args.max_price
if not pool_add:
if args.max_duration is not None:
profile.max_duration = args.max_duration

if args.pool_name:
profile.pool_name = args.pool_name
if args.idle_duration is not None:
profile.termination_idle_time = args.idle_duration
if args.dont_destroy:
profile.termination_policy = TerminationPolicy.DONT_DESTROY
if not pool_add:
if args.instance_name:
profile.instance_name = args.instance_name
if args.creation_policy_reuse:
profile.creation_policy = CreationPolicy.REUSE

if args.spot_policy is not None:
profile.spot_policy = args.spot_policy

if args.retry_policy is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = args.retry_policy
elif args.retry_limit is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = True
profile.retry_policy.limit = args.retry_limit
if not pool_add:
if args.retry_policy is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = args.retry_policy
elif args.retry_limit is not None:
if not profile.retry_policy:
profile.retry_policy = ProfileRetryPolicy()
profile.retry_policy.retry = True
profile.retry_policy.limit = args.retry_limit


def max_duration(v: str) -> int:
Expand Down
Loading
Loading