Chat dataset support for Sarvam-1 with correct chat template (#11)
* Add Sarvam-1 prompt template

* Fix chat template

* Add links to sources

* Support caching packed datasets and correct the save interval

* Run isort

* Correct the cache save logic

---------

Co-authored-by: mohit_sarvam_ai <[email protected]>
mohit-sarvam and mohit_sarvam_ai authored Dec 8, 2024
1 parent 70a2f91 commit 6d89357
Showing 9 changed files with 340 additions and 58 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
   "datasets",
   "huggingface_hub[hf_transfer]",
   "safetensors",
+  "transformers",
 
   # Kaggle Integrations
   "kagglehub",
40 changes: 33 additions & 7 deletions recipes/configs/sarvam1/full_finetune.yaml
@@ -23,23 +23,46 @@ model:

 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.llama2.llama2_tokenizer
+  _component_: torchtune.models.sarvam1.sarvam1_tokenizer
   path: /projects/data/rahul_sarvam_ai/nemo_models/sarvam-1-pt/tokenizer.model
   max_seq_len: 8192
 
 # Dataset
 dataset:
-  _component_: torchtune.datasets.instruct_dataset
+  _component_: torchtune.datasets.chat_dataset
   packed: True # True increases speed
+  split_across_pack: False
+  conversation_style: openai
+  conversation_column: messages
   source: json
+  packs_cache_path: /projects/data/mohit_sarvam_ai/torchtune/data/ft_train_cache
   data_files: [
-    # /projects/data/mohit_sarvam_ai/sarvam-2b-ft/nemo/sft-data-train/sample.jsonl,
-    /projects/data/mohit_sarvam_ai/sarvam-2b-ft/nemo/sft-data-train/part-1-to-5.jsonl
+    # /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/sample.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_0.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_1.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_2.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_3.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_4.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_5.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_6.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_7.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_8.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_9.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_10.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_11.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_12.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_13.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_14.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_15.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_16.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_17.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_18.json,
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_19.json,
   ]
   split: train
 seed: null
 shuffle: True
 output_dir: /projects/data/mohit_sarvam_ai/torchtune/output-sft/sarvam1-sft-v5
 
 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
@@ -48,8 +71,9 @@ checkpointer:
     pytorch_model.bin
   ]
   recipe_checkpoint: null
-  output_dir: /projects/data/mohit_sarvam_ai/torchtune/output-sft/sarvam1-sft-v5-phase-1
+  output_dir: ${output_dir}
   model_type: LLAMA3
+  save_interval: 1000
 resume_from_checkpoint: False
 
 # Fine-tuning arguments
@@ -61,12 +85,15 @@ optimizer:
   lr: 7e-6
   weight_decay: 0.01
   betas: [0.9, 0.98]
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 clip_grad_norm: 1.0
 max_steps_per_epoch: null
-gradient_accumulation_steps: 6 # Use to increase virtual batch size
+gradient_accumulation_steps: 4 # Use to increase virtual batch size
 compile: False # pytorch compile, set to true for better perf/memory
 optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
 
@@ -85,7 +112,6 @@ metric_logger:
   _component_: torchtune.training.metric_logging.WandBLogger
   # the W&B project to log to
   project: torchtune
-output_dir: /projects/data/mohit_sarvam_ai/torchtune/output-sft/sarvam1-sft-v5-phase-1
 log_every_n_steps: 10
 log_peak_memory_stats: True
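The new top-level `output_dir` is consumed by the `${output_dir}` reference in the checkpointer block. A minimal sketch of how that reference resolves, assuming OmegaConf-style interpolation as used by torchtune configs (paths shortened for illustration):

```python
from omegaconf import OmegaConf

# Hypothetical, shortened config mirroring the YAML above.
cfg = OmegaConf.create(
    """
    output_dir: /projects/out/sarvam1-sft-v5
    checkpointer:
      output_dir: ${output_dir}
    """
)
# The reference resolves to the top-level value at access time.
print(cfg.checkpointer.output_dir)  # /projects/out/sarvam1-sft-v5
```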
19 changes: 9 additions & 10 deletions recipes/full_finetune_distributed.py
@@ -7,30 +7,29 @@
+import os
 import sys
 import time
-
 from functools import partial
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
 import torch
 from omegaconf import DictConfig, ListConfig
-
 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
-
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
+from tqdm import tqdm
 
 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
-from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training import PROFILER_KEY, DummyProfiler
+from torchtune.training.activations import \
+    apply_selective_activation_checkpointing
 from torchtune.training.lr_schedulers import get_lr
-
-from tqdm import tqdm
 
 log = utils.get_logger("DEBUG")
 
@@ -747,7 +746,7 @@ def save_checkpoint(
         ckpt_path = f"epoch_{epoch}"
         if step is not None:
             ckpt_path = f"{ckpt_path}_step_{step}"
-        self._checkpointer._output_dir = os.path.join(self._checkpointer._output_dir, ckpt_path)
+        self._checkpointer._output_dir = Path(os.path.join(self._output_dir, ckpt_path))
 
         self._checkpointer.save_checkpoint(
             checkpoint_dict,
@@ -929,8 +928,8 @@ def train(self) -> None:
                 # will include multiple forward / backward passes if gradient accumulation > 1
                 self._profiler.step()
 
-            if self._save_interval is not None and (idx + 1) % self._save_interval == 0:
-                self.save_checkpoint(epoch=curr_epoch, step=idx)
+            if self._save_interval is not None and ((idx + 1) / self._gradient_accumulation_steps) % self._save_interval == 0:
+                self.save_checkpoint(epoch=curr_epoch, step=int((idx + 1) / self._gradient_accumulation_steps))
 
         self.epochs_run += 1
         self.save_checkpoint(epoch=curr_epoch, step=None)
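With gradient accumulation, `idx` counts micro-batches, so the reworked condition above measures the save interval in optimizer steps rather than batches. A standalone sketch of the arithmetic, using the values from the config (`save_interval: 1000`, `gradient_accumulation_steps: 4`); the names are illustrative, not the recipe's attributes:

```python
SAVE_INTERVAL = 1000   # optimizer steps between checkpoints
GRAD_ACCUM_STEPS = 4   # micro-batches per optimizer step

def should_checkpoint(idx: int) -> bool:
    """idx is the 0-based micro-batch index within the epoch."""
    optim_steps = (idx + 1) / GRAD_ACCUM_STEPS
    # True only when idx + 1 is an exact multiple of
    # GRAD_ACCUM_STEPS * SAVE_INTERVAL (the float remainder is 0.0 there),
    # so the check never fires mid-accumulation.
    return optim_steps % SAVE_INTERVAL == 0

# Checkpoints fire at micro-batch 4000 (step 1000), 8000 (step 2000), ...
assert should_checkpoint(3999)
assert not should_checkpoint(3998)
```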
27 changes: 26 additions & 1 deletion torchtune/datasets/_chat.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import pickle
 from typing import Any, Callable, Dict, Optional, Union
 
 from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
@@ -21,6 +23,8 @@ def chat_dataset(
     train_on_input: bool = False,
     new_system_prompt: Optional[str] = None,
     packed: bool = False,
+    split_across_pack: bool = False,
+    packs_cache_path: Optional[str] = None,
     filter_fn: Optional[Callable] = None,
     split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
@@ -155,6 +159,20 @@ def chat_dataset(
     Raises:
         ValueError: if the conversation format is not supported
     """
+
+    if packs_cache_path is not None and not packed:
+        raise ValueError("read_packs_from_path can only be used with packed=True.")
+
+    if packs_cache_path is not None and os.path.exists(packs_cache_path):
+        with open(packs_cache_path, "rb") as f:
+            packed_ds = pickle.load(f)
+        # check instance type
+        if packed_ds.__class__.__name__ != "PackedDataset":
+            raise ValueError(
+                "PackedDataset cache file is not a PackedDataset instance."
+            )
+        return packed_ds
+
     if conversation_style == "sharegpt":
         message_transform = ShareGPTToMessages(
             train_on_input=train_on_input,
Expand Down Expand Up @@ -183,5 +201,12 @@ def chat_dataset(
             raise ValueError(
                 "PackedDataset requires a max_seq_len to be set on the tokenizer."
             )
-        return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
+        ds = PackedDataset(
+            ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=split_across_pack
+        )
+
+    if packs_cache_path is not None:
+        with open(packs_cache_path, "wb") as f:
+            pickle.dump(ds, f)
+
     return ds
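A usage sketch for the new caching path, following the signature above; the tokenizer path and file names are placeholders. On the first run the packed dataset is pickled to `packs_cache_path`; later runs load the pickle whenever that file exists, so the cache must be deleted by hand if the underlying data changes:

```python
from torchtune.datasets import chat_dataset
from torchtune.models.sarvam1 import sarvam1_tokenizer

tokenizer = sarvam1_tokenizer(path="tokenizer.model", max_seq_len=8192)

ds = chat_dataset(
    tokenizer=tokenizer,
    source="json",
    conversation_column="messages",
    conversation_style="openai",
    packed=True,                        # required when packs_cache_path is set
    split_across_pack=False,
    packs_cache_path="ft_train_cache",  # pickled PackedDataset lives here
    split="train",
    data_files=["data_0.json"],         # forwarded to load_dataset
)
```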
5 changes: 3 additions & 2 deletions torchtune/models/sarvam1/__init__.py
@@ -4,13 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._model_builders import sarvam1_tokenizer, sarvam1, lora_sarvam1
-
+from ._model_builders import lora_sarvam1, sarvam1, sarvam1_tokenizer
+from ._prompt_template import Sarvam1ChatTemplate
 from ._tokenizer import Sarvam1Tokenizer
 
 __all__ = [
     "sarvam1",
     "lora_sarvam1",
     "Sarvam1Tokenizer",
     "sarvam1_tokenizer",
+    "Sarvam1ChatTemplate",
 ]
27 changes: 20 additions & 7 deletions torchtune/models/sarvam1/_model_builders.py
@@ -6,20 +6,25 @@
 from functools import partial
 from typing import List, Optional
 
+from torchtune.data._prompt_templates import (_get_prompt_template,
+                                              _TemplateType)
 from torchtune.models.llama3._component_builders import llama3, lora_llama3
-
-from torchtune.modules import TransformerDecoder
 from torchtune.models.sarvam1._tokenizer import Sarvam1Tokenizer
+from torchtune.modules import TransformerDecoder
 from torchtune.modules.peft import LORA_ATTN_MODULES
-from torchtune.data._prompt_templates import _TemplateType
-from torchtune.data._prompt_templates import _get_prompt_template
 
 
 """
 Model builders build specific instantiations using Llama 3 component builders.
 """
 
 
-def sarvam1_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[str] = None) -> Sarvam1Tokenizer:
+def sarvam1_tokenizer(
+    path: str,
+    max_seq_len: Optional[int] = None,
+    prompt_template: Optional[
+        _TemplateType
+    ] = "torchtune.models.sarvam1.Sarvam1ChatTemplate",
+) -> Sarvam1Tokenizer:
     """
     Tokenizer for Sarvam1.
@@ -32,7 +37,15 @@ def sarvam1_tokenizer(
     Returns:
         Sarvam1Tokenizer: Instantiation of the Llama2 tokenizer
     """
-    return Sarvam1Tokenizer(path=path, max_seq_len=max_seq_len, prompt_template=prompt_template)
+    return Sarvam1Tokenizer(
+        path=path,
+        max_seq_len=max_seq_len,
+        prompt_template=(
+            _get_prompt_template(prompt_template)
+            if prompt_template is not None
+            else None
+        ),
+    )
 
 
 def sarvam1() -> TransformerDecoder:
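With the new default, the chat template is applied unless the caller opts out explicitly. A small sketch of the two call patterns implied by the builder above (the path is a placeholder):

```python
from torchtune.models.sarvam1 import sarvam1_tokenizer

# Default: the template string is resolved via _get_prompt_template and
# Sarvam1ChatTemplate is applied to every sample.
chat_tok = sarvam1_tokenizer(path="tokenizer.model", max_seq_len=8192)

# Opting out: passing None skips template resolution entirely.
plain_tok = sarvam1_tokenizer(
    path="tokenizer.model", max_seq_len=8192, prompt_template=None
)
```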
84 changes: 84 additions & 0 deletions torchtune/models/sarvam1/_prompt_template.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+from torchtune.data import Message, PromptTemplateInterface
+
+# https://github.com/pytorch/torchtune/blob/26b2200010a37474015925c5e3f4606435b72dd3/torchtune/models/llama2/_prompt_template.py
+
+
+class Sarvam1ChatTemplate(PromptTemplateInterface):
+    """
+    Prompt template that formats chat data of human and system prompts with appropriate tags
+    used in Llama2 pre-training. Taken from Meta's official `Llama inference
+    repository <https://github.com/meta-llama/llama/blob/main/llama/generation.py>`_.
+
+    .. code-block:: text
+
+        "[INST] <<SYS>>
+        You are a helpful, respectful and honest assistant.
+        <</SYS>>"
+
+        I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture..."
+
+    """
+
+    template = {
+        "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+        "user": ("[INST] ", " [/INST] "),
+        "assistant": ("", ""),
+        "ipython": ("", ""),
+    }
+
+    def __call__(
+        self,
+        messages: List[Message],
+    ) -> List[Message]:
+        """
+        Format user and system messages with appropriate tags.
+
+        Args:
+            messages (List[Message]): a single conversation, structured as a list
+                of `Message` objects
+
+        Returns:
+            The formatted list of messages
+        """
+        system_message = []
+        formatted_dialogue = []
+        for message in messages:
+            if message.role == "system":
+                system_message = (
+                    [{"type": "text", "content": self.template["system"][0]}]
+                    + message.content
+                    + [{"type": "text", "content": self.template["system"][1]}]
+                )
+                # Incorporate the system message in the user message - Llama2 only
+                # looks for the <<SYS>> tags and not the explicit role so this will
+                # be treated the same as an actual system message. We do this because
+                # of the nesting of the system prompt in the user message.
+                continue
+            elif message.role == "user":
+                content = (
+                    [{"type": "text", "content": self.template["user"][0]}]
+                    + system_message
+                    + message.content
+                    + [{"type": "text", "content": self.template["user"][1]}]
+                )
+            elif message.role == "assistant":
+                # No special formatting needed for assistant message
+                content = message.content
+            formatted_dialogue.append(
+                Message(
+                    role=message.role,
+                    content=content,
+                    masked=message.masked,
+                    ipython=message.ipython,
+                    eot=message.eot,
+                ),
+            )
+        return formatted_dialogue
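To make the tagging concrete, here is a hedged illustration of the template applied to a system-plus-user turn; the expected rendering follows from the `template` dict above:

```python
from torchtune.data import Message
from torchtune.models.sarvam1 import Sarvam1ChatTemplate

template = Sarvam1ChatTemplate()
dialogue = template(
    [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="I am going to Paris, what should I see?"),
    ]
)
# One user message remains; the system text is folded into it:
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# I am going to Paris, what should I see? [/INST]
```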