
Mixed Precision Support #69

Merged · 9 commits merged into mit-han-lab:master on Jun 2, 2021

Conversation

@clee-ai (Collaborator) commented on May 24, 2021

I implemented basic half-precision support compatible with torch.cuda.amp.autocast, and also annotated the C++ convolution code a bit.
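
For reference, usage follows the standard torch.cuda.amp pattern; here is a minimal sketch, where `make_model`, `criterion`, and `loader` are hypothetical stand-ins for a torchsparse network and its data pipeline:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Hypothetical stand-ins: any network whose ops support autocast
# (now including the sparse convolutions in this PR) will do.
model = make_model().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()  # scales the loss to avoid fp16 gradient underflow

for inputs, targets in loader:
    optimizer.zero_grad()
    with autocast():  # eligible ops run in fp16, reductions stay in fp32
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips step on inf/nan
    scaler.update()
```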

I experimented a lot with resizing tensors so their dimensions are multiples of 8, but it doesn't change execution time significantly, so I left that out (a sketch of the padding is shown after the results below). With size=400000 and batch_size=4 on my 2080 Ti, I get the following results (nvprof benchmarks attached):

Default precision (nvprof):

Done in: 15.723482s
max_memory_allocated 3492.76513671875 MB
max_memory_reserved 5024.0 MB

Mixed precision, no optimization (nvprof):

Done in: 14.207781s
max_memory_allocated 1876.51904296875 MB
max_memory_reserved 3284.0 MB

Mixed precision, all mm ops in multiples of 8 (nvprof):

Done in: 14.748600s
max_memory_allocated 1924.66552734375 MB
max_memory_reserved 3274.0 MB
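
For context, the multiples-of-8 variant zero-pads the mm operands so their shapes are Tensor Core friendly; a minimal sketch (the shapes are made up):

```python
import torch
import torch.nn.functional as F

def pad_last_dim(x: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    """Zero-pad the last dimension of x up to the next multiple."""
    pad = (-x.shape[-1]) % multiple
    # F.pad takes (left, right) padding for the last dimension.
    return F.pad(x, (0, pad)) if pad else x

# e.g. an (N, 45) half-precision feature matrix becomes (N, 48),
# so the subsequent matrix multiplies hit Tensor Core friendly shapes.
feats = torch.randn(400000, 45, device="cuda", dtype=torch.half)
feats = pad_last_dim(feats)  # -> torch.Size([400000, 48])
```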

Looking at the nvprof results, barely any computation time is spent on mm ops anyway:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   43.98%  3.13228s        40  78.307ms  64.375ms  110.07ms  void at::native::batch_norm_backward_kernel
                   20.28%  1.44425s        40  36.106ms  31.259ms  45.500ms  void at::native::batch_norm_collect_statistics_kernel
                    6.36%  453.17ms        40  11.329ms  7.7208ms  17.821ms  void at::native::batch_norm_transform_input_kernel
                    6.30%  449.02ms        20  22.451ms  7.9090ms  39.716ms  cuckooLookupKernel_Multi
                    3.83%  272.89ms        10  27.289ms  26.730ms  29.211ms  void cunn_ClassNLLCriterion_updateOutput_kernel
                    3.22%  229.55ms      2080  110.36us  18.337us  1.8771ms  void gather_kernel

Still, mixed precision gives a modest speedup and a substantial reduction in memory footprint. Let me know if you have any questions or suggestions. (Should address #17.)
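
For anyone reproducing the numbers above, a harness along these lines should do (a sketch; `run_training` is a hypothetical stand-in for the benchmarked workload):

```python
import time
import torch

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()  # make sure no prior GPU work is pending

start = time.time()
run_training()            # hypothetical: the benchmarked workload
torch.cuda.synchronize()  # wait for all kernels before stopping the clock

print(f"Done in: {time.time() - start:.6f}s")
print("max_memory_allocated", torch.cuda.max_memory_allocated() / 2**20, "MB")
print("max_memory_reserved", torch.cuda.max_memory_reserved() / 2**20, "MB")
```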

@zhijian-liu linked an issue (May 30, 2021) that may be closed by this pull request.
@zhijian-liu (Contributor) commented:

Thank you so much for your great efforts! We are currently running experiments with this new version to verify that it has no negative influence on performance. By the way, I'm wondering if it is possible to also support mixed precision for spvoxelize and spdevoxelize. Thanks!

@zhijian-liu linked an issue (May 31, 2021) that may be closed by this pull request.
@clee-ai (Collaborator, Author) commented on May 31, 2021

Great! Please let me know the performance results.

Adding those functions should be no problem. If you have a minimal test script that uses them, similar to examples/test.py, it would be great if you could share it so I can make sure everything works correctly; if not, I'll write my own.
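
Concretely, something like this is what I have in mind (a sketch; `op` is a hypothetical stand-in for spvoxelize/spdevoxelize, whose exact signatures I won't assume):

```python
import torch

def check_half_precision(op, *tensors, rtol=1e-2, atol=1e-2):
    """Run op in fp32 and fp16 and compare outputs with loose tolerances,
    since some round-off is expected in half precision."""
    def cast(dtype):
        # Cast floating-point inputs; pass index/coordinate tensors through.
        return [t.to(dtype) if t.is_floating_point() else t for t in tensors]

    ref = op(*cast(torch.float32))
    out = op(*cast(torch.float16)).float()
    assert torch.allclose(ref, out, rtol=rtol, atol=atol), \
        f"max abs diff: {(ref - out).abs().max().item()}"
```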

@zhijian-liu (Contributor) commented:

The inference of SPVNAS should be a pretty good example to test these functions: https://github.com/mit-han-lab/spvnas. Thanks!

@zhijian-liu (Contributor) commented on Jun 1, 2021

The large-scale experiments of MinkowskiNet on nuScenes have just finished:

  • The performance of mixed-precision training is almost the same as that of full-precision training: 76.43 vs. 76.78.
  • The speedup is fairly limited: total training time is reduced from 7.3 hours to 6.4 hours (around a 12% reduction).
  • The memory reduction is very significant: memory usage is reduced from 48.8 GB to 28.8 GB (around a 41% reduction).

@clee-ai (Collaborator, Author) commented on Jun 1, 2021

Great, I'm glad it works! 10% and 40% are about what I saw in my tests as well.

I added support for insertion and devoxelization in half/double precision in my latest commits. It worked well in my SPVNAS inference test, but I didn't test the backward functions. Please let me know how they look in your tests!
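
For the backward functions, comparing gradients across precisions would be a reasonable check (a sketch; `op` is again a hypothetical stand-in, with `extra` carrying any index/coordinate tensors unchanged):

```python
import torch

def check_backward(op, feats, *extra, rtol=1e-2, atol=1e-2):
    """Compare fp16 and fp32 gradients of op w.r.t. feats."""
    grads = {}
    for dtype in (torch.float32, torch.float16):
        x = feats.detach().to(dtype).requires_grad_(True)
        op(x, *extra).sum().backward()  # scalar loss to trigger backward
        grads[dtype] = x.grad.float()
    assert torch.allclose(grads[torch.float32], grads[torch.float16],
                          rtol=rtol, atol=atol)
```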

@zhijian-liu (Contributor) commented on Jun 1, 2021

Thanks for the efforts! I will launch some large-scale experiments to test these functions as well.

UPDATE: The results of SPVNAS are similar to those of MinkowskiNet.

@zhijian-liu (Contributor) left a review:

The implementation looks great! I think it's ready to be merged.

@kentang-mit (Collaborator) left a review:

Thanks @CCInc for the great effort on mixed-precision support! I've also been through the changes and believe that this pull request is ready for merging.

@zhijian-liu merged commit 0d5c9f8 into mit-han-lab:master on Jun 2, 2021.
@clee-ai (Collaborator, Author) commented on Jun 2, 2021

Great, glad to hear it! I'd also be happy to help implement SPVNAS architecture search or any other tasks you need; feel free to reach out by email.

Linked issues that may be closed by this pull request:

  • Mixed precision or FP16 support!
  • Does torchsparse supports 16-bit training?