
🐛 [Bug] Error loading model using Torch TensorRT in Libtorch on Windows #3401

Closed
Mmmyyym opened this issue Feb 17, 2025 · 3 comments
Labels: bug (Something isn't working)

Mmmyyym commented Feb 17, 2025

Environment

  • Libtorch 2.5.0.dev (latest nightly) (built with CUDA 12.4)
  • CUDA 12.4
  • TensorRT 10.1.0.27
  • PyTorch 2.4.0+cu124
  • Torch-TensorRT 2.4.0
  • Python 3.12.8
  • Windows 10

I compiled Torch-TensorRT with CMake to generate the lib and dll files:

[Screenshot: CMake build output]

Option: Export
If you want to optimize your model ahead-of-time and/or deploy in a C++ environment, Torch-TensorRT provides an export-style workflow that serializes an optimized module. This module can be deployed in PyTorch or with libtorch (i.e. without a Python dependency).

  1. Optimize + serialize
import torch
import torch_tensorrt

model = MyModel().eval().cuda() # define your model here
inputs = [torch.randn((1, 3, 224, 224)).cuda()] # define a list of representative inputs here

trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs) # PyTorch only supports Python runtime for an ExportedProgram. For C++ deployment, use a TorchScript file
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
  2. Deploy
    Deployment in C++:
#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

auto trt_mod = torch::jit::load("trt.ts");
auto input_tensor = torch::randn({1, 3, 224, 224}).cuda(); // fill this with your representative inputs
auto results = trt_mod.forward({input_tensor});

ERROR
The failure occurs at:

auto trt_mod = torch::jit::load("trt.ts");


Unknown type name '__torch__.torch.classes.tensorrt.Engine':
  File "code/__torch__/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 6
  training : bool
  _is_full_backward_hook : Optional[bool]
  engine : __torch__.torch.classes.tensorrt.Engine
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  def forward(self: __torch__.torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule,
    input: Tensor) -> Tensor:
narendasan (Collaborator) commented

What flags did you use to compile? Usually this issue comes from libtorch_tensorrt getting optimized out because most inference programs contain no direct reference to it. I'm not sure what the flag would be for MSVC, but with GCC you pass something like -Wl,--no-as-needed to avoid this issue.
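
A minimal C++ sketch of creating such a direct reference from the program itself (this assumes the torch_tensorrt::get_build_info() function exposed by torch_tensorrt.h; verify the call against your header version):

#include <iostream>
#include "torch_tensorrt/torch_tensorrt.h"

// Referencing any symbol from libtorchtrt gives the linker a direct
// dependency, so the library (which registers the
// __torch__.torch.classes.tensorrt.Engine custom class when it is loaded)
// cannot be dropped from the final binary.
void ensure_torchtrt_linked() {
    std::cout << torch_tensorrt::get_build_info() << std::endl;
}

Calling this once before torch::jit::load() should keep the runtime linked regardless of the linker's defaults.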

Mmmyyym (Author) commented Feb 21, 2025

> What flags did you use to compile? Usually this issue comes from libtorch_tensorrt getting optimized out because most inference programs contain no direct reference to it. I'm not sure what the flag would be for MSVC, but with GCC you pass something like -Wl,--no-as-needed to avoid this issue.

Thank you for the suggestion! The linker optimization issue you mentioned was likely the root cause. Following your advice, I tried a different approach: explicitly loading the DLL via LoadLibraryA instead, and that resolved the problem!
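
For anyone hitting the same error, a minimal sketch of that workaround, assuming the runtime DLL produced by the CMake build is named torchtrt.dll (the name and path are placeholders; use your actual build output):

#include <windows.h>
#include <iostream>
#include "torch/script.h"

int main() {
    // Load the Torch-TensorRT runtime explicitly so its custom classes
    // (including __torch__.torch.classes.tensorrt.Engine) are registered
    // before the TorchScript module is deserialized.
    HMODULE torchtrt = LoadLibraryA("torchtrt.dll");
    if (torchtrt == nullptr) {
        std::cerr << "LoadLibraryA failed, error code " << GetLastError() << '\n';
        return 1;
    }

    auto trt_mod = torch::jit::load("trt.ts");
    auto input = torch::randn({1, 3, 224, 224}).cuda();
    auto results = trt_mod.forward({input});
    std::cout << results.toTensor().sizes() << '\n';
    return 0;
}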

narendasan (Collaborator) commented

Nice, reopen if there are further issues.
