Multi-GPU training hangs #169

Open
whodianjiao opened this issue Jan 30, 2025 · 3 comments

Comments

@whodianjiao

When I tried training on the DIS5K dataset myself with multiple GPUs, the training hung:
`(birefnet) amax@amax-Super-Server:~/apj/project/birefnet/main/codes/dis/BiRefNet$ sh train_test.sh DIS5K 0,1,2,3 0
task = 'DIS5K'
Training started at 2025年 01月 30日 星期四 23:08:18 CST
Multi-GPU mode received...
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
WARNING:torch.distributed.run:


Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.


INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 2
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 3
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 1
INFO:torch.distributed.distributed_c10d:Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2025-01-30 23:08:23,983 INFO datasets: load_all=False, compile=True.
2025-01-30 23:08:23,983 INFO Other hyperparameters:
2025-01-30 23:08:23,983 INFO Namespace(resume='xx/xx-epoch_244.pth', epochs=500, ckpt_dir='ckpt/DIS5K', dist=False, use_accelerate=True)
batch size: 1
2025-01-30 23:08:23,985 INFO datasets: load_all=False, compile=True.
2025-01-30 23:08:23,985 INFO Other hyperparameters:
2025-01-30 23:08:23,985 INFO Namespace(resume='xx/xx-epoch_244.pth', epochs=500, ckpt_dir='ckpt/DIS5K', dist=False, use_accelerate=True)
batch size: 1
2025-01-30 23:08:23,995 INFO datasets: load_all=False, compile=True.
2025-01-30 23:08:23,995 INFO Other hyperparameters:
2025-01-30 23:08:23,995 INFO Namespace(resume='xx/xx-epoch_244.pth', epochs=500, ckpt_dir='ckpt/DIS5K', dist=False, use_accelerate=True)
batch size: 1
3000 3000batches of train dataloader DIS-TR have been created.
batches of train dataloader DIS-TR have been created.
3000 batches of train dataloader DIS-TR have been created.
2025-01-30 23:08:24,035 INFO datasets: load_all=False, compile=True.
2025-01-30 23:08:24,035 INFO Other hyperparameters:
2025-01-30 23:08:24,035 INFO Namespace(resume='xx/xx-epoch_244.pth', epochs=500, ckpt_dir='ckpt/DIS5K', dist=False, use_accelerate=True)
batch size: 1
3000 batches of train dataloader DIS-TR have been created.
/home/amax/.conda/envs/birefnet/lib/python3.11/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
/home/amax/.conda/envs/birefnet/lib/python3.11/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
/home/amax/.conda/envs/birefnet/lib/python3.11/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
/home/amax/.conda/envs/birefnet/lib/python3.11/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Found correct weights in the "model" item of loaded state_dict.
Found correct weights in the "model" item of loaded state_dict.
Found correct weights in the "model" item of loaded state_dict.
Found correct weights in the "model" item of loaded state_dict.
2025-01-30 23:08:28,039 INFO => no checkpoint found at 'xx/xx-epoch_244.pth'
2025-01-30 23:08:28,040 INFO => no checkpoint found at 'xx/xx-epoch_244.pth'
2025-01-30 23:08:28,040 INFO => no checkpoint found at 'xx/xx-epoch_244.pth'
2025-01-30 23:08:28,041 INFO => no checkpoint found at 'xx/xx-epoch_244.pth'
[W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
2025-01-30 23:08:33,028 INFO Epoch[1/500] Iter[0/750]. Training Losses, loss_pix: 104.246
2025-01-30 23:08:33,028 INFO Epoch[1/500] Iter[0/750]. Training Losses, loss_pix: 103.945
2025-01-30 23:08:33,028 INFO Epoch[1/500] Iter[0/750]. Training Losses, loss_pix: 106.886
2025-01-30 23:08:33,029 INFO Epoch[1/500] Iter[0/750]. Training Losses, loss_pix: 104.542
2025-01-30 23:08:48,241 INFO Epoch[1/500] Iter[20/750]. Training Losses, loss_pix: 139.518
2025-01-30 23:08:48,241 INFO Epoch[1/500] Iter[20/750]. Training Losses, loss_pix: 79.888
2025-01-30 23:08:48,243 INFO Epoch[1/500] Iter[20/750]. Training Losses, loss_pix: 74.888
2025-01-30 23:08:48,244 INFO Epoch[1/500] Iter[20/750]. Training Losses, loss_pix: 71.626
2025-01-30 23:09:03,501 INFO Epoch[1/500] Iter[40/750]. Training Losses, loss_pix: 46.662
2025-01-30 23:09:03,501 INFO Epoch[1/500] Iter[40/750]. Training Losses, loss_pix: 50.000
2025-01-30 23:09:03,502 INFO Epoch[1/500] Iter[40/750]. Training Losses, loss_pix: 84.729
2025-01-30 23:09:03,502 INFO Epoch[1/500] Iter[40/750]. Training Losses, loss_pix: 88.846
2025-01-30 23:09:18,847 INFO Epoch[1/500] Iter[60/750]. Training Losses, loss_pix: 33.655
2025-01-30 23:09:18,848 INFO Epoch[1/500] Iter[60/750]. Training Losses, loss_pix: 56.818
2025-01-30 23:09:18,849 INFO Epoch[1/500] Iter[60/750]. Training Losses, loss_pix: 38.736
2025-01-30 23:09:18,850 INFO Epoch[1/500] Iter[60/750]. Training Losses, loss_pix: 49.841
2025-01-30 23:09:34,081 INFO Epoch[1/500] Iter[80/750]. Training Losses, loss_pix: 34.528
2025-01-30 23:09:34,081 INFO Epoch[1/500] Iter[80/750]. Training Losses, loss_pix: 81.659
2025-01-30 23:09:34,081 INFO Epoch[1/500] Iter[80/750]. Training Losses, loss_pix: 27.556
2025-01-30 23:09:34,084 INFO Epoch[1/500] Iter[80/750]. Training Losses, loss_pix: 38.524
2025-01-30 23:09:49,303 INFO Epoch[1/500] Iter[100/750]. Training Losses, loss_pix: 41.373
2025-01-30 23:09:49,303 INFO Epoch[1/500] Iter[100/750]. Training Losses, loss_pix: 55.972
2025-01-30 23:09:49,304 INFO Epoch[1/500] Iter[100/750]. Training Losses, loss_pix: 14.896
2025-01-30 23:09:49,306 INFO Epoch[1/500] Iter[100/750]. Training Losses, loss_pix: 31.533
2025-01-30 23:10:05,462 INFO Epoch[1/500] Iter[120/750]. Training Losses, loss_pix: 17.740
2025-01-30 23:10:05,462 INFO Epoch[1/500] Iter[120/750]. Training Losses, loss_pix: 25.784
2025-01-30 23:10:05,462 INFO Epoch[1/500] Iter[120/750]. Training Losses, loss_pix: 16.948
2025-01-30 23:10:05,465 INFO Epoch[1/500] Iter[120/750]. Training Losses, loss_pix: 7.412
2025-01-30 23:10:22,359 INFO Epoch[1/500] Iter[140/750]. Training Losses, loss_pix: 19.856
2025-01-30 23:10:22,359 INFO Epoch[1/500] Iter[140/750]. Training Losses, loss_pix: 95.755
2025-01-30 23:10:22,361 INFO Epoch[1/500] Iter[140/750]. Training Losses, loss_pix: 25.095
2025-01-30 23:10:22,362 INFO Epoch[1/500] Iter[140/750]. Training Losses, loss_pix: 25.498
2025-01-30 23:10:39,955 INFO Epoch[1/500] Iter[160/750]. Training Losses, loss_pix: 13.419
2025-01-30 23:10:39,955 INFO Epoch[1/500] Iter[160/750]. Training Losses, loss_pix: 20.907
2025-01-30 23:10:39,956 INFO Epoch[1/500] Iter[160/750]. Training Losses, loss_pix: 18.298
2025-01-30 23:10:39,960 INFO Epoch[1/500] Iter[160/750]. Training Losses, loss_pix: 28.231
2025-01-30 23:10:58,526 INFO Epoch[1/500] Iter[180/750]. Training Losses, loss_pix: 31.708
2025-01-30 23:10:58,526 INFO Epoch[1/500] Iter[180/750]. Training Losses, loss_pix: 26.530
2025-01-30 23:10:58,527 INFO Epoch[1/500] Iter[180/750]. Training Losses, loss_pix: 17.400
2025-01-30 23:10:58,531 INFO Epoch[1/500] Iter[180/750]. Training Losses, loss_pix: 16.572
2025-01-30 23:11:18,317 INFO Epoch[1/500] Iter[200/750]. Training Losses, loss_pix: 24.737
2025-01-30 23:11:18,317 INFO Epoch[1/500] Iter[200/750]. Training Losses, loss_pix: 35.536
2025-01-30 23:11:18,318 INFO Epoch[1/500] Iter[200/750]. Training Losses, loss_pix: 24.866
2025-01-30 23:11:18,321 INFO Epoch[1/500] Iter[200/750]. Training Losses, loss_pix: 12.260
2025-01-30 23:11:37,420 INFO Epoch[1/500] Iter[220/750]. Training Losses, loss_pix: 19.276
2025-01-30 23:11:37,420 INFO Epoch[1/500] Iter[220/750]. Training Losses, loss_pix: 25.449
2025-01-30 23:11:37,421 INFO Epoch[1/500] Iter[220/750]. Training Losses, loss_pix: 23.373
2025-01-30 23:11:37,424 INFO Epoch[1/500] Iter[220/750]. Training Losses, loss_pix: 21.741
2025-01-30 23:11:57,442 INFO Epoch[1/500] Iter[240/750]. Training Losses, loss_pix: 30.334
2025-01-30 23:11:57,442 INFO Epoch[1/500] Iter[240/750]. Training Losses, loss_pix: 14.267
2025-01-30 23:11:57,442 INFO Epoch[1/500] Iter[240/750]. Training Losses, loss_pix: 15.777
2025-01-30 23:11:57,446 INFO Epoch[1/500] Iter[240/750]. Training Losses, loss_pix: 13.934
2025-01-30 23:12:17,771 INFO Epoch[1/500] Iter[260/750]. Training Losses, loss_pix: 28.562
2025-01-30 23:12:17,771 INFO Epoch[1/500] Iter[260/750]. Training Losses, loss_pix: 25.206
2025-01-30 23:12:17,771 INFO Epoch[1/500] Iter[260/750]. Training Losses, loss_pix: 10.646
2025-01-30 23:12:17,777 INFO Epoch[1/500] Iter[260/750]. Training Losses, loss_pix: 10.856
2025-01-30 23:12:38,831 INFO Epoch[1/500] Iter[280/750]. Training Losses, loss_pix: 35.183
2025-01-30 23:12:38,831 INFO Epoch[1/500] Iter[280/750]. Training Losses, loss_pix: 10.211
2025-01-30 23:12:38,831 INFO Epoch[1/500] Iter[280/750]. Training Losses, loss_pix: 28.263
2025-01-30 23:12:38,833 INFO Epoch[1/500] Iter[280/750]. Training Losses, loss_pix: 18.090
2025-01-30 23:13:00,413 INFO Epoch[1/500] Iter[300/750]. Training Losses, loss_pix: 25.030
2025-01-30 23:13:00,413 INFO Epoch[1/500] Iter[300/750]. Training Losses, loss_pix: 22.602
2025-01-30 23:13:00,415 INFO Epoch[1/500] Iter[300/750]. Training Losses, loss_pix: 13.231
2025-01-30 23:13:00,416 INFO Epoch[1/500] Iter[300/750]. Training Losses, loss_pix: 19.964
2025-01-30 23:13:22,138 INFO Epoch[1/500] Iter[320/750]. Training Losses, loss_pix: 6.547
2025-01-30 23:13:22,138 INFO Epoch[1/500] Iter[320/750]. Training Losses, loss_pix: 14.633
2025-01-30 23:13:22,138 INFO Epoch[1/500] Iter[320/750]. Training Losses, loss_pix: 7.163
2025-01-30 23:13:22,140 INFO Epoch[1/500] Iter[320/750]. Training Losses, loss_pix: 15.571
2025-01-30 23:13:44,506 INFO Epoch[1/500] Iter[340/750]. Training Losses, loss_pix: 20.822
2025-01-30 23:13:44,506 INFO Epoch[1/500] Iter[340/750]. Training Losses, loss_pix: 17.056
2025-01-30 23:13:44,506 INFO Epoch[1/500] Iter[340/750]. Training Losses, loss_pix: 19.136
2025-01-30 23:13:44,510 INFO Epoch[1/500] Iter[340/750]. Training Losses, loss_pix: 11.685
2025-01-30 23:14:06,458 INFO Epoch[1/500] Iter[360/750]. Training Losses, loss_pix: 11.883
2025-01-30 23:14:06,459 INFO Epoch[1/500] Iter[360/750]. Training Losses, loss_pix: 15.026
2025-01-30 23:14:06,460 INFO Epoch[1/500] Iter[360/750]. Training Losses, loss_pix: 33.659
2025-01-30 23:14:06,463 INFO Epoch[1/500] Iter[360/750]. Training Losses, loss_pix: 15.400
2025-01-30 23:14:28,715 INFO Epoch[1/500] Iter[380/750]. Training Losses, loss_pix: 18.810
2025-01-30 23:14:28,715 INFO Epoch[1/500] Iter[380/750]. Training Losses, loss_pix: 10.593
2025-01-30 23:14:28,716 INFO Epoch[1/500] Iter[380/750]. Training Losses, loss_pix: 18.225
2025-01-30 23:14:28,718 INFO Epoch[1/500] Iter[380/750]. Training Losses, loss_pix: 12.087
2025-01-30 23:14:50,923 INFO Epoch[1/500] Iter[400/750]. Training Losses, loss_pix: 9.107
2025-01-30 23:14:50,924 INFO Epoch[1/500] Iter[400/750]. Training Losses, loss_pix: 13.522
2025-01-30 23:14:50,925 INFO Epoch[1/500] Iter[400/750]. Training Losses, loss_pix: 25.991
2025-01-30 23:14:50,926 INFO Epoch[1/500] Iter[400/750]. Training Losses, loss_pix: 226.189
2025-01-30 23:15:13,901 INFO Epoch[1/500] Iter[420/750]. Training Losses, loss_pix: 18.309
2025-01-30 23:15:13,901 INFO Epoch[1/500] Iter[420/750]. Training Losses, loss_pix: 18.238
2025-01-30 23:15:13,901 INFO Epoch[1/500] Iter[420/750]. Training Losses, loss_pix: 10.009
2025-01-30 23:15:13,904 INFO Epoch[1/500] Iter[420/750]. Training Losses, loss_pix: 43.536
2025-01-30 23:15:37,102 INFO Epoch[1/500] Iter[440/750]. Training Losses, loss_pix: 32.735
2025-01-30 23:15:37,103 INFO Epoch[1/500] Iter[440/750]. Training Losses, loss_pix: 30.946
2025-01-30 23:15:37,104 INFO Epoch[1/500] Iter[440/750]. Training Losses, loss_pix: 16.624
2025-01-30 23:15:37,105 INFO Epoch[1/500] Iter[440/750]. Training Losses, loss_pix: 42.170
2025-01-30 23:15:59,295 INFO Epoch[1/500] Iter[460/750]. Training Losses, loss_pix: 7.830
2025-01-30 23:15:59,295 INFO Epoch[1/500] Iter[460/750]. Training Losses, loss_pix: 34.940
2025-01-30 23:15:59,296 INFO Epoch[1/500] Iter[460/750]. Training Losses, loss_pix: 29.512
2025-01-30 23:15:59,298 INFO Epoch[1/500] Iter[460/750]. Training Losses, loss_pix: 21.797
2025-01-30 23:16:22,328 INFO Epoch[1/500] Iter[480/750]. Training Losses, loss_pix: 23.224
2025-01-30 23:16:22,328 INFO Epoch[1/500] Iter[480/750]. Training Losses, loss_pix: 27.806
2025-01-30 23:16:22,328 INFO Epoch[1/500] Iter[480/750]. Training Losses, loss_pix: 15.378
2025-01-30 23:16:22,331 INFO Epoch[1/500] Iter[480/750]. Training Losses, loss_pix: 8.479
2025-01-30 23:16:45,357 INFO Epoch[1/500] Iter[500/750]. Training Losses, loss_pix: 54.381
2025-01-30 23:16:45,357 INFO Epoch[1/500] Iter[500/750]. Training Losses, loss_pix: 9.336
2025-01-30 23:16:45,358 INFO Epoch[1/500] Iter[500/750]. Training Losses, loss_pix: 15.031
2025-01-30 23:16:45,360 INFO Epoch[1/500] Iter[500/750]. Training Losses, loss_pix: 32.982
2025-01-30 23:17:08,159 INFO Epoch[1/500] Iter[520/750]. Training Losses, loss_pix: 20.007
2025-01-30 23:17:08,159 INFO Epoch[1/500] Iter[520/750]. Training Losses, loss_pix: 9.449
2025-01-30 23:17:08,159 INFO Epoch[1/500] Iter[520/750]. Training Losses, loss_pix: 6.705
2025-01-30 23:17:08,162 INFO Epoch[1/500] Iter[520/750]. Training Losses, loss_pix: 28.453
2025-01-30 23:17:31,320 INFO Epoch[1/500] Iter[540/750]. Training Losses, loss_pix: 85.709
2025-01-30 23:17:31,320 INFO Epoch[1/500] Iter[540/750]. Training Losses, loss_pix: 42.685
2025-01-30 23:17:31,320 INFO Epoch[1/500] Iter[540/750]. Training Losses, loss_pix: 15.957
2025-01-30 23:17:31,323 INFO Epoch[1/500] Iter[540/750]. Training Losses, loss_pix: 84.751
Then it hung and printed nothing more. I checked the GPUs with the gpustat tool:
amax-Super-Server Thu Jan 30 23:23:20 2025 550.120
[0] NVIDIA GeForce RTX 3090 | 62°C, 100 % | 16299 / 24576 MB | amax(16284M) gdm(4M)
[1] ((Unknown Error)) | ?°C, ? % | ? / ? MB | (Not Supported)
[2] NVIDIA GeForce RTX 3090 | 67°C, 100 % | 16299 / 24576 MB | amax(16284M) gdm(4M)
[3] NVIDIA GeForce RTX 3090 | 64°C, 100 % | 16281 / 24576 MB | amax(16266M) gdm(4M)
`
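I'm using four RTX 3090s, and the GPU memory seems to be sufficient. I'm a beginner and don't understand why this happens; any pointers from the experts would be much appreciated.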
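In the gpustat snapshot above, GPUs 0, 2 and 3 each use about 16.3 GB of 24.6 GB (16299 / 24576 ≈ 66%) at 100% utilisation. GPU 1, by contrast, reports `((Unknown Error))` and no readings at all. That looks more like a GPU the driver can no longer query than one that has run out of memory.

The following is a minimal sketch for flagging such GPUs from gpustat's text output. It assumes the line layout exactly as printed above; other gpustat versions or flags may format lines differently. `parse_gpustat_line` and `GpuStatus` are hypothetical names, not part of gpustat's API.

```python
import re
from typing import NamedTuple, Optional

# Matches lines such as "[0] NVIDIA GeForce RTX 3090 | 62°C, 100 % | 16299 / 24576 MB | ..."
_GPUSTAT_LINE = re.compile(r"^\[(\d+)\]\s+(.*?)\s+\|[^|]*\|\s*(\S+)\s*/\s*(\S+)\s*MB")


def _to_int(text: str) -> Optional[int]:
    # gpustat prints "?" when the driver cannot report a value
    return int(text) if text.isdigit() else None


class GpuStatus(NamedTuple):
    index: int
    name: str
    used_mb: Optional[int]
    total_mb: Optional[int]

    @property
    def healthy(self) -> bool:
        # Treat "Unknown Error" or missing memory figures as a GPU the driver cannot query
        return "Unknown Error" not in self.name and self.used_mb is not None

    @property
    def mem_fraction(self) -> Optional[float]:
        if self.used_mb is None or not self.total_mb:
            return None
        return self.used_mb / self.total_mb


def parse_gpustat_line(line: str) -> GpuStatus:
    match = _GPUSTAT_LINE.match(line.strip())
    if match is None:
        raise ValueError(f"unrecognised gpustat line: {line!r}")
    index, name, used, total = match.groups()
    return GpuStatus(int(index), name, _to_int(used), _to_int(total))


if __name__ == "__main__":
    snapshot = """\
[0] NVIDIA GeForce RTX 3090 | 62°C, 100 % | 16299 / 24576 MB | amax(16284M) gdm(4M)
[1] ((Unknown Error)) | ?°C, ? % | ? / ? MB | (Not Supported)"""
    for gpu in map(parse_gpustat_line, snapshot.splitlines()):
        if gpu.healthy:
            print(f"GPU {gpu.index}: {gpu.mem_fraction:.0%} of memory in use")
        else:
            print(f"GPU {gpu.index}: not queryable ({gpu.name})")
```

Run periodically (for example, on the output of `gpustat` captured every minute), this records which GPU stops responding in each run.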

@ZhengPeng7
Owner

The standard BiRefNet does not fit for training on a 3090. Are you using a smaller backbone?

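I haven't run into this problem myself. Try reducing the number of samples (for example, there are 750 iterations per epoch here; use only the first 400), and see whether it then hangs at around 350/400 in the second epoch. If it does, there is most likely a bottleneck in memory or a similar resource; otherwise, there may be a problem with the samples themselves.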
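The 750 iterations per epoch are consistent with 3000 DIS-TR samples split across 4 GPUs at batch size 1. Capping each epoch at the first 400 batches, as suggested, can be done without touching the dataset. The sketch below assumes the training loop simply iterates over the DataLoader. `first_n_batches` and `iters_per_rank` are hypothetical helpers, not BiRefNet functions; wrapping the dataset in `torch.utils.data.Subset` would be another way to do the same thing.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def first_n_batches(loader: Iterable[T], max_iters: int) -> Iterator[T]:
    """Yield at most `max_iters` batches, e.g. the first 400 of the 750 per epoch."""
    if max_iters < 0:
        raise ValueError("max_iters must be non-negative")
    return islice(loader, max_iters)


def iters_per_rank(num_samples: int, world_size: int, batch_size: int) -> int:
    """Iterations per epoch on each rank, assuming the sampler pads the dataset
    so every rank gets ceil(num_samples / world_size) samples and the last
    partial batch is kept."""
    per_rank = (num_samples + world_size - 1) // world_size
    return (per_rank + batch_size - 1) // batch_size


if __name__ == "__main__":
    # DIS-TR in the first log: 3000 samples, 4 GPUs, batch size 1
    print(iters_per_rank(3000, 4, 1))
    fake_loader = range(750)  # stand-in for the real DataLoader
    for epoch in (1, 2):
        seen = sum(1 for _ in first_n_batches(fake_loader, 400))
        print(f"epoch {epoch}: {seen} iterations")
```

Because a DataLoader is re-iterable, calling `first_n_batches` once per epoch restarts from the beginning each time, which is what the "does it hang at the same point in epoch 2" test needs.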

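Also, make sure the package versions are consistent. It is best to use the latest setting with torch==2.5.1, because torch 2.0.1 has an inherent bug where memory gradually overflows when compile is enabled.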
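A quick way to confirm which torch release an environment actually resolves to, before launching all four ranks, is to read the installed package metadata. This is a sketch that only compares the numeric release part (so `2.5.1+cu124` counts as 2.5.1). The pin `2.5.1` is the one recommended above; the project's requirements file remains the authoritative list. `release_tuple` and `check_pin` are hypothetical helpers.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """'2.5.1+cu124' -> (2, 5, 1): keep only the leading numeric release part."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"not a version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_pin(dist: str, expected: str) -> Optional[str]:
    """Return a problem description, or None if the installed release matches."""
    try:
        found = metadata.version(dist)
    except metadata.PackageNotFoundError:
        return f"{dist}: not installed (expected {expected})"
    if release_tuple(found) != release_tuple(expected):
        return f"{dist}: found {found}, expected {expected}"
    return None


if __name__ == "__main__":
    problem = check_pin("torch", "2.5.1")  # version recommended in this thread
    print(problem or "torch matches 2.5.1")
```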

@whodianjiao
Author

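Thanks for the answer!
Sorry for the late reply.
Below is the output after I updated my environment according to the latest requirements. I've tried several times, and the epoch at which it hangs differs every time: some runs trained for 20 epochs, while others hang right away, like the one below. When it hangs, I looked at the state of the other, still-working GPUs (I train on four GPUs, and it is always a random one or two of them that hang): their memory usage is nowhere near overflowing. This problem really troubles me, because once it happens I cannot kill the process and have to reboot the machine. Is there a good way to deal with this?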
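Since the hang point varies between runs, it can help to record, for each run, the last step that was logged and whether the time between log lines drifted beforehand. In the first log, Iter[0] to Iter[20] takes about 15.2 s, while Iter[520] to Iter[540] takes about 23.2 s. A slowdown alone does not identify a cause, but it is cheap to track.

The sketch below assumes the `Epoch[..] Iter[..]` log format shown in this thread. It keeps the earliest timestamp per step, since every rank logs the same step. `parse_progress`, `last_position` and `step_gaps` are hypothetical helpers.

```python
import re
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

# Matches e.g. "2025-01-30 23:08:33,028 INFO Epoch[1/500] Iter[0/750]. Training Losses, ..."
_PROGRESS = re.compile(
    r"^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),(\d{3}) INFO Epoch\[(\d+)/\d+\] Iter\[(\d+)/\d+\]"
)

Progress = Tuple[datetime, int, int]  # (timestamp, epoch, iteration)


def parse_progress(lines: Iterable[str]) -> List[Progress]:
    """One entry per (epoch, iteration), keeping the earliest timestamp across ranks."""
    seen: Dict[Tuple[int, int], datetime] = {}
    for line in lines:
        match = _PROGRESS.match(line.strip())
        if match is None:
            continue  # warnings, tracebacks, progress bars, ...
        stamp, millis, epoch, it = match.groups()
        when = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S").replace(microsecond=int(millis) * 1000)
        key = (int(epoch), int(it))
        seen[key] = min(seen.get(key, when), when)
    return [(t, e, i) for (e, i), t in sorted(seen.items())]


def last_position(progress: List[Progress]) -> Optional[Tuple[int, int]]:
    """(epoch, iteration) of the last step logged before the run stopped."""
    return (progress[-1][1], progress[-1][2]) if progress else None


def step_gaps(progress: List[Progress]) -> List[float]:
    """Seconds between consecutive logged steps."""
    return [(b[0] - a[0]).total_seconds() for a, b in zip(progress, progress[1:])]


if __name__ == "__main__":
    sample = [
        "2025-01-30 23:08:33,028 INFO Epoch[1/500] Iter[0/750]. Training Losses, loss_pix: 104.246",
        "2025-01-30 23:08:48,241 INFO Epoch[1/500] Iter[20/750]. Training Losses, loss_pix: 139.518",
        "2025-01-30 23:17:08,159 INFO Epoch[1/500] Iter[520/750]. Training Losses, loss_pix: 20.007",
        "2025-01-30 23:17:31,320 INFO Epoch[1/500] Iter[540/750]. Training Losses, loss_pix: 85.709",
    ]
    progress = parse_progress(sample)
    print("last logged step:", last_position(progress))
    print("gaps (s):", [round(g, 1) for g in step_gaps(progress)])
```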

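  • `nvidia-smi` reports my CUDA version as 12.4, but `nvcc --version` gives Cuda compilation tools, release 10.1, V10.1.243. Could this cause the problem above?
  • My backbone is swin-large. Could the backbone be too large? I'll try a smaller backbone later and see whether it still hangs.
  • In the original train.sh, the line to_be_distributed=echo ${nproc_per_node} | awk '{if($e > 0) print "True"; else print "False";}' uses $e > 0. Is that a mistake?
  • One more question: although I am training in fine-tuning mode, for some reason the model always starts from epoch=0. Is something wrong with my settings?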
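On the CUDA version question, the two numbers measure different things:

- The "CUDA Version" in the `nvidia-smi` header is the newest CUDA version the installed driver supports. gpustat's header shows driver 550.120.
- `nvcc --version` reports whichever CUDA toolkit is on `PATH`.

Prebuilt PyTorch packages ship their own CUDA runtime, so an old system `nvcc` usually matters only when compiling CUDA extensions from source. The relevant comparison is `torch.version.cuda` against the driver's version.

The sketch below parses both outputs and applies a conservative "runtime not newer than driver" rule. CUDA's minor-version compatibility can be more permissive than this. The helper names are hypothetical.

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int]


def _major_minor(pattern: str, text: str) -> Optional[Version]:
    match = re.search(pattern, text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def nvcc_release(nvcc_output: str) -> Optional[Version]:
    """Toolkit version from `nvcc --version`, e.g. 'release 10.1, V10.1.243' -> (10, 1)."""
    return _major_minor(r"release (\d+)\.(\d+)", nvcc_output)


def driver_cuda(smi_output: str) -> Optional[Version]:
    """Highest CUDA version the driver supports, from the nvidia-smi header."""
    return _major_minor(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)


def torch_cuda() -> Optional[Version]:
    """CUDA version PyTorch was built with; None if torch is absent or CPU-only."""
    try:
        import torch
    except ImportError:
        return None
    return _major_minor(r"(\d+)\.(\d+)", torch.version.cuda or "")


def driver_is_new_enough(driver: Version, runtime: Version) -> bool:
    # Conservative rule: the driver's CUDA version should be at least the runtime's
    return runtime <= driver


if __name__ == "__main__":
    print("system nvcc:", nvcc_release("Cuda compilation tools, release 10.1, V10.1.243"))
    runtime = torch_cuda()
    if runtime is not None:
        verdict = "ok" if driver_is_new_enough((12, 4), runtime) else "newer than the driver supports"
        print("torch built with CUDA", runtime, verdict)
```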
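On the `$e > 0` question: in awk, a variable that is never assigned, such as `e`, evaluates to 0 in a numeric context. So `$e` is `$0`, the whole input line. The echoed line is just the value of `${nproc_per_node}`, so `$e > 0` behaves like `$1 > 0` and the test works, although `$1` would state the intent more clearly.

The Python rendering below mirrors that evaluation step by step. `awk_field` and `to_be_distributed` are hypothetical names for illustration, not part of train.sh, and the sketch assumes the echoed value is an integer.

```python
def awk_field(record: str, index: int) -> str:
    """Mimic awk's $index: $0 is the whole record, $1, $2, ... are whitespace-split fields."""
    if index == 0:
        return record
    fields = record.split()
    return fields[index - 1] if index <= len(fields) else ""


def to_be_distributed(nproc_per_node: str) -> str:
    """Python rendering of: echo ${nproc_per_node} | awk '{if($e > 0) print "True"; else print "False";}'"""
    e = 0  # `e` is never assigned in the awk program, so it evaluates to 0
    value = awk_field(nproc_per_node, e)  # i.e. $0, the whole echoed line
    # The echoed text looks numeric, so awk compares it with 0 as a number
    return "True" if float(value) > 0 else "False"


if __name__ == "__main__":
    for n in ("0", "3"):
        print(n, "->", to_be_distributed(n))
```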
(birefnet) amax@amax-Super-Server:~/apj/project/birefnet/main/codes/dis/BiRefNet$ bash train_test.sh custom_task 0,1,2,3 0
Training started at 2025年 02月 24日 星期一 20:52:05 CST
Multi-GPU mode received...
W0224 20:52:07.352000 79004 site-packages/torch/distributed/run.py:793]
W0224 20:52:07.352000 79004 site-packages/torch/distributed/run.py:793] *****************************************
W0224 20:52:07.352000 79004 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0224 20:52:07.352000 79004 site-packages/torch/distributed/run.py:793] *****************************************
[W224 20:52:10.669061487 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())
[W224 20:52:10.674492097 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())
[W224 20:52:10.676259746 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())
[W224 20:52:10.678651741 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())
2025-02-24 20:52:11,023 INFO datasets: load_all=True, compile=True.
2025-02-24 20:52:11,023 INFO Other hyperparameters:
2025-02-24 20:52:11,023 INFO Namespace(resume='weights/BiRefNet-general-epoch_244.pth', epochs=274, ckpt_dir='ckpt/custom_task', dist=False, use_accelerate=True)
batch size: 1
  0%|                                                                                                                                                       | 0/162 [00:00<?, ?it/s]2025-02-24 20:52:11,174 INFO datasets: load_all=True, compile=True.
2025-02-24 20:52:11,174 INFO Other hyperparameters:
2025-02-24 20:52:11,174 INFO Namespace(resume='weights/BiRefNet-general-epoch_244.pth', epochs=274, ckpt_dir='ckpt/custom_task', dist=False, use_accelerate=True)
batch size: 1
  0%|                                                                                                                                                       | 0/162 [00:00<?, ?it/s]2025-02-24 20:52:11,182 INFO datasets: load_all=True, compile=True.
2025-02-24 20:52:11,182 INFO Other hyperparameters:
2025-02-24 20:52:11,182 INFO Namespace(resume='weights/BiRefNet-general-epoch_244.pth', epochs=274, ckpt_dir='ckpt/custom_task', dist=False, use_accelerate=True)
batch size: 1
  0%|                                                                                                                                                       | 0/162 [00:00<?, ?it/s]2025-02-24 20:52:11,215 INFO datasets: load_all=True, compile=True.
2025-02-24 20:52:11,216 INFO Other hyperparameters:
2025-02-24 20:52:11,216 INFO Namespace(resume='weights/BiRefNet-general-epoch_244.pth', epochs=274, ckpt_dir='ckpt/custom_task', dist=False, use_accelerate=True)
batch size: 1
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:32<00:00,  5.02it/s]
162 batches of train dataloader 003+002+004 have been created.
 99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏| 161/162 [00:32<00:00,  4.98it/s]
162 batches of train dataloader 003+002+004 have been created.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:32<00:00,  5.01it/s]
162 batches of train dataloader 003+002+004 have been created.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:32<00:00,  4.99it/s]
162 batches of train dataloader 003+002+004 have been created.
Found correct weights in the "model" item of loaded state_dict.
Found correct weights in the "model" item of loaded state_dict.
Found correct weights in the "model" item of loaded state_dict.
2025-02-24 20:52:45,757 INFO => no checkpoint found at 'weights/BiRefNet-general-epoch_244.pth'
Found correct weights in the "model" item of loaded state_dict.
2025-02-24 20:52:45,910 INFO => no checkpoint found at 'weights/BiRefNet-general-epoch_244.pth'
2025-02-24 20:52:45,923 INFO => no checkpoint found at 'weights/BiRefNet-general-epoch_244.pth'
2025-02-24 20:52:46,136 INFO => no checkpoint found at 'weights/BiRefNet-general-epoch_244.pth'
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/amax/apj/project/birefnet/main/codes/dis/BiRefNet/train.py", line 255, in <module>
[rank1]:     main()
[rank1]:   File "/home/amax/apj/project/birefnet/main/codes/dis/BiRefNet/train.py", line 240, in main
[rank1]:     train_loss = trainer.train_epoch(epoch)
[rank1]:   File "/home/amax/apj/project/birefnet/main/codes/dis/BiRefNet/train.py", line 217, in train_epoch
[rank1]:     self._train_batch(batch)
[rank1]:   File "/home/amax/apj/project/birefnet/main/codes/dis/BiRefNet/train.py", line 170, in _train_batch
[rank1]:     scaled_preds, class_preds_lst = self.model(inputs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank1]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank1]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank1]:     return model_forward(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank1]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/amax/apj/project/birefnet/main/codes/dis/BiRefNet/models/birefnet.py", line 126, in forward
[rank1]:     def forward(self, x):
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
[rank1]:     return self._wrapped_call(self, *args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/fx/graph_module.py", line 361, in __call__
[rank1]:     raise e
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/fx/graph_module.py", line 348, in __call__
[rank1]:     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "<eval_with_key>.497", line 605, in forward
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_dynamo/backends/distributed.py", line 154, in forward
[rank1]:     x = self.submod(*args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
[rank1]:     return compiled_fn(full_args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 308, in runtime_wrapper
[rank1]:     all_outs = call_func_at_runtime_with_args(
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 98, in g
[rank1]:     return f(*args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1525, in forward
[rank1]:     fw_outs = call_func_at_runtime_with_args(
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 579, in wrapper
[rank1]:     return compiled_fn(runtime_args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
[rank1]:     return compiled_fn(runtime_args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
[rank1]:     outs = compiled_fn(args)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
[rank1]:     return self.current_callable(inputs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/utils.py", line 1977, in run
[rank1]:     return model(new_inputs)
[rank1]:   File "/tmp/torchinductor_amax/ww/cwwtdiin3cvnbr3syqec7o3zv4ndtnk23vabnwhjgb3vcjblvw3i.py", line 1057, in call
[rank1]:     triton_poi_fused_clone_5.run(buf16, primals_8, primals_9, buf17, 768, 4096, grid=grid(768, 4096), stream=stream1)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 836, in run
[rank1]:     self.autotune_to_one_config(*args, grid=grid, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 729, in autotune_to_one_config
[rank1]:     timings = self.benchmark_all_configs(*args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 704, in benchmark_all_configs
[rank1]:     timings = {
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 705, in <dictcomp>
[rank1]:     launcher: self.bench(launcher, *args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 675, in bench
[rank1]:     return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 66, in wrapper
[rank1]:     return fn(self, *args, **kwargs)
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 201, in benchmark_gpu
[rank1]:     return self.triton_do_bench(_callable, **kwargs, return_mode="median")
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/triton/testing.py", line 151, in do_bench
[rank1]:     di.synchronize()
[rank1]:   File "/home/amax/.conda/envs/birefnet/lib/python3.10/site-packages/torch/cuda/__init__.py", line 954, in synchronize
[rank1]:     return torch._C._cuda_synchronize()
[rank1]: RuntimeError: CUDA error: unspecified launch failure
[rank1]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank1]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank1]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[rank2]:[W224 20:56:51.679552824 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[rank0]:[W224 20:56:52.149261875 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[rank3]:[W224 20:56:53.337508501 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())

Below are my config.py, train.sh, and my file structure:
config.py

import os
import math


class Config():
    def __init__(self) -> None:
        # PATH settings
        # Organize your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
        self.sys_home_dir = [os.path.expanduser('~'), '/home/amax/apj/project/birefnet/main'][1]   # Default, custom 
        # self.data_root_dir = os.path.join(self.sys_home_dir, 'mp/datasets/dis')
        self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')

        # TASK settings
        self.task = ['DIS5K', 'COD', 'HRSOD', 'custom_task', 'General-2K', 'Matting'][3]
        self.testsets = {
            # Benchmarks
            # 'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:1]),
            'DIS5K': ','.join(['DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:1]),
            'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']),
            'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']),
            # Practical use
            'custom_task': ','.join(['005']),
            'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']),
            'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']),
        }[self.task]
        datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.testsets.split(',')])
        self.training_set = {
            'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
            'COD': 'TR-COD10K+TR-CAMO',
            'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
            'custom_task': datasets_all,
            'General-2K': datasets_all,
            'Matting': datasets_all,
        }[self.task]
        self.prompt4loc = ['dense', 'sparse'][0]

        # Faster-Training settings
        self.load_all = True   # Turn it on/off depending on your case. It may consume a lot of CPU memory, and with N GPUs, loading the data costs N times the CPU memory.
        self.compile = True                             # 1. Triggers a CPU memory leak to some extent, which is an inherent problem of PyTorch.
                                                        #   Machines with > 70GB CPU memory can run the whole training on DIS5K with the default settings.
                                                        # 2. A newer PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
                                                        # 3. However, compiling with PyTorch > 2.0.1 seems to bring no training speedup.
        self.precisionHigh = True

        # MODEL settings
        self.ms_supervision = True
        self.out_ref = self.ms_supervision and True
        self.dec_ipt = True
        self.dec_ipt_split = True
        self.cxt_num = [0, 3][1]    # multi-scale skip connections from encoder
        self.mul_scl_ipt = ['', 'add', 'cat'][2]
        self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2]
        self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1]
        self.dec_blk = ['BasicDecBlk', 'ResBlk'][0]

        # TRAINING settings
        self.batch_size = 1
        self.finetune_last_epochs = [
            0,
            {
                'DIS5K': -40,
                'COD': -20,
                'HRSOD': -20,
                'custom_task': -20,
                'General-2K': -20,
                'Matting': -20,
            }[self.task]
        ][1]    # choose 0 to skip
        self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4)     # DIS needs a higher lr to converge faster. Scale the lr by sqrt(batch_size / 4).
        self.size = (1024, 1024) if self.task not in ['General-2K'] else (2560, 1440)   # wid, hei
        self.num_workers = max(4, self.batch_size)          # will be decreased to min(it, batch_size) when the data_loader is initialized

        # Backbone settings
        self.bb = [
            'vgg16', 'vgg16bn', 'resnet50',         # 0, 1, 2
            'swin_v1_t', 'swin_v1_s',               # 3, 4
            'swin_v1_b', 'swin_v1_l',               # 5-bs9, 6-bs4
            'pvt_v2_b0', 'pvt_v2_b1',               # 7, 8
            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5
        ][6]
        self.lateral_channels_in_collection = {
            'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
            'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
            'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
            'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96],
            'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64],
        }[self.bb]
        if self.mul_scl_ipt == 'cat':
            self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection]
        self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else []

        # MODEL settings - inactive
        self.lat_blk = ['BasicLatBlk'][0]
        self.dec_channels_inter = ['fixed', 'adap'][0]
        self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0]
        self.progressive_ref = self.refine and True
        self.ender = self.progressive_ref and False
        self.scale = self.progressive_ref and 2
        self.auxiliary_classification = False       # Only for DIS5K, where class labels are saved in `dataset.py`.
        self.refine_iteration = 1
        self.freeze_bb = False
        self.model = [
            'BiRefNet',
            'BiRefNetC2F',
        ][0]

        # TRAINING settings - inactive
        self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4]
        self.optimizer = ['Adam', 'AdamW'][1]
        self.lr_decay_epochs = [1e5]    # Set to negative N to decay the lr in the last N-th epoch.
        self.lr_decay_rate = 0.5
        # Loss
        if self.task in ['Matting']:
            self.lambdas_pix_last = {
                'bce': 30 * 1,
                'iou': 0.5 * 0,
                'iou_patch': 0.5 * 0,
                'mae': 100 * 1,
                'mse': 30 * 0,
                'triplet': 3 * 0,
                'reg': 100 * 0,
                'ssim': 10 * 1,
                'cnt': 5 * 0,
                'structure': 5 * 0,
            }
        elif self.task in ['custom_task', 'General-2K']:
            self.lambdas_pix_last = {
                'bce': 30 * 1,
                'iou': 0.5 * 1,
                'iou_patch': 0.5 * 0,
                'mae': 100 * 1,
                'mse': 30 * 0,
                'triplet': 3 * 0,
                'reg': 100 * 0,
                'ssim': 10 * 1,
                'cnt': 5 * 0,
                'structure': 5 * 0,
            }
        else:
            self.lambdas_pix_last = {
                # not 0 means opening this loss
                # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
                'bce': 30 * 1,          # high performance
                'iou': 0.5 * 1,         # 0 / 255
                'iou_patch': 0.5 * 0,   # 0 / 255, win_size = (64, 64)
                'mae': 30 * 0,
                'mse': 30 * 0,         # can smooth the saliency map
                'triplet': 3 * 0,
                'reg': 100 * 0,
                'ssim': 10 * 1,          # help contours,
                'cnt': 5 * 0,          # help contours
                'structure': 5 * 0,    # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4.
            }
        self.lambdas_cls = {
            'ce': 5.0
        }

        # PATH settings - inactive
        self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights/cv')
        self.weights = {
            'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
            'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
            'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
            'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
            'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
            'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
            'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
            'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
        }

        # Callbacks - inactive
        self.verbose_eval = True
        self.only_S_MAE = False
        self.SDPA_enabled = False    # Buggy: slower, and errors occur with multiple GPUs.

        # others
        self.device = [0, 'cpu'][0]     # .to(0) == .to('cuda:0')

        self.batch_size_valid = 1
        self.rand_seed = 7
        run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
        if run_sh_file:
            with open(run_sh_file[0], 'r') as f:
                lines = f.readlines()
                self.save_last = int([l.strip() for l in lines if "'{}')".format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
                self.save_step = int([l.strip() for l in lines if "'{}')".format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])


# Return task for choosing settings in shell scripts.
if __name__ == '__main__':
    import argparse


    parser = argparse.ArgumentParser(description='Only choose one argument to activate.')
    parser.add_argument('--print_task', action='store_true', help='print task name')
    parser.add_argument('--print_testsets', action='store_true', help='print validation set')
    args = parser.parse_args()

    config = Config()
    for arg_name, arg_value in args._get_kwargs():
        if arg_value:
            print(config.__getattribute__(arg_name[len('print_'):]))
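To make a few derived values in this config concrete, here is a standalone re-computation for the settings shown (task `custom_task`, backbone `swin_v1_l`, `mul_scl_ipt = 'cat'`, `cxt_num = 3`, `batch_size = 1`). It copies only the relevant expressions from `Config.__init__`; `derived_settings` is a hypothetical helper, and the final `num_workers` value assumes the data loader clamps it to `min(num_workers, batch_size)` as the inline comment says.

```python
import math


def derived_settings(task, bb, mul_scl_ipt, cxt_num, batch_size):
    """Hypothetical helper that mirrors a few expressions from config.py."""
    # Only the two Swin entries are needed here, copied from lateral_channels_in_collection.
    lateral = {
        'swin_v1_b': [1024, 512, 256, 128],
        'swin_v1_l': [1536, 768, 384, 192],
    }[bb]
    if mul_scl_ipt == 'cat':
        # 'cat' concatenates the multi-scale inputs, doubling every lateral channel count.
        lateral = [c * 2 for c in lateral]
    # Drop the deepest level, reverse the order, keep the last cxt_num entries.
    cxt = lateral[1:][::-1][-cxt_num:] if cxt_num else []
    # Base lr is 1e-4 for DIS5K and 1e-5 otherwise, scaled by sqrt(batch_size / 4).
    lr = (1e-4 if 'DIS5K' in task else 1e-5) * math.sqrt(batch_size / 4)
    # Assumption: the data loader later clamps num_workers to min(num_workers, batch_size).
    num_workers = min(max(4, batch_size), batch_size)
    return {'lateral': lateral, 'cxt': cxt, 'lr': lr, 'num_workers': num_workers}


print(derived_settings('custom_task', 'swin_v1_l', 'cat', 3, batch_size=1))
```

With `batch_size = 1`, the lr for `custom_task` comes out at half of the 1e-5 base (5e-6), and the data loader would end up with a single worker.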

train.sh

#!/bin/bash

# Run script
# Settings of training & test for different tasks.
method="$1"
task=$(python3 config.py --print_task)
case "${task}" in
    'DIS5K') epochs=500 && val_last=50 && step=5 ;;
    'COD') epochs=150 && val_last=50 && step=5 ;;
    'HRSOD') epochs=150 && val_last=50 && step=5 ;;
    'General') epochs=150 && val_last=50 && step=5 ;;
    'General-2K') epochs=250 && val_last=30 && step=2 ;;
    'Matting') epochs=150 && val_last=50 && step=5 ;;
    'custom_task') epochs=274 && val_last=50 && step=5 ;;
esac

# Train
devices=$2
nproc_per_node=$(echo ${devices%%,} | grep -o "," | wc -l)

to_be_distributed=`echo ${nproc_per_node} | awk '{if($1 > 0) print "True"; else print "False";}'`

echo Training started at $(date)
if [ ${to_be_distributed} == "True" ]
then
    # Adapt the nproc_per_node by the number of GPUs. Give 8989 as the default value of master_port.
    echo "Multi-GPU mode received..."
    CUDA_VISIBLE_DEVICES=${devices} \
    torchrun --standalone --nproc_per_node $((nproc_per_node+1)) \
    train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \
        --dist ${to_be_distributed} \
        --resume weights/BiRefNet-general-epoch_244.pth \
        --use_accelerate
else
    echo "Single-GPU mode received..."
    CUDA_VISIBLE_DEVICES=${devices} \
    python3 train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \
        --dist ${to_be_distributed} \
        --resume weights/BiRefNet-general-epoch_244.pth \
        --use_accelerate
fi

echo Training finished at $(date)
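The GPU count in train.sh comes purely from counting commas in the device list, and config.py reads `val_last` and `step` back out of the matching `case` line. A minimal Python sketch of both pieces of logic, for checking what a given invocation will do; `gpu_plan` and `save_schedule` are hypothetical names, not part of the repository.

```python
def gpu_plan(devices):
    """Mirror train.sh: count commas in the device list to choose single- or multi-GPU mode."""
    # ${devices%%,} strips one trailing comma, so "0,1," behaves like "0,1".
    if devices.endswith(','):
        devices = devices[:-1]
    commas = devices.count(',')
    to_be_distributed = commas > 0
    # torchrun is launched with nproc_per_node + 1 processes, i.e. one per listed GPU.
    nproc = commas + 1
    return to_be_distributed, nproc


def save_schedule(sh_lines, task):
    """Mirror Config: read val_last= and step= from the case line of `task` in train.sh."""
    line = next(l.strip() for l in sh_lines if "'{}')".format(task) in l and 'val_last=' in l)
    save_last = int(line.split('val_last=')[-1].split()[0])
    save_step = int(line.split('step=')[-1].split()[0])
    return save_last, save_step


print(gpu_plan('0,1,2,3'))
```

For the device list `0,1,2,3` used in the run above, this gives multi-GPU mode with 4 processes.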

My file structure:
main/code/dis/BiRefNet (the folder cloned from git)/config.py, etc.
main/weights/cv/swin_large_patch4_window12_384_22kto1k.pth
main/weights/BiRefNet-general-epoch_244.pth
main/datasets/dis/custom_task/002/im
main/datasets/dis/custom_task/002/gt
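With this layout and the config above (`task = 'custom_task'`, test set `'005'`), the `training_set` expression in `Config` picks up every folder under `datasets/dis/custom_task` that is not listed as a test set. A minimal sketch re-implementing that expression against a temporary copy of the layout; `split_sets` is a hypothetical helper, and the check for test sets missing on disk is an extra that config.py itself does not perform. The tree listed above has no `005` folder, so that check would flag it; whether validation then fails depends on train.py, which is not shown here.

```python
import os
import tempfile


def split_sets(data_root_dir, task, testsets):
    """Mirror Config: every folder under data_root_dir/task that is not a test set is used for training."""
    task_dir = os.path.join(data_root_dir, task)
    found = os.listdir(task_dir) if os.path.isdir(task_dir) else []
    training_set = '+'.join(ds for ds in found if ds not in testsets.split(','))
    # Extra check (not in config.py): test sets named in the config with no folder on disk.
    missing_tests = [ts for ts in testsets.split(',') if ts not in found]
    return training_set, missing_tests


with tempfile.TemporaryDirectory() as root:
    # Recreate the layout listed above: datasets/dis/custom_task/002/{im,gt}
    for sub in ('im', 'gt'):
        os.makedirs(os.path.join(root, 'custom_task', '002', sub))
    print(split_sets(root, 'custom_task', '005'))
```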

@ZhengPeng7
Owner

Regarding the epochs, check what the epoch value is here:

if os.path.isfile(args.resume):
.
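A sketch of why the resume checkpoint matters for the epoch count, assuming train.py derives the starting epoch from the `epoch_N` suffix of the `--resume` filename and then trains up to `--epochs` inclusive (the pointer above suggests this, but the referenced code is not reproduced here). `resume_schedule` is a hypothetical helper.

```python
import re


def resume_schedule(resume_path, total_epochs):
    """Hypothetical: infer the first epoch and how many epochs remain when resuming."""
    match = re.search(r'epoch_(\d+)\.pth$', resume_path)
    if match is None:
        # No epoch suffix: assume training starts from scratch at epoch 1.
        return 1, total_epochs
    start = int(match.group(1)) + 1
    # Assumes the loop is `for epoch in range(start, total_epochs + 1)`.
    remaining = max(0, total_epochs - start + 1)
    return start, remaining


print(resume_schedule('weights/BiRefNet-general-epoch_244.pth', 274))
```

Under these assumptions, resuming from `BiRefNet-general-epoch_244.pth` with `epochs=274` (the `custom_task` case in train.sh) would start at epoch 245 and leave only 30 epochs of training.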
Also, why is your batch_size 1? With that setting the model has no BN layers, but all of the pretrained weights I provide do have them.
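The point above is that with `batch_size = 1` the model is built without BatchNorm layers, so the BN entries in the released checkpoints have nowhere to go. The sketch below assumes a block whose BN is swapped for an identity layer when `batch_size == 1` and uses plain key sets to show which checkpoint entries end up unused; the layer names `conv` and `bn` are hypothetical. In PyTorch, `model.load_state_dict(state_dict, strict=False)` reports the same split through its `missing_keys` and `unexpected_keys` fields, while the default `strict=True` raises an error instead.

```python
def block_param_names(batch_size):
    """Hypothetical block: BN is replaced by an identity (no parameters) when batch_size == 1."""
    names = ['conv.weight', 'conv.bias']
    if batch_size > 1:
        # Entries that a default nn.BatchNorm2d contributes to a state_dict.
        names += ['bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']
    return set(names)


def compare_keys(checkpoint_keys, model_keys):
    """Split keys the way load_state_dict(strict=False) reports them."""
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent in the checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the checkpoint, unused by the model
    return missing, unexpected


# Checkpoint trained with batch_size > 1, model built with batch_size = 1.
print(compare_keys(block_param_names(4), block_param_names(1)))
```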

ZhengPeng7 reopened this on Feb 24, 2025