!7671 refractor pynative

Merge pull request !7671 from lianliguang/refractor-pynative
mindspore-ci-bot 2020-10-28 14:46:16 +08:00 committed by Gitee
commit ada2f4cf1d
3 changed files with 100 additions and 78 deletions
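
Summary of the change: the backend-policy setup in PynativeExecutor::RunOpInner is extracted into a new InitEnv helper, the OpExecInfoPtr overload of RunOpInner is renamed to RunOpWithInitBackendPolicy, the parameter mixed-precision cast logic in MakeCNode is moved into a new RunParameterAutoMixPrecisionCast method, and the error messages in AnfRuntimeAlgorithm::GetOutputInferShape now include the offending node's DebugString().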


@@ -479,14 +479,18 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
MS_EXCEPTION_IF_NULL(node);
abstract::BaseShapePtr base_shape = node->Shape();
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
if (base_shape->isa<abstract::Shape>()) {
if (output_idx == 0) {
return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
}
MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
<< ".";
} else if (base_shape->isa<abstract::TupleShape>()) {
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
if (output_idx >= tuple_shape->size()) {
MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
<< ".";
<< " node:" << node->DebugString() << ".";
}
auto b_shp = (*tuple_shape)[output_idx];
if (b_shp->isa<abstract::Shape>()) {
@@ -495,13 +499,14 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
return std::vector<size_t>();
} else {
MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
<< "node :" << node->DebugString() << ".";
}
} else if (base_shape->isa<abstract::NoShape>()) {
return std::vector<size_t>();
}
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
<< base_shape->ToString();
<< base_shape->ToString() << " node : " << node->DebugString();
}
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
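
In effect, the reworked GetOutputInferShape dispatches on the concrete shape kind (single Shape, TupleShape, NoShape) and appends the node's debug string to every failure path. The following is a minimal, self-contained sketch of that dispatch pattern using simplified stand-in types, not the real MindSpore abstract::* classes:

#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-ins for abstract::BaseShape / Shape / TupleShape / NoShape.
struct BaseShape { virtual ~BaseShape() = default; };
struct SingleShape : BaseShape { std::vector<size_t> dims; };
struct TupleShape : BaseShape { std::vector<std::vector<size_t>> elements; };
struct NoShape : BaseShape {};

// Dispatch on the concrete shape kind; every failure message carries the node name,
// mirroring how the patch appends node->DebugString() to each MS_LOG(EXCEPTION).
std::vector<size_t> GetOutputInferShape(const BaseShape &shape, size_t output_idx,
                                        const std::string &node_debug) {
  if (auto single = dynamic_cast<const SingleShape *>(&shape)) {
    if (output_idx == 0) {
      return single->dims;
    }
    throw std::runtime_error("The node " + node_debug +
                             " is a single output node but got index [" +
                             std::to_string(output_idx) + "].");
  }
  if (auto tuple = dynamic_cast<const TupleShape *>(&shape)) {
    if (output_idx >= tuple->elements.size()) {
      throw std::runtime_error("Output index " + std::to_string(output_idx) +
                               " is larger than output number " +
                               std::to_string(tuple->elements.size()) +
                               " node: " + node_debug);
    }
    return tuple->elements[output_idx];
  }
  if (dynamic_cast<const NoShape *>(&shape)) {
    return {};  // a NoShape output has an empty shape vector
  }
  throw std::runtime_error("Unsupported shape kind, node: " + node_debug);
}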


@@ -742,6 +742,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);
CNodePtr cnode = nullptr;
std::vector<AnfNodePtr> inputs;
@@ -750,8 +751,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
inputs.push_back(NewValueNode(prim));
size_t size = op_exec_info->op_inputs.size();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();
// ignore signature for cast op
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
@@ -759,48 +760,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
}
bool is_cast_op = (op_exec_info->op_name == "Cast");
if (!is_cast_op) {
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
auto sig = SignatureEnumRW::kRWDefault;
if (sig_size > 0) {
sig = signature[i].rw;
}
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
<< std::string(py::repr(obj));
// mix precision for non param
bool is_cast = false;
py::object cast_output;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor && meta_tensor->is_parameter()) {
if (sig != SignatureEnumRW::kRWRead) {
continue;
}
}
// redundant cast call if the tensor is a const Tensor.
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
// mix precision for tuple inputs
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
}
if (is_cast) {
op_exec_info->op_inputs[i] = cast_output;
}
}
std::vector<SignatureEnumDType> dtypes;
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
std::map<SignatureEnumDType, TypeId> dst_types;
if (has_dtype_sig) {
// fetch info for implicit cast
auto type_indexes = GetTypeIndex(dtypes);
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
}
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
for (size_t i = 0; i < size; i++) {
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
const auto &obj = op_exec_info->op_inputs[i];
bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
@@ -1065,32 +1028,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
return cell_id;
}
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!context::IsTsdOpened(ms_context)) {
if (!context::OpenTsd(ms_context)) {
MS_LOG(EXCEPTION) << "Open tsd failed";
}
}
if (ms_context->backend_policy() == "ms") {
backend_policy = kMsBackendMsPrior;
} else {
backend_policy = kMsBackendVmOnly;
}
#else
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
context::PynativeInitGe(ms_context);
backend_policy = kMsBackendGeOnly;
#endif
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
backend_policy = kMsBackendVmOnly;
}
py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
auto backend_policy = InitEnv(op_exec_info);
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
// returns a null py::tuple on error
py::tuple err_ret(0);
@@ -1113,7 +1052,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
std::vector<bool> op_masks;
op_exec_info = GenerateOpExecInfo(args);
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
return RunOpInner(op_exec_info);
return RunOpWithInitBackendPolicy(op_exec_info);
}
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
bool is_find = false;
@@ -1171,7 +1110,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
(void)GetOpId(op_exec_info);
}
auto result = RunOpInner(op_exec_info);
auto result = RunOpWithInitBackendPolicy(op_exec_info);
py::object out_real = result;
if (result.size() == 1) {
MS_LOG(DEBUG) << "MakeCnode out size is one.";
@@ -1798,6 +1737,81 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!context::IsTsdOpened(ms_context)) {
if (!context::OpenTsd(ms_context)) {
MS_LOG(EXCEPTION) << "Open tsd failed";
}
}
if (ms_context->backend_policy() == "ms") {
backend_policy = kMsBackendMsPrior;
} else {
backend_policy = kMsBackendVmOnly;
}
#else
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
context::PynativeInitGe(ms_context);
backend_policy = kMsBackendGeOnly;
#endif
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
backend_policy = kMsBackendVmOnly;
}
return backend_policy;
}
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
size_t size = op_exec_info->op_inputs.size();
auto prim = op_exec_info->py_primitive;
const auto &signature = prim->signatures();
auto sig_size = signature.size();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
auto sig = SignatureEnumRW::kRWDefault;
if (sig_size > 0) {
sig = signature[i].rw;
}
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
<< std::string(py::repr(obj));
// mix precision for non param
bool is_cast = false;
py::object cast_output;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor && meta_tensor->is_parameter()) {
if (sig != SignatureEnumRW::kRWRead) {
continue;
}
}
// redundant cast call if the tensor is a const Tensor.
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
// mix precision for tuple inputs
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
}
if (is_cast) {
op_exec_info->op_inputs[i] = cast_output;
}
}
std::vector<SignatureEnumDType> dtypes;
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
std::map<SignatureEnumDType, TypeId> dst_types;
if (has_dtype_sig) {
// fetch info for implicit cast
auto type_indexes = GetTypeIndex(dtypes);
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
}
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
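
After this refactor the op-run path in pynative_execute.cc reads as: RunOpInner(py::args) builds the op info and calls RunOpWithInitBackendPolicy, which obtains the backend policy from the new InitEnv helper, while MakeCNode delegates the implicit-cast handling to RunParameterAutoMixPrecisionCast. A condensed, hypothetical sketch of that call chain, with the real OpExecInfoPtr / py::object types replaced by simple stand-ins, is:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins; the real code works on OpExecInfoPtr,
// py::object inputs, and returns a py::tuple.
enum class MsBackendPolicy { kMsPrior, kVmOnly, kGeOnly };

struct OpExecInfo {
  std::string op_name;
  std::vector<std::string> op_inputs;  // placeholder for the real py::object inputs
};

class PynativeExecutor {
 public:
  std::string RunOpWithInitBackendPolicy(const OpExecInfo &op) {
    MsBackendPolicy policy = InitEnv(op);  // backend selection moved out of RunOpInner
    (void)policy;  // the real code uses the policy to dispatch the op
    return "result of " + op.op_name;
  }

  void MakeCNode(OpExecInfo *op) {
    if (op->op_name != "Cast") {
      // implicit mixed-precision handling now lives in its own helper
      RunParameterAutoMixPrecisionCast(op);
    }
    // ... build the CNode from op->op_inputs ...
  }

 private:
  MsBackendPolicy InitEnv(const OpExecInfo &op) {
    std::cout << "RunOp start, op name is: " << op.op_name << "\n";
    // the real helper opens TSD / initializes GE via MsContext; here we pick a default
    return MsBackendPolicy::kMsPrior;
  }

  void RunParameterAutoMixPrecisionCast(OpExecInfo *op) {
    // the real helper walks op_inputs, applies DoParamMixPrecisionCast(Tuple) where
    // needed, and finishes with DoSignatrueCast for implicit dtype promotion
    for (auto &input : op->op_inputs) {
      (void)input;  // per-input cast decision elided in this sketch
    }
  }
};

A caller would construct an OpExecInfo and invoke RunOpWithInitBackendPolicy, mirroring how the py::args overload of RunOpInner forwards to it in the patch.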


@@ -68,6 +68,13 @@ struct GraphInfo {
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
private:
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
public:
static std::shared_ptr<PynativeExecutor> GetInstance() {
std::lock_guard<std::mutex> i_lock(instance_lock_);
@@ -117,9 +124,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
graph_info_map_[g].param_map[obj] = std::make_pair(node, index);
}
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void SaveOpForwardValue(const std::string &id, const ValuePtr &value,
@@ -137,7 +141,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
py::tuple RunOpInner(const py::args &args);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
~PynativeExecutor();