forked from mindspore-Ecosystem/mindspore
add recompute nodes
parent 237faca57e
commit 7b412d7cb2
@@ -38,9 +38,6 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne
   if (value->isa<Scalar>()) {
     tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
   } else if (value->isa<ValueTuple>()) {
-    if (!AnfAlgo::IsRealCNodeKernel(node)) {
-      return nullptr;
-    }
     tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
   } else {
     MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
@@ -89,7 +86,11 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
   MS_EXCEPTION_IF_NULL(func_graph);
   auto new_cnode = func_graph->NewCNode(new_inputs);
   MS_EXCEPTION_IF_NULL(new_cnode);
-  new_cnode->set_abstract(cnode->abstract());
+  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
+    new_cnode->set_abstract(new_inputs[1]->abstract());
+  } else {
+    new_cnode->set_abstract(cnode->abstract());
+  }
   new_cnode->set_scope(cnode->scope());
   AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
   if (kernel_graph != nullptr) {
@@ -123,7 +124,8 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
 
 const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                          const EquivPtr &) const {
-  if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
+  if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
+      AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
     return nullptr;
   }
   if (!node->isa<CNode>()) {
@@ -0,0 +1,442 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/recompute.h"
#include <memory>
#include <queue>
#include <utility>
#include <list>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>
#include "ir/func_graph.h"
#include "mindspore/core/base/core_ops.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecomputed = "recomputed";
constexpr auto kAttrNoRecomputed = "no_recomputed";
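
// How the markers above are used in this file:
// - kGradientsFlag: scope-name prefix that identifies backward-pass (gradient) nodes.
// - kAttrRecomputed: primitive attr marking forward nodes whose results should be recomputed.
// - kAttrNoRecomputed: primitive attr that opts a node out of recomputation.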
bool IsTargetNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return node->fullname_with_scope().find(kGradientsFlag) == 0;
}

bool HasNoRecomputedAttr(const AnfNodePtr &node) {
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed);
    if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) {
      return GetValue<bool>(no_recompute_val);
    }
  }
  return false;
}

bool WithRecomputedScope(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return node->fullname_with_scope().find(kAttrRecomputed) == 0;
}

bool IsSetRecomputed(const AnfNodePtr &node) {
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    auto recompute_val = prim->GetAttr(kAttrRecomputed);
    if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) {
      return GetValue<bool>(recompute_val);
    }
  }
  return false;
}

bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); }
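
// A candidate recomputed node is a forward node (not a gradient node itself)
// carrying the "recomputed" attr. The function below additionally requires
// that it feeds at least one gradient node directly and takes no gradient
// node as a direct input.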

std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
                                                   const std::vector<CNodePtr> &cnodes) {
  MS_EXCEPTION_IF_NULL(mng);
  std::vector<CNodePtr> candidate_recomputed_nodes;
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (!IsCandidateRecomputedNode(cnode)) {
      continue;
    }
    // Check outputs.
    const auto &node_users = mng->node_users();
    auto output_set_iter = node_users.find(cnode);
    if (output_set_iter == node_users.end()) {
      continue;
    }
    const auto &node_index_set = output_set_iter->second;
    if (!std::any_of(node_index_set.begin(), node_index_set.end(),
                     [](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) {
      continue;
    }
    // Check inputs.
    const auto &inputs = cnode->inputs();
    if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) {
      continue;
    }
    candidate_recomputed_nodes.emplace_back(cnode);
  }
  return candidate_recomputed_nodes;
}
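
// Flood-fill from the seed nodes through candidate inputs and/or users to
// collect the maximal connected sub-graph of recomputed nodes.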

void GetMaxSubGraph(const FuncGraphManagerPtr &mng, std::unordered_set<CNodePtr> *recomputed_nodes, bool get_inputs,
                    bool get_outputs) {
  MS_EXCEPTION_IF_NULL(mng);
  MS_EXCEPTION_IF_NULL(recomputed_nodes);
  std::queue<CNodePtr> nodes_to_visit;
  for (const auto &node : *recomputed_nodes) {
    nodes_to_visit.push(node);
  }
  recomputed_nodes->clear();
  while (!nodes_to_visit.empty()) {
    auto current_node = nodes_to_visit.front();
    nodes_to_visit.pop();
    recomputed_nodes->insert(current_node);
    if (get_inputs) {
      for (const auto &input : current_node->inputs()) {
        MS_EXCEPTION_IF_NULL(input);
        if (input->isa<CNode>()) {
          auto input_cnode = input->cast<CNodePtr>();
          if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() &&
              IsCandidateRecomputedNode(input_cnode)) {
            nodes_to_visit.push(input_cnode);
          }
        }
      }
    }

    if (get_outputs) {
      const auto &node_users = mng->node_users();
      auto output_set_iter = node_users.find(current_node);
      if (output_set_iter == node_users.end()) {
        continue;
      }
      for (const auto &node_index_set : output_set_iter->second) {
        auto output_node = node_index_set.first;
        MS_EXCEPTION_IF_NULL(output_node);
        if (output_node->isa<CNode>()) {
          auto output_cnode = output_node->cast<CNodePtr>();
          if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() &&
              IsCandidateRecomputedNode(output_cnode)) {
            nodes_to_visit.push(output_cnode);
          }
        }
      }
    }
  }
}
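
// For every sub-graph node consumed by a gradient node, record the node itself
// in recompute_nodes and its gradient consumers in target_nodes.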

void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
                                      const std::unordered_set<CNodePtr> &max_recomputed_sub_graph,
                                      std::unordered_set<CNodePtr> *recompute_nodes,
                                      std::unordered_set<CNodePtr> *target_nodes) {
  MS_EXCEPTION_IF_NULL(mng);
  MS_EXCEPTION_IF_NULL(recompute_nodes);
  MS_EXCEPTION_IF_NULL(target_nodes);
  const auto &node_users = mng->node_users();
  for (const auto &node : max_recomputed_sub_graph) {
    bool inserted = false;
    auto output_set_iter = node_users.find(node);
    if (output_set_iter == node_users.end()) {
      continue;
    }
    for (const auto &node_index_set : output_set_iter->second) {
      auto output_node = node_index_set.first;
      MS_EXCEPTION_IF_NULL(output_node);
      if (!IsTargetNode(output_node)) {
        continue;
      }
      target_nodes->insert(output_node->cast<CNodePtr>());
      if (!inserted) {
        recompute_nodes->insert(node);
        inserted = true;
      }
    }
  }
}
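
// The duplicated nodes must not run before the backward pass needs them, so we
// collect the non-recomputed inputs of the first target node in topological
// order; they later anchor a Depend edge.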

std::vector<AnfNodePtr> GetFirstTargetInputs(const std::vector<CNodePtr> &origin_nodes_topological,
                                             const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
                                             const std::unordered_set<CNodePtr> &target_nodes) {
  std::vector<AnfNodePtr> first_target_inputs;
  for (const auto &node : origin_nodes_topological) {
    MS_EXCEPTION_IF_NULL(node);
    if (target_nodes.find(node) != target_nodes.end()) {
      for (size_t i = 1; i < node->size(); ++i) {
        auto input = node->input(i);
        MS_EXCEPTION_IF_NULL(input);
        if (!input->isa<CNode>()) {
          continue;
        }
        if (recomputed_origin_nodes.find(input->cast<CNodePtr>()) != recomputed_origin_nodes.end()) {
          continue;
        }
        first_target_inputs.emplace_back(input);
      }
      break;
    }
  }
  return first_target_inputs;
}
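
// Memoized recursion: a node "has grad inputs" if any input is a gradient node
// or transitively has one; results are cached in has_grad_inputs_map.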

bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> *has_grad_inputs_map) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(has_grad_inputs_map);
  if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
    return has_grad_inputs_map->find(node)->second;
  }
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    has_grad_inputs_map->insert(std::make_pair(node, false));
    return false;
  }
  const auto &inputs = cnode->inputs();
  if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
        return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map);
      })) {
    has_grad_inputs_map->insert(std::make_pair(node, true));
    return true;
  }
  has_grad_inputs_map->insert(std::make_pair(node, false));
  return false;
}

bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(mng);
  const auto &node_users = mng->node_users();
  auto output_set_iter = node_users.find(node);
  if (output_set_iter == node_users.end()) {
    return false;
  }
  for (const auto &node_index_set : output_set_iter->second) {
    if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
      return true;
    }
  }
  return false;
}

void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
                                std::vector<AnfNodePtr> *tuple_getitem_output_nodes) {
  MS_EXCEPTION_IF_NULL(mng);
  MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes);
  const auto &node_users = mng->node_users();
  auto output_set_iter = node_users.find(node);
  if (output_set_iter == node_users.end()) {
    return;
  }
  for (const auto &node_index_set : output_set_iter->second) {
    if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) {
      tuple_getitem_output_nodes->emplace_back(node_index_set.first);
    }
  }
}

// Set the 'recomputed' attr for nodes according to their scopes.
// A node with the 'recomputed' attr can be a candidate recomputed node.
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
  for (const auto &node : origin_nodes_topological) {
    MS_EXCEPTION_IF_NULL(node);
    if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() ||
        prim->name() == prim::kPrimAllGather->name()) {
      continue;
    }
    if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
      continue;
    }

    // Make a new primitive to set the attr because some nodes probably share the same primitive.
    auto new_prim = std::make_shared<Primitive>(prim->name());
    new_prim->SetAttrs(prim->attrs());
    new_prim->set_prim_type(prim->prim_type());
    new_prim->set_attr(kAttrRecomputed, MakeValue(true));
    std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
    const auto &origin_inputs = node->inputs();
    std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
    auto new_node = graph->NewCNode(new_inputs);
    new_node->set_abstract(node->abstract());
    new_node->set_scope(node->scope());
    mng->Replace(node, new_node);

    // Set the attr for the tuple_getitem outputs.
    std::vector<AnfNodePtr> tuple_getitem_output_nodes;
    GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes);
    for (const auto &output_node : tuple_getitem_output_nodes) {
      auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name());
      new_output_prim->set_attr(kAttrRecomputed, MakeValue(true));
      std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)};
      auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs();
      std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(),
                std::back_inserter(new_tuple_getitem_inputs));
      auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs);
      new_tuple_getitem->set_abstract(output_node->abstract());
      mng->Replace(output_node, new_tuple_getitem);
    }
  }
}
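
// Duplicate one recomputed node, recursively duplicating its recomputed
// inputs. A duplicate with no recomputed inputs gets a Depend on the first
// target node's other inputs, so recomputation cannot be scheduled before the
// backward pass actually needs it.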

CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                           const std::vector<AnfNodePtr> &first_target_inputs,
                           const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
                           std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes);
  auto iter = origin_to_recomputed_nodes->find(origin_node);
  if (iter != origin_to_recomputed_nodes->end()) {
    return iter->second;
  }
  MS_LOG(DEBUG) << "Begin duplicating origin recomputed node: " << origin_node->DebugString();
  auto prim = GetCNodePrimitive(origin_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto new_prim = std::make_shared<Primitive>(prim->name());
  new_prim->SetAttrs(prim->attrs());
  new_prim->set_attr("duplicated", MakeValue(true));
  new_prim->set_prim_type(prim->prim_type());
  std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
  bool has_recomputed_inputs = false;
  for (size_t i = 1; i < origin_node->size(); ++i) {
    auto input = origin_node->input(i);
    MS_EXCEPTION_IF_NULL(input);
    if (!input->isa<CNode>()) {
      new_inputs.emplace_back(input);
      continue;
    }
    auto input_cnode = input->cast<CNodePtr>();
    if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
      new_inputs.emplace_back(input);
    } else {
      has_recomputed_inputs = true;
      new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,
                                                origin_to_recomputed_nodes));
    }
  }
  // Add the execution dependency.
  if (!has_recomputed_inputs && new_inputs.size() > 1) {
    std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
    std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs));
    auto first_input = new_inputs[1];
    MS_EXCEPTION_IF_NULL(first_input);
    std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_input,
                                          graph->NewCNode(make_tuple_inputs)};
    auto depend_node = graph->NewCNode(depend_inputs);
    MS_EXCEPTION_IF_NULL(depend_node);
    depend_node->set_abstract(first_input->abstract());
    new_inputs[1] = depend_node;
  }
  auto recomputed_node = graph->NewCNode(new_inputs);
  recomputed_node->set_abstract(origin_node->abstract());
  recomputed_node->set_scope(origin_node->scope());
  origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
  return recomputed_node;
}
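
// Rebuild each gradient (target) node so that it consumes the duplicated
// recomputed nodes instead of the original forward nodes.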

void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_set<CNodePtr> &target_nodes,
                              const std::unordered_set<CNodePtr> &origin_recomputed_nodes,
                              const std::vector<AnfNodePtr> &first_target_inputs,
                              std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &target_node : target_nodes) {
    MS_EXCEPTION_IF_NULL(target_node);
    MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
    auto target_cnode = target_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(target_cnode);
    auto prim = GetCNodePrimitive(target_cnode);
    if (prim != nullptr) {
      prim->set_attr("target_grad", MakeValue(true));
    }
    std::vector<AnfNodePtr> new_target_inputs;
    for (const auto &input : target_cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      if (!input->isa<CNode>()) {
        new_target_inputs.emplace_back(input);
      } else {
        auto input_cnode = input->cast<CNodePtr>();
        if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) {
          new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs,
                                                           origin_recomputed_nodes, origin_to_recomputed_nodes));
        } else {
          new_target_inputs.emplace_back(input_cnode);
        }
      }
    }
    auto new_target_node = graph->NewCNode(new_target_inputs);
    new_target_node->set_abstract(target_node->abstract());
    new_target_node->set_scope(target_node->scope());
    mng->Replace(target_node, new_target_node);
  }
}
}  // namespace

void InsertRecomputedNodes(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  std::list<CNodePtr> old_orders = graph->GetOrderedCnodes();
  std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end());
  SetRecomputedAttr(graph, old_nodes_topological);
  std::list<CNodePtr> new_orders = graph->GetOrderedCnodes();
  std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end());
  // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
  std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
  std::unordered_set<CNodePtr> visited_nodes;
  for (const auto &candidate_recomputed_node : candidate_recomputed_nodes) {
    if (visited_nodes.find(candidate_recomputed_node) != visited_nodes.end()) {
      continue;
    }
    std::unordered_set<CNodePtr> max_recomputed_sub_graph = {candidate_recomputed_node};
    // Get the max continuous recomputed sub-graph.
    GetMaxSubGraph(mng, &max_recomputed_sub_graph, true, true);
    visited_nodes.insert(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end());
    // Get the origin recomputed nodes which directly output to the grad nodes.
    std::unordered_set<CNodePtr> origin_recomputed_nodes;
    std::unordered_set<CNodePtr> target_nodes;
    GetOriginRecomputeAndTargetNodes(mng, max_recomputed_sub_graph, &origin_recomputed_nodes, &target_nodes);
    // Also get the inputs of the origin recomputed nodes which eventually output to the grad nodes.
    GetMaxSubGraph(mng, &origin_recomputed_nodes, true, false);

    // Get the inputs of the first target node in the topological sequence. The duplicated recomputed nodes should
    // not be executed until these inputs are ready.
    std::vector<AnfNodePtr> first_target_inputs =
      GetFirstTargetInputs(origin_nodes_topological, origin_recomputed_nodes, target_nodes);
    std::unordered_map<CNodePtr, CNodePtr> origin_to_recomputed_nodes;
    // Duplicate the origin recomputed nodes for each target node.
    DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
                             &origin_to_recomputed_nodes);
  }
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,28 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_

#include "ir/anf.h"

namespace mindspore {
namespace opt {
// Automatically insert duplicated recomputed nodes.
void InsertRecomputedNodes(const FuncGraphPtr &graph);
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_
@@ -37,6 +37,7 @@
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/step_auto_parallel.h"
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
+#include "frontend/optimizer/recompute.h"
 #include "utils/log_adapter.h"
 #include "pipeline/jit/pipeline_split.h"
@@ -383,6 +384,12 @@ bool AddControlDependPass(const ResourcePtr &res) {
   return true;
 }
 
+bool AddRecomputationPass(const ResourcePtr &res) {
+  MS_EXCEPTION_IF_NULL(res);
+  opt::InsertRecomputedNodes(res->func_graph());
+  return true;
+}
+
 bool MergeDupGraphPass(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -474,7 +481,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"tuple_transform", OptPassTransformGraphGroup},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
-                                   {"add_control_depend", AddControlDependPass}};
+                                   {"add_control_depend", AddControlDependPass},
+                                   {"add_recomputation", AddRecomputationPass}};
 
 std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_a", OptPassAGroup},
@@ -913,6 +913,8 @@ class Cell(Cell_):
         """Sets the name on the first time."""
         if self._scope is None:
             self._scope = name
+        elif self._scope == 'recomputed':
+            self._scope = self._scope + "_" + name
 
     def _children_scope_recursive(self, parent_prefix='Default'):
         """Generates the scope of each layer of the network recursively."""
@@ -1093,6 +1095,15 @@ class Cell(Cell_):
             param.comm_fusion = fusion_type
         return self
 
+    def recompute(self):
+        """
+        Set the cell recomputed. All the primitives in the cell will be set recomputed. If a primitive that feeds
+        into a grad node is set recomputed, it is computed again for the grad node after the forward computation.
+        """
+        self._set_scope('recomputed')
+        for cell in self.cells():
+            cell.recompute()
+
 
 class GraphKernel(Cell):
     """
@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool.recompute()
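        # The pooling cell is marked recomputed: the recompute pass duplicates
        # its forward nodes for the backward pass instead of keeping their
        # outputs alive between the forward and backward computation.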
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def train(net, data, label):
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    diff = res.asnumpy() - 2.302585
    assert np.all(diff < 1.e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lenet():
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)