From a412cac763f954d019d52db9a8175e70a2137a99 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Tue, 23 Feb 2021 10:02:42 +0800 Subject: [PATCH] add common interface for control sink --- .../backend/session/anf_runtime_algorithm.cc | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 2c815f81f4..f3a4ca15c7 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -112,20 +112,32 @@ AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) { return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad); } +// Convert: +// a = former(xxx) +// b = latter(x, xxx) +// To: +// a = former(xxx) +// d1 = Depend(x, a) +// b = latter(d1, xxx) +// ... +// out = Depend(out, latter) void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) { if (latter->isa()) { - auto latter_input = latter->cast()->input(kFirstDataInputIndex); + auto latter_cnode = latter->cast(); + constexpr size_t inputsize = 2; + constexpr size_t kFirstDataInputIndex = 1; + if (latter_cnode->inputs().size() < inputsize) { + return; + } + auto latter_input = latter_cnode->input(kFirstDataInputIndex); auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former}); - - auto mgr = kg->manager(); - mgr->SetEdge(latter, kFirstDataInputIndex, depend1); + latter_cnode->set_input(kFirstDataInputIndex, depend1); auto return_node = kg->get_return(); MS_EXCEPTION_IF_NULL(return_node); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_node->input(kFirstDataInputIndex), latter}; - auto depend2 = kg->NewCNode(inputs); - mgr->SetEdge(return_node, kFirstDataInputIndex, depend2); + auto depend2 = kg->NewCNode({NewValueNode(prim::kPrimDepend), + return_node->cast()->input(kFirstDataInputIndex), latter}); + kg->set_output(depend2); } } -- Gitee