forked from mindspore-Ecosystem/mindspore
adjust-the-location-of-cleaning-unuseless-memory-in-value-node
This commit is contained in:
parent
f2b25d4139
commit
1490947ff0
|
@ -141,11 +141,6 @@ void RunOpsInGraphTask::Run() {
|
|||
session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
|
||||
}
|
||||
|
||||
void CleanUselessTensorsTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->CleanUselessTensorsImpl(useless_tensors_);
|
||||
}
|
||||
|
||||
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
|
||||
|
||||
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
|
||||
|
@ -392,15 +387,6 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
|
|||
*outputs = task->outputs_;
|
||||
}
|
||||
|
||||
void Executor::CleanUselessTensors(const SessionPtr &session,
|
||||
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(useless_tensors);
|
||||
auto task = std::make_shared<CleanUselessTensorsTask>();
|
||||
task->session_ = session;
|
||||
task->useless_tensors_ = useless_tensors;
|
||||
SyncRunTask(task);
|
||||
}
|
||||
|
||||
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
|
||||
auto task = std::make_shared<CreateCommGroupTask>();
|
||||
task->group_name_ = group_name;
|
||||
|
|
|
@ -46,8 +46,7 @@ enum TaskType {
|
|||
kRunOp,
|
||||
kCreateCommGroup,
|
||||
kDestroyCommGroup,
|
||||
kRunOpsInGraph,
|
||||
kCleanUselessTensors
|
||||
kRunOpsInGraph
|
||||
};
|
||||
|
||||
class Task {
|
||||
|
@ -110,14 +109,6 @@ class RunOpsInGraphTask : public Task {
|
|||
GraphId graph_id_{0};
|
||||
};
|
||||
|
||||
class CleanUselessTensorsTask : public Task {
|
||||
public:
|
||||
CleanUselessTensorsTask() { type_ = kCleanUselessTensors; }
|
||||
~CleanUselessTensorsTask() override = default;
|
||||
void Run() override;
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> useless_tensors_{nullptr};
|
||||
};
|
||||
|
||||
class RunOpTask : public Task {
|
||||
public:
|
||||
RunOpTask() { type_ = kRunOp; }
|
||||
|
@ -175,8 +166,6 @@ class Executor {
|
|||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs);
|
||||
void CleanUselessTensors(const SessionPtr &session,
|
||||
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
|
||||
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
|
||||
bool DestroyCommGroup(const std::string &group_name);
|
||||
void OnEvent(const ExecutorEvent &event);
|
||||
|
|
|
@ -1657,11 +1657,6 @@ void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tens
|
|||
executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
|
||||
}
|
||||
|
||||
void SessionBasic::CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->CleanUselessTensors(shared_from_this(), useless_tensors);
|
||||
}
|
||||
|
||||
void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
|
||||
|
@ -1710,22 +1705,6 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &ro
|
|||
root_graph->UpdateGraphDynamicAttr();
|
||||
}
|
||||
|
||||
void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target == "CPU") {
|
||||
return;
|
||||
}
|
||||
for (const auto &tensor : *useless_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
const auto &shape = tensor->shape();
|
||||
if (!shape.empty()) {
|
||||
// The address of scalar value node does not need to be deleted
|
||||
tensor->set_device_address(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) {
|
||||
auto kernel_graph = graphs_[graph_id];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
|
|
@ -82,7 +82,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
void CleanUselessTensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
|
||||
|
||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
|
||||
|
@ -142,7 +141,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
friend class RunGraphTask;
|
||||
friend class RunOpTask;
|
||||
friend class RunOpsInGraphTask;
|
||||
friend class CleanUselessTensorsTask;
|
||||
virtual bool IsSupportSummary() { return true; }
|
||||
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
VectorRef *outputs,
|
||||
|
@ -164,7 +162,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<int64_t> &tensors_mask) {}
|
||||
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {}
|
||||
void CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
|
||||
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||
|
||||
virtual void SetSummaryNodes(KernelGraph *graph);
|
||||
|
|
|
@ -253,6 +253,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
k_app = k_graph_->NewCNode(inputs);
|
||||
}
|
||||
ReplaceEquivdout(k_app, cnode_morph);
|
||||
cnode_morph->clear_inputs_value();
|
||||
cnode_morph->set_forward(nullptr, "");
|
||||
for (size_t i = 0; i < param_adjoints.size(); ++i) {
|
||||
param_adjoints[i]->RegisterKUser(k_app, i);
|
||||
|
@ -387,7 +388,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
|
||||
// clear resource
|
||||
cnode_morph->clear_inputs_value();
|
||||
fg->ClearAllManagerInfo();
|
||||
func_graph->ClearAllManagerInfo();
|
||||
}
|
||||
|
|
|
@ -92,6 +92,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
|
|||
const char NAMED_PRIMITIVE_LEN[] = "len";
|
||||
const char NAMED_PRIMITIVE_BODY[] = "body";
|
||||
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
|
||||
const char NAMED_PRIMITIVE_AUGASSIGN[] = "AugAssign";
|
||||
const char NAMED_PRIMITIVE_FOR[] = "For";
|
||||
const char NAMED_PRIMITIVE_IF[] = "If";
|
||||
const char NAMED_PRIMITIVE_WHILE[] = "While";
|
||||
|
@ -105,6 +106,7 @@ const char NAMED_PRIMITIVE_ATTRIBUTE[] = "Attribute";
|
|||
const char NAMED_PRIMITIVE_COMPARE[] = "Compare";
|
||||
const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant";
|
||||
const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators";
|
||||
const char NAMED_PRIMITIVE_TARGET[] = "target";
|
||||
const char NAMED_PRIMITIVE_SLICE[] = "slice";
|
||||
const char NAMED_PRIMITIVE_NAME[] = "Name";
|
||||
const char NAMED_PRIMITIVE_NUM[] = "Num";
|
||||
|
|
|
@ -622,7 +622,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
|
|||
op_exec_info->op_name = op_name;
|
||||
if (grad_flag()) {
|
||||
int64_t graph_id = graph_id_;
|
||||
auto resource = GetResource();
|
||||
auto resource = GetResource(top_cell_id_);
|
||||
if (resource != nullptr) {
|
||||
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
|
||||
auto it = resource->results().find(pipeline::kPynativeGraphId);
|
||||
|
@ -1007,21 +1007,21 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
|
|||
MS_EXCEPTION_IF_NULL(output_value);
|
||||
std::vector<tensor::TensorPtr> output_tensors;
|
||||
TensorValueToTensor(output_value, &output_tensors);
|
||||
if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) {
|
||||
if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) {
|
||||
// first step
|
||||
std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
|
||||
op_index_with_tensor_id_[op_index].emplace_back(tensor->id());
|
||||
cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id());
|
||||
});
|
||||
return;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
|
||||
const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index];
|
||||
for (size_t i = 0; i < tensor_id_list.size(); ++i) {
|
||||
auto tensor_id = tensor_id_list[i];
|
||||
if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) {
|
||||
if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) {
|
||||
auto &new_tensor = output_tensors[i];
|
||||
auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id];
|
||||
auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id];
|
||||
std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
|
||||
MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id "
|
||||
<< tensor->id() << ", device address " << tensor->device_address().get()
|
||||
|
@ -1050,7 +1050,15 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
|
|||
|
||||
void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
tensor_id_with_tensor_.clear();
|
||||
std::set<std::string> forward_op_tensor_id;
|
||||
for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
|
||||
const auto &tensor_id_list = elem.second;
|
||||
for (const auto &tensor_id : tensor_id_list) {
|
||||
forward_op_tensor_id.emplace(tensor_id);
|
||||
}
|
||||
}
|
||||
|
||||
cell_tensor_id_with_tensor_[top_cell_id_].clear();
|
||||
const auto &func_graph = resource->func_graph();
|
||||
const auto &value_node_list = func_graph->value_nodes();
|
||||
for (const auto &elem : value_node_list) {
|
||||
|
@ -1059,8 +1067,9 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
|
|||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(value_node->value(), &tensors);
|
||||
for (const auto &tensor : tensors) {
|
||||
if (tensor->device_address() != nullptr) {
|
||||
tensor_id_with_tensor_[tensor->id()].emplace_back(tensor);
|
||||
if (tensor->device_address() != nullptr &&
|
||||
forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
|
||||
cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor);
|
||||
MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id()
|
||||
<< ", device address " << tensor->device_address().get();
|
||||
}
|
||||
|
@ -1068,16 +1077,22 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
|
|||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::CleanTensorsInValueNode() {
|
||||
// Only need clean in ms backend policy and session should not be nullptr in ms backend.
|
||||
if (session == nullptr) {
|
||||
void PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target == "CPU") {
|
||||
top_cell_id_ = cell_id;
|
||||
return;
|
||||
}
|
||||
auto useless_tensors = std::make_shared<std::vector<tensor::TensorPtr>>();
|
||||
for (const auto &id_tensor_pair : tensor_id_with_tensor_) {
|
||||
std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors));
|
||||
const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_];
|
||||
for (const auto &elem : tensor_id_with_tensor) {
|
||||
const auto &tensors_in_value_node = elem.second;
|
||||
for (const auto &tensor : tensors_in_value_node) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
tensor->set_device_address(nullptr);
|
||||
}
|
||||
}
|
||||
session->CleanUselessTensors(useless_tensors);
|
||||
top_cell_id_ = cell_id;
|
||||
}
|
||||
|
||||
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
|
||||
|
@ -1462,6 +1477,12 @@ void PynativeExecutor::SubNestedGradOrder() {
|
|||
}
|
||||
|
||||
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
|
||||
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfo &value) {
|
||||
return value.cell_id == cell_id && value.is_dynamic_cell;
|
||||
});
|
||||
if (it != top_cell_list_.end()) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) {
|
||||
return value.cell_id == cell_id && (!is_grad || value.is_grad);
|
||||
});
|
||||
|
@ -1584,6 +1605,22 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs
|
|||
}
|
||||
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
|
||||
auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
|
||||
// while self.a > self.b and changed self.a or self.b
|
||||
if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
|
||||
auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
|
||||
std::string left_variable;
|
||||
if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
|
||||
left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
|
||||
}
|
||||
auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
|
||||
std::string right_variable;
|
||||
if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
|
||||
right_variable =
|
||||
py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
|
||||
}
|
||||
return ParseBodyContext(ast, node, {left_variable, right_variable});
|
||||
}
|
||||
// if a[0]
|
||||
if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
|
||||
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
|
||||
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
|
||||
|
@ -1623,6 +1660,34 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
|
|||
return false;
|
||||
}
|
||||
|
||||
bool PynativeExecutor::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
const std::vector<std::string> &compare_prim) {
|
||||
MS_LOG(DEBUG) << "Parse augassign expr";
|
||||
bool ret = false;
|
||||
if (compare_prim.empty()) {
|
||||
return ret;
|
||||
}
|
||||
py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
|
||||
if (py::isinstance<py::none>(target_node)) {
|
||||
MS_LOG(DEBUG) << "Parse target node is none!";
|
||||
return ret;
|
||||
}
|
||||
py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
|
||||
if (py::isinstance<py::none>(value_node)) {
|
||||
MS_LOG(DEBUG) << "Parse value node is none!";
|
||||
return ret;
|
||||
}
|
||||
std::string assign_prim;
|
||||
if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
|
||||
assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
|
||||
}
|
||||
auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
|
||||
if (iter != compare_prim.end()) {
|
||||
ret = true;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Parse for expr";
|
||||
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
|
||||
|
@ -1643,7 +1708,8 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &
|
|||
return false;
|
||||
}
|
||||
|
||||
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
|
||||
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
|
||||
const std::vector<std::string> &compare_prim) {
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
|
||||
if (py::isinstance<py::none>(func_obj)) {
|
||||
|
@ -1659,6 +1725,8 @@ bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &
|
|||
const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
|
||||
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
|
||||
ret = ParseAssignExprNode(ast, node);
|
||||
} else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
|
||||
ret = ParseAugAssignExprNode(ast, node, compare_prim);
|
||||
} else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
|
||||
ret = ParseForExprNode(ast, node);
|
||||
} else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
|
||||
|
@ -1713,17 +1781,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
MS_LOG(EXCEPTION) << "Top cell list is empty";
|
||||
}
|
||||
if (IsTopGraph(cell_id)) {
|
||||
// Clear previous step resource
|
||||
op_index_map_.clear();
|
||||
CleanPreMemoryInValueNode(cell_id);
|
||||
}
|
||||
MS_LOG(INFO) << "NewGraph already compiled";
|
||||
return;
|
||||
}
|
||||
// init resource for constructing forward graph and grad graph
|
||||
auto g = std::make_shared<FuncGraph>();
|
||||
curr_g_ = g;
|
||||
curr_g_ = std::make_shared<FuncGraph>();
|
||||
ClearResidualRes(cell_id);
|
||||
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
|
||||
MakeNewTopGraph(cell_id, args, g);
|
||||
MakeNewTopGraph(cell_id, args, curr_g_);
|
||||
}
|
||||
PushCurrentGraphToStack();
|
||||
if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
|
||||
|
@ -1732,7 +1801,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
}
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
auto param = args[i];
|
||||
auto new_param = g->add_parameter();
|
||||
auto new_param = curr_g_->add_parameter();
|
||||
std::string param_id = GetId(param);
|
||||
SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
|
||||
SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
|
||||
|
@ -1741,6 +1810,13 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
// check whether the construct of cell will be changed
|
||||
if (!dynamic_cell_) {
|
||||
dynamic_cell_ = IsDynamicCell(cell);
|
||||
if (dynamic_cell_) {
|
||||
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
|
||||
[&](const TopCellInfo &value) { return value.cell_id == top_cell_id_; });
|
||||
if (it != top_cell_list_.end()) {
|
||||
it->is_dynamic_cell = dynamic_cell_;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
|
||||
}
|
||||
}
|
||||
|
@ -1754,16 +1830,17 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
|
|||
}
|
||||
}
|
||||
}
|
||||
// Clear runop pre
|
||||
// Clear resource in old top cell
|
||||
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
|
||||
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
|
||||
if (it != top_cell_list_.end()) {
|
||||
top_cell_list_.erase(it);
|
||||
}
|
||||
dynamic_cell_ = false;
|
||||
op_index_map_.clear();
|
||||
op_index_with_tensor_id_.clear();
|
||||
CleanPreMemoryInValueNode(cell_id);
|
||||
|
||||
// Init resource for new top cell
|
||||
dynamic_cell_ = false;
|
||||
auto df_builder = std::make_shared<FuncGraph>();
|
||||
GraphInfo graph_info = GraphInfo(cell_id);
|
||||
graph_info_map_.emplace(df_builder, graph_info);
|
||||
|
@ -2353,7 +2430,6 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
|
|||
MS_LOG(DEBUG) << "Eval run " << backend;
|
||||
set_grad_runing(true);
|
||||
BaseRef value = (*run)(arg_list);
|
||||
CleanTensorsInValueNode();
|
||||
set_grad_runing(false);
|
||||
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
|
||||
auto out = BaseRefToPyData(value);
|
||||
|
@ -2464,8 +2540,8 @@ void PynativeExecutor::ClearRes() {
|
|||
cell_graph_list_.clear();
|
||||
top_cell_list_.clear();
|
||||
op_index_map_.clear();
|
||||
op_index_with_tensor_id_.clear();
|
||||
tensor_id_with_tensor_.clear();
|
||||
cell_op_index_with_tensor_id_.clear();
|
||||
cell_tensor_id_with_tensor_.clear();
|
||||
cell_dynamic_map_.clear();
|
||||
prim_abs_list_.clear();
|
||||
std::stack<FuncGraphPtr>().swap(graph_stack_);
|
||||
|
|
|
@ -52,6 +52,8 @@ struct PrimAbsInfo {
|
|||
|
||||
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
|
||||
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
|
||||
using OpIndexWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
|
||||
using TensorIdWithTensor = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
|
||||
|
||||
py::tuple RunOp(const py::args &args);
|
||||
|
||||
|
@ -87,6 +89,7 @@ struct TopCellInfo {
|
|||
FuncGraphPtr df_builder;
|
||||
FuncGraphPtr bg; // Backward graph
|
||||
std::string cell_id;
|
||||
bool is_dynamic_cell{false};
|
||||
TopCellInfo() = default;
|
||||
TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
|
||||
: resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
|
||||
|
@ -154,9 +157,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool IsDynamicCell(const py::object &cell);
|
||||
std::string GetCellInfo(const py::object &cell);
|
||||
void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
|
||||
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
|
||||
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
|
||||
const std::vector<std::string> &compare_prim = {});
|
||||
bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
const std::vector<std::string> &compare_prim = {});
|
||||
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
parse::AstMainType type);
|
||||
|
@ -190,7 +196,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
// Update the abstract and device address info of value node and tensors in bprop graph
|
||||
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
|
||||
void SaveTensorsInValueNode(const ResourcePtr &resource);
|
||||
void CleanTensorsInValueNode();
|
||||
void CleanPreMemoryInValueNode(const std::string &cell_id);
|
||||
|
||||
// Construct grad graph
|
||||
void PushCurrentGraphToStack();
|
||||
|
@ -259,6 +265,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
static std::mutex instance_lock_;
|
||||
static int64_t graph_id_;
|
||||
size_t grad_order_{0};
|
||||
std::string top_cell_id_;
|
||||
bool grad_flag_{false};
|
||||
bool dynamic_cell_{false};
|
||||
bool grad_is_running_{false};
|
||||
|
@ -282,8 +289,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
// Used for runop and replace forward result of grad graph
|
||||
std::unordered_map<std::string, size_t> op_index_map_;
|
||||
std::unordered_map<std::string, std::string> obj_to_forward_id_;
|
||||
std::unordered_map<std::string, std::vector<std::string>> op_index_with_tensor_id_;
|
||||
std::unordered_map<std::string, std::vector<tensor::TensorPtr>> tensor_id_with_tensor_;
|
||||
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
|
||||
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
|
||||
};
|
||||
|
|
|
@ -553,7 +553,7 @@ std::string Tensor::ToStringInternal(int limit_size) const {
|
|||
std::ostringstream buf;
|
||||
auto dtype = Dtype();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value=";
|
||||
buf << "Tensor(id=" << id_ << ", shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value=";
|
||||
if (limit_size <= 0 || DataSize() < limit_size) {
|
||||
// Only print data for small tensor.
|
||||
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false);
|
||||
|
|
|
@ -361,7 +361,6 @@ class Cell(Cell_):
|
|||
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
|
||||
for i, cell in enumerate(self.cells()):
|
||||
cell.set_grad(origin_grad[i])
|
||||
self._already_run = True
|
||||
return output
|
||||
|
||||
def _add_attr(self, name, value):
|
||||
|
|
|
@ -182,6 +182,9 @@ class GradOperation(GradOperation_):
|
|||
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
|
||||
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
|
||||
Default: False.
|
||||
If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
|
||||
the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
|
||||
parameter, the key must be sens.
|
||||
|
||||
Returns:
|
||||
The higher-order function which takes a function as argument and returns gradient function for it.
|
||||
|
@ -311,16 +314,23 @@ class GradOperation(GradOperation_):
|
|||
|
||||
def _pynative_forward_run(self, args, kwargs, fn):
|
||||
""" Pynative forward run to build grad graph. """
|
||||
new_kwargs = {}
|
||||
if self.sens_param:
|
||||
args = args[:-1]
|
||||
if not 'sens' in kwargs.keys():
|
||||
args = args[:-1]
|
||||
new_kwargs = kwargs
|
||||
else:
|
||||
for key, value in kwargs.items():
|
||||
if key != 'sens':
|
||||
new_kwargs[key] = value
|
||||
for arg in args:
|
||||
if not isinstance(arg, Tensor):
|
||||
raise TypeError("grad inputs should be tensor in pynative mode")
|
||||
if isinstance(fn, FunctionType):
|
||||
_pynative_exec.set_grad_flag(True)
|
||||
_pynative_exec.new_graph(fn, *args, **kwargs)
|
||||
output = fn(*args, **kwargs)
|
||||
_pynative_exec.end_graph(fn, output, *args, **kwargs)
|
||||
_pynative_exec.new_graph(fn, *args, **new_kwargs)
|
||||
output = fn(*args, **new_kwargs)
|
||||
_pynative_exec.end_graph(fn, output, *args, **new_kwargs)
|
||||
else:
|
||||
if fn.already_run and not fn.requires_grad:
|
||||
raise ValueError("obj must set_grad.")
|
||||
|
@ -328,7 +338,7 @@ class GradOperation(GradOperation_):
|
|||
self.need_forward = True
|
||||
if self.need_forward:
|
||||
fn.set_grad()
|
||||
fn(*args, **kwargs)
|
||||
fn(*args, **new_kwargs)
|
||||
fn.already_run = False
|
||||
|
||||
def __call__(self, fn, weights=None):
|
||||
|
|
|
@ -404,10 +404,10 @@ def test_pynative_resnet50():
|
|||
step = step + 1
|
||||
if step > max_step:
|
||||
break
|
||||
start_time = time.time()
|
||||
input_data = element["image"]
|
||||
input_label = element["label"]
|
||||
loss_output = net_with_criterion(input_data, input_label)
|
||||
start_time = time.time()
|
||||
grads = train_network(input_data, input_label)
|
||||
optimizer(grads)
|
||||
end_time = time.time()
|
||||
|
|
|
@ -403,10 +403,10 @@ def test_pynative_resnet50():
|
|||
step = step + 1
|
||||
if step > max_step:
|
||||
break
|
||||
start_time = time.time()
|
||||
input_data = element["image"]
|
||||
input_label = element["label"]
|
||||
loss_output = net_with_criterion(input_data, input_label)
|
||||
start_time = time.time()
|
||||
grads = train_network(input_data, input_label)
|
||||
optimizer(grads)
|
||||
end_time = time.time()
|
||||
|
|
Loading…
Reference in New Issue