!22650 fixed sparse attention modify

Merge pull request !22650 from yao_yf/fixed_sparse_attention_modify
i-robot 2021-09-01 01:29:45 +00:00 committed by Gitee
commit cc8d614b25
4 changed files with 80 additions and 19 deletions

View File

@@ -308,7 +308,7 @@ class FixedSparseAttention(nn.Cell):
any local window, since there are multi-heads, each head can use a
different global representative, only supports 4 for now
size_per_head (int): An integer determining embedding size of each attention head,
only supports 64, 80, 96, 112, 128 for now
only supports 64, 128 for now
Inputs:
- **q** - Tensor query (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
@@ -317,8 +317,8 @@ class FixedSparseAttention(nn.Cell):
queries to query the context.
- **v** - Tensor value (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
values to be used as the attended context.
- **input_mask** - Tensor the mask of (:class:`mstype.fp32` [batch_size, seq_length]):
Sequence of 0 and 1 to pass masked information.
- **attention_mask** - Tensor, the attention mask (:class:`mstype.fp32` [batch_size, seq_length, seq_length]):
Lower triangular matrix to pass masked information.
Outputs:
A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size]
@@ -334,8 +334,8 @@ class FixedSparseAttention(nn.Cell):
>>> q = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> input_mask = Tensor(np.ones((2, 1024)), dtype.float16)
>>> output = model(q, k, v, input_mask)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), dtype.float32)
>>> output = model(q, k, v, attention_mask)
>>> print(output.shape)
(2, 1024, 512)
"""
@@ -404,6 +404,7 @@ class FixedSparseAttention(nn.Cell):
self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),))
self.transpose3 = P.Transpose().shard(((dp, mp, 1, 1, 1, 1),))
self.transpose4 = P.Transpose().shard(((dp, mp, 1, 1),))
self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))
def _transpose_inputs(self, q, k, v):
"""
@@ -426,10 +427,14 @@ class FixedSparseAttention(nn.Cell):
return q, k, v
def _generate_attention_mask(self, input_mask):
def _generate_attention_mask(self, attention_mask):
"""
generate attention mask from input mask
generate the global and local attention masks from the original attention mask
"""
attention_mask = self.reshape(attention_mask, (-1, self.seq_length, self.seq_length))
input_mask = self.slice1(attention_mask, (0, self.seq_length - 1, 0),
(self.batch_size, self.seq_length, self.seq_length), (1, 1, 1))
input_mask = self.reshape(input_mask, (-1, self.seq_length))
input_shape = P.Shape()(input_mask) # bs, seq_length
# bs, block_num, 1, block_size
local_shape_right = (input_shape[0], self.block_num, 1, self.block_size)
@@ -457,7 +462,7 @@ class FixedSparseAttention(nn.Cell):
return local_mask, global_mask
def construct(self, q, k, v, input_mask):
def construct(self, q, k, v, attention_mask):
_check_shape_equal(F.shape(q), "q", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
_check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
@@ -467,12 +472,12 @@ class FixedSparseAttention(nn.Cell):
_check_shape_equal(F.shape(v), "v", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
_check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, self.seq_length])
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32], self.cls_name)
q, k, v = self._transpose_inputs(q, k, v)
local_mask, global_mask = self._generate_attention_mask(input_mask)
local_mask, global_mask = self._generate_attention_mask(attention_mask)
q = q / F.cast(self.scale_factor, F.dtype(q))
k = k / F.cast(self.scale_factor, F.dtype(k))
local_prob, global_prob = self.matmul_dds(q, k, local_mask, global_mask)
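
For intuition on the reworked `_generate_attention_mask`: the new `StridedSlice` takes the last row of each [seq_length, seq_length] mask, which recovers the old [batch_size, seq_length] input mask whenever the attention mask is built the usual way (a lower-triangular matrix combined with a padding mask). A small numpy sketch of that equivalence, with made-up shapes:

    import numpy as np

    batch_size, seq_length = 2, 8
    # Padding mask marking valid (1) vs padded (0) tokens -- the old input_mask.
    input_mask = np.ones((batch_size, seq_length), dtype=np.float32)
    input_mask[1, 5:] = 0.0  # second sample has only 5 valid tokens

    # Usual lower-triangular attention mask: row i sees valid positions j <= i.
    causal = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))
    attention_mask = causal[None, :, :] * input_mask[:, None, :]

    # The last row of the causal part is all ones, so slicing row seq_length - 1
    # (what self.slice1 / StridedSlice does above) recovers the padding mask.
    recovered = attention_mask[:, seq_length - 1, :]
    assert np.array_equal(recovered, input_mask)
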

View File

@@ -21,7 +21,7 @@ def test_net():
k = k.astype(np.float16)
v = np.random.rand(bs, seq_len, heads * size_per_head)
v = v.astype(np.float16)
input_mask = np.ones((bs, seq_len), dtype=np.float32)
out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(input_mask))
attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32)
out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask))
out_np = out.asnumpy()
print("local output: ", out_np[0, 0])

View File

@@ -251,5 +251,5 @@ def test_sparse_attention():
q = Tensor(np.ones((2, 1024, 512)), dtype.float16)
k = Tensor(np.ones((2, 1024, 512)), dtype.float16)
v = Tensor(np.ones((2, 1024, 512)), dtype.float16)
mask = Tensor(np.ones((2, 1024)), dtype.float32)
mask = Tensor(np.ones((2, 1024, 1024)), dtype.float32)
_cell_graph_executor.compile(model, q, k, v, mask)

View File

@@ -29,6 +29,7 @@ from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.train import Model
from mindspore.parallel import set_algo_parameters
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -355,9 +356,11 @@ def test_vocabembedding_dp_false():
model.train(1, dataset, dataset_sink_mode=False)
def _test_sparse_attention_parallel():
def test_sparse_attention_parallel_mp():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(fully_use_devices=False)
sparse_attention_config = OpParallelConfig(model_parallel=8)
net = FixedSparseAttention(batch_size=2,
net = FixedSparseAttention(batch_size=16,
seq_length=1024,
size_per_head=64,
num_heads=8,
@@ -366,14 +369,67 @@ def _test_sparse_attention_parallel():
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
mask = Tensor(np.ones((2, 1024)), mstype.float32)
mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
dataset = Dataset(q, k, v, mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
def test_sparse_attention_parallel_mix():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(fully_use_devices=False)
sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
net = FixedSparseAttention(batch_size=16,
seq_length=1024,
size_per_head=64,
num_heads=8,
block_size=64,
parallel_config=sparse_attention_config)
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
dataset = Dataset(q, k, v, mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
def test_sparse_attention_parallel_mix1():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(fully_use_devices=False)
sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
net = FixedSparseAttention(batch_size=16,
seq_length=1024,
size_per_head=64,
num_heads=8,
block_size=64,
parallel_config=sparse_attention_config)
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
dataset = Dataset(q, k, v, mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
def test_sparse_attention_parallel_dp():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(fully_use_devices=False)
sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
net = FixedSparseAttention(batch_size=16,
seq_length=1024,
size_per_head=64,
num_heads=8,
block_size=64,
parallel_config=sparse_attention_config)
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
dataset = Dataset(q, k, v, mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
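
The four parallel variants above differ only in how the 8 devices are split between data and model parallelism; a plain-Python sanity check of that pattern (values copied from the tests, not an API example):

    device_num = 8
    # (data_parallel, model_parallel) pairs used by the _mp, _mix, _mix1 and _dp tests.
    for dp, mp in [(1, 8), (2, 4), (4, 2), (8, 1)]:
        assert dp * mp == device_num  # each split exactly covers the 8 configured devices
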
def test_parallel_cross_entroy_loss_semi_auto_parallel():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
class NetWithLoss(nn.Cell):
def __init__(self, network, config_setting):