forked from mindspore-Ecosystem/mindspore
!7671 refractor pynative
Merge pull request !7671 from lianliguang/refractor-pynative
This commit is contained in:
commit
ada2f4cf1d
|
@ -479,14 +479,18 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
abstract::BaseShapePtr base_shape = node->Shape();
|
||||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
|
||||
return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
|
||||
if (base_shape->isa<abstract::Shape>()) {
|
||||
if (output_idx == 0) {
|
||||
return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The node " << node->DebugString() << "is a single output node but got index [" << output_idx
|
||||
<< ".";
|
||||
} else if (base_shape->isa<abstract::TupleShape>()) {
|
||||
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_shape);
|
||||
if (output_idx >= tuple_shape->size()) {
|
||||
MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
|
||||
<< ".";
|
||||
<< " node:" << node->DebugString() << ".";
|
||||
}
|
||||
auto b_shp = (*tuple_shape)[output_idx];
|
||||
if (b_shp->isa<abstract::Shape>()) {
|
||||
|
@ -495,13 +499,14 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
|
|||
return std::vector<size_t>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
|
||||
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
|
||||
<< " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString()
|
||||
<< "node :" << node->DebugString() << ".";
|
||||
}
|
||||
} else if (base_shape->isa<abstract::NoShape>()) {
|
||||
return std::vector<size_t>();
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
|
||||
<< base_shape->ToString();
|
||||
<< base_shape->ToString() << " node : " << node->DebugString();
|
||||
}
|
||||
|
||||
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
|
||||
|
|
|
@ -742,6 +742,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
abstract::AbstractBasePtrList *args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(op_masks);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
CNodePtr cnode = nullptr;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
||||
|
@ -750,8 +751,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
|
||||
inputs.push_back(NewValueNode(prim));
|
||||
|
||||
size_t size = op_exec_info->op_inputs.size();
|
||||
auto sig_size = signature.size();
|
||||
auto size = op_exec_info->op_inputs.size();
|
||||
// ignore signature for cast op
|
||||
if (sig_size > 0 && sig_size != size) {
|
||||
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
|
||||
|
@ -759,48 +760,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
}
|
||||
bool is_cast_op = (op_exec_info->op_name == "Cast");
|
||||
if (!is_cast_op) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto obj = op_exec_info->op_inputs[i];
|
||||
auto sig = SignatureEnumRW::kRWDefault;
|
||||
if (sig_size > 0) {
|
||||
sig = signature[i].rw;
|
||||
}
|
||||
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
||||
<< std::string(py::repr(obj));
|
||||
// mix precision for non param
|
||||
bool is_cast = false;
|
||||
py::object cast_output;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor && meta_tensor->is_parameter()) {
|
||||
if (sig != SignatureEnumRW::kRWRead) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// redundant cast call if the tensor is a const Tensor.
|
||||
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
// mix precision for tuple inputs
|
||||
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
|
||||
}
|
||||
if (is_cast) {
|
||||
op_exec_info->op_inputs[i] = cast_output;
|
||||
}
|
||||
}
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
|
||||
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
|
||||
std::map<SignatureEnumDType, TypeId> dst_types;
|
||||
if (has_dtype_sig) {
|
||||
// fetch info for implicit cast
|
||||
auto type_indexes = GetTypeIndex(dtypes);
|
||||
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
|
||||
}
|
||||
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
|
||||
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
|
||||
RunParameterAutoMixPrecisionCast(op_exec_info);
|
||||
}
|
||||
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
|
||||
const auto &obj = op_exec_info->op_inputs[i];
|
||||
bool op_mask = false;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
|
@ -1065,32 +1028,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
|
|||
return cell_id;
|
||||
}
|
||||
|
||||
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
||||
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
||||
mindspore::parse::python_adapter::set_python_env_flag(true);
|
||||
MsBackendPolicy backend_policy;
|
||||
#if (!defined ENABLE_GE)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (!context::IsTsdOpened(ms_context)) {
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(EXCEPTION) << "Open tsd failed";
|
||||
}
|
||||
}
|
||||
if (ms_context->backend_policy() == "ms") {
|
||||
backend_policy = kMsBackendMsPrior;
|
||||
} else {
|
||||
backend_policy = kMsBackendVmOnly;
|
||||
}
|
||||
#else
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
context::PynativeInitGe(ms_context);
|
||||
backend_policy = kMsBackendGeOnly;
|
||||
#endif
|
||||
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
|
||||
backend_policy = kMsBackendVmOnly;
|
||||
}
|
||||
py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
|
||||
auto backend_policy = InitEnv(op_exec_info);
|
||||
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
|
||||
// returns a null py::tuple on error
|
||||
py::tuple err_ret(0);
|
||||
|
@ -1113,7 +1052,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
|||
std::vector<bool> op_masks;
|
||||
op_exec_info = GenerateOpExecInfo(args);
|
||||
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
|
||||
return RunOpInner(op_exec_info);
|
||||
return RunOpWithInitBackendPolicy(op_exec_info);
|
||||
}
|
||||
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
|
||||
bool is_find = false;
|
||||
|
@ -1171,7 +1110,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
|||
(void)GetOpId(op_exec_info);
|
||||
}
|
||||
|
||||
auto result = RunOpInner(op_exec_info);
|
||||
auto result = RunOpWithInitBackendPolicy(op_exec_info);
|
||||
py::object out_real = result;
|
||||
if (result.size() == 1) {
|
||||
MS_LOG(DEBUG) << "MakeCnode out size is one.";
|
||||
|
@ -1798,6 +1737,81 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
|
|||
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
|
||||
}
|
||||
|
||||
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
|
||||
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
||||
mindspore::parse::python_adapter::set_python_env_flag(true);
|
||||
MsBackendPolicy backend_policy;
|
||||
#if (!defined ENABLE_GE)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (!context::IsTsdOpened(ms_context)) {
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(EXCEPTION) << "Open tsd failed";
|
||||
}
|
||||
}
|
||||
if (ms_context->backend_policy() == "ms") {
|
||||
backend_policy = kMsBackendMsPrior;
|
||||
} else {
|
||||
backend_policy = kMsBackendVmOnly;
|
||||
}
|
||||
#else
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
context::PynativeInitGe(ms_context);
|
||||
backend_policy = kMsBackendGeOnly;
|
||||
#endif
|
||||
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
|
||||
backend_policy = kMsBackendVmOnly;
|
||||
}
|
||||
return backend_policy;
|
||||
}
|
||||
|
||||
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
|
||||
size_t size = op_exec_info->op_inputs.size();
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
const auto &signature = prim->signatures();
|
||||
auto sig_size = signature.size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto obj = op_exec_info->op_inputs[i];
|
||||
auto sig = SignatureEnumRW::kRWDefault;
|
||||
if (sig_size > 0) {
|
||||
sig = signature[i].rw;
|
||||
}
|
||||
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "
|
||||
<< std::string(py::repr(obj));
|
||||
// mix precision for non param
|
||||
bool is_cast = false;
|
||||
py::object cast_output;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor && meta_tensor->is_parameter()) {
|
||||
if (sig != SignatureEnumRW::kRWRead) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// redundant cast call if the tensor is a const Tensor.
|
||||
cast_output = DoParamMixPrecisionCast(&is_cast, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
// mix precision for tuple inputs
|
||||
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
|
||||
}
|
||||
if (is_cast) {
|
||||
op_exec_info->op_inputs[i] = cast_output;
|
||||
}
|
||||
}
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
|
||||
bool has_dtype_sig = GetSignatureType(prim, &dtypes);
|
||||
std::map<SignatureEnumDType, TypeId> dst_types;
|
||||
if (has_dtype_sig) {
|
||||
// fetch info for implicit cast
|
||||
auto type_indexes = GetTypeIndex(dtypes);
|
||||
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes);
|
||||
}
|
||||
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name;
|
||||
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info);
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
||||
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
|
||||
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
|
||||
|
|
|
@ -68,6 +68,13 @@ struct GraphInfo {
|
|||
};
|
||||
|
||||
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||
private:
|
||||
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
|
||||
py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
|
||||
|
||||
public:
|
||||
static std::shared_ptr<PynativeExecutor> GetInstance() {
|
||||
std::lock_guard<std::mutex> i_lock(instance_lock_);
|
||||
|
@ -117,9 +124,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
|
||||
graph_info_map_[g].param_map[obj] = std::make_pair(node, index);
|
||||
}
|
||||
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
|
||||
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
|
||||
void SaveOpForwardValue(const std::string &id, const ValuePtr &value,
|
||||
|
@ -137,7 +141,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void SetTupleParam(const py::object &obj, const AnfNodePtr ¶_node, std::vector<int> idx);
|
||||
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
|
||||
py::tuple RunOpInner(const py::args &args);
|
||||
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
|
||||
|
||||
~PynativeExecutor();
|
||||
|
||||
|
|
Loading…
Reference in New Issue