adjust-the-location-of-cleaning-unuseless-memory-in-value-node

This commit is contained in:
lvliang 2020-12-19 11:02:05 +08:00
parent f2b25d4139
commit 1490947ff0
13 changed files with 136 additions and 91 deletions

View File

@ -141,11 +141,6 @@ void RunOpsInGraphTask::Run() {
session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
void CleanUselessTensorsTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->CleanUselessTensorsImpl(useless_tensors_);
}
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
@ -392,15 +387,6 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
*outputs = task->outputs_;
}
void Executor::CleanUselessTensors(const SessionPtr &session,
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
MS_EXCEPTION_IF_NULL(useless_tensors);
auto task = std::make_shared<CleanUselessTensorsTask>();
task->session_ = session;
task->useless_tensors_ = useless_tensors;
SyncRunTask(task);
}
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
auto task = std::make_shared<CreateCommGroupTask>();
task->group_name_ = group_name;

View File

@ -46,8 +46,7 @@ enum TaskType {
kRunOp,
kCreateCommGroup,
kDestroyCommGroup,
kRunOpsInGraph,
kCleanUselessTensors
kRunOpsInGraph
};
class Task {
@ -110,14 +109,6 @@ class RunOpsInGraphTask : public Task {
GraphId graph_id_{0};
};
class CleanUselessTensorsTask : public Task {
public:
CleanUselessTensorsTask() { type_ = kCleanUselessTensors; }
~CleanUselessTensorsTask() override = default;
void Run() override;
std::shared_ptr<std::vector<tensor::TensorPtr>> useless_tensors_{nullptr};
};
class RunOpTask : public Task {
public:
RunOpTask() { type_ = kRunOp; }
@ -175,8 +166,6 @@ class Executor {
const std::vector<int64_t> &tensors_mask);
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);
void CleanUselessTensors(const SessionPtr &session,
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name);
void OnEvent(const ExecutorEvent &event);

View File

@ -1657,11 +1657,6 @@ void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tens
executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
}
void SessionBasic::CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->CleanUselessTensors(shared_from_this(), useless_tensors);
}
void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
@ -1710,22 +1705,6 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &ro
root_graph->UpdateGraphDynamicAttr();
}
void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
auto ms_context = MsContext::GetInstance();
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == "CPU") {
return;
}
for (const auto &tensor : *useless_tensors) {
MS_EXCEPTION_IF_NULL(tensor);
const auto &shape = tensor->shape();
if (!shape.empty()) {
// The address of scalar value node does not need to be deleted
tensor->set_device_address(nullptr);
}
}
}
bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) {
auto kernel_graph = graphs_[graph_id];
MS_EXCEPTION_IF_NULL(kernel_graph);

View File

@ -82,7 +82,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
const std::vector<int64_t> &tensors_mask);
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
void CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@ -142,7 +141,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
friend class RunGraphTask;
friend class RunOpTask;
friend class RunOpsInGraphTask;
friend class CleanUselessTensorsTask;
virtual bool IsSupportSummary() { return true; }
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
@ -164,7 +162,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<int64_t> &tensors_mask) {}
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {}
void CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
virtual void SetSummaryNodes(KernelGraph *graph);

View File

@ -253,6 +253,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
k_app = k_graph_->NewCNode(inputs);
}
ReplaceEquivdout(k_app, cnode_morph);
cnode_morph->clear_inputs_value();
cnode_morph->set_forward(nullptr, "");
for (size_t i = 0; i < param_adjoints.size(); ++i) {
param_adjoints[i]->RegisterKUser(k_app, i);
@ -387,7 +388,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
MS_EXCEPTION_IF_NULL(out_node);
out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
// clear resource
cnode_morph->clear_inputs_value();
fg->ClearAllManagerInfo();
func_graph->ClearAllManagerInfo();
}

View File

@ -92,6 +92,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_BODY[] = "body";
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
const char NAMED_PRIMITIVE_AUGASSIGN[] = "AugAssign";
const char NAMED_PRIMITIVE_FOR[] = "For";
const char NAMED_PRIMITIVE_IF[] = "If";
const char NAMED_PRIMITIVE_WHILE[] = "While";
@ -105,6 +106,7 @@ const char NAMED_PRIMITIVE_ATTRIBUTE[] = "Attribute";
const char NAMED_PRIMITIVE_COMPARE[] = "Compare";
const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant";
const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators";
const char NAMED_PRIMITIVE_TARGET[] = "target";
const char NAMED_PRIMITIVE_SLICE[] = "slice";
const char NAMED_PRIMITIVE_NAME[] = "Name";
const char NAMED_PRIMITIVE_NUM[] = "Num";

View File

@ -622,7 +622,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
op_exec_info->op_name = op_name;
if (grad_flag()) {
int64_t graph_id = graph_id_;
auto resource = GetResource();
auto resource = GetResource(top_cell_id_);
if (resource != nullptr) {
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
auto it = resource->results().find(pipeline::kPynativeGraphId);
@ -1007,21 +1007,21 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
MS_EXCEPTION_IF_NULL(output_value);
std::vector<tensor::TensorPtr> output_tensors;
TensorValueToTensor(output_value, &output_tensors);
if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) {
if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) {
// first step
std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
op_index_with_tensor_id_[op_index].emplace_back(tensor->id());
cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id());
});
return;
}
auto ms_context = MsContext::GetInstance();
auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index];
for (size_t i = 0; i < tensor_id_list.size(); ++i) {
auto tensor_id = tensor_id_list[i];
if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) {
if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) {
auto &new_tensor = output_tensors[i];
auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id];
auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id];
std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id "
<< tensor->id() << ", device address " << tensor->device_address().get()
@ -1050,7 +1050,15 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
tensor_id_with_tensor_.clear();
std::set<std::string> forward_op_tensor_id;
for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
const auto &tensor_id_list = elem.second;
for (const auto &tensor_id : tensor_id_list) {
forward_op_tensor_id.emplace(tensor_id);
}
}
cell_tensor_id_with_tensor_[top_cell_id_].clear();
const auto &func_graph = resource->func_graph();
const auto &value_node_list = func_graph->value_nodes();
for (const auto &elem : value_node_list) {
@ -1059,8 +1067,9 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value_node->value(), &tensors);
for (const auto &tensor : tensors) {
if (tensor->device_address() != nullptr) {
tensor_id_with_tensor_[tensor->id()].emplace_back(tensor);
if (tensor->device_address() != nullptr &&
forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor);
MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id()
<< ", device address " << tensor->device_address().get();
}
@ -1068,16 +1077,22 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
}
}
void PynativeExecutor::CleanTensorsInValueNode() {
// Only need clean in ms backend policy and session should not be nullptr in ms backend.
if (session == nullptr) {
void PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) {
auto ms_context = MsContext::GetInstance();
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == "CPU") {
top_cell_id_ = cell_id;
return;
}
auto useless_tensors = std::make_shared<std::vector<tensor::TensorPtr>>();
for (const auto &id_tensor_pair : tensor_id_with_tensor_) {
std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors));
const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_];
for (const auto &elem : tensor_id_with_tensor) {
const auto &tensors_in_value_node = elem.second;
for (const auto &tensor : tensors_in_value_node) {
MS_EXCEPTION_IF_NULL(tensor);
tensor->set_device_address(nullptr);
}
}
session->CleanUselessTensors(useless_tensors);
top_cell_id_ = cell_id;
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
@ -1462,6 +1477,12 @@ void PynativeExecutor::SubNestedGradOrder() {
}
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfo &value) {
return value.cell_id == cell_id && value.is_dynamic_cell;
});
if (it != top_cell_list_.end()) {
return false;
}
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) {
return value.cell_id == cell_id && (!is_grad || value.is_grad);
});
@ -1584,6 +1605,22 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs
}
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
// while self.a > self.b and changed self.a or self.b
if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
std::string left_variable;
if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
}
auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
std::string right_variable;
if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
right_variable =
py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
}
return ParseBodyContext(ast, node, {left_variable, right_variable});
}
// if a[0]
if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
@ -1623,6 +1660,34 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
return false;
}
bool PynativeExecutor::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim) {
MS_LOG(DEBUG) << "Parse augassign expr";
bool ret = false;
if (compare_prim.empty()) {
return ret;
}
py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
if (py::isinstance<py::none>(target_node)) {
MS_LOG(DEBUG) << "Parse target node is none!";
return ret;
}
py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
if (py::isinstance<py::none>(value_node)) {
MS_LOG(DEBUG) << "Parse value node is none!";
return ret;
}
std::string assign_prim;
if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
}
auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
if (iter != compare_prim.end()) {
ret = true;
}
return ret;
}
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse for expr";
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
@ -1643,7 +1708,8 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &
return false;
}
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim) {
MS_EXCEPTION_IF_NULL(ast);
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
if (py::isinstance<py::none>(func_obj)) {
@ -1659,6 +1725,8 @@ bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &
const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
ret = ParseAssignExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
ret = ParseAugAssignExprNode(ast, node, compare_prim);
} else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
ret = ParseForExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
@ -1713,17 +1781,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
MS_LOG(EXCEPTION) << "Top cell list is empty";
}
if (IsTopGraph(cell_id)) {
// Clear previous step resource
op_index_map_.clear();
CleanPreMemoryInValueNode(cell_id);
}
MS_LOG(INFO) << "NewGraph already compiled";
return;
}
// init resource for constructing forward graph and grad graph
auto g = std::make_shared<FuncGraph>();
curr_g_ = g;
curr_g_ = std::make_shared<FuncGraph>();
ClearResidualRes(cell_id);
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
MakeNewTopGraph(cell_id, args, g);
MakeNewTopGraph(cell_id, args, curr_g_);
}
PushCurrentGraphToStack();
if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
@ -1732,7 +1801,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
}
for (size_t i = 0; i < args.size(); ++i) {
auto param = args[i];
auto new_param = g->add_parameter();
auto new_param = curr_g_->add_parameter();
std::string param_id = GetId(param);
SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
@ -1741,6 +1810,13 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
// check whether the construct of cell will be changed
if (!dynamic_cell_) {
dynamic_cell_ = IsDynamicCell(cell);
if (dynamic_cell_) {
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&](const TopCellInfo &value) { return value.cell_id == top_cell_id_; });
if (it != top_cell_list_.end()) {
it->is_dynamic_cell = dynamic_cell_;
}
}
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
}
}
@ -1754,16 +1830,17 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
}
}
}
// Clear runop pre
// Clear resource in old top cell
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
if (it != top_cell_list_.end()) {
top_cell_list_.erase(it);
}
dynamic_cell_ = false;
op_index_map_.clear();
op_index_with_tensor_id_.clear();
CleanPreMemoryInValueNode(cell_id);
// Init resource for new top cell
dynamic_cell_ = false;
auto df_builder = std::make_shared<FuncGraph>();
GraphInfo graph_info = GraphInfo(cell_id);
graph_info_map_.emplace(df_builder, graph_info);
@ -2353,7 +2430,6 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
MS_LOG(DEBUG) << "Eval run " << backend;
set_grad_runing(true);
BaseRef value = (*run)(arg_list);
CleanTensorsInValueNode();
set_grad_runing(false);
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
auto out = BaseRefToPyData(value);
@ -2464,8 +2540,8 @@ void PynativeExecutor::ClearRes() {
cell_graph_list_.clear();
top_cell_list_.clear();
op_index_map_.clear();
op_index_with_tensor_id_.clear();
tensor_id_with_tensor_.clear();
cell_op_index_with_tensor_id_.clear();
cell_tensor_id_with_tensor_.clear();
cell_dynamic_map_.clear();
prim_abs_list_.clear();
std::stack<FuncGraphPtr>().swap(graph_stack_);

View File

@ -52,6 +52,8 @@ struct PrimAbsInfo {
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
using OpIndexWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensor = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
py::tuple RunOp(const py::args &args);
@ -87,6 +89,7 @@ struct TopCellInfo {
FuncGraphPtr df_builder;
FuncGraphPtr bg; // Backward graph
std::string cell_id;
bool is_dynamic_cell{false};
TopCellInfo() = default;
TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
: resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
@ -154,9 +157,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool IsDynamicCell(const py::object &cell);
std::string GetCellInfo(const py::object &cell);
void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim = {});
bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim = {});
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type);
@ -190,7 +196,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// Update the abstract and device address info of value node and tensors in bprop graph
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveTensorsInValueNode(const ResourcePtr &resource);
void CleanTensorsInValueNode();
void CleanPreMemoryInValueNode(const std::string &cell_id);
// Construct grad graph
void PushCurrentGraphToStack();
@ -259,6 +265,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::mutex instance_lock_;
static int64_t graph_id_;
size_t grad_order_{0};
std::string top_cell_id_;
bool grad_flag_{false};
bool dynamic_cell_{false};
bool grad_is_running_{false};
@ -282,8 +289,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// Used for runop and replace forward result of grad graph
std::unordered_map<std::string, size_t> op_index_map_;
std::unordered_map<std::string, std::string> obj_to_forward_id_;
std::unordered_map<std::string, std::vector<std::string>> op_index_with_tensor_id_;
std::unordered_map<std::string, std::vector<tensor::TensorPtr>> tensor_id_with_tensor_;
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
};

View File

@ -553,7 +553,7 @@ std::string Tensor::ToStringInternal(int limit_size) const {
std::ostringstream buf;
auto dtype = Dtype();
MS_EXCEPTION_IF_NULL(dtype);
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value=";
buf << "Tensor(id=" << id_ << ", shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value=";
if (limit_size <= 0 || DataSize() < limit_size) {
// Only print data for small tensor.
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false);

View File

@ -361,7 +361,6 @@ class Cell(Cell_):
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
return output
def _add_attr(self, name, value):

View File

@ -182,6 +182,9 @@ class GradOperation(GradOperation_):
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
Default: False.
If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
parameter, the key must be sens.
Returns:
The higher-order function which takes a function as argument and returns gradient function for it.
@ -311,16 +314,23 @@ class GradOperation(GradOperation_):
def _pynative_forward_run(self, args, kwargs, fn):
""" Pynative forward run to build grad graph. """
new_kwargs = {}
if self.sens_param:
args = args[:-1]
if not 'sens' in kwargs.keys():
args = args[:-1]
new_kwargs = kwargs
else:
for key, value in kwargs.items():
if key != 'sens':
new_kwargs[key] = value
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
if isinstance(fn, FunctionType):
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args, **kwargs)
output = fn(*args, **kwargs)
_pynative_exec.end_graph(fn, output, *args, **kwargs)
_pynative_exec.new_graph(fn, *args, **new_kwargs)
output = fn(*args, **new_kwargs)
_pynative_exec.end_graph(fn, output, *args, **new_kwargs)
else:
if fn.already_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
@ -328,7 +338,7 @@ class GradOperation(GradOperation_):
self.need_forward = True
if self.need_forward:
fn.set_grad()
fn(*args, **kwargs)
fn(*args, **new_kwargs)
fn.already_run = False
def __call__(self, fn, weights=None):

View File

@ -404,10 +404,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()

View File

@ -403,10 +403,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()