forked from mindspore-Ecosystem/mindspore

Modify code to support dynamic graph.

commit e2a322b6b7 (parent 72fd41786c)
@@ -19,14 +19,15 @@ Interfaces for parser module in c++.
 from .parser import (Parser, create_obj_instance, generate_scope,
                      get_bprop_method_of_class, get_class_instance_type,
                      get_class_member_namespace_symbol, create_slice_obj,
-                     get_dataclass_attributes, get_dataclass_methods,
+                     get_dataclass_attributes, get_dataclass_methods, get_obj_id,
                      get_module_namespace, get_obj_type, get_object_key,
-                     get_parse_method_of_class, get_scope_name,
+                     get_default_input, get_parse_method_of_class, get_scope_name,
                      is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
 from .serialize import *
 
 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
-           'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
-           'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
-           'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
-           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
+           'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
+           'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
+           'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
+           'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
+           'create_slice_obj', 'create_ellipsis_obj']
@@ -209,6 +209,14 @@ def get_object_key(obj):
         obj_id = instance_id + obj_id
     return obj_id, obj_key
 
+def get_default_input(obj):
+    if hasattr(obj, '__parameter__'):
+        return obj.default_input
+    if isinstance(obj, tuple):
+        convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
+        args = tuple(convert(x) for x in obj)
+        return args
+    return obj
 
 def is_class_member(node):
     """Check the attr is class member variable."""
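To make the new helper's behavior concrete, here is a minimal usage sketch.
FakeParam is a hypothetical stand-in for mindspore.Parameter, which carries a
__parameter__ marker and a default_input payload; get_default_input is the
function added in the hunk above:

    class FakeParam:
        """Hypothetical stand-in for a Parameter object."""
        __parameter__ = True

        def __init__(self, value):
            self.default_input = value

    p = FakeParam(3.0)
    assert get_default_input(p) == 3.0            # a Parameter yields its default_input
    assert get_default_input((p, 7)) == (3.0, 7)  # tuples are unwrapped element-wise
    assert get_default_input("abc") == "abc"      # anything else passes through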
@@ -221,6 +229,9 @@ def is_class_member(node):
         return True
     return False
 
+def get_obj_id(obj):
+    """Get the obj id."""
+    return str(id(obj))
 
 def get_obj_type(obj):
     """Get the obj type."""
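get_obj_id fingerprints an object by its CPython identity. The executor added
later in this commit keys several of its caches on ids like this (one compiled
graph per cell instance); the pattern, as a sketch rather than the shipped code:

    compiled_graphs = {}

    def compile_cell_once(cell, build_graph):
        key = get_obj_id(cell)  # str(id(cell)): stable while the cell is alive
        if key not in compiled_graphs:
            compiled_graphs[key] = build_graph(cell)
        return compiled_graphs[key]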
@@ -328,9 +328,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
     DropEdge(node, index, inp);
   } else {
     MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
-    if (inp->func_graph() != nullptr) {
-      AddFuncGraph(inp->func_graph());
-    }
     if (IsValueNode<FuncGraph>(inp)) {
       MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
       AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));

@@ -372,9 +369,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
 
   for (auto &node : acq) {
     MS_EXCEPTION_IF_NULL(node);
-    FuncGraphPtr fg = node->func_graph();
+    auto fg = node->func_graph();
     if (fg != nullptr) {
-      AddFuncGraph(fg);
       fg->AddNode(node);
     }
     ProcessInputs(node, kIncEdge);
@@ -28,7 +28,7 @@ namespace py = pybind11;
 class ParamValuePy : public ParamValue {
  public:
   ParamValuePy() : value_(py::none()) {}
-  explicit ParamValuePy(py::object value) : value_(value) {}
+  explicit ParamValuePy(const py::object &value) : value_(value) {}
   ~ParamValuePy() override = default;
 
   py::object value() { return value_; }
@@ -75,7 +75,7 @@ py::function PrimitivePy::GetComputeFunction() {
   py::function vm_fn = get_fn(python_obj_);
 
   if (py::isinstance<py::none>(vm_fn)) {
-    MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
+    MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
     vm_fn = mindspore::GetComputeFunction(Primitive::name());
   }
   return vm_fn;
@@ -81,6 +81,7 @@ Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
     : MetaTensor(tensor), device_address_(tensor.device_address_) {
   init(tensor.data_, data_type);
   dirty_ = tensor.is_dirty();
+  id_ = tensor.id();
 }
 
 Tensor &Tensor::operator=(const Tensor &tensor) {

@@ -89,6 +90,7 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
     dirty_ = tensor.is_dirty();
     device_address_ = tensor.device_address();
     data_ = tensor.data_;
+    id_ = tensor.id();
   }
   return *this;
 }

@@ -208,6 +210,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
     data_ = input;
   }
   dirty_ = true;
+  id_ = std::to_string((uintptr_t)(this));
 }
 
 void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {

@@ -254,6 +257,7 @@ void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
       MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
       break;
   }
+  id_ = std::to_string((uintptr_t)(this));
 }
 
 TypePtr Tensor::SetDtype(const TypePtr type_ptr) {
@@ -263,9 +263,11 @@ class Tensor : public MetaTensor {
   DeviceAddressPtr device_address() const { return device_address_; }
   void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
   py::array data_sync();
+  std::string id() const { return id_; }
 
  private:
   bool dirty_{true};
+  std::string id_{""};
   DeviceAddressPtr device_address_{nullptr};
 };
 
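Each Tensor now carries an identity string fixed at construction time. A rough
Python analogue shows why value equality is not enough: two tensors with equal
contents must still produce distinct cache keys:

    class TensorWithId:
        """Rough Python analogue of the new Tensor::id()."""
        def __init__(self, data):
            self.data = data
            self.id = str(id(self))  # mirrors std::to_string((uintptr_t)this)

    a = TensorWithId([1, 2])
    b = TensorWithId([1, 2])
    assert a.data == b.data and a.id != b.id  # equal payloads, distinct identities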
@@ -501,10 +501,16 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
 }
 
 FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
-                                    const std::vector<AnfNodePtr> &params_list, bool applyJ) {
+                                    const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
+                                    bool applyJ) {
   FuncGraphPtr ret = std::make_shared<FuncGraph>();
   ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
 
+  auto weights_node = weights;
+  if (weights == nullptr && !args.empty()) {
+    weights_node = ret->NewCNode(args);
+  }
+
   ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
   ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
 

@@ -537,7 +543,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
   inputs.push_back(NewValueNode(1));
   AnfNodePtr ptrBprop = ret->NewCNode(inputs);
 
-  doGetGrad(ret, out, ptrBprop, weights, opsTupleItem);
+  doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
   return ret;
 }
 

@@ -129,7 +129,7 @@ class GradOperation : public MetaFuncGraph {
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
 
   FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
-                       bool applyJ = false);
+                       const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
   bool sens_param() const { return sens_param_; }
   bool get_all_;
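The new args parameter lets GetGrad synthesize a weights node out of a list of
parameter nodes when no explicit weights node is supplied, which is the path
PyNative's GradGraph (later in this commit) goes through. The corresponding
front-end pattern looks roughly like this (a hedged sketch; import paths and
the GradOperation signature follow the MindSpore of this commit's era):

    import mindspore.nn as nn
    from mindspore import ParameterTuple
    from mindspore.ops import composite as C

    class GradByList(nn.Cell):
        """Differentiate net with respect to its weights."""
        def __init__(self, net):
            super(GradByList, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())
            self.grad = C.GradOperation('grad', get_by_list=True)

        def construct(self, x):
            return self.grad(self.net, self.weights)(x)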
@@ -285,6 +285,10 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
   // and add cast op on other inputs to keep the same type with assigned parameter.
   for (size_t i = 0; i < args_spec_list.size(); ++i) {
     AnfNodePtr param = params_list[i];
+    if (args_spec_list[i] == nullptr) {
+      op_inputs.push_back(param);
+      continue;
+    }
     SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
     // If sig_size is 0 use defalut.
     if (sig_size > 0 && i < sig_size) {

@@ -292,6 +296,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
     } else if (has_var && i >= sig_size) {
       sig = signature[sig_size - 1].rw;
     }
+
     TypePtr type = args_spec_list[i]->GetTypeTrack();
     if (type && type->type_id() == kObjectTypeRef) {
       if (sig == SignatureEnumRW::kRWRead) {
@@ -551,6 +551,10 @@ AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
 }
 
 void DFunctor::CallDoutHoleOnTape() {
+  if (!is_top_) {
+    return;
+  }
+
   // Call dout hole of all adjoint.
   for (auto &f : func_graph_to_functor_) {
     for (auto &adjoint : f.second->anfnode_to_adjoin_) {

@@ -55,6 +55,8 @@ class DFunctor {
   FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
   // Register functor objects to form a global view.
   void Init(const DFunctorPtr &functor, bool is_top = false);
+  bool IsInScope(const AnfNodePtr &node);
+
   // Clear resources.
   static void Clear();
 

@@ -62,7 +64,6 @@ class DFunctor {
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
   bool IsFreeMorphism(const AnfNodePtr &node);
-  bool IsInScope(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
@@ -23,7 +23,7 @@
 
 namespace mindspore {
 namespace ad {
-FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
+FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
   MS_EXCEPTION_IF_NULL(func_graph);
   auto gradkv = func_graph->transforms().find("grad");
   if (gradkv != func_graph->transforms().end()) {

@@ -46,14 +46,18 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
   auto user_defined = f->KUserDefined(func_graph);
   if (user_defined != nullptr) {
     multi_graph_sink(user_defined);
-    DFunctor::Clear();
+    if (is_top) {
+      DFunctor::Clear();
+    }
     return user_defined;
   }
-  f->Init(f, true);
+  f->Init(f, is_top);
   f->MapObject();
   f->MapMorphism();
   auto ret = f->k_graph();
-  DFunctor::Clear();
+  if (is_top) {
+    DFunctor::Clear();
+  }
 
   multi_graph_sink(ret);
   return ret;

@@ -71,5 +75,7 @@ MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) {
   MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
   return fg;
 }
+
+void CleanRes() { DFunctor::Clear(); }
 }  // namespace ad
 }  // namespace mindspore

@@ -28,9 +28,10 @@ namespace mindspore {
 namespace ad {
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;
 
-FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources);
+FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true);
 FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
 MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &);
+void CleanRes();
 }  // namespace ad
 }  // namespace mindspore
 
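The is_top threading exists because ad::Grad can now be re-entered for nested
graphs (PyNative's EndGraph calls it once per cell), while DFunctor keeps
process-wide functor state that only the outermost invocation may tear down.
In miniature, with hypothetical helpers standing in for the real machinery:

    _functor_state = {}  # shared across recursive grad transforms

    def grad(graph, is_top=True):
        _functor_state[graph] = init_adjoints(graph)  # hypothetical helper
        result = map_morphisms(graph)                 # may call grad(sub, is_top=False)
        if is_top:
            _functor_state.clear()  # only the outermost call cleans up
        return result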
@@ -167,7 +167,8 @@ class InlinerBase : public AnfVisitor {
     auto params = fg->parameters();
     auto old_size = params.size();
     if (old_size != new_params.size()) {
-      MS_LOG(EXCEPTION) << "Parameter size not match.";
+      MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
+                        << fg->output()->DebugString(10);
     }
     for (size_t i = 0; i < old_size; i++) {
       (void)mng->Replace(params[i], new_params[i]);
@@ -276,6 +276,8 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
 
 bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
 
+bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
+
 static bool IsCtrlSink() {
   auto ms_ctx = MsContext::GetInstance();
   std::string device_target = ms_ctx->device_target();

@@ -35,6 +35,7 @@ bool SymbolResolveAction(const ResourcePtr &res);
 bool AbstractSpecializeAction(const ResourcePtr &res);
 bool GeOptimizeAction(const ResourcePtr &res);
 bool VmOptimizeAction(const ResourcePtr &res);
+bool PynativeOptimizeAction(const ResourcePtr &res);
 bool TaskEmitAction(const ResourcePtr &res);
 bool ExecuteAction(const ResourcePtr &res);
 
@@ -32,6 +32,7 @@
 #include "utils/symbolic.h"
 #include "utils/context/ms_context.h"
 #include "debug/trace.h"
+#include "optimizer/ad/grad.h"
 
 namespace mindspore {
 namespace parse {

@@ -338,6 +339,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
   } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
     std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
     converted = env;
+  } else if (py::hasattr(obj, "__parameter__")) {
+    auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
+    ret = ConvertData(to_convert, &converted);
   } else {
     ret = ConvertOtherObj(obj, &converted);
   }
@@ -60,6 +60,7 @@ const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
 const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key";
 const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
 const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
+const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
 const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
 const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
 const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";

@@ -83,6 +84,7 @@ const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
 
 const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
 const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
+const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
 
 // define the common name
 const char NAMED_PRIMITIVE_ITER[] = "iter";
@@ -278,5 +278,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"opt_control", ControlGroup},
                                    {"opt_prepare", PrepareGroup},
                                    {"cconv", CconvPass}};
+
+std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
 }  // namespace pipeline
 }  // namespace mindspore

@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
 
 extern std::vector<PassItem> kGePasses;
 extern std::vector<PassItem> kVmPasses;
+extern std::vector<PassItem> kPynativePasses;
 
 bool CconvPass(const ResourcePtr &res);
 bool ValidatePass(const ResourcePtr &res);
@@ -608,7 +608,7 @@ void Pipeline::Run() {
   MS_LOG(INFO) << "End";
 }
 
-void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
+void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
   std::size_t size = args.size();
 
   for (std::size_t i = 0; i < size; i++) {

@@ -625,7 +625,6 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
     arg_list->push_back(converted);
   }
 
-  ResourcePtr res = GetResource(phase);
   MS_EXCEPTION_IF_NULL(res);
   auto graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(graph);

@@ -647,6 +646,10 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
   }
 }
 
+void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
+  ProcessVmArgInner(args, GetResource(phase), arg_list);
+}
+
 py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
   std::size_t size = args.size();
   if (!py::isinstance<py::str>(phase)) {
@@ -874,6 +877,8 @@ void ClearResAtexit() {
   compile::ClearConvertCache();
   pipeline::GetMethodMap().clear();
   pipeline::ExecutorPy::ClearRes();
+  pipeline::ReclaimOptimizer();
+  pynative::PynativeExecutor::GetInstance()->Clean();
 #ifdef ENABLE_GE
   transform::DfGraphManager::GetInstance().ClearGraph();
   transform::DfGraphConvertor::get_adpt_map().clear();

@@ -139,6 +139,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                        const std::vector<int64_t> &input_indexes, bool need_run);
 
+void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
+
 }  // namespace pipeline
 }  // namespace mindspore
 
@@ -22,17 +22,30 @@
 #include <unordered_set>
 #include <algorithm>
 
+#include "ir/param_value_py.h"
 #include "utils/any.h"
 #include "utils/utils.h"
 #include "utils/context/ms_context.h"
 #include "operator/ops.h"
+#include "operator/composite/composite.h"
 #include "operator/composite/do_signature.h"
 #include "pipeline/parse/data_converter.h"
+#include "pipeline/parse/parse_base.h"
+#include "pipeline/parse/resolve.h"
 #include "pipeline/static_analysis/prim.h"
 #include "session/session_factory.h"
 #include "pre_activate/pass/const_input_to_attr_registry.h"
 #include "pre_activate/common/helper.h"
+#include "pipeline/action.h"
 
 #include "pynative/base.h"
+#include "pybind_api/api_register.h"
+#include "vm/transform.h"
+
+#include "optimizer/ad/grad.h"
+#include "pipeline/resource.h"
+#include "pipeline/pipeline.h"
+#include "pipeline/pass.h"
+
 #ifdef ENABLE_GE
 #include "pynative/pynative_execute_ge.h"
@@ -40,21 +53,55 @@
 
 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor"};
+const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor", "HookBackward"};
 
 namespace mindspore {
 namespace pynative {
+
 static std::shared_ptr<session::SessionBasic> session = nullptr;
+PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
+std::mutex PynativeExecutor::instance_lock_;
+ResourcePtr PynativeExecutor::resource_;
+
 inline ValuePtr PyAttrValue(const py::object &obj) {
-  ValuePtr converted_ret = nullptr;
-  bool converted = parse::ConvertData(obj, &converted_ret);
-  if (!converted) {
+  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
+  if (!converted_ret) {
     MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
   }
   return converted_ret;
 }
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
+std::string GetId(const py::object &obj) {
+  py::object to_process = obj;
+  std::string prefix = "";
+  if (py::isinstance<py::tuple>(to_process)) {
+    auto p_list = py::cast<py::tuple>(to_process);
+    to_process = p_list[0];
+    prefix = "tuple:";
+    if (!py::isinstance<tensor::Tensor>(to_process)) {
+      std::string key = "";
+      for (size_t i = 0; i < p_list.size(); ++i) {
+        key += std::string(py::str(p_list[i])) + ":";
+      }
+      return prefix + key;
+    }
+  }
+  if (py::isinstance<py::int_>(to_process)) {
+    return prefix + std::string(py::str(to_process));
+  }
+  if (py::isinstance<py::float_>(to_process)) {
+    return prefix + std::string(py::str(to_process));
+  }
+  if (py::isinstance<tensor::Tensor>(to_process)) {
+    auto tensor_ptr = py::cast<tensor::TensorPtr>(to_process);
+    return prefix + tensor_ptr->id();
+  }
+
+  py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
+  return py::cast<std::string>(ret);
+}
+
+py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
   auto signature = prim->signatures();
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
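Restated in Python, the new GetId derives cache keys as follows (a sketch, not
the shipped code; is_tensor stands in for the py::isinstance<tensor::Tensor>
check and .id for the new Tensor::id()):

    def get_id(obj):
        if isinstance(obj, tuple):
            if not (obj and is_tensor(obj[0])):
                return "tuple:" + ":".join(str(x) for x in obj) + ":"
            return "tuple:" + obj[0].id     # tensor tuples key by the head tensor
        if isinstance(obj, (int, float)):
            return str(obj)                 # scalars key by value
        if is_tensor(obj):
            return obj.id                   # tensors key by identity
        return str(id(obj))                 # fallback: parser get_obj_id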
@@ -87,7 +134,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
     }
     (void)dst_type.insert(std::make_pair(type, m_index));
   }
-  py::tuple py_inputs(py_args.size());
+  py::list py_inputs(py_args.size());
   for (size_t i = 0; i < py_args.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it != dst_type.end() && it->second != i &&

@@ -105,12 +152,12 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
   return py_inputs;
 }
 
-void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) {
+void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
   size_t size = py_args.size();
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < size; i++) {
     ValuePtr input_value = PyAttrValue(py_args[i]);
-    if (py::isinstance<tensor::Tensor>(py_args[i])) {
+    if (input_value->isa<tensor::Tensor>()) {
       args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
     } else {
       args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@@ -120,6 +167,12 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) {
   op_exec_info->abstract = infer_res;
 }
 
+py::object GetTupleObj(const py::object &obj) {
+  py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
+  py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
+  return obj_tuple;
+}
+
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (args.size() != PY_ARGS_NUM) {
     MS_LOG(ERROR) << "Four args are needed by RunOp";

@@ -133,14 +186,19 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (pyobj == nullptr) {
     MS_LOG(EXCEPTION) << "pyobj is empty";
   }
-  py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
+  py::list py_args = ConvertInputs(prim, args[PY_INPUTS]);
   // use python infer method
   if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
     PynativeInfer(prim, py_args, op_exec_info.get());
   }
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  op_exec_info->op_inputs = py_args;
+  size_t input_num = py_args.size();
+  op_exec_info->op_inputs = py::tuple(input_num);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto obj = py_args[i];
+    op_exec_info->op_inputs[i] = GetTupleObj(obj);
+  }
   op_exec_info->inputs_mask = args[PY_INPUT_MASK];
   if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
     MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
@@ -154,9 +212,13 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
   MS_EXCEPTION_IF_NULL(op_exec_info);
   std::string graph_info;
   // get input tensor info
-  for (const auto &input_tensor : input_tensors) {
-    MS_EXCEPTION_IF_NULL(input_tensor);
-    (void)graph_info.append(input_tensor->GetShapeAndDataTypeInfo() + "_");
+  size_t input_num = op_exec_info->op_inputs.size();
+  for (size_t index = 0; index < input_num; ++index) {
+    auto input = op_exec_info->op_inputs[index];
+    if (py::isinstance<tensor::Tensor>(input)) {
+      auto tensor_ptr = py::cast<tensor::TensorPtr>(input);
+      (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
+    }
   }
   // get prim and abstract info
   MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
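The single-op graph cache key is now derived from op_inputs directly, with
scalar inputs simply skipped. The resulting key scheme, sketched with
hypothetical helpers for the tensor predicate and the shape/dtype string:

    def single_op_graph_key(op_exec_info):
        key = ""
        for x in op_exec_info.op_inputs:
            if is_tensor(x):                           # hypothetical predicate
                key += shape_and_dtype_info(x) + "_"   # hypothetical helper
        return key + str(op_exec_info.abstract) + op_exec_info.op_name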
@@ -171,6 +233,23 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   MS_EXCEPTION_IF_NULL(status);
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
+  if (op_exec_info->op_name == "HookBackward") {
+    auto op_inputs = op_exec_info->op_inputs;
+    py::tuple result(op_inputs.size());
+    for (size_t i = 0; i < op_inputs.size(); i++) {
+      py::object input = op_inputs[i];
+      if (py::hasattr(input, "__parameter__")) {
+        result[i] = py::getattr(input, "data");
+      } else {
+        auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]);
+        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data());
+        result[i] = new_tensor;
+      }
+    }
+    *status = PYNATIVE_SUCCESS;
+    MS_LOG(INFO) << "RunOpInVM end";
+    return std::move(result);
+  }
   auto func = op_exec_info->py_primitive->GetComputeFunction();
   if (py::isinstance<py::none>(func)) {
     MS_LOG(ERROR) << "VM failed to get func";
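HookBackward is routed to the VM so that its forward pass stays a no-op — as
the branch above shows, it just forwards its inputs — while the registered
hook fires during the backward pass. A hedged usage sketch, following the
Python API shape of this era:

    from mindspore.ops import operations as P

    def print_grad(grad):
        print("gradient flowing back:", grad)

    hook = P.HookBackward(print_grad)

    def construct(x, y):
        z = x * y
        z = hook(z)   # forward: identity; backward: print_grad receives dz
        return z + x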
@@ -288,7 +367,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
   opt::ConstInputToAttrInfoRegister reg;
   bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
   size_t input_num = op_run_info->op_inputs.size();
-  MS_LOG(INFO) << "py input size: " << input_num;
   for (size_t index = 0; index < input_num; ++index) {
     // convert const input to attr
     if (reg_exist &&
@@ -386,7 +464,56 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
   return result;
 }
 
+AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
+  if (!grad_flag_ || graph_info_map_.size() == 0) {
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> inputs;
+  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
+  inputs.push_back(NewValueNode(prim));
+  py::tuple op_masks = args[PY_INPUT_MASK];
+  py::list op_args = args[PY_INPUTS];
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < op_args.size(); i++) {
+    auto node = GetInput(op_args[i], op_masks[i]);
+    args_spec_list.push_back(node->abstract());
+    inputs.push_back(node);
+  }
+
+  auto cnode = curr_g_->NewCNode(inputs);
+  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString();
+  py::object out_real = out;
+  if (out.size() == 1) {
+    MS_LOG(DEBUG) << "MakeCnode out size is one.";
+    out_real = out[0];
+  }
+  std::string obj_id = GetId(out_real);
+  if (py::isinstance<py::tuple>(out_real)) {
+    auto value = py::cast<py::tuple>(out_real);
+    if (value.size() > 1) {
+      for (int i = 0; i < static_cast<int>(value.size()); i++) {
+        auto value_id = GetId(value[i]);
+        set_obj_node_map(curr_g_, value_id, cnode, i);
+      }
+    }
+  }
+  set_obj_node_map(curr_g_, obj_id, cnode);
+  set_pyobj(curr_g_, obj_id);
+  return cnode;
+}
+
+AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
+  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
+  if (out.second == -1) {
+    return out.first;
+  }
+  std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), out.first,
+                                                NewValueNode(out.second)};
+  return curr_g_->NewCNode(tuple_get_item_inputs);
+}
+
 py::tuple RunOp(const py::args &args) {
+  MS_LOG(DEBUG) << "RunOp start" << args.size();
   py::object result;
   // returns a null py::tuple on error
   py::tuple err_ret(0);
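The bookkeeping behind MakeCNode and GetObjNode reduces to a per-graph map
from Python object ids to (node, index) pairs: index -1 means the object is
the node's whole output, any other index materializes a tuple_getitem on
demand. As a plain-dict sketch:

    obj_node_map = {}  # object id -> (anf_node, index)

    def record_output(out_obj, node, index=-1):
        obj_node_map[get_id(out_obj)] = (node, index)

    def get_obj_node(obj):
        node, index = obj_node_map[get_id(obj)]
        if index == -1:
            return node                        # the CNode itself
        return ("tuple_getitem", node, index)  # stands in for a new CNode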
@@ -428,10 +555,298 @@ py::tuple RunOp(const py::args &args) {
     return err_ret;
   }
 
-  MS_LOG(INFO) << "RunOp end";
+  auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
+  if (node != nullptr) {
+    node->set_abstract(op_exec_info->abstract);
+    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
+  }
+  MS_LOG(DEBUG) << "RunOp end";
   return result;
 }
 
 void ClearPyNativeSession() { session = nullptr; }
 
+PynativeExecutor::~PynativeExecutor() { Clean(); }
+
+PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
+
+void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
+  auto cell_id = GetId(cell);
+  if (cell_graph_map_.count(cell_id) != 0) {
+    MS_LOG(DEBUG) << "Newgraph already compiled";
+    return;
+  }
+
+  auto g = std::make_shared<FuncGraph>();
+
+  if (top_g_ == nullptr) {
+    top_g_ = curr_g_ = g;
+    df_builder_ = std::make_shared<FuncGraph>();
+    MS_LOG(DEBUG) << "First new graph" << top_g_.get();
+    Pushp();
+  } else {
+    Pushp();
+    curr_g_ = g;
+  }
+  if (graph_info_map_.count(g) == 0) {
+    graph_info_map_[g] = GraphInfo();
+  }
+  for (size_t i = 0; i < args.size(); i++) {
+    auto new_param = g->add_parameter();
+    std::string param_obj = GetId(args[i]);
+    graph_info_map_[g].param_map[param_obj] = new_param;
+  }
+}
+
+AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
+  AnfNodePtr node = nullptr;
+  std::string obj_id = GetId(obj);
+
+  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
+    MS_LOG(DEBUG) << "Topgraph free parameter";
+    // get the parameter name from parameter object
+    auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
+    if (py::isinstance<py::none>(name_attr)) {
+      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
+    }
+    std::string param_name = py::cast<std::string>(name_attr);
+    if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
+      auto free_param = df_builder_->add_parameter();
+      free_param->set_name(param_name);
+      auto free_param_new = std::make_shared<ParamValuePy>(obj);
+      free_param->set_default_param(free_param_new);
+      free_param->debug_info()->set_name(param_name);
+      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
+      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
+      return free_param;
+    }
+    return graph_info_map_[df_builder_].param_map[obj_id];
+  }
+
+  // if input is graph output
+  if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
+    // op(x, y)
+    node = graph_info_map_[curr_g_].param_map[obj_id];
+  } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
+    // out = op(op1(x, y))
+    // out = op(cell1(x, y))
+    // out = op(cell1(x, y)[0])
+    node = GetObjNode(obj);
+  } else {
+    // out = op(x, 1)
+    ValuePtr converted_ret = nullptr;
+    parse::ConvertData(obj, &converted_ret);
+    node = NewValueNode(converted_ret);
+    set_obj_node_map(curr_g_, obj_id, node);
+  }
+
+  MS_LOG(DEBUG) << "Now getinput " << py::str(obj) << " node " << node->ToString();
+  return node;
+}
+
+void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
+
+void PynativeExecutor::Popp() {
+  if (graph_p_.empty()) {
+    MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
+  }
+  curr_g_ = graph_p_.top();
+  graph_p_.pop();
+}
+
+void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
+  auto cell_id = GetId(cell);
+  if (cell_graph_map_.count(cell_id) != 0) {
+    MS_LOG(DEBUG) << "Endgraph already compiled";
+    return;
+  }
+  cell_graph_map_[cell_id] = curr_g_;
+  auto out_id = GetId(out);
+  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
+    MS_LOG(ERROR) << "graph has no this out: " << out_id;
+    return;
+  }
+
+  auto output_node = GetObjNode(out);
+  curr_g_->set_output(output_node);
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(NewValueNode(curr_g_));
+  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
+  resource_->manager()->AddFuncGraph(curr_g_);
+  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  if (curr_g_ != top_g_) {
+    Popp();
+    for (size_t i = 0; i < args.size(); i++) {
+      auto input = GetInput(args[i], py::object());
+      inputs.push_back(input);
+    }
+    auto out_cnode = curr_g_->NewCNode(inputs);
+    set_pyobj(curr_g_, GetId(cell));
+    set_obj_node_map(curr_g_, GetId(out), out_cnode);
+  } else {
+    parse::ResolveFuncGraph(newfg, resource_);
+    resource_->set_func_graph(newfg);
+  }
+}
+
+void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+                               const py::args &args) {
+  MS_LOG(INFO) << "GradNet start" << args.size();
+
+  std::size_t size = args.size();
+  auto cell_id = GetId(cell);
+  if (graph_map_.count(cell_id) != 0) {
+    MS_LOG(DEBUG) << "GradNet already compiled";
+    return;
+  }
+  MS_LOG(DEBUG) << "GradNet first compiled";
+  std::vector<AnfNodePtr> new_params;
+  for (size_t i = 0; i < size; i++) {
+    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
+    new_params.push_back(p);
+  }
+  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
+  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
+  df_builder_->set_parameters(new_params);
+  resource_->manager()->SetParameters(df_builder_, new_params);
+
+  std::vector<AnfNodePtr> w_args;
+  if (py::hasattr(weights, "__parameter_tuple__")) {
+    auto tuple = weights.cast<py::tuple>();
+    MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size();
+    w_args.push_back(NewValueNode(prim::kPrimMakeTuple));
+    for (size_t it = 0; it < tuple.size(); ++it) {
+      auto param = tuple[it];
+      auto param_id = GetId(param);
+      AnfNodePtr para_node = nullptr;
+      if (graph_info_map_[df_builder_].param_map.count(param_id)) {
+        para_node = graph_info_map_[df_builder_].param_map[param_id];
+
+        AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
+        AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
+        auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
+        AnfNodePtr ref_key_node = NewValueNode(refkey);
+        AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
+
+        w_args.push_back(ref_node);
+      }
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+  }
+  MS_EXCEPTION_IF_NULL(resource_->func_graph());
+  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
+  resource_->set_func_graph(g);
+
+  // get the parameters items and add the value to args_spec
+  abstract::AbstractBasePtrList args_spec;
+  for (std::size_t i = 0; i < size; i++) {
+    ValuePtr converted = nullptr;
+    bool succ = parse::ConvertData(args[i], &converted);
+    if (!succ) {
+      MS_LOG(EXCEPTION) << "Args convert error";
+    }
+    bool broaden = true;
+    auto abs = abstract::FromValue(converted, broaden);
+    args_spec.push_back(abs);
+    auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]);
+    param_node->set_abstract(abs);
+  }
+
+  for (const auto &param : df_builder_->parameters()) {
+    auto param_node = std::static_pointer_cast<Parameter>(param);
+    if (param_node->has_default()) {
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
+      AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
+      if (ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "Args convert error";
+      }
+      args_spec.push_back(ptr);
+      param_node->set_abstract(ptr);
+    }
+  }
+  MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
+
+  resource_->set_args_spec(args_spec);
+  MS_LOG(DEBUG) << "Start opt";
+
+  // Create backend and session
+  resource_->results()[pipeline::kBackend] = compile::CreateBackend();
+
+  graph_map_[cell_id] = g;
+  PynativeOptimizeAction(resource_);
+  TaskEmitAction(resource_);
+  ExecuteAction(resource_);
+  resource_->Clean();
+  ad::CleanRes();
+  pipeline::ReclaimOptimizer();
+}
+
+void PynativeExecutor::Clear() {
+  MS_LOG(INFO) << "Clear all res";
+  top_g_ = curr_g_ = nullptr;
+  std::stack<FuncGraphPtr>().swap(graph_p_);
+  graph_info_map_.clear();
+}
+
+void PynativeExecutor::Clean() {
+  graph_map_.clear();
+  cell_graph_map_.clear();
+  Clear();
+  resource_.reset();
+}
+
+py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
+  VectorRef arg_list;
+  pipeline::ProcessVmArgInner(args, resource_, &arg_list);
+  if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
+      !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
+    MS_LOG(EXCEPTION) << "Can't find run graph func for ";
+  }
+  compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
+  if (run == nullptr) {
+    MS_LOG(EXCEPTION) << "Can't find run graph func for ";
+  }
+
+  std::string backend = MsContext::GetInstance()->backend_policy();
+
+  MS_LOG(DEBUG) << "Eval run" << backend;
+  BaseRef value = (*run)(arg_list);
+  MS_LOG(DEBUG) << "Run end" << value.ToString();
+  return BaseRefToPyData(value);
+}
+
+FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op,
+                                         const std::vector<AnfNodePtr> &weights, size_t arg_size) {
+  auto nparam = top_g_->parameters().size();
+  std::ostringstream ss;
+  ss << "grad{" << nparam << "}";
+  df_builder_->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+  df_builder_->debug_info()->set_name(ss.str());
+
+  auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
+  for (size_t i = 0; i < arg_size; ++i) {
+    inputs.push_back(df_builder_->parameters()[i]);
+  }
+  auto out = df_builder_->NewCNode(inputs);
+  df_builder_->set_output(out);
+  resource_->manager()->AddFuncGraph(df);
+  resource_->manager()->AddFuncGraph(df_builder_);
+  return df_builder_;
+}
+
+REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
+                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
+                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
+                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
+                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
+                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
+                           .def("clear", &PynativeExecutor::Clear, "pynative clear status.")
+                           .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""),
+                                "Executor run function.")
+                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
+                                "Executor set grad flag.");
+                       }));
 }  // namespace pynative
 }  // namespace mindspore
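Put together, the pybind registration above implies the following driving
sequence from the Python side (a hedged sketch: executor stands for
PynativeExecutor_.get_instance(), and the Cell wrapper that actually issues
these calls is not part of this commit's visible hunks):

    executor.set_grad_flag(True)          # record CNodes while ops execute
    executor.new_graph(cell, x)           # open a FuncGraph, one parameter per arg
    out = cell.construct(x)               # each RunOp also calls MakeCNode
    executor.end_graph(cell, out, x)      # set the output, run ad::Grad
    executor.grad_net(grad_op, cell, weights, x)  # build and compile the grad graph
    grads = executor((x,), "")            # __call__ -> PynativeExecutor::Run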
@ -22,23 +22,93 @@
 #include <string>
 #include <memory>
 #include <unordered_map>
+#include <mutex>
+#include <stack>

 #include "pybind11/pybind11.h"

 #include "pynative/base.h"
 #include "utils/context/ms_context.h"
+#include "ir/anf.h"
+#include "pipeline/resource.h"
+#include "operator/composite/composite.h"

 namespace mindspore {
 namespace pynative {

 namespace py = pybind11;
+using ResourcePtr = std::shared_ptr<pipeline::Resource>;
+using GradOperationPtr = std::shared_ptr<prim::GradOperation>;

 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);

 py::tuple RunOp(const py::args &args);

+py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args);
+
 void ClearPyNativeSession();

+struct GraphInfo {
+  std::unordered_map<std::string, AnfNodePtr> param_map;
+  std::unordered_map<std::string, std::pair<AnfNodePtr, int>> obj_node_map;
+  AnfNodePtr output;
+  std::vector<std::string> objects;
+};
+
+class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
+ public:
+  static std::shared_ptr<PynativeExecutor> GetInstance() {
+    std::lock_guard<std::mutex> i_lock(instance_lock_);
+    if (executor_ == nullptr) {
+      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
+      resource_ = std::make_shared<pipeline::Resource>();
+    }
+    return executor_;
+  }
+  void NewGraph(const py::object &cell, const py::args &args);
+  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
+  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
+  void Clear();
+  void Clean();
+  bool grad_flag() { return grad_flag_; }
+  void set_grad_flag(bool flag) { grad_flag_ = flag; }
+  AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
+  AnfNodePtr GetObjNode(const py::object &obj);
+  FuncGraphPtr curr_g() { return curr_g_; }
+  void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
+  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
+    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, -1);
+  }
+  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
+    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
+  }
+  AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
+  py::object Run(const py::tuple &args, const py::object &phase);
+
+  void Pushp();
+  void Popp();
+  FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
+                         size_t arg_size);
+
+  ~PynativeExecutor();
+
+ private:
+  PynativeExecutor();
+  static std::shared_ptr<PynativeExecutor> executor_;
+  static std::mutex instance_lock_;
+  static ResourcePtr resource_;
+  bool grad_flag_;
+  std::unordered_map<std::string, FuncGraphPtr> graph_map_;
+  std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
+  std::stack<FuncGraphPtr> graph_p_;
+  FuncGraphPtr top_g_;
+  FuncGraphPtr df_builder_;
+  FuncGraphPtr curr_g_;
+};
+
+using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
+
 }  // namespace pynative
 }  // namespace mindspore
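The GraphInfo bookkeeping above is the core of the recording scheme: each Python object id maps to the ANF node that produced it, plus an index when the value is one element of a tuple output. A rough Python analogue of that lookup, purely for illustration (these names are not part of any API):

# Illustration only, not MindSpore code.
obj_node_map = {}

def set_obj_node_map(obj_id, node, index=-1):
    # index == -1 means obj is the node's whole output; otherwise obj is
    # the index-th element of a tuple-valued node.
    obj_node_map[obj_id] = (node, index)

def get_obj_node(obj_id):
    node, index = obj_node_map[obj_id]
    if index == -1:
        return node
    return ("tuple_getitem", node, index)  # stands in for a TupleGetItem CNode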
@ -20,7 +20,7 @@ from collections import OrderedDict
 from functools import wraps
 from mindspore import context
 from mindspore import log as logger
-from .._c_expression import generate_key, Executor_, Tensor, MetaTensor
+from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
@ -273,6 +273,34 @@ def _generate_pip_args(obj, *args, method="construct"):
     obj.__parse_method__ = parse_method
     return args_names, args_list

+
+class _PynativeExecutor:
+    """
+    A pynative executor used to compile, manage, and run the graph.
+
+    Returns:
+        Graph, the result of running the pipeline.
+    """
+
+    def __init__(self):
+        self._executor = PynativeExecutor_.get_instance()
+
+    def new_graph(self, obj, *args):
+        self._executor.new_graph(obj, *args)
+
+    def end_graph(self, obj, output, *args):
+        self._executor.end_graph(obj, output, *args)
+
+    def grad(self, grad, obj, weights, *args):
+        self._executor.grad_net(grad, obj, weights, *args)
+
+    def clear(self):
+        self._executor.clear()
+
+    def set_grad_flag(self, flag):
+        self._executor.set_grad_flag(flag)
+
+    def __call__(self, *args):
+        return self._executor(args, "")
+
+
 class _Executor:
     """

@ -500,5 +528,6 @@ class _Executor:


 _executor = _Executor()
+_pynative_exec = _PynativeExecutor()

 __all__ = ['ms_function']
@ -89,7 +89,6 @@ class Tensor(Tensor_):
         return hash(id(self))

     def __mul__(self, other):
-        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(self, other)
         return out

@ -101,7 +100,6 @@ class Tensor(Tensor_):
         return out

     def __radd__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(other, self)
         return out

@ -110,22 +108,18 @@ class Tensor(Tensor_):
         return out

     def __rmul__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(other, self)
         return out

     def __truediv__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__div__')(self, other)
         return out

     def __rtruediv__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__div__')(other, self)
         return out

     def __sub__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = self.__add__(-other)
         return out

@ -134,7 +128,6 @@ class Tensor(Tensor_):
         return out

     def __rsub__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
         return out
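All of these dunder methods funnel through the same tensor_operator_registry, so dropping the per-call check_type guards leaves input validation to the registered op itself. A small illustration of the dispatch path (the comment traces what the code above does; the snippet itself is ordinary Tensor arithmetic):

import numpy as np
from mindspore import Tensor

x = Tensor(np.ones([2, 2], dtype=np.float32))
y = Tensor(np.ones([2, 2], dtype=np.float32))
z = x * y      # Tensor.__mul__ -> tensor_operator_registry.get('__mul__')(x, y)
w = 2.0 * x    # __rmul__ path, same registry entry with operands swapped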
@ -19,7 +19,7 @@ from collections import OrderedDict
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
-from ..common.api import _executor
+from ..common.api import _executor, _pynative_exec
 from .._checkparam import _check_str_by_regular
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
@ -60,6 +60,7 @@ class Cell:
         self._params = OrderedDict()
         self._cells = OrderedDict()
         self.training = False
+        self.requires_grad = False
         self.pynative = False
         self._param_prefix = ''
         self._auto_prefix = auto_prefix

@ -79,6 +80,15 @@ class Cell:
         self._backward_hook = None
         self.enable_hook = False
         self._bprop_debug = False
+        self._is_run = False
+
+    @property
+    def is_run(self):
+        return self._is_run
+
+    @is_run.setter
+    def is_run(self, value):
+        self._is_run = value

     @property
     def create_time(self):
@ -192,9 +202,20 @@ class Cell:
             out = self.compile_and_run(*inputs)
             return out
         self.init_parameters_data()
-        output = self.construct(*inputs)
+        if self.requires_grad is True:
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(self, *inputs)
+        else:
+            _pynative_exec.set_grad_flag(False)
+        if self.enable_hook:
+            output = self._hook_construct(*inputs)
+        else:
+            output = self.construct(*inputs)
         if isinstance(output, Parameter):
             output = output.data
+        if self.requires_grad is True:
+            _pynative_exec.end_graph(self, output, *inputs)
+        self._is_run = True
         return output

     def __setattr__(self, name, value):
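Putting the new __call__ flow together: when requires_grad is set, every PyNative forward call is bracketed by new_graph/end_graph so the executor can record the ops it will need for a later grad. A minimal usage sketch under that reading (FlattenNet is a placeholder cell):

# Hedged sketch of PyNative forward recording.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class FlattenNet(nn.Cell):
    def __init__(self):
        super(FlattenNet, self).__init__()
        self.flatten = nn.Flatten()

    def construct(self, x):
        return self.flatten(x)

net = FlattenNet()
net.set_grad()            # propagates requires_grad through the cell tree
x = Tensor(np.ones([2, 3], dtype=np.float32))
out = net(x)              # __call__: set_grad_flag -> new_graph -> construct -> end_graph
assert net.is_run         # _is_run flips to True after the first forward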
@ -722,6 +743,10 @@ class Cell:
         self.add_flags_recursive(**flags)
         return self

+    def set_grad(self, mode=True):
+        self.add_flags_recursive(requires_grad=mode)
+        return self
+
     def set_train(self, mode=True):
         """
         Sets the cell to training mode.
@ -762,9 +787,9 @@ class Cell:
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()

-    def _hook_construct(self, inputs):
+    def _hook_construct(self, *inputs):
         """Hook construct method to replace original construct method when hook function enabled."""
-        inputs = self._backward_hook(inputs)
+        inputs = self._backward_hook(*inputs)
         inputs = self.construct(inputs)
         outputs = self._backward_hook(inputs)
         return outputs
@ -166,6 +166,7 @@ class TrainOneStepCell(Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
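Calling set_grad() in the constructor means any network wrapped in TrainOneStepCell records its forward pass in PyNative mode without the user having to remember the flag. A hedged usage sketch (the Dense network and loss setup are placeholders):

# Hedged sketch: TrainOneStepCell now enables grad recording itself.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Dense(3, 2)                      # stand-in for a real network
loss_net = nn.WithLossCell(net, nn.MSELoss())
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = nn.TrainOneStepCell(loss_net, optimizer)  # __init__ calls network.set_grad()
data = Tensor(np.ones([4, 3], dtype=np.float32))
label = Tensor(np.zeros([4, 2], dtype=np.float32))
loss = train_net(data, label)             # forward is recorded, then grads applied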
@ -18,14 +18,16 @@
 """Basic composite operations."""
 from functools import partial

+from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
-from ...common.api import ms_function
+from ...common.api import ms_function, _pynative_exec
 from .. import functional as F
 from .. import operations as P
 from ...common.parameter import Parameter


 __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@ -105,14 +107,34 @@ class GradOperation(GradOperation_):
         GradOperation_.__init__(self, name, get_all, get_by_list, sens_param)
         self.grad_fn = None
         self.fn = None
+        self.need_forward = False

     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
             if self.get_by_list:
-                @ms_function(obj=fn)
-                def after_grad(*args):
-                    return grad_(fn, weights)(*args)
+                if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
+                    @ms_function(obj=fn)
+                    def after_grad(*args):
+                        return grad_(fn, weights)(*args)
+                else:
+                    def after_grad(*args):
+                        if fn.is_run and not fn.requires_grad:
+                            raise ValueError("obj must call set_grad() before computing gradients.")
+                        if not fn.is_run:
+                            self.need_forward = True
+                        else:
+                            print("already has forward run before grad by user")
+                        if self.need_forward:
+                            fn.set_grad()
+                            if self.sens_param:
+                                f_args = args[:-1]
+                                fn(*f_args)
+                            else:
+                                fn(*args)
+                        _pynative_exec.grad(grad_, fn, weights, *args)
+                        out = _pynative_exec(*args)
+                        _pynative_exec.clear()
+                        return out
             else:
                 @ms_function(obj=fn)
                 def after_grad(*args):
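The new branch gives GradOperation a PyNative path: if the cell has not been run yet, after_grad runs the forward itself (dropping the trailing sens argument when sens_param is set), then asks the executor to build and run the gradient graph. A hedged usage sketch of grad-by-list under that reading (the Dense net and input are placeholders):

# Hedged sketch of grad-by-list in PyNative mode.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.common.parameter import ParameterTuple

context.set_context(mode=context.PYNATIVE_MODE)

net = nn.Dense(3, 2)                               # any Cell works here
weights = ParameterTuple(net.trainable_params())
grad_by_list = C.GradOperation('grad', get_by_list=True)
x = Tensor(np.ones([1, 3], dtype=np.float32))
grads = grad_by_list(net, weights)(x)   # runs the forward first if net was never run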
@ -286,12 +286,6 @@ class HookBackward(PrimitiveWithInfer):
         self.register_hook(hook_fn)
         self.cell_id = cell_id

-    def __call__(self, *inputs):
-        """run in PyNative mode."""
-        if len(inputs) == 1:
-            return inputs[0]
-        return inputs
-
     def infer_shape(self, *inputs_shape):
         if len(inputs_shape) == 1:
             return inputs_shape[0]
@ -328,15 +328,9 @@ def _run_op(obj, op_name, args):
     op_inputs = []
     for i, arg in enumerate(args):
         if hasattr(arg, '__parameter__'):
-            op_inputs.append(arg.default_input)
             op_mask[i] = 1
-        elif isinstance(arg, tuple):
-            convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
-            args_ = tuple(convert(x) for x in arg)
-            op_inputs.append(args_)
-        else:
-            op_inputs.append(arg)
-    output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
+        op_inputs.append(arg)
+    output = real_run_op(obj, op_name, args, tuple(op_mask))
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
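The Python side now only computes a 0/1 mask marking which positional args are Parameters and forwards the raw args; unwrapping default_input, including inside tuples, moves to the C++ pipeline (get_default_input and ConvertInputs in this commit). A small self-contained sketch of the mask computation the loop above performs:

# Hedged sketch: build the op_mask the way _run_op does above.
def build_op_mask(args):
    # 1 marks a Parameter (exposes __parameter__); the value itself is
    # passed through untouched and unwrapped later on the C++ side.
    return tuple(1 if hasattr(arg, '__parameter__') else 0 for arg in args)

# e.g. for (tensor, weight_parameter) this yields (0, 1)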
@ -54,4 +54,4 @@ class Net_Dropout(nn.Cell):
 def test_compile_dropout():
     net = Net_Dropout()
     input_data = Tensor(np.ones([20, 16, 50], dtype=np.float32))
-    _executor.compile(net, input_data)
+    net(input_data)
@ -18,6 +18,7 @@ import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
 from .vm_interface import vm

@ -225,7 +226,7 @@ def vm_impl_slice(self):
     return vm_impl


-@vm_impl_getters.register(P._grad_ops.ConcatOffset)
+@vm_impl_getters.register(G.ConcatOffset)
 def vm_impl_concatOffset(self):
     """Generate vm_impl function for ConcatOffset"""