!17118 Add arg to control whether the output of cell should be recomputed

Merge pull request !17118 from YuJianfeng/recompute
i-robot 2021-06-18 06:08:09 +00:00 committed by Gitee
commit e61c81a8e3
11 changed files with 138 additions and 28 deletions
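In Python terms, the change extends Cell.recompute with an output_recompute argument. A minimal usage sketch (the Block cell and its layers below are illustrative, not part of this change):

import mindspore.nn as nn

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 16)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.dense(x))

block = Block()
# Recompute all primitives in the block, but keep its output activation
# in memory for the backward pass (the new default: output_recompute=False).
block.recompute()
# Or allow the output itself to be recomputed as well:
# block.recompute(mode=True, output_recompute=True)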

View File

@@ -50,6 +50,7 @@
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
#include "frontend/optimizer/irpass/recompute_prepare.h"
namespace mindspore {
namespace opt {
@@ -250,6 +251,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
// recompute
set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
"set_cell_output_no_recompute", IsValueNode<FuncGraph>);
}
ResolveIRPassLib::ResolveIRPassLib() {

View File

@@ -156,6 +156,9 @@ class OptimizeIRPassLib {
// Eliminate getattr bool scalar
SubstitutionPtr bool_scalar_eliminate_;
// Recompute
SubstitutionPtr set_cell_output_no_recompute_;
};
// the collection of irpass for resolve action

View File

@@ -45,6 +45,10 @@ class ReplaceApplicator : public AnfVisitor {
*(fg->switch_layer_input())) {
return nullptr;
}
// Defer inlining so that the output nodes of a recomputed cell whose output is non-recomputed can still be found.
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
auto out = fg->output();
MS_EXCEPTION_IF_NULL(out);
@@ -100,6 +104,10 @@ class InlinerBase : public AnfVisitor {
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
return nullptr;
}
// Defer inlining so that the output nodes of a recomputed cell whose output is non-recomputed can still be found.
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
Reset();

View File

@@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
#include <unordered_set>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
namespace irpass {
class SetCellOutputNoRecompute : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsValueNode<FuncGraph>(node)) {
return nullptr;
}
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
auto output = fg->output();
if (output == nullptr) {
return nullptr;
}
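// Collect the real output nodes of the cell and mark them with recompute=false,
// so the cell's output activation is kept in memory for the backward pass.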
if (output->isa<CNode>()) {
std::unordered_set<CNodePtr> real_outputs;
GetRealOutputNodes(output, &real_outputs);
for (const auto &real_output : real_outputs) {
auto prim = GetValueNode<PrimitivePtr>(real_output->input(0));
MS_EXCEPTION_IF_NULL(prim);
prim->set_attr(kAttrRecompute, MakeValue(false));
}
}
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
return nullptr;
}
void GetRealOutputNodes(const AnfNodePtr &output, std::unordered_set<CNodePtr> *real_outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(real_outputs);
if (!output->isa<CNode>()) {
return;
}
auto output_cnode = output->cast<CNodePtr>();
if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
} else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
auto &inputs = output_cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
GetRealOutputNodes(output_cnode->input(i), real_outputs);
}
} else {
real_outputs->insert(output_cnode);
}
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
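For readers skimming the pass, the traversal in GetRealOutputNodes amounts to the following toy Python model (Node and the plain-string op names are simplifications for illustration, not MindSpore APIs; inputs here exclude the leading primitive input of a real CNode):

class Node:
    def __init__(self, op, inputs=()):
        self.op = op              # primitive name, e.g. "MatMul" or "MakeTuple"
        self.inputs = list(inputs)

def real_output_nodes(node, outs):
    # Look through Depend/TupleGetItem to their real input, expand MakeTuple
    # into its elements, and collect everything else as a real output node.
    if node.op in ("Depend", "TupleGetItem"):
        real_output_nodes(node.inputs[0], outs)
    elif node.op == "MakeTuple":
        for inp in node.inputs:
            real_output_nodes(inp, outs)
    else:
        outs.add(node)

matmul = Node("MatMul")
relu = Node("ReLU")
out = Node("Depend", [Node("MakeTuple", [matmul, Node("TupleGetItem", [relu])])])
real = set()
real_output_nodes(out, real)
assert real == {matmul, relu}   # these nodes would get recompute=False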

View File

@@ -31,6 +31,14 @@ namespace mindspore {
namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
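// Ops that must never be recomputed: communication ops like AllGather would
// communicate twice, and DropoutGenMask would regenerate a different random mask.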
bool CanNotRecomputed(const CNodePtr &node) {
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask};
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@@ -223,11 +231,8 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
return false;
}
-if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
-                [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) {
-  return true;
-}
-return false;
+return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
+                   [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); });
}
void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
@@ -263,8 +268,8 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
if (IsBpropNode(node)) {
continue;
}
-// Do not recompute the communicate op.
-if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
+// Filter some unrecomputable operators.
+if (CanNotRecomputed(node)) {
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {

View File

@@ -279,7 +279,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
auto base_graph = cloner->cloned_func_graph()[fg];
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
-if (fg->used_global_parameters().empty() || graphs.size() <= 1) {
+if (fg->used_global_parameters().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
continue;
}
auto &cloned_nodes = *cloner->cloned_node();

View File

@@ -239,7 +239,12 @@ class Parser {
// In order to keep effect order in the sub-graphs which generated by control flow.
// We copy the flags from the top graph to the sub-graphs.
if (func_graph_ && !func_graph_->attrs().empty()) {
-block->func_graph()->set_attrs(func_graph_->attrs());
+for (const auto &attr : func_graph_->attrs()) {
+  // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should only be set on the top graph.
+  if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
+    block->func_graph()->set_attr(attr.first, attr.second);
+  }
+}
}
func_block_list_.push_back(block);
return block;

View File

@@ -302,9 +302,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
},
false, true);
-opt::OptPassConfig a_after_grad = opt::OptPassConfig({
-    irpass.inline_without_move_,
-});
+opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_});
opt::OptPassConfig a_3 = opt::OptPassConfig(
{
irpass.arithmetic_simplify2_,
@@ -318,9 +317,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.mini_step_allgather_replace_,
},
false, true);
-opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({
-    irpass.less_batch_normalization_,
-});
+opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::irpass::ResolveIRPassLib resolve_irpass;
@@ -477,6 +474,12 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
return map;
}
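// This group runs ahead of opt_a (see kVmPasses), so the output attrs are set
// while the cell graphs are still separate, before inlining merges them.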
OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
return map;
}
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
return map;
@@ -499,6 +502,8 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_grad_epilogue"] =
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_before_recompute"] =
Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
}
@@ -537,6 +542,7 @@ bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
@@ -644,6 +650,7 @@ bool PynativeOptPass(const ResourcePtr &res) {
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_before_recompute", OptBeforeRecomputeGroup},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup},

View File

@@ -82,6 +82,7 @@ const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";

View File

@@ -1213,7 +1213,7 @@ class Cell(Cell_):
elif not self._scope is None and self._scope.startswith(prefix):
self._scope = self._scope[len(prefix):]
-def recompute(self, mode=True):
+def recompute(self, mode=True, output_recompute=False):
"""
Set the cell recomputed. All the primitives in the cell will be set recomputed. If a primitive
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
@@ -1228,13 +1228,18 @@
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
output_recompute (bool): Specifies whether the output of this cell is recomputed when
the mode is true. Note that when the mode is false, this arg does not take effect. Default: False.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
Validator.check_bool(output_recompute)
self._set_recompute_scope(mode)
if mode and not output_recompute:
self.add_flags(output_no_recompute=True)
for cell in self.cells():
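# Sub-cells may recompute their own outputs; only the outermost cell's output is kept.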
-cell.recompute(mode)
+cell.recompute(mode, True)
class GraphKernel(Cell):

View File

@@ -565,10 +565,7 @@ class Block(nn.Cell):
self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
-self.last_add = P.TensorAdd().shard(
-((config.dp, 1, 1), (config.dp, 1, 1)))
-# Last activation of this layer will be saved for recompute in backward process
-self.last_add.recompute(False)
self.dtype = config.compute_dtype
def construct(self, x, input_mask, layer_past=None):
@@ -591,9 +588,9 @@
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
if self.post_layernorm_residual:
-output = self.last_add(output_x, mlp_logit)
+output = self.add(output_x, mlp_logit)
else:
-output = self.last_add(x, mlp_logit)
+output = self.add(x, mlp_logit)
return output, layer_present
@@ -653,10 +650,6 @@ class QueryLayer(nn.Cell):
self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
-self.last_add = P.TensorAdd().shard(
-((config.dp, 1, 1), (config.dp, 1,
-1))).add_prim_attr("recompute", False)
self.dtype = config.compute_dtype
def construct(self, x, query_hidden_state, input_mask, layer_past=None):
@@ -679,9 +672,9 @@
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
if self.post_layernorm_residual:
-output = self.last_add(output_x, mlp_logit)
+output = self.add(output_x, mlp_logit)
else:
-output = self.last_add(x, mlp_logit)
+output = self.add(x, mlp_logit)
return output, layer_present
class PanguAlpha_Model(nn.Cell):