From e3ae23c9392655c3f4e0529a211d8f14cd35f770 Mon Sep 17 00:00:00 2001 From: Su Teng Date: Fri, 21 Aug 2020 10:10:53 +0800 Subject: [PATCH] add parallel attention test --- tests/ut/python/parallel/test_attention.py | 146 +++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 tests/ut/python/parallel/test_attention.py diff --git a/tests/ut/python/parallel/test_attention.py b/tests/ut/python/parallel/test_attention.py new file mode 100644 index 00000000000..e8ed4cc94fc --- /dev/null +++ b/tests/ut/python/parallel/test_attention.py @@ -0,0 +1,146 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.context import set_auto_parallel_context +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return C.grad_all(self.network)(x) + + +def compile_net(net, x): + net.set_auto_parallel() + _executor.compile(net, x) + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5): + super().__init__() + self.query_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='query') + self.query = P.MatMul().set_strategy(strategy1) + + self.key_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='key') + self.key = P.MatMul().set_strategy(strategy2) + + self.value_w = Parameter(initializer( + "normal", [8, 16], ms.float32), name='value') + self.value = P.MatMul().set_strategy(strategy3) + + self.score = P.MatMul().set_strategy(strategy4) + self.context = P.MatMul().set_strategy(strategy5) + self.transpose1 = P.Transpose() + self.transpose2 = P.Transpose() + self.relu = P.ReLU() + + def construct(self, x): + q = self.query(x, self.query_w) + k = self.key(x, self.key_w) + v = self.value(x, self.value_w) + + k = self.transpose1(k, (1, 0)) + s = self.score(q, k) + + v = self.transpose2(v, (1, 0)) + c = self.context(v, s) + out = self.relu(c) + + return out + + +def test_self_attention_standalone(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="stand_alone") + net = GradWrap(NetWithLoss( + Net(None, None, None, None, None))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_semi(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + strategy1 = ((2, 2), (2, 2)) + strategy2 = ((2, 2), (2, 2)) + strategy3 = ((2, 2), (2, 2)) + strategy4 = ((2, 4), (4, 1)) + strategy5 = ((2, 1), (1, 4)) + + net = GradWrap(NetWithLoss( + Net(strategy1, strategy2, strategy3, strategy4, strategy5))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_dp(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + strategy1 = ((8, 1), (1, 1)) + strategy2 = ((8, 1), (1, 1)) + strategy3 = ((8, 1), (1, 1)) + strategy4 = ((8, 1), (1, 1)) + strategy5 = ((8, 1), (1, 1)) + + net = GradWrap(NetWithLoss( + Net(strategy1, strategy2, strategy3, strategy4, strategy5))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x) + + +def test_self_attention_auto(): + set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net = GradWrap(NetWithLoss( + Net(None, None, None, None, None))) + + x = Tensor(np.ones([32, 8]), dtype=ms.float32) + + compile_net(net, x)