mindspore/tests/ut/python/parallel/test_dsd_matmul.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import DSDMatmul
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)


# input_w1, the shape is (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
# input_w1 cum_shape = batch_size * seq_len * embedding_size * (block_size // size_per_head)
# = batch_size * seq_len * (embedding_size // 2)
# input_w2, the shape is (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
# input_w2 cum_shape = batch_size * seq_len * embedding_size * (global_size // size_per_head)
# = batch_size * seq_len * embedding_size * 2
# input_v, the shape is (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16)
# block_num = seq_len // block_size, block_size = 64, head * v_embedding = embedding_size, always.
# output shape is (batch_size, head, v_embedding // 16, seq_len//16, 16, 16)
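# As a concrete illustration, with the defaults hard-coded below and used by every test
# (batch_size=128, head=32, seq_len=1024, block_size=head_size=64, block_num=16,
# global_size=256, v_embedding=128, embedding_size=4096):
#   input_w1: (128, 32, 16, 4, 4, 16, 16)
#   input_w2: (128, 32, 16, 16, 4, 16, 16)
#   input_v:  (8192, 256, 16, 16)
#   output:   (128, 32, 8, 64, 16, 16)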
class Net(nn.Cell):
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_len = 1024
        self.block_size = 64
        self.head_size = self.block_size
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.v_embedding = 128
        self.embedding_size = num_heads * self.v_embedding
        self.dsd_matmul = DSDMatmul()
        self.reduce_sum = P.ReduceSum()
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size // 2, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size * 2, has_bias=False)
        self.dense3 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.transpose1 = P.Transpose()
        self.add = P.Add()
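        # In every strategy below, dp splits the leading batch-like dimension and mp splits the
        # head / output-channel dimension; the trailing 16x16 block dimensions are kept whole.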
        if shard:
            self.dsd_matmul.shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.dense3.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))
            self.transpose1.shard(((dp, mp, 1, 1, 1, 1),))

    def construct(self, x):
        # x (batch_size * seq_len, embedding_size)
        q = self.dense1(x)
        # q (batch_size * seq_len, embedding_size // 2)
        # (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
        k = self.dense2(x)
        # k (batch_size * seq_len, embedding_size * 2)
        # (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
        v = self.dense3(x)
        # v (batch_size * seq_len, embedding_size)
        q = self.reshape(q, (self.batch_size, self.num_heads, self.block_num, self.head_size // 16,
                             self.block_size // 16, 16, 16))
        k = self.reshape(k, (self.batch_size, self.num_heads, self.block_num, self.global_size // 16,
                             self.head_size // 16, 16, 16))
        v = self.transpose(self.reshape(v, (-1, 16, self.embedding_size // 16, 16)), (0, 2, 3, 1))
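        # v (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16), matching the input_v layout above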
        dsd = self.dsd_matmul(q, k, v)
        # dsd (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
        dsd = self.transpose1(dsd, (0, 1, 3, 4, 2, 5))
        # dsd (batch_size, head, seq_len // 16, 16, v_embedding // 16, 16)
        dsd = self.reshape(dsd, (-1, self.seq_len, self.v_embedding * self.num_heads))
        result = self.reduce_sum(dsd, 2)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
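    # Build GradWrap(NetWithLoss(Net(...))) and compile it under the chosen parallel mode;
    # each test passes if graph compilation succeeds with the given dp/mp shard setting.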
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_dsd_matmul_model_parallel_mix():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_dp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)