Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename FormValidation #6468

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/6463.removal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The conversation event `form_validation` was renamed to `loop_unhappy`. Rasa Open Source
will continue to be able to read and process old `form_validation` events.
34 changes: 26 additions & 8 deletions rasa/core/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,9 @@ def __eq__(self, other) -> bool:
)

def __str__(self) -> Text:
return "UserUttered(text: {}, intent: {}, entities: {})".format(
self.text, self.intent, self.entities
return (
f"UserUttered(text: {self.text}, "
f"intent: {self.intent}, entities: {self.entities})"
)

@staticmethod
Expand Down Expand Up @@ -1183,11 +1184,11 @@ def as_dict(self) -> Dict[Text, Any]:
return d


class FormValidation(Event):
class LoopUnhappy(Event):
"""Event added by FormPolicy and RulePolicy to notify form action
whether or not to validate the user input."""

type_name = "form_validation"
type_name = "loop_unhappy"

def __init__(
self,
Expand All @@ -1199,20 +1200,20 @@ def __init__(
super().__init__(timestamp, metadata)

def __str__(self) -> Text:
return f"FormValidation({self.validate})"
return f"{LoopUnhappy.__name__}({self.validate})"

def __hash__(self) -> int:
return hash(self.validate)

def __eq__(self, other) -> bool:
return isinstance(other, FormValidation)
return isinstance(other, LoopUnhappy)

def as_story_string(self) -> None:
return None

@classmethod
def _from_parameters(cls, parameters) -> "FormValidation":
return FormValidation(
def _from_parameters(cls, parameters) -> "LoopUnhappy":
return LoopUnhappy(
parameters.get("validate"),
parameters.get("timestamp"),
parameters.get("metadata"),
Expand All @@ -1227,6 +1228,23 @@ def apply_to(self, tracker: "DialogueStateTracker") -> None:
tracker.set_form_validation(self.validate)


class LegacyFormValidation(LoopUnhappy):
"""Legacy handler of old `FormValidation` events.

The `LoopUnhappy` event used to be called `FormValidation`. This class is there to
handle old legacy events which were stored with the old type name `form_validation`.
"""

type_name = "form_validation"

def as_dict(self) -> Dict[Text, Any]:
d = super().as_dict()
# Dump old `Form` events as `ActiveLoop` events instead of keeping the old
# event type.
d["event"] = LoopUnhappy.type_name
return d


class ActionExecutionRejected(Event):
"""Notify Core that the execution of the action has been rejected"""

Expand Down
4 changes: 2 additions & 2 deletions rasa/core/policies/form_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rasa.constants import DOCS_URL_MIGRATION_GUIDE
from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.domain import PREV_PREFIX, ACTIVE_FORM_PREFIX, Domain
from rasa.core.events import FormValidation
from rasa.core.events import LoopUnhappy
from rasa.core.featurizers import TrackerFeaturizer
from rasa.core.interpreter import NaturalLanguageInterpreter, RegexInterpreter
from rasa.core.policies.memoization import MemoizationPolicy
Expand Down Expand Up @@ -148,7 +148,7 @@ def predict_action_probabilities(

if tracker.active_loop.get("rejected"):
if self.state_is_unhappy(tracker, domain):
tracker.update(FormValidation(False))
tracker.update(LoopUnhappy(False))
return result

result = self._prediction_result(
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from collections import defaultdict

from rasa.core.events import FormValidation
from rasa.core.events import LoopUnhappy
from rasa.core.domain import PREV_PREFIX, ACTIVE_FORM_PREFIX, Domain, InvalidDomain
from rasa.core.featurizers import TrackerFeaturizer
from rasa.core.interpreter import NaturalLanguageInterpreter, RegexInterpreter
Expand Down Expand Up @@ -467,7 +467,7 @@ def _find_action_from_rules(

if DO_NOT_VALIDATE_FORM in unhappy_path_conditions:
logger.debug("Added `FormValidation(False)` event.")
tracker.update(FormValidation(False))
tracker.update(LoopUnhappy(False))

if predicted_action_name is not None:
logger.debug(
Expand Down
18 changes: 16 additions & 2 deletions rasa/core/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,24 @@ def change_form_to(self, form_name: Text) -> None:
)
self.change_loop_to(form_name)

def set_form_validation(self, validate: bool) -> None:
"""Toggle form validation"""
# TODO: Change this :-D
def make_loop_unhappy(self, validate: bool) -> None:
"""Toggle loop validation.

Args:
validate: `False` if the loop was run after an unhappy path.
"""
self.active_loop["validate"] = validate

def set_form_validation(self, validate: bool) -> None:
common_utils.raise_warning(
"`set_form_validation` is deprecated and will be removed "
"in future versions. Please use `change_loop_to` "
"instead.",
category=DeprecationWarning,
)
self.make_loop_unhappy(validate)

def reject_action(self, action_name: Text) -> None:
"""Notify active loop that it was rejected"""
if action_name == self.active_loop.get("name"):
Expand Down
6 changes: 3 additions & 3 deletions tests/core/policies/test_rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ActiveLoop,
SlotSet,
ActionExecutionRejected,
FormValidation,
LoopUnhappy,
)
from rasa.core.interpreter import RegexInterpreter
from rasa.core.nlg import TemplatedNaturalLanguageGenerator
Expand Down Expand Up @@ -645,7 +645,7 @@ async def test_form_unhappy_path_no_validation_from_rule():
action_probabilities = policy.predict_action_probabilities(tracker, domain)
assert_predicted_action(action_probabilities, domain, form_name)
# check that RulePolicy added FormValidation False event based on the training rule
assert tracker.events[-1] == FormValidation(False)
assert tracker.events[-1] == LoopUnhappy(False)


async def test_form_unhappy_path_no_validation_from_story():
Expand Down Expand Up @@ -712,7 +712,7 @@ async def test_form_unhappy_path_no_validation_from_story():
# there is no rule for next action
assert max(action_probabilities) == policy._core_fallback_threshold
# check that RulePolicy added FormValidation False event based on the training story
assert tracker.events[-1] == FormValidation(False)
assert tracker.events[-1] == LoopUnhappy(False)


async def test_form_unhappy_path_without_rule():
Expand Down
44 changes: 44 additions & 0 deletions tests/core/test_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
ActionExecutionRejected,
BotUttered,
LegacyForm,
LoopUnhappy,
LegacyFormValidation,
)
from rasa.core.slots import FloatSlot, BooleanSlot, ListSlot, TextSlot, DataSlot, Slot
from rasa.core.tracker_store import (
Expand Down Expand Up @@ -1111,3 +1113,45 @@ def test_change_form_to_deprecation_warning():
tracker.change_form_to(new_form)

assert tracker.active_loop_name() == new_form


def test_reading_of_trackers_with_legacy_form_validation_events():
tracker = DialogueStateTracker.from_dict(
"sender",
events_as_dict=[
{"event": LegacyFormValidation.type_name, "name": None, "validate": True},
{"event": LegacyFormValidation.type_name, "name": None, "validate": False},
],
)

expected_events = [LegacyFormValidation(True), LegacyFormValidation(False)]
actual_events = list(tracker.events)
assert list(tracker.events) == expected_events
assert actual_events[0].validate
assert not actual_events[1].validate

assert not tracker.active_loop.get("validate")


def test_writing_trackers_with_legacy_for_validation_events():
tracker = DialogueStateTracker.from_events(
"sender", evts=[LegacyFormValidation(True), LegacyFormValidation(False)]
)

events_as_dict = [event.as_dict() for event in tracker.events]

for event in events_as_dict:
assert event["event"] == LoopUnhappy.type_name

assert events_as_dict[0]["validate"]
assert not events_as_dict[1]["validate"]


@pytest.mark.parametrize("validate", [True, False])
def test_set_form_validation_deprecation_warning(validate: bool):
tracker = DialogueStateTracker.from_events("conversation", evts=[])

with pytest.warns(DeprecationWarning):
tracker.set_form_validation(validate)

assert tracker.active_loop["validate"] == validate
10 changes: 5 additions & 5 deletions tests/core/training/story_reader/test_markdown_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
ActionExecuted,
ActionExecutionRejected,
ActiveLoop,
FormValidation,
LoopUnhappy,
SlotSet,
LegacyForm,
)
Expand Down Expand Up @@ -124,7 +124,7 @@ async def test_persist_legacy_form_story():
ActionExecuted("action_listen"),
# out of form input but continue with the form
UserUttered(intent={"name": "affirm"}),
FormValidation(False),
LoopUnhappy(False),
ActionExecuted("some_form"),
ActionExecuted("action_listen"),
# out of form input
Expand All @@ -134,7 +134,7 @@ async def test_persist_legacy_form_story():
ActionExecuted("action_listen"),
# form input
UserUttered(intent={"name": "inform"}),
FormValidation(True),
LoopUnhappy(True),
ActionExecuted("some_form"),
ActionExecuted("action_listen"),
ActiveLoop(None),
Expand Down Expand Up @@ -199,7 +199,7 @@ async def test_persist_form_story():
ActionExecuted("action_listen"),
# out of form input but continue with the form
UserUttered(intent={"name": "affirm"}),
FormValidation(False),
LoopUnhappy(False),
ActionExecuted("some_form"),
ActionExecuted("action_listen"),
# out of form input
Expand All @@ -209,7 +209,7 @@ async def test_persist_form_story():
ActionExecuted("action_listen"),
# form input
UserUttered(intent={"name": "inform"}),
FormValidation(True),
LoopUnhappy(True),
ActionExecuted("some_form"),
ActionExecuted("action_listen"),
ActiveLoop(None),
Expand Down