!22650 fixed sparse attention modify
Merge pull request !22650 from yao_yf/fixed_sparse_attention_modify
Commit cc8d614b25
@@ -308,7 +308,7 @@ class FixedSparseAttention(nn.Cell):
             any local window, since there are multi-heads, each head can use a
             different global representative, only supports 4 for now
         size_per_head (int): An integer determining embedding size of each attention head,
-            only supports 64, 80, 96, 112, 128 for now
+            only supports 64, 128 for now

    Inputs:
        - **q** - Tensor query (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
@@ -317,8 +317,8 @@ class FixedSparseAttention(nn.Cell):
          queries to query the context.
        - **v** - Tensor value (:class:`mstype.fp16` [batch size, sequence length, Embedding Size]): Sequence of
          queries to query the context.
-        - **input_mask** - Tensor the mask of (:class:`mstype.fp32` [batch_size, seq_length]):
-          Sequence of 0 and 1 to pass masked information.
+        - **attention_mask** - Tensor the mask of (:class:`mstype.fp32` [batch_size, seq_length, seq_length]):
+          Lower triangular matrix to pass masked information.

    Outputs:
        A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size]
@@ -334,8 +334,8 @@ class FixedSparseAttention(nn.Cell):
         >>> q = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
         >>> k = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
         >>> v = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
-        >>> input_mask = Tensor(np.ones((2, 1024)), dtype.float16)
-        >>> output = model(q, k, v, input_mask)
+        >>> attention_mask = Tensor(np.ones((2, 1024, 1024)), dtype.float16)
+        >>> output = model(q, k, v, attention_mask)
         >>> print(output.shape)
         (2, 1024, 512)
    """
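For reference, a standalone sketch of how the updated example can be driven end to end, using the constructor arguments exercised by the tests in this PR; the import path is an assumption, and the mask is built as float32 to satisfy the dtype check in construct below:

# Illustrative sketch only (not part of this commit).
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.parallel.nn import FixedSparseAttention  # import path is an assumption

model = FixedSparseAttention(batch_size=2,
                             num_heads=8,
                             size_per_head=64,   # now limited to 64 or 128
                             block_size=64,
                             seq_length=1024)
q = Tensor(np.ones((2, 1024, 8 * 64)), dtype.float16)
k = Tensor(np.ones((2, 1024, 8 * 64)), dtype.float16)
v = Tensor(np.ones((2, 1024, 8 * 64)), dtype.float16)
attention_mask = Tensor(np.ones((2, 1024, 1024)), dtype.float32)
output = model(q, k, v, attention_mask)  # shape (2, 1024, 512)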
@@ -404,6 +404,7 @@ class FixedSparseAttention(nn.Cell):
         self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),))
         self.transpose3 = P.Transpose().shard(((dp, mp, 1, 1, 1, 1),))
         self.transpose4 = P.Transpose().shard(((dp, mp, 1, 1),))
+        self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))

     def _transpose_inputs(self, q, k, v):
         """
@@ -426,10 +427,14 @@ class FixedSparseAttention(nn.Cell):

         return q, k, v

-    def _generate_attention_mask(self, input_mask):
+    def _generate_attention_mask(self, attention_mask):
         """
-        generate attention mask from input mask
+        generate global attention mask and local attention mask from origin attention mask
         """
+        attention_mask = self.reshape(attention_mask, (-1, self.seq_length, self.seq_length))
+        input_mask = self.slice1(attention_mask, (0, self.seq_length - 1, 0),
+                                 (self.batch_size, self.seq_length, self.seq_length), (1, 1, 1))
+        input_mask = self.reshape(input_mask, (-1, self.seq_length))
         input_shape = P.Shape()(input_mask) # bs, seq_length
         # bs, block_num, 1, block_size
         local_shape_right = (input_shape[0], self.block_num, 1, self.block_size)
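To make the new slicing step easier to follow: it takes the last row of each [seq_length, seq_length] attention matrix, which yields the [batch_size, seq_length] mask that the rest of the method keeps using as input_mask. A rough NumPy sketch of the same computation, with sizes borrowed from the tests (batch 2, seq_length 1024, block_size 64):

# Rough NumPy equivalent of the slice/reshape above (illustrative only).
import numpy as np

batch_size, seq_length, block_size = 2, 1024, 64
block_num = seq_length // block_size

attention_mask = np.tril(np.ones((batch_size, seq_length, seq_length), dtype=np.float32))
# the last row of each lower-triangular matrix acts as the per-token padding mask
input_mask = attention_mask[:, seq_length - 1, :]                    # (batch_size, seq_length)
local_mask_right = input_mask.reshape(batch_size, block_num, 1, block_size)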
@@ -457,7 +462,7 @@ class FixedSparseAttention(nn.Cell):

         return local_mask, global_mask

-    def construct(self, q, k, v, input_mask):
+    def construct(self, q, k, v, attention_mask):
         _check_shape_equal(F.shape(q), "q", self.cls_name,
                            [self.batch_size, self.seq_length, self.hidden_size])
         _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
@@ -467,12 +472,12 @@ class FixedSparseAttention(nn.Cell):
         _check_shape_equal(F.shape(v), "v", self.cls_name,
                            [self.batch_size, self.seq_length, self.hidden_size])
         _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
-        _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
-                           [self.batch_size, self.seq_length])
-        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
+        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                           [self.batch_size, self.seq_length, self.seq_length])
+        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32], self.cls_name)

         q, k, v = self._transpose_inputs(q, k, v)
-        local_mask, global_mask = self._generate_attention_mask(input_mask)
+        local_mask, global_mask = self._generate_attention_mask(attention_mask)
         q = q / F.cast(self.scale_factor, F.dtype(q))
         k = k / F.cast(self.scale_factor, F.dtype(k))
         local_prob, global_prob = self.matmul_dds(q, k, local_mask, global_mask)
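Callers that previously passed a [batch_size, seq_length] input_mask now have to supply the full [batch_size, seq_length, seq_length] attention_mask. One possible migration, assuming the old 0/1 padding semantics combined with a lower-triangular causal pattern (a sketch, not something this commit prescribes):

# Possible migration sketch (assumption, not part of this commit).
import numpy as np

batch_size, seq_length = 2, 1024
input_mask = np.ones((batch_size, seq_length), dtype=np.float32)        # old-style padding mask
causal = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))   # lower-triangular pattern
attention_mask = input_mask[:, None, :] * causal                        # (batch_size, seq_length, seq_length)
# the last row of each matrix reproduces the original input_mask,
# which is exactly what _generate_attention_mask slices back out above
assert np.array_equal(attention_mask[:, -1, :], input_mask)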
@@ -21,7 +21,7 @@ def test_net():
     k = k.astype(np.float16)
     v = np.random.rand(bs, seq_len, heads * size_per_head)
     v = v.astype(np.float16)
-    input_mask = np.ones((bs, seq_len), dtype=np.float32)
-    out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(input_mask))
+    attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32)
+    out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask))
     out_np = out.asnumpy()
     print("local output: ", out_np[0, 0])
@@ -251,5 +251,5 @@ def test_sparse_attention():
     q = Tensor(np.ones((2, 1024, 512)), dtype.float16)
     k = Tensor(np.ones((2, 1024, 512)), dtype.float16)
     v = Tensor(np.ones((2, 1024, 512)), dtype.float16)
-    mask = Tensor(np.ones((2, 1024)), dtype.float32)
+    mask = Tensor(np.ones((2, 1024, 1024)), dtype.float32)
     _cell_graph_executor.compile(model, q, k, v, mask)
@@ -29,6 +29,7 @@ from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
 from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from mindspore.train import Model
+from mindspore.parallel import set_algo_parameters
 from tests.dataset_mock import MindData
 from tests.ut.python.ops.test_math_ops import VirtualLoss

@@ -355,9 +356,11 @@ def test_vocabembedding_dp_false():
     model.train(1, dataset, dataset_sink_mode=False)


-def _test_sparse_attention_parallel():
+def test_sparse_attention_parallel_mp():
     set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
+    set_algo_parameters(fully_use_devices=False)
     sparse_attention_config = OpParallelConfig(model_parallel=8)
-    net = FixedSparseAttention(batch_size=2,
+    net = FixedSparseAttention(batch_size=16,
                                seq_length=1024,
                                size_per_head=64,
                                num_heads=8,
@@ -366,14 +369,67 @@ def _test_sparse_attention_parallel():
     q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
     k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
     v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
-    mask = Tensor(np.ones((2, 1024)), mstype.float32)
+    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
     dataset = Dataset(q, k, v, mask)
     model = Model(net)
     model.train(1, dataset, dataset_sink_mode=False)

+def test_sparse_attention_parallel_mix():
+    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
+    set_algo_parameters(fully_use_devices=False)
+    sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
+    net = FixedSparseAttention(batch_size=16,
+                               seq_length=1024,
+                               size_per_head=64,
+                               num_heads=8,
+                               block_size=64,
+                               parallel_config=sparse_attention_config)
+    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
+    dataset = Dataset(q, k, v, mask)
+    model = Model(net)
+    model.train(1, dataset, dataset_sink_mode=False)
+
+def test_sparse_attention_parallel_mix1():
+    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
+    set_algo_parameters(fully_use_devices=False)
+    sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
+    net = FixedSparseAttention(batch_size=16,
+                               seq_length=1024,
+                               size_per_head=64,
+                               num_heads=8,
+                               block_size=64,
+                               parallel_config=sparse_attention_config)
+    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
+    dataset = Dataset(q, k, v, mask)
+    model = Model(net)
+    model.train(1, dataset, dataset_sink_mode=False)
+
+def test_sparse_attention_parallel_dp():
+    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
+    set_algo_parameters(fully_use_devices=False)
+    sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
+    net = FixedSparseAttention(batch_size=16,
+                               seq_length=1024,
+                               size_per_head=64,
+                               num_heads=8,
+                               block_size=64,
+                               parallel_config=sparse_attention_config)
+    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
+    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
+    dataset = Dataset(q, k, v, mask)
+    model = Model(net)
+    model.train(1, dataset, dataset_sink_mode=False)
+
 def test_parallel_cross_entroy_loss_semi_auto_parallel():
-    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)

     class NetWithLoss(nn.Cell):
         def __init__(self, network, config_setting):