From 8f4cf323f84951d8c192f76ae41c98261128d4a5 Mon Sep 17 00:00:00 2001 From: kingfo Date: Fri, 17 Apr 2020 18:39:57 +0800 Subject: [PATCH] fix cell output issue and vm operator in pynative mode --- mindspore/ccsrc/pynative/pynative_execute.cc | 6 +++--- mindspore/ccsrc/session/anf_runtime_algorithm.cc | 2 +- mindspore/nn/cell.py | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 5620634bcca..6a1ddf6a7e4 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -39,7 +39,7 @@ const char SINGLE_OP_GRAPH[] = "single_op_graph"; // primitive unable to infer value for constant input in PyNative mode -const std::unordered_set vm_operators = {"partial", "depend"}; +const std::unordered_set vm_operators = {"partial", "depend", "make_ref"}; namespace mindspore { namespace pynative { @@ -141,7 +141,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { op_exec_info->op_inputs = py_args; op_exec_info->inputs_mask = args[PY_INPUT_MASK]; if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { - MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask"; + MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; return nullptr; } return op_exec_info; @@ -163,7 +163,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { // get prim and abstract info (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + op_exec_info->abstract->ToString()); - MS_LOG(INFO) << "graph info [" << graph_info << "]"; + MS_LOG(INFO) << "Graph info [" << graph_info << "]"; return graph_info; } diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 44472a9a6f6..0fcb3ce39e8 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -457,7 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_ } else if (tuple_i->isa()) { return tuple_i->type_id(); } else { - MS_LOG(EXCEPTION) << "Not support type " << tuple_i->ToString(); + MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); return tuple_i->type_id(); } } else if (type_ptr->isa()) { diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 088f3f3e57a..5507d12af89 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -140,7 +140,10 @@ class Cell: if context.get_context("mode") == context.GRAPH_MODE: out = self.compile_and_run(*inputs) return out - return self.construct(*inputs) + output = self.construct(*inputs) + if isinstance(output, Parameter): + output = output.data + return output def __setattr__(self, name, value): cells = self.__dict__.get('_cells')