fix cell output issue and vm operator in pynative mode
This commit is contained in:
parent
6e183fcc0f
commit
8f4cf323f8
|
@ -39,7 +39,7 @@
|
||||||
|
|
||||||
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
||||||
// primitive unable to infer value for constant input in PyNative mode
|
// primitive unable to infer value for constant input in PyNative mode
|
||||||
const std::unordered_set<std::string> vm_operators = {"partial", "depend"};
|
const std::unordered_set<std::string> vm_operators = {"partial", "depend", "make_ref"};
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace pynative {
|
namespace pynative {
|
||||||
|
@ -141,7 +141,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
|
||||||
op_exec_info->op_inputs = py_args;
|
op_exec_info->op_inputs = py_args;
|
||||||
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
||||||
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
||||||
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return op_exec_info;
|
return op_exec_info;
|
||||||
|
@ -163,7 +163,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) {
|
||||||
// get prim and abstract info
|
// get prim and abstract info
|
||||||
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
|
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
|
||||||
op_exec_info->abstract->ToString());
|
op_exec_info->abstract->ToString());
|
||||||
MS_LOG(INFO) << "graph info [" << graph_info << "]";
|
MS_LOG(INFO) << "Graph info [" << graph_info << "]";
|
||||||
return graph_info;
|
return graph_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -457,7 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
|
||||||
} else if (tuple_i->isa<Number>()) {
|
} else if (tuple_i->isa<Number>()) {
|
||||||
return tuple_i->type_id();
|
return tuple_i->type_id();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Not support type " << tuple_i->ToString();
|
MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
|
||||||
return tuple_i->type_id();
|
return tuple_i->type_id();
|
||||||
}
|
}
|
||||||
} else if (type_ptr->isa<Number>()) {
|
} else if (type_ptr->isa<Number>()) {
|
||||||
|
|
|
@ -140,7 +140,10 @@ class Cell:
|
||||||
if context.get_context("mode") == context.GRAPH_MODE:
|
if context.get_context("mode") == context.GRAPH_MODE:
|
||||||
out = self.compile_and_run(*inputs)
|
out = self.compile_and_run(*inputs)
|
||||||
return out
|
return out
|
||||||
return self.construct(*inputs)
|
output = self.construct(*inputs)
|
||||||
|
if isinstance(output, Parameter):
|
||||||
|
output = output.data
|
||||||
|
return output
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
cells = self.__dict__.get('_cells')
|
cells = self.__dict__.get('_cells')
|
||||||
|
|
Loading…
Reference in New Issue