[RFC] truncation and skipping #2419

Open · wants to merge 100 commits into main
Conversation

@krammnic (Contributor) commented Feb 21, 2025

#2344 mentions two important points related to our data loading and processing. This RFC works on both of these aspects.

Truncation

Currently, we support truncation in only one direction, not both right and left. Unfortunately, this is not the best strategy, as both truncation types can be relevant in different situations: if the important information in a sample appears early, right truncation (dropping tokens from the end) preserves it, and if it appears late, left truncation (dropping tokens from the beginning) does. First, here are the losses for both experiments (Qwen2.5 3B, LoRA, max_seq_len 2048, packed: True).

Right truncation:

[W&B loss chart, 2/21/2025 10:06 PM]

Left truncation:

[W&B loss chart, 2/21/2025 9:53 PM]

We see better convergence with left truncation. Note that this is due to the structure of this dataset, not because left truncation is always better!
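
For reference, here is a minimal sketch of what the two directions mean on a token sequence. This is illustrative only, not torchtune's actual implementation; the truncation_type name simply matches the option discussed below.

```python
from typing import List


def truncate(tokens: List[int], max_seq_len: int, truncation_type: str = "right") -> List[int]:
    """Illustrative truncation in either direction.

    "right" drops tokens from the end (keeps the beginning of the sample);
    "left" drops tokens from the beginning (keeps the end).
    """
    if len(tokens) <= max_seq_len:
        return tokens
    if truncation_type == "right":
        return tokens[:max_seq_len]
    if truncation_type == "left":
        return tokens[-max_seq_len:]
    raise ValueError(f"truncation_type must be 'left' or 'right', got {truncation_type}")
```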

Therefore, there are two options:

1. truncation_type as a tokenizer option in the config.

Pros: adapts to datasets with different sample structures.

Cons: one more argument in the config!

2. truncation_type hard-coded to a single default (right or left), with no config option.

The pros and cons are the opposite of option 1.

I like the idea of setting truncation_type: "left" by default without exposing it in every config, while keeping the possibility to change the truncation type.
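
For concreteness, a rough sketch of how option 1 (and the default-plus-override middle ground) could be plumbed through a tokenizer builder. The names build_tokenizer and ModelTokenizer are illustrative assumptions, not torchtune's API, and the default value shown simply follows the preference stated above.

```python
from typing import Optional


class ModelTokenizer:
    """Hypothetical tokenizer that stores a truncation direction."""

    def __init__(self, max_seq_len: Optional[int], truncation_type: str = "left"):
        if truncation_type not in ("left", "right"):
            raise ValueError(f"truncation_type must be 'left' or 'right', got {truncation_type}")
        self.max_seq_len = max_seq_len
        self.truncation_type = truncation_type


def build_tokenizer(
    max_seq_len: Optional[int] = None,
    truncation_type: str = "left",
) -> ModelTokenizer:
    # Because the argument has a default, existing configs need no changes.
    # A config that wants the non-default direction would add a single key, e.g.:
    #   tokenizer:
    #     max_seq_len: 2048
    #     truncation_type: right
    return ModelTokenizer(max_seq_len=max_seq_len, truncation_type=truncation_type)
```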

Skipping

This is a simple check directly in the recipe; the only open question is where to put the check_batch_requires_grad function.
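
Below is a minimal sketch of what such a check could look like, assuming "requires grad" means "the batch contains at least one label token that contributes to the loss" and that masked label positions use an ignore index of -100; the function name comes from this RFC, everything else is an assumption.

```python
import torch

# Assumed ignore index for masked label tokens (e.g. prompt tokens in SFT).
CROSS_ENTROPY_IGNORE_IDX = -100


def check_batch_requires_grad(labels: torch.Tensor) -> bool:
    """Return True if at least one label token is unmasked, i.e. the batch
    would actually produce a gradient signal."""
    return bool((labels != CROSS_ENTROPY_IGNORE_IDX).any())


# Sketch of usage inside a recipe's training loop:
#
# for idx, batch in enumerate(dataloader):
#     if not check_batch_requires_grad(batch["labels"]):
#         continue  # skip batches that contribute nothing to the loss
#     ...
```

Whether a helper like this lives in torchtune/data/_utils.py or directly in the recipes is exactly the open question above.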

@krammnic (Contributor, Author)

[Screenshot 2025-02-26 01:44:37] DPO LoRA single device, Qwen 3B, 20 steps, truncation_type: right

@krammnic (Contributor, Author)

[Screenshot 2025-02-26 01:53:18] DPO LoRA single device, Qwen 3B, 20 steps, truncation_type: left

@krammnic (Contributor, Author) commented Feb 25, 2025

Gemma 2B, LoRA distributed, truncation_type: "right"

Step 1  | loss: 26.85302734375     | lr: 2.0000000000000003e-06 | tokens_per_second_per_gpu: 209.4905242919922
Step 2  | loss: 27.898555755615234 | lr: 4.000000000000001e-06  | tokens_per_second_per_gpu: 399.8436279296875
Step 3  | loss: 26.375633239746094 | lr: 6e-06                  | tokens_per_second_per_gpu: 589.7933959960938
Step 4  | loss: 25.86777687072754  | lr: 8.000000000000001e-06  | tokens_per_second_per_gpu: 457.12811279296875
Step 5  | loss: 23.593978881835938 | lr: 1e-05                  | tokens_per_second_per_gpu: 502.33099365234375
Step 6  | loss: 20.89067268371582  | lr: 1.2e-05                | tokens_per_second_per_gpu: 652.4540405273438
Step 7  | loss: 20.04551887512207  | lr: 1.4e-05                | tokens_per_second_per_gpu: 482.68695068359375
Step 8  | loss: 17.218624114990234 | lr: 1.6000000000000003e-05 | tokens_per_second_per_gpu: 551.7702026367188
Step 9  | loss: 17.941665649414062 | lr: 1.8e-05                | tokens_per_second_per_gpu: 546.103515625
Step 10 | loss: 16.08778953552246  | lr: 2e-05                  | tokens_per_second_per_gpu: 485.8663024902344
Step 11 | loss: 15.600129127502441 | lr: 1.9999998828397348e-05 | tokens_per_second_per_gpu: 557.3656616210938
Step 12 | loss: 14.559475898742676 | lr: 1.999999531358965e-05  | tokens_per_second_per_gpu: 532.0314331054688
Step 13 | loss: 13.859354019165039 | lr: 1.9999989455577743e-05 | tokens_per_second_per_gpu: 672.8965454101562
Step 14 | loss: 12.657991409301758 | lr: 1.999998125436299e-05  | tokens_per_second_per_gpu: 654.568115234375
Step 15 | loss: 11.866594314575195 | lr: 1.9999970709947322e-05 | tokens_per_second_per_gpu: 413.97149658203125
Step 16 | loss: 11.371137619018555 | lr: 1.9999957822333203e-05 | tokens_per_second_per_gpu: 567.3087158203125
Step 17 | loss: 11.43862247467041  | lr: 1.9999942591523652e-05 | tokens_per_second_per_gpu: 556.8408203125
Step 18 | loss: 10.980806350708008 | lr: 1.9999925017522245e-05 | tokens_per_second_per_gpu: 574.8773193359375
Step 19 | loss: 10.352394104003906 | lr: 1.999990510033309e-05  | tokens_per_second_per_gpu: 548.7708740234375
Step 20 | loss: 10.502054214477539 | lr: 1.9999882839960865e-05 | tokens_per_second_per_gpu: 683.5306396484375

But I found a small bug to fix.

@krammnic (Contributor, Author)

Gemma 2B, LoRA distributed, truncation_type: "left"

[W&B loss chart, 2/26/2025 2:11 AM]

@krammnic (Contributor, Author) commented Feb 25, 2025

Re-verified DPO: it saves correctly, so if the recipe fails again, it is Llama2-specific.

UPD: it was.

@krammnic (Contributor, Author)

[W&B loss chart, 2/26/2025 3:11 AM]
[W&B loss chart, 2/26/2025 3:14 AM]

@codecov-commenter commented Feb 26, 2025

Codecov Report

Attention: Patch coverage is 18.80342% with 95 lines in your changes missing coverage. Please review.

Project coverage is 23.07%. Comparing base (4d9840c) to head (302946a).

Files with missing lines | Patch % | Lines missing
...chtune/modules/transforms/tokenizers/test_utils.py | 39.28% | 17 ⚠️
torchtune/modules/transforms/tokenizers/_utils.py | 46.15% | 7 ⚠️
recipes/full_dpo_distributed.py | 0.00% | 6 ⚠️
torchtune/data/_utils.py | 0.00% | 6 ⚠️
tests/torchtune/data/test_data_utils.py | 0.00% | 4 ⚠️
recipes/full_finetune_distributed.py | 0.00% | 3 ⚠️
recipes/full_finetune_single_device.py | 0.00% | 3 ⚠️
recipes/knowledge_distillation_distributed.py | 0.00% | 3 ⚠️
recipes/knowledge_distillation_single_device.py | 0.00% | 3 ⚠️
recipes/lora_dpo_distributed.py | 0.00% | 3 ⚠️
... and 21 more

❗ A different number of reports was uploaded between BASE (4d9840c) and HEAD (302946a): HEAD has 1 upload less than BASE (BASE: 2, HEAD: 1).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2419       +/-   ##
===========================================
- Coverage   65.34%   23.07%   -42.27%     
===========================================
  Files         374      379        +5     
  Lines       22161    22785      +624     
===========================================
- Hits        14481     5258     -9223     
- Misses       7680    17527     +9847     

