Modify code to support dynamic graph (PyNative mode).

rick_sanchez 2020-05-26 09:14:40 +08:00 committed by kpy
parent 72fd41786c
commit e2a322b6b7
34 changed files with 673 additions and 72 deletions
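
For orientation: the changes below wire up dynamic-graph (PyNative) execution with gradient recording. Each operator call still runs eagerly, but when a cell is marked with the new set_grad(), the call is also recorded into a FuncGraph from which a grad graph is built and executed. A minimal sketch of the user-facing flow this enables, assuming a MindSpore build that contains this change set (the GradOperation('name', ...) constructor matches the one used in composite/base.py further down):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import context, Tensor
    from mindspore.common.parameter import ParameterTuple
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(3, 2)
        def construct(self, x):
            return self.dense(x)

    net = Net()
    net.set_grad()                        # new API in this commit: enable op recording for grad
    x = Tensor(np.ones([1, 3]).astype(np.float32))
    out = net(x)                          # ops run eagerly while being recorded via PynativeExecutor

    weights = ParameterTuple(net.trainable_params())
    grad_by_list = C.GradOperation('grad', get_by_list=True)
    grads = grad_by_list(net, weights)(x)  # takes the new pynative branch added in composite/base.py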

View File

@ -19,14 +19,15 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_obj_instance, generate_scope, from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type, get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol, create_slice_obj, get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name, get_default_input, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj) is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
from .serialize import * from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', 'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj'] 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'create_ellipsis_obj']

View File

@ -209,6 +209,14 @@ def get_object_key(obj):
obj_id = instance_id + obj_id obj_id = instance_id + obj_id
return obj_id, obj_key return obj_id, obj_key
def get_default_input(obj):
if hasattr(obj, '__parameter__'):
return obj.default_input
if isinstance(obj, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args = tuple(convert(x) for x in obj)
return args
return obj
def is_class_member(node): def is_class_member(node):
"""Check the attr is class member variable.""" """Check the attr is class member variable."""
@ -221,6 +229,9 @@ def is_class_member(node):
return True return True
return False return False
def get_obj_id(obj):
"""Get the obj id."""
return str(id(obj))
def get_obj_type(obj): def get_obj_type(obj):
"""Get the obj type.""" """Get the obj type."""

View File

@ -328,9 +328,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
DropEdge(node, index, inp); DropEdge(node, index, inp);
} else { } else {
MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
if (inp->func_graph() != nullptr) {
AddFuncGraph(inp->func_graph());
}
if (IsValueNode<FuncGraph>(inp)) { if (IsValueNode<FuncGraph>(inp)) {
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
@ -372,9 +369,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
for (auto &node : acq) { for (auto &node : acq) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr) { if (fg != nullptr) {
AddFuncGraph(fg);
fg->AddNode(node); fg->AddNode(node);
} }
ProcessInputs(node, kIncEdge); ProcessInputs(node, kIncEdge);

View File

@ -28,7 +28,7 @@ namespace py = pybind11;
class ParamValuePy : public ParamValue { class ParamValuePy : public ParamValue {
public: public:
ParamValuePy() : value_(py::none()) {} ParamValuePy() : value_(py::none()) {}
explicit ParamValuePy(py::object value) : value_(value) {} explicit ParamValuePy(const py::object &value) : value_(value) {}
~ParamValuePy() override = default; ~ParamValuePy() override = default;
py::object value() { return value_; } py::object value() { return value_; }

View File

@ -75,7 +75,7 @@ py::function PrimitivePy::GetComputeFunction() {
py::function vm_fn = get_fn(python_obj_); py::function vm_fn = get_fn(python_obj_);
if (py::isinstance<py::none>(vm_fn)) { if (py::isinstance<py::none>(vm_fn)) {
MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
vm_fn = mindspore::GetComputeFunction(Primitive::name()); vm_fn = mindspore::GetComputeFunction(Primitive::name());
} }
return vm_fn; return vm_fn;

View File

@ -81,6 +81,7 @@ Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address_) { : MetaTensor(tensor), device_address_(tensor.device_address_) {
init(tensor.data_, data_type); init(tensor.data_, data_type);
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
id_ = tensor.id();
} }
Tensor &Tensor::operator=(const Tensor &tensor) { Tensor &Tensor::operator=(const Tensor &tensor) {
@ -89,6 +90,7 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
device_address_ = tensor.device_address(); device_address_ = tensor.device_address();
data_ = tensor.data_; data_ = tensor.data_;
id_ = tensor.id();
} }
return *this; return *this;
} }
@ -208,6 +210,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
data_ = input; data_ = input;
} }
dirty_ = true; dirty_ = true;
id_ = std::to_string((uintptr_t)(this));
} }
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) { void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
@ -254,6 +257,7 @@ void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *co
MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
break; break;
} }
id_ = std::to_string((uintptr_t)(this));
} }
TypePtr Tensor::SetDtype(const TypePtr type_ptr) { TypePtr Tensor::SetDtype(const TypePtr type_ptr) {

View File

@ -263,9 +263,11 @@ class Tensor : public MetaTensor {
DeviceAddressPtr device_address() const { return device_address_; } DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
py::array data_sync(); py::array data_sync();
std::string id() const { return id_; }
private: private:
bool dirty_{true}; bool dirty_{true};
std::string id_{""};
DeviceAddressPtr device_address_{nullptr}; DeviceAddressPtr device_address_{nullptr};
}; };

View File

@ -501,10 +501,16 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
} }
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, bool applyJ) { const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
auto weights_node = weights;
if (weights == nullptr && !args.empty()) {
weights_node = ret->NewCNode(args);
}
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
@ -537,7 +543,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
inputs.push_back(NewValueNode(1)); inputs.push_back(NewValueNode(1));
AnfNodePtr ptrBprop = ret->NewCNode(inputs); AnfNodePtr ptrBprop = ret->NewCNode(inputs);
doGetGrad(ret, out, ptrBprop, weights, opsTupleItem); doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
return ret; return ret;
} }

View File

@ -129,7 +129,7 @@ class GradOperation : public MetaFuncGraph {
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams, FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
bool applyJ = false); const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
bool sens_param() const { return sens_param_; } bool sens_param() const { return sens_param_; }
bool get_all_; bool get_all_;

View File

@ -285,6 +285,10 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
// and add cast op on other inputs to keep the same type with assigned parameter. // and add cast op on other inputs to keep the same type with assigned parameter.
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {
AnfNodePtr param = params_list[i]; AnfNodePtr param = params_list[i];
if (args_spec_list[i] == nullptr) {
op_inputs.push_back(param);
continue;
}
SignatureEnumRW sig = SignatureEnumRW::kRWDefault; SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
// If sig_size is 0 use defalut. // If sig_size is 0 use defalut.
if (sig_size > 0 && i < sig_size) { if (sig_size > 0 && i < sig_size) {
@ -292,6 +296,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
} else if (has_var && i >= sig_size) { } else if (has_var && i >= sig_size) {
sig = signature[sig_size - 1].rw; sig = signature[sig_size - 1].rw;
} }
TypePtr type = args_spec_list[i]->GetTypeTrack(); TypePtr type = args_spec_list[i]->GetTypeTrack();
if (type && type->type_id() == kObjectTypeRef) { if (type && type->type_id() == kObjectTypeRef) {
if (sig == SignatureEnumRW::kRWRead) { if (sig == SignatureEnumRW::kRWRead) {

View File

@ -551,6 +551,10 @@ AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
} }
void DFunctor::CallDoutHoleOnTape() { void DFunctor::CallDoutHoleOnTape() {
if (!is_top_) {
return;
}
// Call dout hole of all adjoint. // Call dout hole of all adjoint.
for (auto &f : func_graph_to_functor_) { for (auto &f : func_graph_to_functor_) {
for (auto &adjoint : f.second->anfnode_to_adjoin_) { for (auto &adjoint : f.second->anfnode_to_adjoin_) {

View File

@ -55,6 +55,8 @@ class DFunctor {
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
// Register functor objects to form a global view. // Register functor objects to form a global view.
void Init(const DFunctorPtr &functor, bool is_top = false); void Init(const DFunctorPtr &functor, bool is_top = false);
bool IsInScope(const AnfNodePtr &node);
// Clear resources. // Clear resources.
static void Clear(); static void Clear();
@ -62,7 +64,6 @@ class DFunctor {
// Map one morphism. // Map one morphism.
AdjointPtr MapMorphism(const AnfNodePtr &morph); AdjointPtr MapMorphism(const AnfNodePtr &morph);
bool IsFreeMorphism(const AnfNodePtr &node); bool IsFreeMorphism(const AnfNodePtr &node);
bool IsInScope(const AnfNodePtr &node);
// Map morphism that's not attached to output. // Map morphism that's not attached to output.
void MapFreeMorphism(); void MapFreeMorphism();
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);

View File

@ -23,7 +23,7 @@
namespace mindspore { namespace mindspore {
namespace ad { namespace ad {
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) { FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto gradkv = func_graph->transforms().find("grad"); auto gradkv = func_graph->transforms().find("grad");
if (gradkv != func_graph->transforms().end()) { if (gradkv != func_graph->transforms().end()) {
@ -46,14 +46,18 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
auto user_defined = f->KUserDefined(func_graph); auto user_defined = f->KUserDefined(func_graph);
if (user_defined != nullptr) { if (user_defined != nullptr) {
multi_graph_sink(user_defined); multi_graph_sink(user_defined);
DFunctor::Clear(); if (is_top) {
DFunctor::Clear();
}
return user_defined; return user_defined;
} }
f->Init(f, true); f->Init(f, is_top);
f->MapObject(); f->MapObject();
f->MapMorphism(); f->MapMorphism();
auto ret = f->k_graph(); auto ret = f->k_graph();
DFunctor::Clear(); if (is_top) {
DFunctor::Clear();
}
multi_graph_sink(ret); multi_graph_sink(ret);
return ret; return ret;
@ -71,5 +75,7 @@ MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr
MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
return fg; return fg;
} }
void CleanRes() { DFunctor::Clear(); }
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore

View File

@ -28,9 +28,10 @@ namespace mindspore {
namespace ad { namespace ad {
using ResourcePtr = std::shared_ptr<pipeline::Resource>; using ResourcePtr = std::shared_ptr<pipeline::Resource>;
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources); FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true);
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &);
void CleanRes();
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore

View File

@ -167,7 +167,8 @@ class InlinerBase : public AnfVisitor {
auto params = fg->parameters(); auto params = fg->parameters();
auto old_size = params.size(); auto old_size = params.size();
if (old_size != new_params.size()) { if (old_size != new_params.size()) {
MS_LOG(EXCEPTION) << "Parameter size not match."; MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
<< fg->output()->DebugString(10);
} }
for (size_t i = 0; i < old_size; i++) { for (size_t i = 0; i < old_size; i++) {
(void)mng->Replace(params[i], new_params[i]); (void)mng->Replace(params[i], new_params[i]);

View File

@ -276,6 +276,8 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
static bool IsCtrlSink() { static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance(); auto ms_ctx = MsContext::GetInstance();
std::string device_target = ms_ctx->device_target(); std::string device_target = ms_ctx->device_target();

View File

@ -35,6 +35,7 @@ bool SymbolResolveAction(const ResourcePtr &res);
bool AbstractSpecializeAction(const ResourcePtr &res); bool AbstractSpecializeAction(const ResourcePtr &res);
bool GeOptimizeAction(const ResourcePtr &res); bool GeOptimizeAction(const ResourcePtr &res);
bool VmOptimizeAction(const ResourcePtr &res); bool VmOptimizeAction(const ResourcePtr &res);
bool PynativeOptimizeAction(const ResourcePtr &res);
bool TaskEmitAction(const ResourcePtr &res); bool TaskEmitAction(const ResourcePtr &res);
bool ExecuteAction(const ResourcePtr &res); bool ExecuteAction(const ResourcePtr &res);

View File

@ -32,6 +32,7 @@
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "optimizer/ad/grad.h"
namespace mindspore { namespace mindspore {
namespace parse { namespace parse {
@ -338,6 +339,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>(); std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env; converted = env;
} else if (py::hasattr(obj, "__parameter__")) {
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ret = ConvertData(to_convert, &converted);
} else { } else {
ret = ConvertOtherObj(obj, &converted); ret = ConvertOtherObj(obj, &converted);
} }

View File

@ -60,6 +60,7 @@ const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key"; const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key";
const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member"; const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type"; const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type"; const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance"; const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
@ -83,6 +84,7 @@ const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
// define the common name // define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter"; const char NAMED_PRIMITIVE_ITER[] = "iter";

View File

@ -278,5 +278,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_control", ControlGroup}, {"opt_control", ControlGroup},
{"opt_prepare", PrepareGroup}, {"opt_prepare", PrepareGroup},
{"cconv", CconvPass}}; {"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
extern std::vector<PassItem> kGePasses; extern std::vector<PassItem> kGePasses;
extern std::vector<PassItem> kVmPasses; extern std::vector<PassItem> kVmPasses;
extern std::vector<PassItem> kPynativePasses;
bool CconvPass(const ResourcePtr &res); bool CconvPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res); bool ValidatePass(const ResourcePtr &res);

View File

@ -608,7 +608,7 @@ void Pipeline::Run() {
MS_LOG(INFO) << "End"; MS_LOG(INFO) << "End";
} }
void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) { void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
std::size_t size = args.size(); std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) { for (std::size_t i = 0; i < size; i++) {
@ -625,7 +625,6 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
arg_list->push_back(converted); arg_list->push_back(converted);
} }
ResourcePtr res = GetResource(phase);
MS_EXCEPTION_IF_NULL(res); MS_EXCEPTION_IF_NULL(res);
auto graph = res->func_graph(); auto graph = res->func_graph();
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -647,6 +646,10 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
} }
} }
void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
ProcessVmArgInner(args, GetResource(phase), arg_list);
}
py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
std::size_t size = args.size(); std::size_t size = args.size();
if (!py::isinstance<py::str>(phase)) { if (!py::isinstance<py::str>(phase)) {
@ -874,6 +877,8 @@ void ClearResAtexit() {
compile::ClearConvertCache(); compile::ClearConvertCache();
pipeline::GetMethodMap().clear(); pipeline::GetMethodMap().clear();
pipeline::ExecutorPy::ClearRes(); pipeline::ExecutorPy::ClearRes();
pipeline::ReclaimOptimizer();
pynative::PynativeExecutor::GetInstance()->Clean();
#ifdef ENABLE_GE #ifdef ENABLE_GE
transform::DfGraphManager::GetInstance().ClearGraph(); transform::DfGraphManager::GetInstance().ClearGraph();
transform::DfGraphConvertor::get_adpt_map().clear(); transform::DfGraphConvertor::get_adpt_map().clear();

View File

@ -139,6 +139,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run); const std::vector<int64_t> &input_indexes, bool need_run);
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -22,17 +22,30 @@
#include <unordered_set> #include <unordered_set>
#include <algorithm> #include <algorithm>
#include "ir/param_value_py.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "operator/composite/composite.h"
#include "operator/composite/do_signature.h" #include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/resolve.h"
#include "pipeline/static_analysis/prim.h" #include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h" #include "session/session_factory.h"
#include "pre_activate/pass/const_input_to_attr_registry.h" #include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "pipeline/action.h"
#include "pynative/base.h" #include "pynative/base.h"
#include "pybind_api/api_register.h"
#include "vm/transform.h"
#include "optimizer/ad/grad.h"
#include "pipeline/resource.h"
#include "pipeline/pipeline.h"
#include "pipeline/pass.h"
#ifdef ENABLE_GE #ifdef ENABLE_GE
#include "pynative/pynative_execute_ge.h" #include "pynative/pynative_execute_ge.h"
@ -40,21 +53,55 @@
const char SINGLE_OP_GRAPH[] = "single_op_graph"; const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode // primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor"}; const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor", "HookBackward"};
namespace mindspore { namespace mindspore {
namespace pynative { namespace pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr; static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;
inline ValuePtr PyAttrValue(const py::object &obj) { inline ValuePtr PyAttrValue(const py::object &obj) {
ValuePtr converted_ret = nullptr; ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
bool converted = parse::ConvertData(obj, &converted_ret); if (!converted_ret) {
if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
} }
return converted_ret; return converted_ret;
} }
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) { std::string GetId(const py::object &obj) {
py::object to_process = obj;
std::string prefix = "";
if (py::isinstance<py::tuple>(to_process)) {
auto p_list = py::cast<py::tuple>(to_process);
to_process = p_list[0];
prefix = "tuple:";
if (!py::isinstance<tensor::Tensor>(to_process)) {
std::string key = "";
for (size_t i = 0; i < p_list.size(); ++i) {
key += std::string(py::str(p_list[i])) + ":";
}
return prefix + key;
}
}
if (py::isinstance<py::int_>(to_process)) {
return prefix + std::string(py::str(to_process));
}
if (py::isinstance<py::float_>(to_process)) {
return prefix + std::string(py::str(to_process));
}
if (py::isinstance<tensor::Tensor>(to_process)) {
auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
return prefix + tensor_ptr->id();
}
py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
return py::cast<std::string>(ret);
}
py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
auto signature = prim->signatures(); auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes; std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
@ -87,7 +134,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
} }
(void)dst_type.insert(std::make_pair(type, m_index)); (void)dst_type.insert(std::make_pair(type, m_index));
} }
py::tuple py_inputs(py_args.size()); py::list py_inputs(py_args.size());
for (size_t i = 0; i < py_args.size(); ++i) { for (size_t i = 0; i < py_args.size(); ++i) {
auto it = dst_type.find(dtypes[i]); auto it = dst_type.find(dtypes[i]);
if (it != dst_type.end() && it->second != i && if (it != dst_type.end() && it->second != i &&
@ -105,12 +152,12 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
return py_inputs; return py_inputs;
} }
void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) { void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
size_t size = py_args.size(); size_t size = py_args.size();
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
ValuePtr input_value = PyAttrValue(py_args[i]); ValuePtr input_value = PyAttrValue(py_args[i]);
if (py::isinstance<tensor::Tensor>(py_args[i])) { if (input_value->isa<tensor::Tensor>()) {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, true)); args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
} else { } else {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@ -120,6 +167,12 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
op_exec_info->abstract = infer_res; op_exec_info->abstract = infer_res;
} }
py::object GetTupleObj(const py::object &obj) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
return obj_tuple;
}
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
if (args.size() != PY_ARGS_NUM) { if (args.size() != PY_ARGS_NUM) {
MS_LOG(ERROR) << "Four args are needed by RunOp"; MS_LOG(ERROR) << "Four args are needed by RunOp";
@ -133,14 +186,19 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
if (pyobj == nullptr) { if (pyobj == nullptr) {
MS_LOG(EXCEPTION) << "pyobj is empty"; MS_LOG(EXCEPTION) << "pyobj is empty";
} }
py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]); py::list py_args = ConvertInputs(prim, args[PY_INPUTS]);
// use python infer method // use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, py_args, op_exec_info.get()); PynativeInfer(prim, py_args, op_exec_info.get());
} }
op_exec_info->py_primitive = prim; op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = py_args; size_t input_num = py_args.size();
op_exec_info->op_inputs = py::tuple(input_num);
for (size_t i = 0; i < input_num; ++i) {
auto obj = py_args[i];
op_exec_info->op_inputs[i] = GetTupleObj(obj);
}
op_exec_info->inputs_mask = args[PY_INPUT_MASK]; op_exec_info->inputs_mask = args[PY_INPUT_MASK];
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
@ -154,9 +212,13 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
std::string graph_info; std::string graph_info;
// get input tensor info // get input tensor info
for (const auto &input_tensor : input_tensors) { size_t input_num = op_exec_info->op_inputs.size();
MS_EXCEPTION_IF_NULL(input_tensor); for (size_t index = 0; index < input_num; ++index) {
(void)graph_info.append(input_tensor->GetShapeAndDataTypeInfo() + "_"); auto input = op_exec_info->op_inputs[index];
if (py::isinstance<tensor::Tensor>(input)) {
auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
(void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
}
} }
// get prim and abstract info // get prim and abstract info
MS_EXCEPTION_IF_NULL(op_exec_info->abstract); MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
@ -171,6 +233,23 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(status); MS_EXCEPTION_IF_NULL(status);
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
if (op_exec_info->op_name == "HookBackward") {
auto op_inputs = op_exec_info->op_inputs;
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
if (py::hasattr(input, "__parameter__")) {
result[i] = py::getattr(input, "data");
} else {
auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
result[i] = new_tensor;
}
}
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "RunOpInVM end";
return std::move(result);
}
auto func = op_exec_info->py_primitive->GetComputeFunction(); auto func = op_exec_info->py_primitive->GetComputeFunction();
if (py::isinstance<py::none>(func)) { if (py::isinstance<py::none>(func)) {
MS_LOG(ERROR) << "VM failed to get func"; MS_LOG(ERROR) << "VM failed to get func";
@ -288,7 +367,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
opt::ConstInputToAttrInfoRegister reg; opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg); bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
size_t input_num = op_run_info->op_inputs.size(); size_t input_num = op_run_info->op_inputs.size();
MS_LOG(INFO) << "py input size: " << input_num;
for (size_t index = 0; index < input_num; ++index) { for (size_t index = 0; index < input_num; ++index) {
// convert const input to attr // convert const input to attr
if (reg_exist && if (reg_exist &&
@ -386,7 +464,56 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
return result; return result;
} }
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
if (!grad_flag_ || graph_info_map_.size() == 0) {
return nullptr;
}
std::vector<AnfNodePtr> inputs;
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
inputs.push_back(NewValueNode(prim));
py::tuple op_masks = args[PY_INPUT_MASK];
py::list op_args = args[PY_INPUTS];
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < op_args.size(); i++) {
auto node = GetInput(op_args[i], op_masks[i]);
args_spec_list.push_back(node->abstract());
inputs.push_back(node);
}
auto cnode = curr_g_->NewCNode(inputs);
MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString();
py::object out_real = out;
if (out.size() == 1) {
MS_LOG(DEBUG) << "MakeCnode out size is one.";
out_real = out[0];
}
std::string obj_id = GetId(out_real);
if (py::isinstance<py::tuple>(out_real)) {
auto value = py::cast<py::tuple>(out_real);
if (value.size() > 1) {
for (int i = 0; i < static_cast<int>(value.size()); i++) {
auto value_id = GetId(value[i]);
set_obj_node_map(curr_g_, value_id, cnode, i);
}
}
}
set_obj_node_map(curr_g_, obj_id, cnode);
set_pyobj(curr_g_, obj_id);
return cnode;
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
if (out.second == -1) {
return out.first;
}
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), out.first,
NewValueNode(out.second)};
return curr_g_->NewCNode(tuple_get_item_inputs);
}
py::tuple RunOp(const py::args &args) { py::tuple RunOp(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start" << args.size();
py::object result; py::object result;
// returns a null py::tuple on error // returns a null py::tuple on error
py::tuple err_ret(0); py::tuple err_ret(0);
@ -428,10 +555,298 @@ py::tuple RunOp(const py::args &args) {
return err_ret; return err_ret;
} }
MS_LOG(INFO) << "RunOp end"; auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
if (node != nullptr) {
node->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
}
MS_LOG(DEBUG) << "RunOp end";
return result; return result;
} }
void ClearPyNativeSession() { session = nullptr; } void ClearPyNativeSession() { session = nullptr; }
PynativeExecutor::~PynativeExecutor() { Clean(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Newgraph already compiled";
return;
}
auto g = std::make_shared<FuncGraph>();
if (top_g_ == nullptr) {
top_g_ = curr_g_ = g;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
Pushp();
} else {
Pushp();
curr_g_ = g;
}
if (graph_info_map_.count(g) == 0) {
graph_info_map_[g] = GraphInfo();
}
for (size_t i = 0; i < args.size(); i++) {
auto new_param = g->add_parameter();
std::string param_obj = GetId(args[i]);
graph_info_map_[g].param_map[param_obj] = new_param;
}
}
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
AnfNodePtr node = nullptr;
std::string obj_id = GetId(obj);
if (op_mask != nullptr && py::cast<bool>(op_mask)) {
MS_LOG(DEBUG) << "Topgraph free parameter";
// get the parameter name from parameter object
auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
std::string param_name = py::cast<std::string>(name_attr);
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);
auto free_param_new = std::make_shared<ParamValuePy>(obj);
free_param->set_default_param(free_param_new);
free_param->debug_info()->set_name(param_name);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
return free_param;
}
return graph_info_map_[df_builder_].param_map[obj_id];
}
// if input is graph output
if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
// op(x, y)
node = graph_info_map_[curr_g_].param_map[obj_id];
} else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
// out = op(op1(x, y))
// out = op(cell1(x, y))
// out = op(cell1(x, y)[0])
node = GetObjNode(obj);
} else {
// out = op(x, 1)
ValuePtr converted_ret = nullptr;
parse::ConvertData(obj, &converted_ret);
node = NewValueNode(converted_ret);
set_obj_node_map(curr_g_, obj_id, node);
}
MS_LOG(DEBUG) << "Now getinput " << py::str(obj) << " node " << node->ToString();
return node;
}
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
void PynativeExecutor::Popp() {
if (graph_p_.empty()) {
MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
}
curr_g_ = graph_p_.top();
graph_p_.pop();
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Endgraph already compiled";
return;
}
cell_graph_map_[cell_id] = curr_g_;
auto out_id = GetId(out);
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
MS_LOG(ERROR) << "graph has no this out: " << out_id;
return;
}
auto output_node = GetObjNode(out);
curr_g_->set_output(output_node);
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(curr_g_));
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
resource_->manager()->AddFuncGraph(curr_g_);
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
if (curr_g_ != top_g_) {
Popp();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], py::object());
inputs.push_back(input);
}
auto out_cnode = curr_g_->NewCNode(inputs);
set_pyobj(curr_g_, GetId(cell));
set_obj_node_map(curr_g_, GetId(out), out_cnode);
} else {
parse::ResolveFuncGraph(newfg, resource_);
resource_->set_func_graph(newfg);
}
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
MS_LOG(INFO) << "GradNet start" << args.size();
std::size_t size = args.size();
auto cell_id = GetId(cell);
if (graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "GradNet already compiled";
return;
}
MS_LOG(DEBUG) << "GradNet first compiled";
std::vector<AnfNodePtr> new_params;
for (size_t i = 0; i < size; i++) {
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
new_params.push_back(p);
}
MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
df_builder_->set_parameters(new_params);
resource_->manager()->SetParameters(df_builder_, new_params);
std::vector<AnfNodePtr> w_args;
if (py::hasattr(weights, "__parameter_tuple__")) {
auto tuple = weights.cast<py::tuple>();
MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size();
w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t it = 0; it < tuple.size(); ++it) {
auto param = tuple[it];
auto param_id = GetId(param);
AnfNodePtr para_node = nullptr;
if (graph_info_map_[df_builder_].param_map.count(param_id)) {
para_node = graph_info_map_[df_builder_].param_map[param_id];
AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
AnfNodePtr ref_key_node = NewValueNode(refkey);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
w_args.push_back(ref_node);
}
}
} else {
MS_LOG(EXCEPTION) << "training not paramter_tuple";
}
MS_EXCEPTION_IF_NULL(resource_->func_graph());
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
resource_->set_func_graph(g);
// get the parameters items and add the value to args_spec
abstract::AbstractBasePtrList args_spec;
for (std::size_t i = 0; i < size; i++) {
ValuePtr converted = nullptr;
bool succ = parse::ConvertData(args[i], &converted);
if (!succ) {
MS_LOG(EXCEPTION) << "Args convert error";
}
bool broaden = true;
auto abs = abstract::FromValue(converted, broaden);
args_spec.push_back(abs);
auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
param_node->set_abstract(abs);
}
for (const auto &param : df_builder_->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Args convert error";
}
args_spec.push_back(ptr);
param_node->set_abstract(ptr);
}
}
MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
resource_->set_args_spec(args_spec);
MS_LOG(DEBUG) << "Start opt";
// Create backend and session
resource_->results()[pipeline::kBackend] = compile::CreateBackend();
graph_map_[cell_id] = g;
PynativeOptimizeAction(resource_);
TaskEmitAction(resource_);
ExecuteAction(resource_);
resource_->Clean();
ad::CleanRes();
pipeline::ReclaimOptimizer();
}
void PynativeExecutor::Clear() {
MS_LOG(INFO) << "Clear all res";
top_g_ = curr_g_ = nullptr;
std::stack<FuncGraphPtr>().swap(graph_p_);
graph_info_map_.clear();
}
void PynativeExecutor::Clean() {
graph_map_.clear();
cell_graph_map_.clear();
Clear();
resource_.reset();
}
py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
VectorRef arg_list;
pipeline::ProcessVmArgInner(args, resource_, &arg_list);
if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
!resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
MS_LOG(EXCEPTION) << "Can't find run graph func for ";
}
compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
if (run == nullptr) {
MS_LOG(EXCEPTION) << "Can't find run graph func for ";
}
std::string backend = MsContext::GetInstance()->backend_policy();
MS_LOG(DEBUG) << "Eval run" << backend;
BaseRef value = (*run)(arg_list);
MS_LOG(DEBUG) << "Run end" << value.ToString();
return BaseRefToPyData(value);
}
FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
const std::vector<AnfNodePtr> &weights, size_t arg_size) {
auto nparam = top_g_->parameters().size();
std::ostringstream ss;
ss << "grad{" << nparam << "}";
df_builder_->set_flags(FUNC_GRAPH_FLAG_CORE, true);
df_builder_->debug_info()->set_name(ss.str());
auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
for (size_t i = 0; i < arg_size; ++i) {
inputs.push_back(df_builder_->parameters()[i]);
}
auto out = df_builder_->NewCNode(inputs);
df_builder_->set_output(out);
resource_->manager()->AddFuncGraph(df);
resource_->manager()->AddFuncGraph(df_builder_);
return df_builder_;
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
.def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
.def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
.def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
.def("clear", &PynativeExecutor::Clear, "pynative clear status.")
.def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
"Executor run function.")
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.");
}));
} // namespace pynative } // namespace pynative
} // namespace mindspore } // namespace mindspore
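
The bulk of this file implements the recorder: when the grad flag is set and a graph is being recorded, every RunOp call also builds a CNode in the current FuncGraph via MakeCNode, and each output is keyed by GetId so that later ops can find the node that produced their inputs. A self-contained toy model of that bookkeeping (all names here are hypothetical Python stand-ins, not the real API):

    grad_flag = True
    obj_node_map = {}                     # object id -> recorded node, like GraphInfo.obj_node_map

    def get_id(obj):
        return str(id(obj))               # mirrors get_obj_id() / Tensor::id()

    def run_op(op_name, fn, *inputs):
        out = fn(*inputs)                                                 # eager execution
        if grad_flag:                                                     # record a node for grad
            node = (op_name, [obj_node_map.get(get_id(x), x) for x in inputs])
            obj_node_map[get_id(out)] = node
        return out

    a = run_op("Mul", lambda x, y: x * y, 2.0, 3.0)
    b = run_op("Add", lambda x, y: x + y, a, 1.0)
    print(obj_node_map[get_id(b)])        # ('Add', [('Mul', [2.0, 3.0]), 1.0])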

View File

@ -22,23 +22,93 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <mutex>
#include <stack>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pynative/base.h" #include "pynative/base.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "ir/anf.h"
#include "pipeline/resource.h"
#include "operator/composite/composite.h"
namespace mindspore { namespace mindspore {
namespace pynative { namespace pynative {
namespace py = pybind11; namespace py = pybind11;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args); py::tuple RunOp(const py::args &args);
py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args);
void ClearPyNativeSession(); void ClearPyNativeSession();
struct GraphInfo {
std::unordered_map<std::string, AnfNodePtr> param_map;
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> obj_node_map;
AnfNodePtr output;
std::vector<std::string> objects;
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
public:
static std::shared_ptr<PynativeExecutor> GetInstance() {
std::lock_guard<std::mutex> i_lock(instance_lock_);
if (executor_ == nullptr) {
executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
resource_ = std::make_shared<pipeline::Resource>();
}
return executor_;
}
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void Clear();
void Clean();
bool grad_flag() { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
AnfNodePtr GetObjNode(const py::object &obj);
FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, -1);
}
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase);
void Pushp();
void Popp();
FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size);
~PynativeExecutor();
private:
PynativeExecutor();
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static ResourcePtr resource_;
bool grad_flag_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_;
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
} // namespace pynative } // namespace pynative
} // namespace mindspore } // namespace mindspore

View File

@ -20,7 +20,7 @@ from collections import OrderedDict
from functools import wraps from functools import wraps
from mindspore import context from mindspore import context
from mindspore import log as logger from mindspore import log as logger
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor from .tensor import Tensor as MsTensor
@ -273,6 +273,34 @@ def _generate_pip_args(obj, *args, method="construct"):
obj.__parse_method__ = parse_method obj.__parse_method__ = parse_method
return args_names, args_list return args_names, args_list
class _PynativeExecutor:
"""
A pynative executor used to compile/manage/run the graph.
Returns:
Graph, return the result of pipeline running.
"""
def __init__(self):
self._executor = PynativeExecutor_.get_instance()
def new_graph(self, obj, *args):
self._executor.new_graph(obj, *args)
def end_graph(self, obj, output, *args):
self._executor.end_graph(obj, output, *args)
def grad(self, grad, obj, weights, *args):
self._executor.grad_net(grad, obj, weights, *args)
def clear(self):
self._executor.clear()
def set_grad_flag(self, flag):
self._executor.set_grad_flag(flag)
def __call__(self, *args):
return self._executor(args, "")
class _Executor: class _Executor:
""" """
@ -500,5 +528,6 @@ class _Executor:
_executor = _Executor() _executor = _Executor()
_pynative_exec = _PynativeExecutor()
__all__ = ['ms_function'] __all__ = ['ms_function']
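
The wrapper's methods map one-to-one onto the pybind bindings registered above (new_graph -> NewGraph, end_graph -> EndGraph, grad -> GradNet, __call__ -> Run, clear -> Clear), and the call order matters. The toy driver below uses stubs only (no MindSpore imports) to show the sequence that Cell.__call__ and GradOperation, changed later in this commit, follow on your behalf:

    class StubExecutor:                                   # hypothetical stand-in for _PynativeExecutor
        def new_graph(self, cell, *args): print("NewGraph for", cell)
        def end_graph(self, cell, out, *args): print("EndGraph, output =", out)
        def grad(self, grad_op, cell, weights, *args): print("GradNet built")
        def __call__(self, *args): print("Run grad graph"); return ("grads",)
        def clear(self): print("Clear recorded graphs")

    exec_ = StubExecutor()
    exec_.new_graph("net", "x")                  # 1. Cell.__call__ opens the forward graph
    out = "out"                                  #    ... eager ops run and are recorded ...
    exec_.end_graph("net", out, "x")             # 2. Cell.__call__ closes it; ad::Grad is applied
    exec_.grad("grad_op", "net", "weights", "x") # 3. GradOperation builds the grad graph over weights
    grads = exec_("x")                           # 4. run the compiled grad graph
    exec_.clear()                                # 5. reset recorded state for the next iteration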

View File

@ -89,7 +89,6 @@ class Tensor(Tensor_):
return hash(id(self)) return hash(id(self))
def __mul__(self, other): def __mul__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(self, other) out = tensor_operator_registry.get('__mul__')(self, other)
return out return out
@ -101,7 +100,6 @@ class Tensor(Tensor_):
return out return out
def __radd__(self, other): def __radd__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, self) out = tensor_operator_registry.get('__add__')(other, self)
return out return out
@ -110,22 +108,18 @@ class Tensor(Tensor_):
return out return out
def __rmul__(self, other): def __rmul__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(other, self) out = tensor_operator_registry.get('__mul__')(other, self)
return out return out
def __truediv__(self, other): def __truediv__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__div__')(self, other) out = tensor_operator_registry.get('__div__')(self, other)
return out return out
def __rtruediv__(self, other): def __rtruediv__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__div__')(other, self) out = tensor_operator_registry.get('__div__')(other, self)
return out return out
def __sub__(self, other): def __sub__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = self.__add__(-other) out = self.__add__(-other)
return out return out
@ -134,7 +128,6 @@ class Tensor(Tensor_):
return out return out
def __rsub__(self, other): def __rsub__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy())) out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
return out return out

View File

@ -19,7 +19,7 @@ from collections import OrderedDict
from mindspore import log as logger from mindspore import log as logger
from .. import context from .. import context
from ..common import dtype as mstype from ..common import dtype as mstype
from ..common.api import _executor from ..common.api import _executor, _pynative_exec
from .._checkparam import _check_str_by_regular from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend from .._c_expression import init_backend
@ -60,6 +60,7 @@ class Cell:
self._params = OrderedDict() self._params = OrderedDict()
self._cells = OrderedDict() self._cells = OrderedDict()
self.training = False self.training = False
self.requires_grad = False
self.pynative = False self.pynative = False
self._param_prefix = '' self._param_prefix = ''
self._auto_prefix = auto_prefix self._auto_prefix = auto_prefix
@ -79,6 +80,15 @@ class Cell:
self._backward_hook = None self._backward_hook = None
self.enable_hook = False self.enable_hook = False
self._bprop_debug = False self._bprop_debug = False
self._is_run = False
@property
def is_run(self):
return self._is_run
@is_run.setter
def is_run(self, value):
self._is_run = value
@property @property
def create_time(self): def create_time(self):
@ -192,9 +202,20 @@ class Cell:
out = self.compile_and_run(*inputs) out = self.compile_and_run(*inputs)
return out return out
self.init_parameters_data() self.init_parameters_data()
output = self.construct(*inputs) if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(self, *inputs)
else:
_pynative_exec.set_grad_flag(False)
if self.enable_hook:
output = self._hook_construct(*inputs)
else:
output = self.construct(*inputs)
if isinstance(output, Parameter): if isinstance(output, Parameter):
output = output.data output = output.data
if self.requires_grad is True:
_pynative_exec.end_graph(self, output, *inputs)
self._is_run = True
return output return output
def __setattr__(self, name, value): def __setattr__(self, name, value):
@ -722,6 +743,10 @@ class Cell:
self.add_flags_recursive(**flags) self.add_flags_recursive(**flags)
return self return self
def set_grad(self, mode=True):
self.add_flags_recursive(requires_grad=mode)
return self
def set_train(self, mode=True): def set_train(self, mode=True):
""" """
Sets the cell to training mode. Sets the cell to training mode.
@ -762,9 +787,9 @@ class Cell:
self.add_flags(auto_parallel=True) self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name() self._get_construct_inputs_number_and_name()
def _hook_construct(self, inputs): def _hook_construct(self, *inputs):
"""Hook construct method to replace original construct method when hook function enabled.""" """Hook construct method to replace original construct method when hook function enabled."""
inputs = self._backward_hook(inputs) inputs = self._backward_hook(*inputs)
inputs = self.construct(inputs) inputs = self.construct(inputs)
outputs = self._backward_hook(inputs) outputs = self._backward_hook(inputs)
return outputs return outputs

View File

@ -166,6 +166,7 @@ class TrainOneStepCell(Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False) super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer

View File

@ -18,14 +18,16 @@
"""Basic composite operations.""" """Basic composite operations."""
from functools import partial from functools import partial
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.api import ms_function from ...common.api import ms_function, _pynative_exec
from .. import functional as F from .. import functional as F
from .. import operations as P from .. import operations as P
from ...common.parameter import Parameter from ...common.parameter import Parameter
__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@ -105,14 +107,34 @@ class GradOperation(GradOperation_):
GradOperation_.__init__(self, name, get_all, get_by_list, sens_param) GradOperation_.__init__(self, name, get_all, get_by_list, sens_param)
self.grad_fn = None self.grad_fn = None
self.fn = None self.fn = None
self.need_forward = False
def __call__(self, fn, weights=None): def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
if self.grad_fn is None or self.fn != fn: if self.grad_fn is None or self.fn != fn:
if self.get_by_list: if self.get_by_list:
@ms_function(obj=fn) if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
def after_grad(*args): @ms_function(obj=fn)
return grad_(fn, weights)(*args) def after_grad(*args):
return grad_(fn, weights)(*args)
else:
def after_grad(*args):
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
if self.sens_param:
f_args = args[:-1]
fn(*f_args)
else:
fn(*args)
_pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args)
_pynative_exec.clear()
return out
else: else:
@ms_function(obj=fn) @ms_function(obj=fn)
def after_grad(*args): def after_grad(*args):

View File

@ -286,12 +286,6 @@ class HookBackward(PrimitiveWithInfer):
self.register_hook(hook_fn) self.register_hook(hook_fn)
self.cell_id = cell_id self.cell_id = cell_id
def __call__(self, *inputs):
"""run in PyNative mode."""
if len(inputs) == 1:
return inputs[0]
return inputs
def infer_shape(self, *inputs_shape): def infer_shape(self, *inputs_shape):
if len(inputs_shape) == 1: if len(inputs_shape) == 1:
return inputs_shape[0] return inputs_shape[0]

View File

@ -328,15 +328,9 @@ def _run_op(obj, op_name, args):
op_inputs = [] op_inputs = []
for i, arg in enumerate(args): for i, arg in enumerate(args):
if hasattr(arg, '__parameter__'): if hasattr(arg, '__parameter__'):
op_inputs.append(arg.default_input)
op_mask[i] = 1 op_mask[i] = 1
elif isinstance(arg, tuple): op_inputs.append(arg)
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x output = real_run_op(obj, op_name, args, tuple(op_mask))
args_ = tuple(convert(x) for x in arg)
op_inputs.append(args_)
else:
op_inputs.append(arg)
output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
if not output: if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name) raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1: if len(output) == 1:
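
With this change, _run_op no longer unwraps Parameters on the Python side; it forwards them untouched and only flags them in op_mask, leaving the resolution to graph parameters to PynativeExecutor::GetInput in C++. A standalone illustration of the new masking contract (FakeParam is a hypothetical stand-in for Parameter):

    class FakeParam:
        __parameter__ = True              # the flag _run_op checks with hasattr
        def __init__(self, name):
            self.name = name

    def build_inputs_and_mask(args):
        op_mask = [0] * len(args)
        op_inputs = []
        for i, arg in enumerate(args):
            if hasattr(arg, '__parameter__'):
                op_mask[i] = 1            # this input is a Parameter (weight)
            op_inputs.append(arg)         # no .default_input unwrapping here any more
        return tuple(op_inputs), tuple(op_mask)

    _, mask = build_inputs_and_mask((FakeParam('w'), 3.0))
    print(mask)                           # -> (1, 0)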

View File

@ -54,4 +54,4 @@ class Net_Dropout(nn.Cell):
def test_compile_dropout(): def test_compile_dropout():
net = Net_Dropout() net = Net_Dropout()
input_data = Tensor(np.ones([20, 16, 50], dtype=np.float32)) input_data = Tensor(np.ones([20, 16, 50], dtype=np.float32))
_executor.compile(net, input_data) net(input_data)

View File

@ -18,6 +18,7 @@ import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm from .vm_interface import vm
@ -225,7 +226,7 @@ def vm_impl_slice(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P._grad_ops.ConcatOffset) @vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self): def vm_impl_concatOffset(self):
"""Generate vm_impl function for ConcatOffset""" """Generate vm_impl function for ConcatOffset"""