Support_non_tensor_as_inputs_of_outermost_cell_in_pynative
commit f020fd5865
parent 40370dc8d7
@@ -136,18 +136,6 @@ bool CheckArgValid(const py::handle &arg) {
         (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
}

void CheckArgsValid(const py::tuple &args) {
  for (size_t i = 0; i < args.size(); i++) {
    if (!CheckArgValid(args[i])) {
      MS_EXCEPTION(TypeError)
        << "The inputs types of the outermost network support bool, int, float, tensor, "
           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
           "and tuple or list containing only these types, and dict whose values are these types, but got "
        << i << "th arg is " << py::str(args[i]);
    }
  }
}

std::string GetCompileExceptionInfo() {
  std::ostringstream oss;
  trace::TraceGraphEval();
@@ -236,6 +224,18 @@ void CacheFuncGraph(const ResourcePtr &resource) {
}
} // namespace

void CheckArgsValid(const py::tuple &args) {
  for (size_t i = 0; i < args.size(); i++) {
    if (!CheckArgValid(args[i])) {
      MS_EXCEPTION(TypeError)
        << "The inputs types of the outermost network support bool, int, float, tensor, "
           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
           "and tuple or list containing only these types, and dict whose values are these types, but got "
        << i << "th arg is " << py::str(args[i]);
    }
  }
}

py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
  MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size();
  abstract::AbstractBasePtrList args_spec;
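The two hunks above move CheckArgsValid out of the anonymous namespace so that the PyNative executor (see MakeNewTopGraph below) can call pipeline::CheckArgsValid on the outermost cell's inputs. Per the error message, the rule accepts bool, int, float, mstype.Number, non-Parameter tensors, and tuples/lists/dicts built from them. A minimal Python sketch of that rule, for illustration only; the real check is the C++ CheckArgValid, and _is_tensor is a hypothetical stand-in predicate:

def _is_tensor(arg):
    # Stand-in for py::isinstance<Tensor>; a valid tensor input also must not
    # carry the "__parameter__" attribute that marks a Parameter.
    return type(arg).__name__ == "Tensor" and not hasattr(arg, "__parameter__")

def arg_is_valid(arg):
    """Sketch of CheckArgValid: scalars, non-Parameter tensors, and
    containers composed of them are accepted as outermost-cell inputs."""
    if isinstance(arg, (tuple, list)):
        return all(arg_is_valid(x) for x in arg)
    if isinstance(arg, dict):
        return all(arg_is_valid(v) for v in arg.values())
    return isinstance(arg, (bool, int, float)) or _is_tensor(arg)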
@@ -130,6 +130,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
};
using ExecutorPyPtr = std::shared_ptr<ExecutorPy>;

void CheckArgsValid(const py::tuple &args);
// Generate a key for mapping function graph
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults);
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs);
@@ -162,7 +163,6 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const size_t key_len, std::string enc_mode);
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_len, std::string dec_mode);
bool PyIsCipherFile(const std::string &file_path);

} // namespace pipeline
} // namespace mindspore
@@ -345,6 +345,24 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::v
  return graph_info;
}

py::args FilterTensorArgs(const py::args &args, bool has_sens = false) {
  size_t size = args.size();
  if (size == 0 && has_sens) {
    MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
  }
  py::list only_tensors;
  size_t forward_args_size = has_sens ? size - 1 : size;
  for (size_t i = 0; i < forward_args_size; ++i) {
    if (py::isinstance<tensor::Tensor>(args[i])) {
      only_tensors.append(args[i]);
    }
  }
  if (has_sens) {
    only_tensors.append(args[forward_args_size]);
  }
  return only_tensors;
}

bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
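FilterTensorArgs is the core of this commit: before a top cell's inputs become graph parameters, every non-tensor positional argument is dropped, while a trailing sens value (has_sens) is kept unconditionally. A minimal Python sketch of that contract, for illustration only (the real function operates on py::args in C++; is_tensor is a caller-supplied stand-in predicate):

def filter_tensor_args(args, is_tensor, has_sens=False):
    # Mirrors the C++ logic: refuse an empty arg list when sens is expected.
    if not args and has_sens:
        raise ValueError("args is empty but has_sens is True")
    forward_args = args[:-1] if has_sens else args
    only_tensors = [a for a in forward_args if is_tensor(a)]
    if has_sens:
        only_tensors.append(args[-1])  # sens is always kept, tensor or not
    return only_tensors

# Worked example with string stand-ins marking "tensors":
kept = filter_tensor_args(("x", 2, True, "sens"), lambda a: isinstance(a, str), has_sens=True)
assert kept == ["x", "sens"]  # the scalar 2 and the bool True are filtered out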
@@ -1913,50 +1931,63 @@ pipeline::ResourcePtr GradExecutor::GetResource(const std::string &cell_id) {
  return nullptr;
}

void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top) {
  if (is_bprop_top) {
    // Convert input args to parameters for top cell graph in bprop.
    for (size_t i = 0; i < args.size(); ++i) {
      auto param = args[i];
      auto new_param = curr_g_->add_parameter();
      std::string param_id = GetId(param);
      SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
      SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
    }
    return;
  }
  // Convert input args to parameters for top cell graph in construct.
  std::vector<ValuePtr> input_param_values;
  py::args only_tensors = FilterTensorArgs(args);
  for (size_t i = 0; i < only_tensors.size(); ++i) {
    auto new_param = curr_g_->add_parameter();
    auto param_i = only_tensors[i];
    ValuePtr param_i_value = PyAttrValue(param_i);
    MS_EXCEPTION_IF_NULL(param_i_value);
    input_param_values.emplace_back(param_i_value);
    auto param_i_abs = param_i_value->ToAbstract();
    MS_EXCEPTION_IF_NULL(param_i_abs);
    new_param->set_abstract(param_i_abs->Broaden());
    std::string param_i_id = GetId(param_i);
    SetTupleArgsToGraphInfoMap(curr_g_, param_i, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
  }
  top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters(), input_param_values));
}

void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args) {
  auto bprop_fn = [this, &cell_id, &args]() {
  if (IsBpropGraph(cell_id)) {
  if (cell_stack_.empty() || IsNestedGrad()) {
    if (cell_stack_.empty() && !grad_is_running_) {
      MS_LOG(DEBUG) << "Make new topest graph";
      MakeNewTopGraph(cell_id, args, true);
    } else if (grad_is_running_ && IsBpropGraph(cell_id)) {
      MS_LOG(DEBUG) << "Run bprop cell";
      curr_g_ = std::make_shared<FuncGraph>();
      auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
      top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
      for (size_t i = 0; i < args.size(); ++i) {
        auto param = args[i];
        auto new_param = curr_g_->add_parameter();
        std::string param_id = GetId(param);
        SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
        SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
        SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
      }
      HandleInputArgsForTopCell(args, true);
      bprop_grad_stack_.push(std::make_pair(cell_id, false));
    } else if (top_cell()->grad_order() != grad_order_) {
    } else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph existed in bprop";
      MakeNewTopGraph(cell_id, args, false);
      bprop_grad_stack_.push(std::make_pair(cell_id, true));
    }
  };

  if (cell_stack_.empty()) {
    if (grad_is_running_) {
      bprop_fn();
    } else {
      MakeNewTopGraph(cell_id, args, true);
    }
  } else {
    // High order
    if (IsNestedGrad()) {
      if (grad_is_running_) {
        bprop_fn();
      } else if (top_cell()->grad_order() != grad_order_) {
        MS_LOG(DEBUG) << "Enter nested graph";
        auto cur_top_is_dynamic = top_cell()->is_dynamic();
        MakeNewTopGraph(cell_id, args, false);
        // If outer is dynamic, inner set dynamic too
        if (cur_top_is_dynamic) {
          top_cell()->set_is_dynamic(cur_top_is_dynamic);
        }
      }
    } else if (!cell_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph existed in construct";
      auto cur_top_is_dynamic = top_cell()->is_dynamic();
      MakeNewTopGraph(cell_id, args, false);
      top_cell()->set_is_dynamic(cur_top_is_dynamic);
    }
  }

  PushCellStack(cell_id);
  // Init kPynativeCellPtr with input parameters of top cell
  if (!top_cell()->is_init_kpynative()) {
@@ -1965,23 +1996,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
    auto df_builder = GetDfbuilder(cell_id);
    auto graph_info_df = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[df_builder] = graph_info_df;
    // Init parameter info for make cnode and curr_g
    std::vector<ValuePtr> input_param_values;
    for (size_t i = 0; i < args.size(); ++i) {
      auto new_param = curr_g_->add_parameter();
      auto param_i = args[i];
      ValuePtr param_i_value = PyAttrValue(param_i);
      MS_EXCEPTION_IF_NULL(param_i_value);
      input_param_values.emplace_back(param_i_value);
      auto param_i_abs = param_i_value->ToAbstract();
      MS_EXCEPTION_IF_NULL(param_i_abs);
      new_param->set_abstract(param_i_abs->Broaden());
      std::string param_i_id = GetId(param_i);
      SetTupleArgsToGraphInfoMap(curr_g_, param_i, new_param, true);
      SetNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
    }
    top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters(), input_param_values));
    HandleInputArgsForTopCell(args, false);
    top_cell()->set_need_compile_graph(true);
    top_cell()->set_init_kpynative(true);
  } else {
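In the two hunks above, the parameter-setup loop that used to be inlined in InitResourceAndDfBuilder is hoisted into the new HandleInputArgsForTopCell helper, and the construct path now routes its args through FilterTensorArgs first, so only tensor inputs become top-graph parameters. A rough Python sketch of the consolidated shape (illustration only; the names mirror the C++ but nothing here is MindSpore API):

def handle_input_args_for_top_cell(args, is_bprop_top, is_tensor):
    if is_bprop_top:
        return list(args)                      # bprop top cell: every arg becomes a parameter
    return [a for a in args if is_tensor(a)]   # construct path: tensor args only

# Example: scalar and bool inputs no longer produce graph parameters.
assert handle_input_args_for_top_cell((3, "t1", True, "t2"), False,
                                      lambda a: isinstance(a, str)) == ["t1", "t2"]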
@@ -2011,12 +2026,11 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
      return;
    }
  }

  // When the cell has custom bprop, in_custom_bprop_cell is larger than 0
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    custom_bprop_cell_count_ += 1;
  }
  // Init resource for resource and df_builder
  // Make top graph and init resource for resource and df_builder
  InitResourceAndDfBuilder(cell_id, args);
  // Check whether cell has dynamic construct
  if (!top_cell()->is_dynamic()) {
@@ -2029,26 +2043,18 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
}

void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest) {
  for (const auto &arg : args) {
    if (py::isinstance<tensor::Tensor>(arg)) {
      auto tensor = arg.cast<tensor::TensorPtr>();
      if (tensor && tensor->is_parameter()) {
        MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
      }
    }
  }

  pipeline::CheckArgsValid(args);
  // Record input args info
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }

  // Run forward first need plus 1
  if (grad_order_ == 0) {
    ++grad_order_;
  }
  // Create top cell
  curr_g_ = std::make_shared<FuncGraph>();
  // Init resource for new top cell
  auto df_builder = std::make_shared<FuncGraph>();
  auto resource = std::make_shared<pipeline::Resource>();
  auto top_cell = std::make_shared<TopCellInfo>(is_topest, grad_order_, resource, df_builder, cell_id);
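MakeNewTopGraph now rejects Parameter tensors explicitly and then delegates the remaining type checking to the relocated pipeline::CheckArgsValid, so the outermost cell accepts the same input types in PyNative mode as in graph mode. It also keys the top cell by the concatenated ids of all inputs, tensors and non-tensors alike. A tiny Python sketch of that key construction (get_id is a hypothetical stand-in for the C++ GetId):

def make_input_args_id(args, get_id=id):
    # Concatenate each argument's identity, underscore-terminated, into one cache key.
    return "".join(f"{get_id(a)}_" for a in args)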
@@ -2166,25 +2172,20 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
      MakeValueNode(out, out_id);
    }
  }

  DoGradForCustomBprop(cell, out, args);
  // Set output node for forward graph when need.
  PopCellStack();
  auto set_fg_fn = [this, &out, &out_id]() {
    AnfNodePtr output_node = GetObjNode(out, out_id);
    MS_EXCEPTION_IF_NULL(output_node);
    curr_g_->set_output(output_node);
  };
  if (grad_is_running_) {
    if (!bprop_grad_stack_.top().second) {
      bprop_grad_stack_.pop();
      set_fg_fn();
      curr_g_->set_output(GetObjNode(out, out_id));
      return;
    } else if (bprop_grad_stack_.top().first == cell_id) {
      bprop_grad_stack_.pop();
    }
  }
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    set_fg_fn();
    curr_g_->set_output(GetObjNode(out, out_id));
    DumpIR("fg.ir", curr_g_);
  }
  // Reset grad flag and update output node of top cell
@@ -2440,7 +2441,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop_graph->debug_info()->set_name(ss.str());
  // Get the parameters items and add the value to args_spec
  (void)GetArgsSpec(args, bprop_graph);
  (void)GetArgsSpec(FilterTensorArgs(args, grad->sens_param_), bprop_graph);

  // Do opt for final bprop graph
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
@@ -2552,7 +2553,7 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();

  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(args);
  py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens));
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
@@ -202,7 +202,6 @@ class GradExecutor {

 private:
  ForwardExecutorPtr forward() const;

  // Higher derivative
  bool IsNestedGrad() const;
  void SwitchTopcell();
@@ -213,17 +212,20 @@ class GradExecutor {
  void PopCellStack();
  void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
  TopCellInfoPtr PopHighOrderGraphStack();

  // Manage information of top cell.
  FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
  pipeline::ResourcePtr GetResource(const std::string &cell_id = "");
  bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
  bool IsBpropGraph(const std::string &cell_id);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
  void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
  void HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top);
  void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args);
  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
  void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
  // Manage resource when run grad process.
  bool IsBpropGraph(const std::string &cell_id);
  bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
  void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
  void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
  void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
                            py::args *forward_args = nullptr);
  void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
@@ -232,11 +234,12 @@ class GradExecutor {
                    const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
  // Manage resource for construct forward graph.
  std::string &graph_phase() { return graph_phase_; }
  AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
  std::string &graph_phase() { return graph_phase_; }
  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                  bool is_param = false);
  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
@@ -252,7 +255,6 @@ class GradExecutor {
  }
  void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out,
                                      const std::string &out_id);
  void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);

 private:
  bool grad_flag_{false};
@@ -56,6 +56,7 @@ AbstractFunctionPtr FuncGraph::abstract() {
}

void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
  MS_EXCEPTION_IF_NULL(value);
  if (force_new_ret || return_ == nullptr) {
    std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
    FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
@@ -25,7 +25,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
from ...common.tensor import Tensor
from .. import signature as sig

__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -339,9 +338,6 @@ class GradOperation(GradOperation_):
            else:
                new_kwargs = kwargs.copy()
                new_kwargs.pop('sens')
            for arg in args:
                if not isinstance(arg, Tensor):
                    raise TypeError("grad inputs should be tensor in pynative mode")
            if isinstance(fn, FunctionType):
                if not _pynative_exec.check_run(fn, *args, **new_kwargs):
                    _pynative_exec.set_grad_flag(True)
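With the Python-side isinstance check (and the now-unused Tensor import) removed, GradOperation no longer rejects non-tensor positional inputs up front; validation and filtering happen in the C++ executor instead. A hedged usage sketch, assuming a MindSpore build that includes this commit (the network and values are illustrative, not from the commit):

import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

def net(x, scale):          # `scale` is a plain Python float, not a Tensor
    return x * scale

x = Tensor([1.0, 2.0], ms.float32)
grad_fn = ops.GradOperation()(net)
# Before this commit, the line below raised
# TypeError("grad inputs should be tensor in pynative mode").
print(grad_fn(x, 3.0))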