From 0da0bdcf40efaa45e8cac01b1730e8b8ec9b934e Mon Sep 17 00:00:00 2001
From: buxue
Date: Wed, 25 Mar 2020 20:04:12 +0800
Subject: [PATCH] Fix bug in structure output when there is a depend whose
 first input is a constant in the outputs

---
 mindspore/ccsrc/pipeline/pipeline.cc        | 81 +++++++++++++++------
 tests/ut/python/nn/test_structure_output.py | 32 ++++++--
 2 files changed, 84 insertions(+), 29 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 35336e975b8..0c2edfc9c3f 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -725,23 +725,15 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
   return BaseRefToPyData(value);
 }
 
-py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data, size_t* count) {
-  MS_EXCEPTION_IF_NULL(output);
-
-  if (!output->isa<AbstractTuple>()) {
-    ValuePtr value = output->BuildValue();
-    if (value != kAnyValue) {
-      return ValuePtrToPyData(value);
-    }
-    if (!output->isa<AbstractTensor>()) {
-      MS_LOG(EXCEPTION) << "Output can only be tensor except for constants, but got "
-                        << output->BuildValue()->ToString() << ".";
-    }
-    if (*count >= data.size()) {
-      MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
-                        << " less than the number of elements required. ";
-    }
-    auto shape = output->BuildShape();
+py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) {
+  MS_EXCEPTION_IF_NULL(cnode_data);
+  if (*count >= data.size()) {
+    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                      << " less than the number of elements required. ";
+  }
+
+  if (cnode_data->isa<AbstractTensor>()) {
+    BaseShapePtr shape = cnode_data->BuildShape();
     auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
     Tensor tensor_exp = py::cast<Tensor>(data[*count]);
     if (shape_act != tensor_exp.shape()) {
@@ -751,16 +743,58 @@ py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data,
     return data[(*count)++];
   }
 
-  auto tuple_output = output->cast<AbstractTuplePtr>();
-  AbstractBasePtrList elements = tuple_output->elements();
-  size_t size = elements.size();
+  if (!cnode_data->isa<AbstractTuple>()) {
+    MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could "
+                      << "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString()
+                      << ".";
+  }
+  auto data_tp = cnode_data->cast<AbstractTuplePtr>();
+  auto elements = data_tp->elements();
+  size_t size = data_tp->size();
   py::tuple tp = py::tuple(size);
   for (size_t i = 0; i < size; i++) {
-    tp[i] = StructureOutput(elements[i], data, count);
+    tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
   }
   return std::move(tp);
 }
 
+py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) {
+  MS_EXCEPTION_IF_NULL(output_node);
+
+  if (output_node->isa<ValueNode>()) {
+    return ValuePtrToPyData(GetValueNode(output_node));
+  }
+
+  if (*count >= data.size()) {
+    MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
+                      << " less than the number of elements required. ";
+  }
+  if (output_node->isa<Parameter>()) {
+    return data[(*count)++];
+  }
+
+  auto output_c = output_node->cast<CNodePtr>();
+  if (output_c == nullptr) {
+    MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got "
+                      << output_node->ToString();
+  }
+
+  if (output_c->IsApply(prim::kPrimMakeTuple)) {
+    auto input_list = output_c->inputs();
+    size_t size = input_list.size();
+    py::tuple tp = py::tuple(size - 1);
+    for (size_t i = 1; i < size; i++) {
+      tp[i - 1] = StructureOutput(input_list[i], data, count);
+    }
+    return std::move(tp);
+  }
+  if (output_c->IsApply(prim::kPrimDepend)) {
+    return StructureOutput(output_c->input(1), data, count);
+  }
+
+  return ExtractGeneralCnodeRet(output_c->abstract(), data, count);
+}
+
 std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
                                         const std::string& phase) {
   std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
@@ -806,11 +840,10 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve
   std::shared_ptr<py::object> ret = nullptr;
 
 #ifdef ENABLE_GE
-  AnfNodePtr root = graph->get_return();
-  MS_EXCEPTION_IF_NULL(root);
-  AbstractBasePtr output = root->abstract();
+  AnfNodePtr output_node = graph->get_return()->input(1);
+  MS_EXCEPTION_IF_NULL(output_node);
   size_t count = 0;
-  py::object oj = StructureOutput(output, outputs, &count);
+  py::object oj = StructureOutput(output_node, outputs, &count);
   ret = std::make_shared<py::object>(oj);
 #else
   if (outputs.size() == 1) {
diff --git a/tests/ut/python/nn/test_structure_output.py b/tests/ut/python/nn/test_structure_output.py
index eb2722878aa..f5f6d77a670 100644
--- a/tests/ut/python/nn/test_structure_output.py
+++ b/tests/ut/python/nn/test_structure_output.py
@@ -236,7 +236,7 @@ def test_soft():
         def __init__(self):
             super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
             self.soft = P.SoftmaxCrossEntropyWithLogits()
-            self.value = (Tensor(np.zeros((2,)).astype(np.float32)), Tensor(np.ones((2,)).astype(np.float32)))
+            self.value = (Tensor(np.zeros((2, 2)).astype(np.float32)), Tensor(np.ones((2, 2)).astype(np.float32)))
 
         def construct(self, x, y, z):
             xx = x + y
@@ -246,8 +246,30 @@ def test_soft():
             ret = (ret, self.value)
             return ret
 
-    input1 = Tensor(np.zeros((2,)).astype(np.float32))
-    input2 = Tensor(np.ones((2,)).astype(np.float32))
-    input3 = Tensor((np.ones((2,)) + np.ones((2,))).astype(np.float32))
+    input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
+    input2 = Tensor(np.ones((2, 2)).astype(np.float32))
+    input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
     net = SoftmaxCrossEntropyWithLogitsNet()
-    print(net(input1, input2, input3))
+    net(input1, input2, input3)
+
+
+def test_const_depend():
+    class ConstDepend(Cell):
+        def __init__(self):
+            super(ConstDepend, self).__init__()
+            self.value = (Tensor(np.zeros((2, 3)).astype(np.float32)), Tensor(np.ones((2, 3)).astype(np.float32)))
+            self.soft = P.SoftmaxCrossEntropyWithLogits()
+            self.depend = depend
+
+        def construct(self, x, y, z):
+            ret = x + y
+            ret = ret * z
+            ret = self.depend(self.value, ret)
+            ret = (ret, self.soft(x, y))
+            return ret
+
+    input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
+    input2 = Tensor(np.ones((2, 2)).astype(np.float32))
+    input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
+    net = ConstDepend()
+    net(input1, input2, input3)