Error message on packed=True for stack exchange dataset #2079

Merged · 5 commits · Nov 27, 2024
11 changes: 11 additions & 0 deletions tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py
@@ -107,3 +107,14 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input):
else:
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16

def test_dataset_fails_with_packed(self):
with pytest.raises(
ValueError,
match="Packed is currently not supported for preference datasets",
):
hh_rlhf_helpful_dataset(
tokenizer=DummyTokenizer(),
train_on_input=True,
packed=True,
)
14 changes: 14 additions & 0 deletions tests/torchtune/datasets/test_preference_dataset.py
@@ -155,3 +155,17 @@ def test_load_local_json(self):

assert expected_chosen_labels[0] == ds[0]["chosen_labels"]
assert expected_rejected_labels[0] == ds[0]["rejected_labels"]

def test_dataset_fails_with_packed(self):
with pytest.raises(
ValueError,
match="Packed is currently not supported for preference datasets.",
):
preference_dataset(
tokenizer=DummyTokenizer(),
source="json",
data_files=str(ASSETS / "hh_rlhf_tiny.json"),
train_on_input=False,
split="train",
packed=True,
)
10 changes: 10 additions & 0 deletions tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
@@ -100,6 +100,16 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input):
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 52

def test_dataset_fails_with_packed(self):
with pytest.raises(
ValueError,
match="Packed is currently not supported for preference datasets",
):
stack_exchange_paired_dataset(
tokenizer=DummyTokenizer(),
packed=True,
)


class TestStackExchangePairedToMessages:
@pytest.fixture
11 changes: 11 additions & 0 deletions torchtune/datasets/_preference.py
@@ -89,9 +89,14 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. Packed is
currently not supported for ``PreferenceDataset`` and a ``ValueError`` will be raised if this is set to True.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
for more details.

Raises:
ValueError: If ``packed`` is True, this feature is not supported for ``PreferenceDataset``.
"""

def __init__(
@@ -101,8 +106,14 @@ def __init__(
message_transform: Transform,
tokenizer: ModelTokenizer,
filter_fn: Optional[Callable] = None,
packed: bool = False,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
if packed:
raise ValueError(
"Packed is currently not supported for preference datasets."
)

self._tokenizer = tokenizer
self._message_transform = message_transform
self._data = load_dataset(source, **load_dataset_kwargs)
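
For context, the fix follows a simple fail-fast pattern: the unsupported ``packed`` flag is rejected in the constructor before any dataset is downloaded or tokenized. Below is a minimal, self-contained sketch of the same guard; the class and attribute names are illustrative stand-ins, not torchtune's actual API.

class PreferenceDatasetSketch:
    """Illustrative stand-in for torchtune's PreferenceDataset."""

    def __init__(self, source: str, packed: bool = False) -> None:
        # Fail fast: packing is not implemented for preference pairs,
        # so reject the flag before any expensive data loading happens.
        if packed:
            raise ValueError(
                "Packed is currently not supported for preference datasets."
            )
        self.source = source

# Constructing with packed=True raises immediately, which is the behavior
# the three new tests above assert via pytest.raises.
try:
    PreferenceDatasetSketch(source="json", packed=True)
except ValueError as err:
    print(err)  # Packed is currently not supported for preference datasets.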