forked from mindspore-Ecosystem/mindspore
add pattern AdjustAllReduceMulAdduse the old opadd test case for bugtemp fix try
This commit is contained in:
parent
2a1aad0f55
commit
3db8cfa54f
|
@ -241,6 +241,7 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
|
||||||
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||||
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||||
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||||
|
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||||
|
|
||||||
// Debug ops
|
// Debug ops
|
||||||
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||||
|
|
|
@ -245,6 +245,7 @@ extern const PrimitivePtr kPrimInDict;
|
||||||
extern const PrimitivePtr kPrimNotInDict;
|
extern const PrimitivePtr kPrimNotInDict;
|
||||||
|
|
||||||
// Comm ops
|
// Comm ops
|
||||||
|
extern const PrimitivePtr kPrimAllReduce;
|
||||||
extern const PrimitivePtr kPrimMirror;
|
extern const PrimitivePtr kPrimMirror;
|
||||||
extern const PrimitivePtr kPrimVirtualDiv;
|
extern const PrimitivePtr kPrimVirtualDiv;
|
||||||
extern const PrimitivePtr kPrimVirtualDataset;
|
extern const PrimitivePtr kPrimVirtualDataset;
|
||||||
|
|
|
@ -53,6 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
||||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
|
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
|
||||||
|
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
|
||||||
|
|
||||||
// ops eliminate
|
// ops eliminate
|
||||||
item_tuple_eliminate_ =
|
item_tuple_eliminate_ =
|
||||||
|
|
|
@ -35,6 +35,7 @@ class OptimizeIRPassLib {
|
||||||
SubstitutionPtr arithmetic_simplify_;
|
SubstitutionPtr arithmetic_simplify_;
|
||||||
SubstitutionPtr special_op_eliminate_;
|
SubstitutionPtr special_op_eliminate_;
|
||||||
SubstitutionPtr zero_like_fill_zero_;
|
SubstitutionPtr zero_like_fill_zero_;
|
||||||
|
SubstitutionPtr adjust_all_reduce_mul_add_;
|
||||||
|
|
||||||
// ops eliminate
|
// ops eliminate
|
||||||
SubstitutionPtr item_tuple_eliminate_;
|
SubstitutionPtr item_tuple_eliminate_;
|
||||||
|
|
|
@ -228,6 +228,116 @@ class ConstantDuplicateMul : public AnfVisitor {
|
||||||
CNodePtr cnode_;
|
CNodePtr cnode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// grad = AllReduce(grad) / worker_number
|
||||||
|
// grad = grad + weight * decy
|
||||||
|
// ->
|
||||||
|
// grad = grad + weight * decy
|
||||||
|
// grad = AllReduce(grad) / worker_number
|
||||||
|
|
||||||
|
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
||||||
|
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
||||||
|
class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
Reset();
|
||||||
|
// {prim::kPrimAddN, Zs}
|
||||||
|
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto addn = node->cast<CNodePtr>();
|
||||||
|
if (addn->size() != 2) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
||||||
|
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto addn_maketuple = addn->input(1);
|
||||||
|
|
||||||
|
auto fg = all_reduce_fg_;
|
||||||
|
// addn inputs cross the graph, make the inputs same as allreduce node.
|
||||||
|
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
||||||
|
auto cnode_z = z_->cast<CNodePtr>();
|
||||||
|
z_ = NewCNode(cnode_z->inputs(), fg);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto addn_op_node = addn->input(0);
|
||||||
|
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
|
||||||
|
|
||||||
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
||||||
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
||||||
|
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
|
||||||
|
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
|
||||||
|
ProcessDependEdge(fg, addn_maketuple, all_reduce);
|
||||||
|
return mul;
|
||||||
|
}
|
||||||
|
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
|
||||||
|
// If has dynamic loss scale.
|
||||||
|
auto &users_map = fg->manager()->node_users();
|
||||||
|
auto it = users_map.find(mul_cnode_);
|
||||||
|
|
||||||
|
if (it != users_map.end()) {
|
||||||
|
auto users = it->second;
|
||||||
|
for (auto &user_pair : users) {
|
||||||
|
auto node = user_pair.first;
|
||||||
|
if (node != addn_maketuple) {
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||||
|
fg->manager()->SetEdge(node, user_pair.second, new_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Visit(const AnfNodePtr &node) override {
|
||||||
|
if (level_ == 0) {
|
||||||
|
level_ = 1;
|
||||||
|
is_reduce_match_ = false;
|
||||||
|
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
||||||
|
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||||
|
level_ = 0;
|
||||||
|
if (is_reduce_match_) {
|
||||||
|
mul_ = node->cast<CNodePtr>()->input(0);
|
||||||
|
mul_cnode_ = node->cast<CNodePtr>();
|
||||||
|
y_ = tmp_;
|
||||||
|
} else {
|
||||||
|
z_ = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (level_ == 1) {
|
||||||
|
// {prim::kPrimAllReduce, X}
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode->size() > 1) {
|
||||||
|
all_reduce_ = cnode->input(0);
|
||||||
|
x_ = cnode->input(1);
|
||||||
|
is_reduce_match_ = true;
|
||||||
|
all_reduce_fg_ = cnode->func_graph();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tmp_ = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() {
|
||||||
|
level_ = 0;
|
||||||
|
is_reduce_match_ = false;
|
||||||
|
x_ = nullptr;
|
||||||
|
y_ = nullptr;
|
||||||
|
z_ = nullptr;
|
||||||
|
tmp_ = nullptr;
|
||||||
|
all_reduce_fg_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int level_{0};
|
||||||
|
bool is_reduce_match_{false};
|
||||||
|
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
||||||
|
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
|
||||||
|
FuncGraphPtr all_reduce_fg_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
class ArithmeticSimplify {
|
class ArithmeticSimplify {
|
||||||
public:
|
public:
|
||||||
ArithmeticSimplify()
|
ArithmeticSimplify()
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "pipeline/parse/parse_base.h"
|
#include "pipeline/parse/parse_base.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "utils/ordered_map.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parse {
|
namespace parse {
|
||||||
|
@ -99,7 +100,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
||||||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
|
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
|
||||||
|
|
||||||
// set state nodes need to insert before function return nodes.
|
// set state nodes need to insert before function return nodes.
|
||||||
std::unordered_map<AnfNodePtr, std::string> state_assign_;
|
OrderedMap<AnfNodePtr, std::string> state_assign_;
|
||||||
|
|
||||||
// hold declared global variables in function
|
// hold declared global variables in function
|
||||||
std::set<std::string> global_vars_;
|
std::set<std::string> global_vars_;
|
||||||
|
|
|
@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
// Arithmetic simplifications
|
// Arithmetic simplifications
|
||||||
irpass.arithmetic_simplify_,
|
irpass.arithmetic_simplify_,
|
||||||
irpass.addn_zero_filter_,
|
irpass.addn_zero_filter_,
|
||||||
|
irpass.adjust_all_reduce_mul_add_,
|
||||||
|
|
||||||
// Miscellaneous
|
// Miscellaneous
|
||||||
irpass.item_tuple_eliminate_,
|
irpass.item_tuple_eliminate_,
|
||||||
|
|
|
@ -1213,7 +1213,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
||||||
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
|
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
|
||||||
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
|
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
|
||||||
>>> num_segments = 4
|
>>> num_segments = 4
|
||||||
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
|
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
|
||||||
|
|
|
@ -1765,7 +1765,7 @@ class LayerNorm(Primitive):
|
||||||
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
|
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||||
|
|
||||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||||
|
|
||||||
|
|
|
@ -284,7 +284,8 @@ def prim_attr_register(fn):
|
||||||
|
|
||||||
def constexpr(fn=None, get_instance=True, name=None):
|
def constexpr(fn=None, get_instance=True, name=None):
|
||||||
"""
|
"""
|
||||||
Makes a PrimitiveWithInfer operator, which infer the value while compiling.
|
Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
|
||||||
|
to compute between constant variable and used in constructß.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn (function): A `fn` use as the infer_value of the output operator.
|
fn (function): A `fn` use as the infer_value of the output operator.
|
||||||
|
|
|
@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
|
||||||
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
|
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
|
||||||
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
|
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
|
||||||
|
FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
|
||||||
|
FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
|
||||||
|
FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
|
||||||
|
FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
|
||||||
|
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
|
||||||
|
FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
|
||||||
|
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
|
||||||
|
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
|
||||||
|
auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
|
||||||
|
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
|
||||||
|
|
||||||
def test_constant_duplicate_mul(tag):
|
def test_constant_duplicate_mul(tag):
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
Mul = Primitive('Mul');
|
Mul = Primitive('Mul')
|
||||||
Sqrt = Primitive('Sqrt');
|
Sqrt = Primitive('Sqrt')
|
||||||
|
|
||||||
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
|
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
|
||||||
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||||
|
@ -1073,3 +1073,44 @@ def test_constant_duplicate_mul(tag):
|
||||||
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
return Mul(Sqrt(x), Mul(tensor1, tensor2))
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
|
def test_adjust_allreduce_mul_add(tag):
|
||||||
|
fns = FnDict()
|
||||||
|
Mul = Primitive('Mul')
|
||||||
|
AddN = Primitive('AddN')
|
||||||
|
AllReduce = Primitive('AllReduce')
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforell(x, y, z):
|
||||||
|
return AddN((z, Mul(y, AllReduce(x))))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforelr(x, y, z):
|
||||||
|
return AddN((z, Mul(AllReduce(x), y)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforerl(x, y, z):
|
||||||
|
return AddN((Mul(y, AllReduce(x)), z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def beforerr(x, y, z):
|
||||||
|
return AddN((Mul(AllReduce(x), y), z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after1(x, y, z):
|
||||||
|
return Mul(AllReduce(AddN((z, x))), y)
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before2r(x, y, z):
|
||||||
|
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before2l(x, y, z):
|
||||||
|
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after2(x, y, z):
|
||||||
|
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
|
||||||
|
|
||||||
|
return fns[tag]
|
||||||
|
|
|
@ -20,9 +20,14 @@ import mindspore.context as context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import amp
|
from mindspore import amp
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model, ParallelMode
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore.model_zoo.resnet import resnet50
|
||||||
from ....dataset_mock import MindData
|
from ....dataset_mock import MindData
|
||||||
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||||
|
from mindspore.communication.management import init
|
||||||
|
|
||||||
def setup_module(module):
|
def setup_module(module):
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
@ -138,3 +143,22 @@ def test_compile_model_train_O2():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# not actual run, the metrics step will fail, check if compile ok.
|
# not actual run, the metrics step will fail, check if compile ok.
|
||||||
model.eval(dataset)
|
model.eval(dataset)
|
||||||
|
|
||||||
|
def test_compile_model_train_O2_parallel():
|
||||||
|
dataset_types = (np.float32, np.float32)
|
||||||
|
dataset_shapes = ((16, 16), (16, 16))
|
||||||
|
|
||||||
|
dataset = MindDataSet(dataset_types, dataset_shapes)
|
||||||
|
|
||||||
|
net = NetNoLoss(16, 16)
|
||||||
|
loss = nn.MSELoss()
|
||||||
|
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
|
||||||
|
|
||||||
|
context.set_auto_parallel_context(
|
||||||
|
global_rank=0, device_num=8,
|
||||||
|
mirror_mean=True, parameter_broadcast=True,
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||||
|
init()
|
||||||
|
|
||||||
|
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
|
||||||
|
model.train(2, dataset, dataset_sink_mode=False)
|
||||||
|
|
Loading…
Reference in New Issue