From 2ceea1e59df3beed948e7fb006c1bf8c24919d12 Mon Sep 17 00:00:00 2001 From: Wan Hanyang Date: Fri, 11 Sep 2020 16:30:28 +0800 Subject: [PATCH] add a self attention test case --- .../python/parallel/test_loss_and_o2_level.py | 121 ++++++++++++++ .../ut/python/parallel/test_self_attention.py | 149 ++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100755 tests/ut/python/parallel/test_loss_and_o2_level.py create mode 100644 tests/ut/python/parallel/test_self_attention.py diff --git a/tests/ut/python/parallel/test_loss_and_o2_level.py b/tests/ut/python/parallel/test_loss_and_o2_level.py new file mode 100755 index 00000000000..358b11e8d34 --- /dev/null +++ b/tests/ut/python/parallel/test_loss_and_o2_level.py @@ -0,0 +1,121 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, Momentum +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.ops import operations as P +from mindspore.train import Model +from tests.dataset_mock import MindData + + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().shard(strategy1) + self.neg = P.Neg().shard(strategy2) + self.mul_weight = Parameter(mul_weight, "w1") + + def construct(self, x): + out = self.mul(x, self.mul_weight) + out = self.neg(out) + return out + + +_x = Tensor(np.ones([32, 128]), dtype=ms.float32) +_b = Tensor(np.ones([32]), dtype=ms.int32) +_w1 = Tensor(np.ones([512, 128]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(save_graphs=True) + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + dataset = Dataset(_x, _b) + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + opt = Momentum(net.trainable_params(), learning_rate, momentum) + model = Model(net, loss, optimizer=opt, amp_level="O2") + model.train(epoch_size, dataset, dataset_sink_mode=False) + context.reset_auto_parallel_context() + + +def test_neg_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((16, 1), (16, 1)) + strategy2 = ((16, 1),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_model_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((1, 16), (1, 16)) + strategy2 = ((1, 16),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_hybrid_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((4, 4),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) + net = Net(_w1) + compile_net(net) + + +def test_neg_repeat_calc(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((2, 2),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_neg_repeat_calc2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((4, 2), (4, 2)) + strategy2 = ((4, 4),) + net = Net(_w1, strategy1, strategy2) + compile_net(net) diff --git a/tests/ut/python/parallel/test_self_attention.py b/tests/ut/python/parallel/test_self_attention.py new file mode 100644 index 00000000000..a484e1fd63e --- /dev/null +++ b/tests/ut/python/parallel/test_self_attention.py @@ -0,0 +1,149 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.context import set_auto_parallel_context +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return grad_all(self.network)(x) + + +def compile_net(net, x): + net.set_auto_parallel() + _executor.compile(net, x) + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5): + super().__init__() + self.query_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='query') + self.query = P.MatMul().shard(strategy1) + + self.key_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='key') + self.key = P.MatMul().shard(strategy2) + + self.value_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='value') + self.value = P.MatMul().shard(strategy3) + + self.score = P.MatMul().shard(strategy4) + self.context = P.MatMul().shard(strategy5) + self.transpose1 = P.Transpose() + self.transpose2 = P.Transpose() + self.relu = P.ReLU() + + def construct(self, x): + q = self.query(x, self.query_w) + k = self.key(x, self.key_w) + v = self.value(x, self.value_w) + + k = self.transpose1(k, (1, 0)) + s = self.score(q, k) + + v = self.transpose2(v, (1, 0)) + c = self.context(v, s) + out = self.relu(c) + + return out + + +def test_self_attention_standalone(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="stand_alone") + net = GradWrap(NetWithLoss( + Net(None, None, None, None, None))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_semi(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + strategy1 = ((2, 2), (2, 2)) + strategy2 = ((2, 2), (2, 2)) + strategy3 = ((2, 2), (2, 2)) + strategy4 = ((2, 4), (4, 1)) + strategy5 = ((2, 1), (1, 4)) + + net = GradWrap(NetWithLoss( + Net(strategy1, strategy2, strategy3, strategy4, strategy5))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_dp(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + strategy1 = ((8, 1), (1, 1)) + strategy2 = ((8, 1), (1, 1)) + strategy3 = ((8, 1), (1, 1)) + strategy4 = ((8, 1), (1, 1)) + strategy5 = ((8, 1), (1, 1)) + + net = GradWrap(NetWithLoss( + Net(strategy1, strategy2, strategy3, strategy4, strategy5))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_auto(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net = GradWrap(NetWithLoss( + Net(None, None, None, None, None))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x)