!17118 Add arg to control whether the output of cell should be recomputed

Merge pull request !17118 from YuJianfeng/recompute
This commit is contained in:
i-robot 2021-06-18 06:08:09 +00:00 committed by Gitee
commit e61c81a8e3
11 changed files with 138 additions and 28 deletions

View File

@ -50,6 +50,7 @@
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h" #include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h" #include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h" #include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
#include "frontend/optimizer/irpass/recompute_prepare.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -250,6 +251,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode); bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
// recompute
set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
"set_cell_output_no_recompute", IsValueNode<FuncGraph>);
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {

View File

@ -156,6 +156,9 @@ class OptimizeIRPassLib {
// Eliminate getattr bool scalar // Eliminate getattr bool scalar
SubstitutionPtr bool_scalar_eliminate_; SubstitutionPtr bool_scalar_eliminate_;
// Recompute
SubstitutionPtr set_cell_output_no_recompute_;
}; };
// the collection of irpass for resolve action // the collection of irpass for resolve action

View File

@ -45,6 +45,10 @@ class ReplaceApplicator : public AnfVisitor {
*(fg->switch_layer_input())) { *(fg->switch_layer_input())) {
return nullptr; return nullptr;
} }
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
auto out = fg->output(); auto out = fg->output();
MS_EXCEPTION_IF_NULL(out); MS_EXCEPTION_IF_NULL(out);
@ -100,6 +104,10 @@ class InlinerBase : public AnfVisitor {
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) { if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
return nullptr; return nullptr;
} }
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
Reset(); Reset();

View File

@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
#include <unordered_set>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
namespace irpass {
class SetCellOutputNoRecompute : public AnfVisitor {
 public:
  // Clears the recompute attr on the real output nodes of a func graph flagged
  // with FUNC_GRAPH_OUTPUT_NO_RECOMPUTE, then erases the flag so the graph can
  // be inlined normally by the later passes.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsValueNode<FuncGraph>(node)) {
      return nullptr;
    }
    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
      return nullptr;
    }
    auto output = fg->output();
    if (output == nullptr) {
      return nullptr;
    }
    if (output->isa<CNode>()) {
      std::unordered_set<CNodePtr> real_outputs;
      GetRealOutputNodes(output, &real_outputs);
      for (const auto &real_output : real_outputs) {
        // The real output may not be a primitive cnode (e.g. a call to another
        // func graph), in which case GetValueNode returns nullptr; only
        // primitive cnodes carry the recompute attr, so skip the others.
        auto prim = GetValueNode<PrimitivePtr>(real_output->input(0));
        if (prim == nullptr) {
          continue;
        }
        prim->set_attr(kAttrRecompute, MakeValue(false));
      }
    }
    fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
    return nullptr;
  }

  // Collects the "real" output cnodes of `output` into `real_outputs`, looking
  // through Depend, TupleGetItem and MakeTuple wrappers.
  void GetRealOutputNodes(const AnfNodePtr &output, std::unordered_set<CNodePtr> *real_outputs) {
    MS_EXCEPTION_IF_NULL(output);
    MS_EXCEPTION_IF_NULL(real_outputs);
    if (!output->isa<CNode>()) {
      return;
    }
    auto output_cnode = output->cast<CNodePtr>();
    if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
      // NOTE(review): both Depend and TupleGetItem keep their real data input
      // at the same index, so one recursion handles either wrapper — confirm
      // kRealInputIndexInDepend matches TupleGetItem's real-input index.
      GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
    } else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
      auto &inputs = output_cnode->inputs();
      for (size_t i = 1; i < inputs.size(); ++i) {
        GetRealOutputNodes(output_cnode->input(i), real_outputs);
      }
    } else {
      real_outputs->insert(output_cnode);
    }
  }
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_

View File

@ -31,6 +31,14 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
constexpr auto kGradientsFlag = "Gradients"; constexpr auto kGradientsFlag = "Gradients";
bool CanNotRecomputed(const CNodePtr &node) {
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask};
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
bool IsBpropNode(const AnfNodePtr &node) { bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
@ -223,11 +231,8 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
return false; return false;
} }
if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
[](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) { [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); });
return true;
}
return false;
} }
void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
@ -263,8 +268,8 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
if (IsBpropNode(node)) { if (IsBpropNode(node)) {
continue; continue;
} }
// Do not recompute the communicate op. // Filter some unrecomputable operators.
if (IsPrimitiveCNode(node, prim::kPrimAllGather)) { if (CanNotRecomputed(node)) {
continue; continue;
} }
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {

View File

@ -279,7 +279,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
auto base_graph = cloner->cloned_func_graph()[fg]; auto base_graph = cloner->cloned_func_graph()[fg];
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString(); MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
if (fg->used_global_parameters().empty() || graphs.size() <= 1) { if (fg->used_global_parameters().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
continue; continue;
} }
auto &cloned_nodes = *cloner->cloned_node(); auto &cloned_nodes = *cloner->cloned_node();

View File

@ -239,7 +239,12 @@ class Parser {
// In order to keep effect order in the sub-graphs which generated by control flow. // In order to keep effect order in the sub-graphs which generated by control flow.
// We copy the flags from the top graph to the sub-graphs. // We copy the flags from the top graph to the sub-graphs.
if (func_graph_ && !func_graph_->attrs().empty()) { if (func_graph_ && !func_graph_->attrs().empty()) {
block->func_graph()->set_attrs(func_graph_->attrs()); for (const auto &attr : func_graph_->attrs()) {
// The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
block->func_graph()->set_attr(attr.first, attr.second);
}
}
} }
func_block_list_.push_back(block); func_block_list_.push_back(block);
return block; return block;

View File

@ -302,9 +302,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
}, },
false, true); false, true);
opt::OptPassConfig a_after_grad = opt::OptPassConfig({ opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_});
irpass.inline_without_move_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig( opt::OptPassConfig a_3 = opt::OptPassConfig(
{ {
irpass.arithmetic_simplify2_, irpass.arithmetic_simplify2_,
@ -318,9 +317,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.mini_step_allgather_replace_, irpass.mini_step_allgather_replace_,
}, },
false, true); false, true);
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({ opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
irpass.less_batch_normalization_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::irpass::ResolveIRPassLib resolve_irpass; opt::irpass::ResolveIRPassLib resolve_irpass;
@ -477,6 +474,12 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
return map; return map;
} }
// Builds the pass group that runs right before the recompute pass; its single
// pass marks cell outputs that must not be recomputed.
OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
  OptPassGroupMap map(
    {{"set_cell_output_no_recompute", opt::OptPassConfig({irpass.set_cell_output_no_recompute_})}});
  return map;
}
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) { OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}}); OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
return map; return map;
@ -499,6 +502,8 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_grad_epilogue"] = g_pass_opts["opt_grad_epilogue"] =
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_before_recompute"] =
Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
g_pass_opts["opt_after_recompute"] = g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass)); Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
} }
@ -537,6 +542,7 @@ bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); } bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
// Runs the "opt_before_recompute" pass group registered in InitOpt.
bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); } bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
@ -644,6 +650,7 @@ bool PynativeOptPass(const ResourcePtr &res) {
} }
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_before_recompute", OptBeforeRecomputeGroup},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass}, {"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},

View File

@ -82,6 +82,7 @@ const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core"; const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
const char kFuncGraphFlagUndetermined[] = "Undeterminate"; const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry"; const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";

View File

@ -1213,7 +1213,7 @@ class Cell(Cell_):
elif not self._scope is None and self._scope.startswith(prefix): elif not self._scope is None and self._scope.startswith(prefix):
self._scope = self._scope[len(prefix):] self._scope = self._scope[len(prefix):]
def recompute(self, mode=True): def recompute(self, mode=True, output_recompute=False):
""" """
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
set recomputed feeds into some backward nodes for computing gradient, rather than storing the set recomputed feeds into some backward nodes for computing gradient, rather than storing the
@ -1228,13 +1228,18 @@ class Cell(Cell_):
Args: Args:
mode (bool): Specifies whether the cell is recomputed. Default: True. mode (bool): Specifies whether the cell is recomputed. Default: True.
output_recompute (bool): Specifies whether the output of this cell is recomputed when
the mode is true. Note that when the mode is false, this arg does not take effect. Default: False.
""" """
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.") raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode) Validator.check_bool(mode)
Validator.check_bool(output_recompute)
self._set_recompute_scope(mode) self._set_recompute_scope(mode)
if mode and not output_recompute:
self.add_flags(output_no_recompute=True)
for cell in self.cells(): for cell in self.cells():
cell.recompute(mode) cell.recompute(mode, True)
class GraphKernel(Cell): class GraphKernel(Cell):

View File

@ -565,10 +565,7 @@ class Block(nn.Cell):
self.output = Output(config, scale) self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.last_add = P.TensorAdd().shard(
((config.dp, 1, 1), (config.dp, 1, 1)))
# Last activation of this layer will be saved for recompute in backward process # Last activation of this layer will be saved for recompute in backward process
self.last_add.recompute(False)
self.dtype = config.compute_dtype self.dtype = config.compute_dtype
def construct(self, x, input_mask, layer_past=None): def construct(self, x, input_mask, layer_past=None):
@ -591,9 +588,9 @@ class Block(nn.Cell):
output_x = F.cast(output_x, self.dtype) output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x) mlp_logit = self.output(output_x)
if self.post_layernorm_residual: if self.post_layernorm_residual:
output = self.last_add(output_x, mlp_logit) output = self.add(output_x, mlp_logit)
else: else:
output = self.last_add(x, mlp_logit) output = self.add(x, mlp_logit)
return output, layer_present return output, layer_present
@ -653,10 +650,6 @@ class QueryLayer(nn.Cell):
self.output = Output(config, scale) self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.last_add = P.TensorAdd().shard(
((config.dp, 1, 1), (config.dp, 1,
1))).add_prim_attr("recompute", False)
self.dtype = config.compute_dtype self.dtype = config.compute_dtype
def construct(self, x, query_hidden_state, input_mask, layer_past=None): def construct(self, x, query_hidden_state, input_mask, layer_past=None):
@ -679,9 +672,9 @@ class QueryLayer(nn.Cell):
output_x = F.cast(output_x, self.dtype) output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x) mlp_logit = self.output(output_x)
if self.post_layernorm_residual: if self.post_layernorm_residual:
output = self.last_add(output_x, mlp_logit) output = self.add(output_x, mlp_logit)
else: else:
output = self.last_add(x, mlp_logit) output = self.add(x, mlp_logit)
return output, layer_present return output, layer_present
class PanguAlpha_Model(nn.Cell): class PanguAlpha_Model(nn.Cell):