Revert "Revert "add pattern AdjustAllReduceMulAdduse the old opadd test case for bugtemp fix try""

This reverts commit 705c71a257.
This commit is contained in:
zhaoting 2020-06-09 20:31:15 +08:00
parent beb714d2d0
commit b16a552d41
13 changed files with 206 additions and 8 deletions

View File

@ -253,6 +253,7 @@ const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
// Debug ops
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");

View File

@ -259,6 +259,7 @@ extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant;
// Comm ops
extern const PrimitivePtr kPrimAllReduce;
extern const PrimitivePtr kPrimMirror;
extern const PrimitivePtr kPrimVirtualDiv;
extern const PrimitivePtr kPrimVirtualDataset;

View File

@ -54,6 +54,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate
item_tuple_eliminate_ =

View File

@ -35,6 +35,7 @@ class OptimizeIRPassLib {
SubstitutionPtr arithmetic_simplify_;
SubstitutionPtr special_op_eliminate_;
SubstitutionPtr zero_like_fill_zero_;
SubstitutionPtr adjust_all_reduce_mul_add_;
// ops eliminate
SubstitutionPtr item_tuple_eliminate_;

View File

@ -228,6 +228,115 @@ class ConstantDuplicateMul : public AnfVisitor {
CNodePtr cnode_;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
// grad = grad + weight * decy
// grad = AllReduce(grad) / worker_number
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
class AdjustAllReduceMulAdd : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
// {prim::kPrimAddN, Zs}
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
return nullptr;
}
auto addn = node->cast<CNodePtr>();
if (addn->size() != 2) {
return nullptr;
}
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
return nullptr;
}
auto addn_maketuple = addn->input(1);
auto fg = all_reduce_fg_;
// addn inputs cross the graph, make the inputs same as allreduce node.
if (z_->isa<CNode>() && fg != z_->func_graph()) {
auto cnode_z = z_->cast<CNodePtr>();
z_ = NewCNode(cnode_z->inputs(), fg);
}
auto addn_op_node = addn->input(0);
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
ProcessDependEdge(fg, addn_maketuple, all_reduce);
return mul;
}
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
// If has dynamic loss scale.
auto &users_map = fg->manager()->node_users();
auto it = users_map.find(mul_cnode_);
if (it != users_map.end()) {
auto users = it->second;
for (auto &user_pair : users) {
auto node = user_pair.first;
if (node != addn_maketuple) {
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
fg->manager()->SetEdge(node, user_pair.second, new_node);
}
}
}
}
}
void Visit(const AnfNodePtr &node) override {
if (level_ == 0) {
level_ = 1;
is_reduce_match_ = false;
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
AnfVisitor::Match(prim::kPrimMul)(node);
level_ = 0;
if (is_reduce_match_) {
mul_ = node->cast<CNodePtr>()->input(0);
mul_cnode_ = node->cast<CNodePtr>();
y_ = tmp_;
} else {
z_ = node;
}
}
if (level_ == 1) {
// {prim::kPrimAllReduce, X}
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
auto cnode = node->cast<CNodePtr>();
if (cnode->size() > 1) {
all_reduce_ = cnode->input(0);
x_ = cnode->input(1);
is_reduce_match_ = true;
all_reduce_fg_ = cnode->func_graph();
}
} else {
tmp_ = node;
}
}
}
void Reset() {
level_ = 0;
is_reduce_match_ = false;
x_ = nullptr;
y_ = nullptr;
z_ = nullptr;
tmp_ = nullptr;
all_reduce_fg_ = nullptr;
}
private:
int level_{0};
bool is_reduce_match_{false};
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
FuncGraphPtr all_reduce_fg_{nullptr};
};
class ArithmeticSimplify {
public:
ArithmeticSimplify()

View File

@ -28,6 +28,7 @@
#include <utility>
#include "pipeline/parse/parse_base.h"
#include "utils/log_adapter.h"
#include "utils/ordered_map.h"
namespace mindspore {
namespace parse {
@ -100,7 +101,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
// set state nodes need to insert before function return nodes.
std::unordered_map<AnfNodePtr, std::string> state_assign_;
OrderedMap<AnfNodePtr, std::string> state_assign_;
// hold declared global variables in function
std::set<std::string> global_vars_;

View File

@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Arithmetic simplifications
irpass.arithmetic_simplify_,
irpass.addn_zero_filter_,
irpass.adjust_all_reduce_mul_add_,
// Miscellaneous
irpass.item_tuple_eliminate_,

View File

@ -1372,7 +1372,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Examples:
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)

View File

@ -1867,7 +1867,7 @@ class LayerNorm(Primitive):
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
.. math::
y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

View File

@ -284,7 +284,8 @@ def prim_attr_register(fn):
def constexpr(fn=None, get_instance=True, name=None):
"""
Makes a PrimitiveWithInfer operator, which infer the value while compiling.
Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
to compute between constant variable and used in constructß.
Args:
fn (function): A `fn` use as the infer_value of the output operator.

View File

@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
}
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
}
} // namespace opt
} // namespace mindspore

View File

@ -1046,8 +1046,8 @@ def test_print_tuple_wrapper(tag):
# pylint: disable=unnecessary-semicolon
def test_constant_duplicate_mul(tag):
fns = FnDict()
Mul = Primitive('Mul');
Sqrt = Primitive('Sqrt');
Mul = Primitive('Mul')
Sqrt = Primitive('Sqrt')
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@ -1074,3 +1074,44 @@ def test_constant_duplicate_mul(tag):
return Mul(Sqrt(x), Mul(tensor1, tensor2))
return fns[tag]
def test_adjust_allreduce_mul_add(tag):
fns = FnDict()
Mul = Primitive('Mul')
AddN = Primitive('AddN')
AllReduce = Primitive('AllReduce')
@fns
def beforell(x, y, z):
return AddN((z, Mul(y, AllReduce(x))))
@fns
def beforelr(x, y, z):
return AddN((z, Mul(AllReduce(x), y)))
@fns
def beforerl(x, y, z):
return AddN((Mul(y, AllReduce(x)), z))
@fns
def beforerr(x, y, z):
return AddN((Mul(AllReduce(x), y), z))
@fns
def after1(x, y, z):
return Mul(AllReduce(AddN((z, x))), y)
@fns
def before2r(x, y, z):
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
@fns
def before2l(x, y, z):
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
@fns
def after2(x, y, z):
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
return fns[tag]

View File

@ -20,9 +20,12 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import amp
from mindspore import nn
from mindspore.train import Model
from mindspore.train import Model, ParallelMode
from mindspore.common import dtype as mstype
from mindspore.model_zoo.resnet import resnet50
from ....dataset_mock import MindData
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import init
def setup_module(module):
_ = module
@ -139,3 +142,22 @@ def test_compile_model_train_O2():
with pytest.raises(ValueError):
# not actual run, the metrics step will fail, check if compile ok.
model.eval(dataset)
def test_compile_model_train_O2_parallel():
dataset_types = (np.float32, np.float32)
dataset_shapes = ((16, 16), (16, 16))
dataset = MindDataSet(dataset_types, dataset_shapes)
net = NetNoLoss(16, 16)
loss = nn.MSELoss()
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
context.set_auto_parallel_context(
global_rank=0, device_num=8,
mirror_mean=True, parameter_broadcast=True,
parallel_mode=ParallelMode.DATA_PARALLEL)
init()
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
model.train(2, dataset, dataset_sink_mode=False)