diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 2c815f81f4a094ee028faa76e9847c2059a6c440..f3a4ca15c7e117d459944a0078edea6700b7d3c3 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -112,20 +112,32 @@ AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) { return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad); } +// Convert: +// a = former(xxx) +// b = latter(x, xxx) +// To: +// a = former(xxx) +// d1 = Depend(x, a) +// b = latter(d1, xxx) +// ... +// out = Depend(out, latter) void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) { if (latter->isa()) { - auto latter_input = latter->cast()->input(kFirstDataInputIndex); + auto latter_cnode = latter->cast(); + constexpr size_t inputsize = 2; + constexpr size_t kFirstDataInputIndex = 1; + if (latter_cnode->inputs().size() < inputsize) { + return; + } + auto latter_input = latter_cnode->input(kFirstDataInputIndex); auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former}); - - auto mgr = kg->manager(); - mgr->SetEdge(latter, kFirstDataInputIndex, depend1); + latter_cnode->set_input(kFirstDataInputIndex, depend1); auto return_node = kg->get_return(); MS_EXCEPTION_IF_NULL(return_node); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_node->input(kFirstDataInputIndex), latter}; - auto depend2 = kg->NewCNode(inputs); - mgr->SetEdge(return_node, kFirstDataInputIndex, depend2); + auto depend2 = kg->NewCNode({NewValueNode(prim::kPrimDepend), + return_node->cast()->input(kFirstDataInputIndex), latter}); + kg->set_output(depend2); } }