forked from mindspore-Ecosystem/mindspore
!17118 Add arg to control whether the output of cell should be recomputed
Merge pull request !17118 from YuJianfeng/recompute
This commit is contained in:
commit
e61c81a8e3
|
@ -50,6 +50,7 @@
|
||||||
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
|
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
|
||||||
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
|
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
|
||||||
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
|
#include "frontend/optimizer/irpass/bool_scalar_eliminate.h"
|
||||||
|
#include "frontend/optimizer/irpass/recompute_prepare.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -250,6 +251,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
|
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
|
||||||
|
|
||||||
bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
|
bool_scalar_eliminate_ = MakeSubstitution(std::make_shared<BoolScalarEliminate>(), "bool_scalar_eliminate_", IsCNode);
|
||||||
|
|
||||||
|
// recompute
|
||||||
|
set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
|
||||||
|
"set_cell_output_no_recompute", IsValueNode<FuncGraph>);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResolveIRPassLib::ResolveIRPassLib() {
|
ResolveIRPassLib::ResolveIRPassLib() {
|
||||||
|
|
|
@ -156,6 +156,9 @@ class OptimizeIRPassLib {
|
||||||
|
|
||||||
// Eliminate getattr bool scalar
|
// Eliminate getattr bool scalar
|
||||||
SubstitutionPtr bool_scalar_eliminate_;
|
SubstitutionPtr bool_scalar_eliminate_;
|
||||||
|
|
||||||
|
// Recompute
|
||||||
|
SubstitutionPtr set_cell_output_no_recompute_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// the collection of irpass for resolve action
|
// the collection of irpass for resolve action
|
||||||
|
|
|
@ -45,6 +45,10 @@ class ReplaceApplicator : public AnfVisitor {
|
||||||
*(fg->switch_layer_input())) {
|
*(fg->switch_layer_input())) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
|
||||||
|
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
auto out = fg->output();
|
auto out = fg->output();
|
||||||
MS_EXCEPTION_IF_NULL(out);
|
MS_EXCEPTION_IF_NULL(out);
|
||||||
|
@ -100,6 +104,10 @@ class InlinerBase : public AnfVisitor {
|
||||||
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
|
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
|
||||||
|
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
Reset();
|
Reset();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
|
||||||
|
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
#include "frontend/optimizer/irpass.h"
|
||||||
|
#include "frontend/optimizer/optimizer.h"
|
||||||
|
#include "frontend/optimizer/anf_visitor.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
class SetCellOutputNoRecompute : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
if (!IsValueNode<FuncGraph>(node)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||||
|
if (fg == nullptr || !fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto output = fg->output();
|
||||||
|
if (output == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (output->isa<CNode>()) {
|
||||||
|
std::unordered_set<CNodePtr> real_outputs;
|
||||||
|
GetRealOutputNodes(output, &real_outputs);
|
||||||
|
for (const auto &real_output : real_outputs) {
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(real_output->input(0));
|
||||||
|
prim->set_attr(kAttrRecompute, MakeValue(false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetRealOutputNodes(const AnfNodePtr &output, std::unordered_set<CNodePtr> *real_outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
MS_EXCEPTION_IF_NULL(real_outputs);
|
||||||
|
if (!output->isa<CNode>()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto output_cnode = output->cast<CNodePtr>();
|
||||||
|
if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
|
||||||
|
GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
|
||||||
|
} else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
|
||||||
|
auto &inputs = output_cnode->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
GetRealOutputNodes(output_cnode->input(i), real_outputs);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
real_outputs->insert(output_cnode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_PREPARE_H_
|
|
@ -31,6 +31,14 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr auto kGradientsFlag = "Gradients";
|
constexpr auto kGradientsFlag = "Gradients";
|
||||||
|
|
||||||
|
bool CanNotRecomputed(const CNodePtr &node) {
|
||||||
|
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask};
|
||||||
|
|
||||||
|
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
|
}
|
||||||
|
|
||||||
bool IsBpropNode(const AnfNodePtr &node) {
|
bool IsBpropNode(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
|
@ -223,11 +231,8 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
|
return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
|
||||||
[](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) {
|
[](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); });
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
|
void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
|
||||||
|
@ -263,8 +268,8 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
|
||||||
if (IsBpropNode(node)) {
|
if (IsBpropNode(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Do not recompute the communicate op.
|
// Filter some unrecomputable operators.
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
|
if (CanNotRecomputed(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||||
|
|
|
@ -279,7 +279,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
||||||
auto base_graph = cloner->cloned_func_graph()[fg];
|
auto base_graph = cloner->cloned_func_graph()[fg];
|
||||||
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
|
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
|
||||||
|
|
||||||
if (fg->used_global_parameters().empty() || graphs.size() <= 1) {
|
if (fg->used_global_parameters().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto &cloned_nodes = *cloner->cloned_node();
|
auto &cloned_nodes = *cloner->cloned_node();
|
||||||
|
|
|
@ -239,7 +239,12 @@ class Parser {
|
||||||
// In order to keep effect order in the sub-graphs which generated by control flow.
|
// In order to keep effect order in the sub-graphs which generated by control flow.
|
||||||
// We copy the flags from the top graph to the sub-graphs.
|
// We copy the flags from the top graph to the sub-graphs.
|
||||||
if (func_graph_ && !func_graph_->attrs().empty()) {
|
if (func_graph_ && !func_graph_->attrs().empty()) {
|
||||||
block->func_graph()->set_attrs(func_graph_->attrs());
|
for (const auto &attr : func_graph_->attrs()) {
|
||||||
|
// The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
|
||||||
|
if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
|
||||||
|
block->func_graph()->set_attr(attr.first, attr.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
func_block_list_.push_back(block);
|
func_block_list_.push_back(block);
|
||||||
return block;
|
return block;
|
||||||
|
|
|
@ -302,9 +302,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
},
|
},
|
||||||
false, true);
|
false, true);
|
||||||
|
|
||||||
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
|
opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_});
|
||||||
irpass.inline_without_move_,
|
|
||||||
});
|
|
||||||
opt::OptPassConfig a_3 = opt::OptPassConfig(
|
opt::OptPassConfig a_3 = opt::OptPassConfig(
|
||||||
{
|
{
|
||||||
irpass.arithmetic_simplify2_,
|
irpass.arithmetic_simplify2_,
|
||||||
|
@ -318,9 +317,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
irpass.mini_step_allgather_replace_,
|
irpass.mini_step_allgather_replace_,
|
||||||
},
|
},
|
||||||
false, true);
|
false, true);
|
||||||
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({
|
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
|
||||||
irpass.less_batch_normalization_,
|
|
||||||
});
|
|
||||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||||
|
|
||||||
|
@ -477,6 +474,12 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
|
opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
|
||||||
|
OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
|
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
|
||||||
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
|
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
|
||||||
return map;
|
return map;
|
||||||
|
@ -499,6 +502,8 @@ void InitOpt(const ResourcePtr &res) {
|
||||||
g_pass_opts["opt_grad_epilogue"] =
|
g_pass_opts["opt_grad_epilogue"] =
|
||||||
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
|
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
|
||||||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||||
|
g_pass_opts["opt_before_recompute"] =
|
||||||
|
Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
|
||||||
g_pass_opts["opt_after_recompute"] =
|
g_pass_opts["opt_after_recompute"] =
|
||||||
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
||||||
}
|
}
|
||||||
|
@ -537,6 +542,7 @@ bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "
|
||||||
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
||||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||||
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
||||||
|
bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
|
||||||
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
||||||
|
|
||||||
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
||||||
|
@ -644,6 +650,7 @@ bool PynativeOptPass(const ResourcePtr &res) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
|
{"opt_before_recompute", OptBeforeRecomputeGroup},
|
||||||
{"opt_a", OptPassAGroup},
|
{"opt_a", OptPassAGroup},
|
||||||
{"clean_after_opta", CleanAfterOptAPass},
|
{"clean_after_opta", CleanAfterOptAPass},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
|
|
|
@ -82,6 +82,7 @@ const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
|
||||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||||
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
||||||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||||
|
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
|
||||||
|
|
||||||
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
|
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
|
||||||
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
|
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
|
||||||
|
|
|
@ -1213,7 +1213,7 @@ class Cell(Cell_):
|
||||||
elif not self._scope is None and self._scope.startswith(prefix):
|
elif not self._scope is None and self._scope.startswith(prefix):
|
||||||
self._scope = self._scope[len(prefix):]
|
self._scope = self._scope[len(prefix):]
|
||||||
|
|
||||||
def recompute(self, mode=True):
|
def recompute(self, mode=True, output_recompute=False):
|
||||||
"""
|
"""
|
||||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
||||||
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
|
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
|
||||||
|
@ -1228,13 +1228,18 @@ class Cell(Cell_):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode (bool): Specifies whether the cell is recomputed. Default: True.
|
mode (bool): Specifies whether the cell is recomputed. Default: True.
|
||||||
|
output_recompute (bool): Specifies whether the output of this cell is recomputed when
|
||||||
|
the mode is true. Note that when the mode is false, this arg is not working. Default: False.
|
||||||
"""
|
"""
|
||||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||||
raise TypeError("Recompute is not supported in pynative mode currently.")
|
raise TypeError("Recompute is not supported in pynative mode currently.")
|
||||||
Validator.check_bool(mode)
|
Validator.check_bool(mode)
|
||||||
|
Validator.check_bool(output_recompute)
|
||||||
self._set_recompute_scope(mode)
|
self._set_recompute_scope(mode)
|
||||||
|
if mode and not output_recompute:
|
||||||
|
self.add_flags(output_no_recompute=True)
|
||||||
for cell in self.cells():
|
for cell in self.cells():
|
||||||
cell.recompute(mode)
|
cell.recompute(mode, True)
|
||||||
|
|
||||||
|
|
||||||
class GraphKernel(Cell):
|
class GraphKernel(Cell):
|
||||||
|
|
|
@ -565,10 +565,7 @@ class Block(nn.Cell):
|
||||||
self.output = Output(config, scale)
|
self.output = Output(config, scale)
|
||||||
self.post_layernorm_residual = config.post_layernorm_residual
|
self.post_layernorm_residual = config.post_layernorm_residual
|
||||||
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
||||||
self.last_add = P.TensorAdd().shard(
|
|
||||||
((config.dp, 1, 1), (config.dp, 1, 1)))
|
|
||||||
# Last activation of this layer will be saved for recompute in backward process
|
# Last activation of this layer will be saved for recompute in backward process
|
||||||
self.last_add.recompute(False)
|
|
||||||
self.dtype = config.compute_dtype
|
self.dtype = config.compute_dtype
|
||||||
|
|
||||||
def construct(self, x, input_mask, layer_past=None):
|
def construct(self, x, input_mask, layer_past=None):
|
||||||
|
@ -591,9 +588,9 @@ class Block(nn.Cell):
|
||||||
output_x = F.cast(output_x, self.dtype)
|
output_x = F.cast(output_x, self.dtype)
|
||||||
mlp_logit = self.output(output_x)
|
mlp_logit = self.output(output_x)
|
||||||
if self.post_layernorm_residual:
|
if self.post_layernorm_residual:
|
||||||
output = self.last_add(output_x, mlp_logit)
|
output = self.add(output_x, mlp_logit)
|
||||||
else:
|
else:
|
||||||
output = self.last_add(x, mlp_logit)
|
output = self.add(x, mlp_logit)
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
|
||||||
|
|
||||||
|
@ -653,10 +650,6 @@ class QueryLayer(nn.Cell):
|
||||||
self.output = Output(config, scale)
|
self.output = Output(config, scale)
|
||||||
self.post_layernorm_residual = config.post_layernorm_residual
|
self.post_layernorm_residual = config.post_layernorm_residual
|
||||||
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
||||||
|
|
||||||
self.last_add = P.TensorAdd().shard(
|
|
||||||
((config.dp, 1, 1), (config.dp, 1,
|
|
||||||
1))).add_prim_attr("recompute", False)
|
|
||||||
self.dtype = config.compute_dtype
|
self.dtype = config.compute_dtype
|
||||||
|
|
||||||
def construct(self, x, query_hidden_state, input_mask, layer_past=None):
|
def construct(self, x, query_hidden_state, input_mask, layer_past=None):
|
||||||
|
@ -679,9 +672,9 @@ class QueryLayer(nn.Cell):
|
||||||
output_x = F.cast(output_x, self.dtype)
|
output_x = F.cast(output_x, self.dtype)
|
||||||
mlp_logit = self.output(output_x)
|
mlp_logit = self.output(output_x)
|
||||||
if self.post_layernorm_residual:
|
if self.post_layernorm_residual:
|
||||||
output = self.last_add(output_x, mlp_logit)
|
output = self.add(output_x, mlp_logit)
|
||||||
else:
|
else:
|
||||||
output = self.last_add(x, mlp_logit)
|
output = self.add(x, mlp_logit)
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
|
||||||
class PanguAlpha_Model(nn.Cell):
|
class PanguAlpha_Model(nn.Cell):
|
||||||
|
|
Loading…
Reference in New Issue