Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ dynamic-services will fail if they have any required input that is not set #5845

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
DefaultPricingUnitNotFoundError,
NodeNotFoundError,
ProjectInvalidRightsError,
ProjectNodeRequiredInputsNotSetError,
ProjectNodeResourcesInsufficientRightsError,
ProjectNodeResourcesInvalidError,
ProjectNotFoundError,
Expand Down Expand Up @@ -105,6 +106,8 @@ async def wrapper(request: web.Request) -> web.StreamResponse:
raise web.HTTPConflict(reason=f"{exc}") from exc
except ClustersKeeperNotAvailableError as exc:
raise web.HTTPServiceUnavailable(reason=f"{exc}") from exc
except ProjectNodeRequiredInputsNotSetError as exc:
raise web.HTTPConflict(reason=f"{exc}") from exc

return wrapper

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import redis.exceptions
from models_library.projects import ProjectID
from models_library.projects_nodes_io import NodeID
from models_library.users import UserID

from ..errors import WebServerBaseError
Expand Down Expand Up @@ -136,6 +137,47 @@ class ProjectNodeResourcesInsufficientRightsError(BaseProjectError):
...


class ProjectNodeRequiredInputsNotSetError(BaseProjectError):
...


class ProjectNodeConnectionsMissingError(ProjectNodeRequiredInputsNotSetError):
msg_template = "Missing '{joined_unset_required_inputs}' connection(s) to '{node_with_required_inputs}'"

def __init__(
self,
*,
unset_required_inputs: list[str],
node_with_required_inputs: NodeID,
**ctx,
):

joined_unset_required_inputs = ", ".join(unset_required_inputs)
ctx["joined_unset_required_inputs"] = joined_unset_required_inputs
super().__init__(**ctx)
self.unset_required_inputs = unset_required_inputs
self.node_with_required_inputs = node_with_required_inputs


class ProjectNodeOutputPortMissingValueError(ProjectNodeRequiredInputsNotSetError):
msg_template = "Missing: {joined_start_message}"

def __init__(
self,
*,
unset_outputs_in_upstream: list[tuple[str, str]],
**ctx,
):
start_messages = [
f"'{input_key}' of '{service_name}'"
for input_key, service_name in unset_outputs_in_upstream
]
joined_start_message = ", ".join(start_messages)
ctx["joined_start_message"] = joined_start_message
super().__init__(**ctx)
self.unset_outputs_in_upstream = unset_outputs_in_upstream


class DefaultPricingUnitNotFoundError(BaseProjectError):
msg_template = "Default pricing unit not found for node '{node_uuid}' in project '{project_uuid}'"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from models_library.errors import ErrorDict
from models_library.products import ProductName
from models_library.projects import Project, ProjectID, ProjectIDStr
from models_library.projects_nodes import Node
from models_library.projects_nodes_io import NodeID, NodeIDStr
from models_library.projects_nodes import Node, OutputsDict
from models_library.projects_nodes_io import NodeID, NodeIDStr, PortLink
from models_library.projects_state import (
Owner,
ProjectLocked,
Expand Down Expand Up @@ -124,6 +124,9 @@
NodeNotFoundError,
ProjectInvalidRightsError,
ProjectLockError,
ProjectNodeConnectionsMissingError,
ProjectNodeOutputPortMissingValueError,
ProjectNodeRequiredInputsNotSetError,
ProjectNodeResourcesInvalidError,
ProjectOwnerNotFoundInTheProjectAccessRightsError,
ProjectStartsTooManyDynamicNodesError,
Expand Down Expand Up @@ -447,6 +450,56 @@ def _by_type_name(ec2: EC2InstanceTypeGet) -> bool:
raise ClustersKeeperNotAvailableError from exc


async def _check_project_node_has_all_required_inputs(
db: ProjectDBAPI, user_id: UserID, project_uuid: ProjectID, node_id: NodeID
) -> None:

project_dict, _ = await db.get_project(user_id, f"{project_uuid}")

nodes_map: dict[NodeID, Node] = {
NodeID(k): Node(**v) for k, v in project_dict["workbench"].items()
}
node = nodes_map[node_id]

unset_required_inputs: list[str] = []
unset_outputs_in_upstream: list[tuple[str, str]] = []

def _check_required_input(required_input_key: str) -> None:
input_entry: PortLink | None = None
if node.inputs:
input_entry = node.inputs.get(required_input_key, None)
if input_entry is None:
# NOT linked to any node connect service or set value manually(whichever applies)
unset_required_inputs.append(required_input_key)
return

source_node_id: NodeID = input_entry.node_uuid
source_output_key = input_entry.output

source_node = nodes_map[source_node_id]

output_entry: OutputsDict | None = None
if source_node.outputs:
output_entry = source_node.outputs.get(source_output_key, None)
if output_entry is None:
unset_outputs_in_upstream.append((source_output_key, source_node.label))

for required_input in node.inputs_required:
_check_required_input(required_input)

node_with_required_inputs = node.label
if unset_required_inputs:
raise ProjectNodeConnectionsMissingError(
unset_required_inputs=unset_required_inputs,
node_with_required_inputs=node_with_required_inputs,
)

if unset_outputs_in_upstream:
raise ProjectNodeOutputPortMissingValueError(
unset_outputs_in_upstream=unset_outputs_in_upstream
)


async def _start_dynamic_service(
request: web.Request,
*,
Expand All @@ -456,6 +509,7 @@ async def _start_dynamic_service(
user_id: UserID,
project_uuid: ProjectID,
node_uuid: NodeID,
graceful_start: bool = False,
) -> None:
if not _is_node_dynamic(service_key):
return
Expand All @@ -464,6 +518,20 @@ async def _start_dynamic_service(

db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(request.app)

try:
await _check_project_node_has_all_required_inputs(
db, user_id, project_uuid, node_uuid
)
except ProjectNodeRequiredInputsNotSetError as e:
if graceful_start:
log.info(
"Did not start '%s' because of missing required inputs: %s",
node_uuid,
e,
)
return
raise

save_state = False
user_role: UserRole = await get_user_role(request.app, user_id)
if user_role > UserRole.GUEST:
Expand Down Expand Up @@ -1464,6 +1532,7 @@ async def run_project_dynamic_services(
user_id=user_id,
project_uuid=project["uuid"],
node_uuid=NodeID(service_uuid),
graceful_start=True,
)
for service_uuid, is_deprecated in zip(
services_to_start_uuids, deprecated_services, strict=True
Expand Down
117 changes: 117 additions & 0 deletions services/web/server/tests/unit/with_dbs/03/test_project_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
from simcore_service_webserver.projects.db import ProjectAccessRights, ProjectDBAPI
from simcore_service_webserver.projects.exceptions import (
NodeNotFoundError,
ProjectNodeRequiredInputsNotSetError,
ProjectNotFoundError,
)
from simcore_service_webserver.projects.models import ProjectDict
from simcore_service_webserver.projects.projects_api import (
_check_project_node_has_all_required_inputs,
)
from simcore_service_webserver.users.exceptions import UserNotFoundError
from simcore_service_webserver.utils import to_datetime
from sqlalchemy.engine.result import Row
Expand Down Expand Up @@ -829,3 +833,116 @@ async def test_has_permission(
await db_api.has_permission(second_user["id"], project_id, permission)
is access_rights[permission]
), f"Found unexpected {permission=} for {access_rights=} of {user_role=} and {project_id=}"


def _fake_output_data() -> dict:
return {
"store": 0,
"path": "9f8207e6-144a-11ef-831f-0242ac140027/98b68cbe-9e22-4eb5-a91b-2708ad5317b7/outputs/output_2/output_2.zip",
"eTag": "ec3bc734d85359b660aab400147cd1ea",
}


def _fake_connect_to(output_number: int) -> dict:
return {
"nodeUuid": "98b68cbe-9e22-4eb5-a91b-2708ad5317b7",
"output": f"output_{output_number}",
}


@pytest.fixture
async def inserted_project(
logged_user: dict[str, Any],
insert_project_in_db: Callable[..., Awaitable[dict[str, Any]]],
fake_project: dict[str, Any],
downstream_inputs: dict,
downstream_required_inputs: list[str],
upstream_outputs: dict,
) -> dict:
fake_project["workbench"] = {
"98b68cbe-9e22-4eb5-a91b-2708ad5317b7": {
"key": "simcore/services/dynamic/jupyter-math",
"version": "2.0.10",
"label": "upstream",
"inputs": {},
"inputsUnits": {},
"inputNodes": [],
"thumbnail": "",
"outputs": upstream_outputs,
"runHash": "c6ae58f36a2e0f65f443441ecda023a451cb1b8051d01412d79aa03653e1a6b3",
},
"324d6ef2-a82c-414d-9001-dc84da1cbea3": {
"key": "simcore/services/dynamic/jupyter-math",
"version": "2.0.10",
"label": "downstream",
"inputs": downstream_inputs,
"inputsUnits": {},
"inputNodes": ["98b68cbe-9e22-4eb5-a91b-2708ad5317b7"],
"thumbnail": "",
"inputsRequired": downstream_required_inputs,
},
}

return await insert_project_in_db(fake_project, user_id=logged_user["id"])


@pytest.mark.parametrize(
"downstream_inputs,downstream_required_inputs,upstream_outputs,expected_error",
[
pytest.param(
{"input_1": _fake_connect_to(1)},
["input_1", "input_2"],
{},
"Missing 'input_2' connection(s) to 'downstream'",
id="missing_connection_on_input_2",
),
pytest.param(
{"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)},
["input_1", "input_2"],
{"output_2": _fake_output_data()},
"Missing: 'output_1' of 'upstream'",
id="output_1_has_not_file",
),
],
)
@pytest.mark.parametrize("user_role", [(UserRole.USER)])
async def test_check_project_node_has_all_required_inputs_raises(
logged_user: dict[str, Any],
db_api: ProjectDBAPI,
inserted_project: dict,
expected_error: str,
):

with pytest.raises(ProjectNodeRequiredInputsNotSetError) as exc:
await _check_project_node_has_all_required_inputs(
db_api,
user_id=logged_user["id"],
project_uuid=UUID(inserted_project["uuid"]),
node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"),
)
assert f"{exc.value}" == expected_error


@pytest.mark.parametrize(
"downstream_inputs,downstream_required_inputs,upstream_outputs",
[
pytest.param(
{"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)},
["input_1", "input_2"],
{"output_1": _fake_output_data(), "output_2": _fake_output_data()},
id="with_required_inputs_present",
),
],
)
@pytest.mark.parametrize("user_role", [(UserRole.USER)])
async def test_check_project_node_has_all_required_inputs_ok(
logged_user: dict[str, Any],
db_api: ProjectDBAPI,
inserted_project: dict,
):
await _check_project_node_has_all_required_inputs(
db_api,
user_id=logged_user["id"],
project_uuid=UUID(inserted_project["uuid"]),
node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"),
)
Loading