Skip to content

Commit

Permalink
Merge pull request #1 from mbencer/mbencer/FixPassingParamsFromParent…
Browse files Browse the repository at this point in the history
…Graph

[ONNX][Loop] Fix passing parameters from parent graph scope to subgraph
  • Loading branch information
itikhono authored Aug 18, 2020
2 parents 72da7ea + 52eb931 commit 864e29a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 10 deletions.
1 change: 1 addition & 0 deletions ngraph/core/include/ngraph/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ namespace ngraph
NGRAPH_API
NodeVector as_node_vector(const OutputVector& values);
/// Returns a ResultVector referencing values.
NGRAPH_API
ResultVector as_result_vector(const OutputVector& values);

/// Alias useful for cloning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ namespace ngraph
void add_provenance_tags(const Node& onnx_node,
const OutputVector& ng_node_vector) const;

protected:
ParameterVector m_parameters;

private:
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
std::unique_ptr<GraphCache> m_cache;
std::vector<Node> m_nodes;
std::vector<ValueInfo> m_inputs;
std::vector<ValueInfo> m_outputs;
ParameterVector m_parameters;
Model* m_model;
};

Expand Down
32 changes: 32 additions & 0 deletions ngraph/frontend/onnx_import/src/core/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <sstream>

#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/provenance.hpp"
#include "onnx_import/core/graph.hpp"
#include "onnx_import/core/node.hpp"
Expand Down Expand Up @@ -291,6 +292,37 @@ namespace ngraph
model,
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
{
std::vector<std::shared_ptr<ngraph::Node>> subgraph_root_nodes;
const auto& outputs = as_result_vector(get_ng_outputs());
for (auto& out : outputs)
{
subgraph_root_nodes.push_back(out);
}
const auto& params = get_ng_parameters();
for (auto& param : params)
{
subgraph_root_nodes.push_back(param);
}
const auto subgraph_nodes = topological_sort(subgraph_root_nodes);

const auto& parent_graph_parameters = parent_graph.get_ng_parameters();
for (const auto& node : subgraph_nodes)
{
if (op::is_parameter(node))
{
const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node);
// not present as subgraph parameter
if (sub_it == m_parameters.end())
{
const auto parent_it = std::find(
parent_graph_parameters.begin(), parent_graph_parameters.end(), node);
if (parent_it != m_parameters.end())
{
m_parameters.push_back(*parent_it);
}
}
}
}
}

} // namespace onnx_import
Expand Down
18 changes: 9 additions & 9 deletions ngraph/frontend/onnx_import/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ namespace ngraph
const auto& graph_outputs = body_graph.get_ng_outputs();
const auto& graph_inputs = body_graph.get_ng_parameters();

CHECK_VALID_NODE(
node,
graph_inputs.size() == loop_carried_dependencies.size() + 2,
"The provided loop body graph inputs size (",
graph_inputs.size(),
"), is not equal to the sum of loop carried dependencies and two mandatory"
" inputs (",
loop_carried_dependencies.size() + 2,
")");
CHECK_VALID_NODE(node,
graph_inputs.size() >= loop_carried_dependencies.size() + 2,
"The provided loop body graph inputs size (",
graph_inputs.size(),
"), is not greater than the sum of loop carried dependencies "
"and two mandatory"
" inputs (",
loop_carried_dependencies.size() + 2,
")");

CHECK_VALID_NODE(node,
graph_outputs.size() >= loop_carried_dependencies.size() + 1,
Expand Down

0 comments on commit 864e29a

Please sign in to comment.