Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pnnx load gpu torchscript and reset device #4330

Merged
merged 1 commit into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
set(pnnx_pass_level0_SRCS
pass_level0/constant_unpooling.cpp
pass_level0/inline_block.cpp
pass_level0/reset_device.cpp
pass_level0/shape_inference.cpp
)

Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ int main(int argc, char** argv)

try
{
mod = torch::jit::load(ptpath);
mod = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
}
catch (const c10::Error& e)
{
Expand Down Expand Up @@ -359,7 +359,7 @@ int main(int argc, char** argv)
fprintf(stderr, "############# pass_level0\n");

std::map<std::string, pnnx::Attribute> foldable_constants;
pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);
pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);

// g->dump();

Expand Down
7 changes: 5 additions & 2 deletions tools/pnnx/src/pass_level0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@

#include "pass_level0/constant_unpooling.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/reset_device.h"
#include "pass_level0/shape_inference.h"

namespace pnnx {

void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants)
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
{
inline_block(g, module_operators);

reset_device(g, device);

constant_unpooling(g);

if (!input_tensors.empty())
{
shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);
shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);
}
}

Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace pnnx {

void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx

Expand Down
36 changes: 36 additions & 0 deletions tools/pnnx/src/pass_level0/reset_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reset_device.h"
#include "../pass_level1.h"

namespace pnnx {

// Patch the device operand of every aten::to call in the graph so the traced
// model is retargeted to the requested device: "gpu" selects the cuda device
// string, any other value falls back to cpu.
void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string& device)
{
    // resolve the replacement once instead of per node
    const std::string target_device = (device == "gpu") ? "cuda" : "cpu";

    for (torch::jit::Node* node : graph->nodes())
    {
        // only aten::to nodes carry a device operand we care about
        if (std::string("aten::to") != node->kind().toDisplayString())
            continue;

        // some aten::to overloads have no "device" input — skip those
        if (!node->hasNamedInput("device"))
            continue;

        // the operand is produced by a constant node; overwrite its string value
        torch::jit::Node* constant_node = node->namedInput("device")->node();
        constant_node->s_(torch::jit::attr::value, target_device);
    }
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level0/reset_device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#pragma once

#include <torch/script.h>

namespace pnnx {

// Rewrites the device operand of aten::to nodes in the graph so the model is
// retargeted to the requested device ("gpu" -> cuda, anything else -> cpu).
// graph  : TorchScript graph to patch in place.
// device : device string as passed on the pnnx command line.
void reset_device(std::shared_ptr<torch::jit::Graph>& graph, const std::string& device);

} // namespace pnnx
7 changes: 5 additions & 2 deletions tools/pnnx/src/pass_level0/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "pass_level0/constant_unpooling.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/reset_device.h"
#include "pass_level0/shape_inference.h"

namespace pnnx {
Expand Down Expand Up @@ -77,7 +78,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc
return false;
}

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants)
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
{
// collect all intermediate output tensors
std::vector<std::unordered_set<std::string> > more_value_names;
Expand Down Expand Up @@ -150,13 +151,15 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::

// auto mod2 = mod.deepcopy();

torch::jit::Module mod2 = torch::jit::load(ptpath);
torch::jit::Module mod2 = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
mod2.eval();

auto graph2 = mod2.get_method("forward").graph();

inline_block(graph2, module_operators);

reset_device(graph2, device);

constant_unpooling(graph2);

std::vector<torch::jit::Value*> values2;
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

namespace pnnx {

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx