forked from mindspore-Ecosystem/mindspore
Revert "Revert "add pattern AdjustAllReduceMulAdduse the old opadd test case for bugtemp fix try""
This reverts commit 705c71a257
.
This commit is contained in:
parent
beb714d2d0
commit
b16a552d41
|
@ -253,6 +253,7 @@ const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
|||
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// Debug ops
|
||||
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
|
|
|
@ -259,6 +259,7 @@ extern const PrimitivePtr kPrimMixedPrecisionCast;
|
|||
extern const PrimitivePtr kPrimIsConsant;
|
||||
|
||||
// Comm ops
|
||||
extern const PrimitivePtr kPrimAllReduce;
|
||||
extern const PrimitivePtr kPrimMirror;
|
||||
extern const PrimitivePtr kPrimVirtualDiv;
|
||||
extern const PrimitivePtr kPrimVirtualDataset;
|
||||
|
|
|
@ -54,6 +54,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
|
||||
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
|
||||
|
||||
// ops eliminate
|
||||
item_tuple_eliminate_ =
|
||||
|
|
|
@ -35,6 +35,7 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr arithmetic_simplify_;
|
||||
SubstitutionPtr special_op_eliminate_;
|
||||
SubstitutionPtr zero_like_fill_zero_;
|
||||
SubstitutionPtr adjust_all_reduce_mul_add_;
|
||||
|
||||
// ops eliminate
|
||||
SubstitutionPtr item_tuple_eliminate_;
|
||||
|
|
|
@ -228,6 +228,115 @@ class ConstantDuplicateMul : public AnfVisitor {
|
|||
CNodePtr cnode_;
|
||||
};
|
||||
|
||||
// grad = AllReduce(grad) / worker_number
|
||||
// grad = grad + weight * decy
|
||||
// ->
|
||||
// grad = grad + weight * decy
|
||||
// grad = AllReduce(grad) / worker_number
|
||||
|
||||
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
||||
class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
// {prim::kPrimAddN, Zs}
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto addn = node->cast<CNodePtr>();
|
||||
if (addn->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
||||
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto addn_maketuple = addn->input(1);
|
||||
|
||||
auto fg = all_reduce_fg_;
|
||||
// addn inputs cross the graph, make the inputs same as allreduce node.
|
||||
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
||||
auto cnode_z = z_->cast<CNodePtr>();
|
||||
z_ = NewCNode(cnode_z->inputs(), fg);
|
||||
}
|
||||
|
||||
auto addn_op_node = addn->input(0);
|
||||
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
|
||||
|
||||
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
||||
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
||||
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
|
||||
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
|
||||
ProcessDependEdge(fg, addn_maketuple, all_reduce);
|
||||
return mul;
|
||||
}
|
||||
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
|
||||
// If has dynamic loss scale.
|
||||
auto &users_map = fg->manager()->node_users();
|
||||
auto it = users_map.find(mul_cnode_);
|
||||
if (it != users_map.end()) {
|
||||
auto users = it->second;
|
||||
for (auto &user_pair : users) {
|
||||
auto node = user_pair.first;
|
||||
if (node != addn_maketuple) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
fg->manager()->SetEdge(node, user_pair.second, new_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (level_ == 0) {
|
||||
level_ = 1;
|
||||
is_reduce_match_ = false;
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
||||
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||
level_ = 0;
|
||||
if (is_reduce_match_) {
|
||||
mul_ = node->cast<CNodePtr>()->input(0);
|
||||
mul_cnode_ = node->cast<CNodePtr>();
|
||||
y_ = tmp_;
|
||||
} else {
|
||||
z_ = node;
|
||||
}
|
||||
}
|
||||
|
||||
if (level_ == 1) {
|
||||
// {prim::kPrimAllReduce, X}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->size() > 1) {
|
||||
all_reduce_ = cnode->input(0);
|
||||
x_ = cnode->input(1);
|
||||
is_reduce_match_ = true;
|
||||
all_reduce_fg_ = cnode->func_graph();
|
||||
}
|
||||
} else {
|
||||
tmp_ = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
level_ = 0;
|
||||
is_reduce_match_ = false;
|
||||
x_ = nullptr;
|
||||
y_ = nullptr;
|
||||
z_ = nullptr;
|
||||
tmp_ = nullptr;
|
||||
all_reduce_fg_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
int level_{0};
|
||||
bool is_reduce_match_{false};
|
||||
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
||||
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
|
||||
FuncGraphPtr all_reduce_fg_{nullptr};
|
||||
};
|
||||
|
||||
class ArithmeticSimplify {
|
||||
public:
|
||||
ArithmeticSimplify()
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <utility>
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ordered_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
@ -100,7 +101,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
|
||||
|
||||
// set state nodes need to insert before function return nodes.
|
||||
std::unordered_map<AnfNodePtr, std::string> state_assign_;
|
||||
OrderedMap<AnfNodePtr, std::string> state_assign_;
|
||||
|
||||
// hold declared global variables in function
|
||||
std::set<std::string> global_vars_;
|
||||
|
|
|
@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
// Arithmetic simplifications
|
||||
irpass.arithmetic_simplify_,
|
||||
irpass.addn_zero_filter_,
|
||||
irpass.adjust_all_reduce_mul_add_,
|
||||
|
||||
// Miscellaneous
|
||||
irpass.item_tuple_eliminate_,
|
||||
|
|
|
@ -1372,7 +1372,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
|
||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
|
||||
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
|
||||
>>> num_segments = 4
|
||||
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
|
||||
|
|
|
@ -1867,7 +1867,7 @@ class LayerNorm(Primitive):
|
|||
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
|
|
|
@ -284,7 +284,8 @@ def prim_attr_register(fn):
|
|||
|
||||
def constexpr(fn=None, get_instance=True, name=None):
|
||||
"""
|
||||
Makes a PrimitiveWithInfer operator, which infer the value while compiling.
|
||||
Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
|
||||
to compute between constant variable and used in constructß.
|
||||
|
||||
Args:
|
||||
fn (function): A `fn` use as the infer_value of the output operator.
|
||||
|
|
|
@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
|
|||
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
|
||||
FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
|
||||
FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
|
||||
FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
|
||||
FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
|
||||
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
|
||||
FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
|
||||
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
|
||||
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
|
||||
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1046,8 +1046,8 @@ def test_print_tuple_wrapper(tag):
|
|||
# pylint: disable=unnecessary-semicolon
|
||||
def test_constant_duplicate_mul(tag):
|
||||
fns = FnDict()
|
||||
Mul = Primitive('Mul');
|
||||
Sqrt = Primitive('Sqrt');
|
||||
Mul = Primitive('Mul')
|
||||
Sqrt = Primitive('Sqrt')
|
||||
|
||||
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
|
||||
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||
|
@ -1074,3 +1074,44 @@ def test_constant_duplicate_mul(tag):
|
|||
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_adjust_allreduce_mul_add(tag):
|
||||
fns = FnDict()
|
||||
Mul = Primitive('Mul')
|
||||
AddN = Primitive('AddN')
|
||||
AllReduce = Primitive('AllReduce')
|
||||
|
||||
@fns
|
||||
def beforell(x, y, z):
|
||||
return AddN((z, Mul(y, AllReduce(x))))
|
||||
|
||||
@fns
|
||||
def beforelr(x, y, z):
|
||||
return AddN((z, Mul(AllReduce(x), y)))
|
||||
|
||||
@fns
|
||||
def beforerl(x, y, z):
|
||||
return AddN((Mul(y, AllReduce(x)), z))
|
||||
|
||||
@fns
|
||||
def beforerr(x, y, z):
|
||||
return AddN((Mul(AllReduce(x), y), z))
|
||||
|
||||
@fns
|
||||
def after1(x, y, z):
|
||||
return Mul(AllReduce(AddN((z, x))), y)
|
||||
|
||||
@fns
|
||||
def before2r(x, y, z):
|
||||
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
|
||||
|
||||
@fns
|
||||
def before2l(x, y, z):
|
||||
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
|
||||
|
||||
@fns
|
||||
def after2(x, y, z):
|
||||
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
|
||||
|
||||
return fns[tag]
|
||||
|
|
|
@ -20,9 +20,12 @@ import mindspore.context as context
|
|||
from mindspore import Tensor
|
||||
from mindspore import amp
|
||||
from mindspore import nn
|
||||
from mindspore.train import Model
|
||||
from mindspore.train import Model, ParallelMode
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.model_zoo.resnet import resnet50
|
||||
from ....dataset_mock import MindData
|
||||
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.communication.management import init
|
||||
|
||||
def setup_module(module):
|
||||
_ = module
|
||||
|
@ -139,3 +142,22 @@ def test_compile_model_train_O2():
|
|||
with pytest.raises(ValueError):
|
||||
# not actual run, the metrics step will fail, check if compile ok.
|
||||
model.eval(dataset)
|
||||
|
||||
def test_compile_model_train_O2_parallel():
|
||||
dataset_types = (np.float32, np.float32)
|
||||
dataset_shapes = ((16, 16), (16, 16))
|
||||
|
||||
dataset = MindDataSet(dataset_types, dataset_shapes)
|
||||
|
||||
net = NetNoLoss(16, 16)
|
||||
loss = nn.MSELoss()
|
||||
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
|
||||
|
||||
context.set_auto_parallel_context(
|
||||
global_rank=0, device_num=8,
|
||||
mirror_mean=True, parameter_broadcast=True,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
init()
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
|
Loading…
Reference in New Issue