forked from mindspore-Ecosystem/mindspore
!16550 add sparse attention related ops
From: @stsuteng Reviewed-by: @kisnwang,@yangzhenzhang Signed-off-by:
This commit is contained in:
commit
bac16639c4
|
@ -0,0 +1,432 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dsd back impl"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
||||
from te import tik
|
||||
from topi.cce import util
|
||||
|
||||
dsd_grad_info = TBERegOp('DSDGrad') \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("dsdbrop.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("dsdbpropimpl") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "w1_gm", False, "required", "all") \
|
||||
.input(1, "w2_gm", False, "required", "all") \
|
||||
.input(2, "v_gm", False, "required", "all") \
|
||||
.input(3, "a_gm", False, "required", "all") \
|
||||
.input(4, "d_a_gm", False, "required", "all") \
|
||||
.output(0, "d_w1_gm", False, "required", "all") \
|
||||
.output(1, "d_w2_gm", False, "required", "all") \
|
||||
.output(2, "d_v_gm", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(dsd_grad_info)
|
||||
def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_gm={}, kernel_name='dsdbpropimpl'):
|
||||
"""dsd back impl"""
|
||||
if util.get_product_version() == util.VERSION_MINI:
|
||||
tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
|
||||
else:
|
||||
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
||||
|
||||
# (batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
||||
input_w1_shape = w1_gm.get('shape')
|
||||
# (batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
||||
input_w2_shape = w2_gm.get('shape')
|
||||
# (batch_size, seq_len//16, 16, head, v_embedding//16, 16)
|
||||
input_v_shape = v_gm.get('shape')
|
||||
|
||||
batch_size = input_w1_shape[0]
|
||||
head = input_w1_shape[1]
|
||||
block_num = input_w1_shape[2]
|
||||
block_size = input_w1_shape[4] * 16
|
||||
head_size = input_w1_shape[3] * 16
|
||||
global_size = input_w2_shape[3] * 16
|
||||
v_embedding = input_v_shape[1] * 16 // head
|
||||
seq_len = input_v_shape[0] * 16 // batch_size
|
||||
|
||||
# batch_size = 1
|
||||
# head = 1
|
||||
# block_num = 1024 // 64
|
||||
# block_size = 64
|
||||
# head_size = 64
|
||||
# global_size = 256
|
||||
# v_embedding = 128
|
||||
# seq_len = 1024
|
||||
|
||||
block_bite_size = 32
|
||||
|
||||
# 4, 16, 1024//64, 64//16, 64//16, 16*16
|
||||
w1_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, block_num, head_size //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='w1_gm',
|
||||
scope=tik.scope_gm)
|
||||
w2_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, block_num, global_size //
|
||||
16, head_size // 16, 16, 16),
|
||||
name='w2_gm',
|
||||
scope=tik.scope_gm)
|
||||
|
||||
v_gm = tik_inst.Tensor('float16',
|
||||
(batch_size*seq_len//16, head*v_embedding//16, 16, 16),
|
||||
name='v_gm',
|
||||
scope=tik.scope_gm)
|
||||
|
||||
# zN
|
||||
a_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, v_embedding //
|
||||
16, seq_len // 16, 16, 16),
|
||||
name='a_gm',
|
||||
scope=tik.scope_gm)
|
||||
local_gm = a_gm
|
||||
global_gm = a_gm
|
||||
# zN
|
||||
d_a_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, v_embedding //
|
||||
16, seq_len // 16, 16, 16),
|
||||
name='d_a_gm',
|
||||
scope=tik.scope_gm)
|
||||
d_local_gm = d_a_gm
|
||||
d_global_gm = d_a_gm
|
||||
|
||||
# output
|
||||
# w-zN
|
||||
# 4, 16, 1024//64, 64//16, 64//16, 16*16
|
||||
d_w1_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, block_num, head_size //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='d_w1_gm',
|
||||
scope=tik.scope_gm)
|
||||
d_w2_gm = tik_inst.Tensor('float16',
|
||||
(batch_size, head, block_num, global_size //
|
||||
16, head_size // 16, 16, 16),
|
||||
name='d_w2_gm',
|
||||
scope=tik.scope_gm)
|
||||
|
||||
# v-nZ
|
||||
# d_v_gm = tik_inst.Tensor('float16',
|
||||
# (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16),
|
||||
# name='d_v_gm',
|
||||
# scope=tik.scope_gm)
|
||||
d_v_gm = tik_inst.Tensor('float16',
|
||||
(batch_size*seq_len//16, head*v_embedding//16, 16, 16),
|
||||
name='d_v_gm',
|
||||
scope=tik.scope_gm)
|
||||
|
||||
channel_num = batch_size * head
|
||||
with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
|
||||
head_idx = channel_idx // batch_size
|
||||
bs_idx = channel_idx % batch_size
|
||||
global_idx = 3 - head_idx % 4
|
||||
# tensor size // (byte * l0b size * thread)
|
||||
cpt_time = 1 if global_size * v_embedding * \
|
||||
4//(1024 * 64) <= 1 else global_size * v_embedding * 4//(1024 * 64)
|
||||
ub_time = 1 if global_size == 256 else 2
|
||||
|
||||
# d_a_gm = (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
|
||||
d_a_l1 = tik_inst.Tensor('float16', (seq_len // 16, v_embedding // 16, 16, 16),
|
||||
name='d_a_l1', scope=tik.scope_cbuf)
|
||||
|
||||
with tik_inst.for_range(0, v_embedding//16) as brick_i:
|
||||
tik_inst.data_move(d_a_l1[0, brick_i, 0, 0], d_a_gm[bs_idx, head_idx, brick_i, 0, 0, 0], 0,
|
||||
seq_len//16, 16*16*2//block_bite_size,
|
||||
0, (v_embedding//16-1)*16*16*2//block_bite_size)
|
||||
|
||||
# dv
|
||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||
|
||||
d_v_l0c = tik_inst.Tensor('float32', (v_embedding // 16, head_size // 16, 16, 16),
|
||||
name='d_v_local_l0c', scope=tik.scope_cc)
|
||||
d_v_ub = tik_inst.Tensor('float16', (v_embedding // 16, head_size // 16, 16, 16),
|
||||
name='d_v_ub', scope=tik.scope_ubuf)
|
||||
d_v_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, head_size // 16, 16, 16),
|
||||
name='d_v_ub_32', scope=tik.scope_ubuf)
|
||||
|
||||
d_v_global_32_l0c = tik_inst.Tensor('float32', (v_embedding // 16, 1, 16, 16),
|
||||
name='d_v_global_32_l0c', scope=tik.scope_cc)
|
||||
|
||||
d_v_global_32_ub = tik_inst.Tensor('float32', (v_embedding // 16, 1, 16, 16),
|
||||
name='d_v_global_32_ub', scope=tik.scope_ubuf)
|
||||
|
||||
# d_v_local
|
||||
with tik_inst.new_stmt_scope():
|
||||
w_local_l1 = tik_inst.Tensor('float16', (head_size//16, block_size//16, 16, 16),
|
||||
name='w_local_l1', scope=tik.scope_cbuf)
|
||||
w_local_l0a = tik_inst.Tensor('float16', (head_size//16, block_size//16, 16, 16),
|
||||
name='w_local_l0a', scope=tik.scope_ca)
|
||||
|
||||
d_a_l0b = tik_inst.Tensor('float16', (block_size//16, v_embedding//16, 16, 16),
|
||||
name='d_a_l0b', scope=tik.scope_cb)
|
||||
|
||||
# (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
|
||||
tik_inst.data_move(w_local_l1[0, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], 0,
|
||||
1, (block_size*head_size*2)//block_bite_size,
|
||||
0, 0)
|
||||
|
||||
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx * block_size//16, 0, 0, 0], 0,
|
||||
(block_size*v_embedding)//(16*16), 1, 0, True)
|
||||
|
||||
tik_inst.load2dv1(w_local_l0a[0, 0, 0, 0], w_local_l1[0, 0, 0, 0],
|
||||
0, (head_size*block_size)//(16*16),
|
||||
1, 0, True)
|
||||
|
||||
tik_inst.mmad(d_v_l0c, w_local_l0a, d_a_l0b,
|
||||
head_size, block_size, v_embedding, 0)
|
||||
|
||||
tik_inst.data_move(d_v_ub_32[0, 0, 0, 0], d_v_l0c[0, 0, 0, 0], 0,
|
||||
1, (v_embedding * head_size)*4//1024, 0, 0)
|
||||
|
||||
# d_v_global
|
||||
with tik_inst.new_stmt_scope():
|
||||
w_global_l1 = tik_inst.Tensor('float16', (1, head_size // 16, 16, 16),
|
||||
name='w_global_l1', scope=tik.scope_cbuf)
|
||||
w_global_l0a = tik_inst.Tensor('float16', (1, head_size // 16, 16, 16),
|
||||
name='w_global_l0a', scope=tik.scope_ca)
|
||||
|
||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
||||
d_a_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16),
|
||||
name='d_a_l0b', scope=tik.scope_cb)
|
||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx_1:
|
||||
# w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
||||
# (1, head_size // 16, 16, 16)
|
||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
||||
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx_1*(block_size//16), 0, 0, 0], 0,
|
||||
(head_size * v_embedding)//(16*16), 1, 0, True)
|
||||
|
||||
tik_inst.data_move(w_global_l1[0, 0, 0, 0], w2_gm[bs_idx, head_idx, w_idx_1, w_idx, 0, 0, 0], 0,
|
||||
head_size // 16, 16*16*2//block_bite_size,
|
||||
0, 0)
|
||||
tik_inst.load2dv1(w_global_l0a[0, 0, 0, 0], w_global_l1[0, 0, 0, 0], 0,
|
||||
16 * head_size // (16 * 16),
|
||||
1, 0, True)
|
||||
|
||||
# d_v_l0c = (v_embedding // 16, head_size // 16, 16, 16)
|
||||
with tik_inst.if_scope(w_idx_1 == 0):
|
||||
tik_inst.mmad(d_v_global_32_l0c, w_global_l0a, d_a_l0b,
|
||||
16, head_size, v_embedding, 0)
|
||||
with tik_inst.else_scope():
|
||||
tik_inst.mmad(d_v_global_32_l0c, w_global_l0a, d_a_l0b,
|
||||
16, head_size, v_embedding, 1)
|
||||
|
||||
tik_inst.data_move(d_v_global_32_ub[0, 0, 0, 0], d_v_global_32_l0c[0, 0, 0, 0], 0,
|
||||
1, v_embedding*16*4//1024, 0, 0)
|
||||
|
||||
with tik_inst.for_range(0, 4) as cpt_i:
|
||||
tik_inst.vadd(64, d_v_ub_32[0, global_idx, cpt_i*4, 0], d_v_ub_32[0, global_idx, cpt_i*4, 0],
|
||||
d_v_global_32_ub[0, 0,
|
||||
cpt_i*4, 0], v_embedding//16,
|
||||
1, 1, 1,
|
||||
head_size*16*4//block_bite_size, head_size*16*4//block_bite_size,
|
||||
16*16*4//block_bite_size)
|
||||
|
||||
tik_inst.vconv(64, '', d_v_ub[0, 0, 0, 0], d_v_ub_32[0, 0, 0, 0],
|
||||
v_embedding * head_size//64, 1, 1, 4, 8)
|
||||
|
||||
with tik_inst.for_range(0, head_size // 16) as h_idx:
|
||||
with tik_inst.for_range(0, v_embedding//16) as v_idx:
|
||||
tik_inst.vtranspose(
|
||||
d_v_ub[v_idx, h_idx, 0, 0], d_v_ub[v_idx, h_idx, 0, 0])
|
||||
tik_inst.data_move(d_v_gm[bs_idx*seq_len//16+w_idx * (block_size // 16) + h_idx,
|
||||
head_idx*v_embedding//16, 0, 0],
|
||||
d_v_ub[0, h_idx, 0, 0], 0,
|
||||
v_embedding // 16, 16 * 16 * 2 // block_bite_size,
|
||||
(head_size // 16 - 1) * 16 * 16 * 2 // 32, 0)
|
||||
|
||||
with tik_inst.new_stmt_scope():
|
||||
# dw = da * v^t
|
||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||
# d_local_l1 = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
||||
# name='d_local_l1', scope=tik.scope_cbuf)
|
||||
d_local_l0a = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
||||
name='d_local_l0a', scope=tik.scope_ca)
|
||||
|
||||
v_local_l1 = tik_inst.Tensor('float16', (v_embedding // 16, head_size // 16, 16, 16),
|
||||
name='v_local_l1', scope=tik.scope_cbuf)
|
||||
v_local_l0b = tik_inst.Tensor('float16', (v_embedding // 16, head_size // 16, 16, 16),
|
||||
name='v_local_l0b', scope=tik.scope_cb)
|
||||
|
||||
# d_w_local
|
||||
d_w_local_l0c = tik_inst.Tensor('float32', (head_size // 16, block_size // 16, 16, 16),
|
||||
name='d_w_local_l0c', scope=tik.scope_cc)
|
||||
|
||||
d_w_local_ub_32 = tik_inst.Tensor('float32', (head_size // 16, block_size // 16, 16, 16),
|
||||
name='d_w_local_ub', scope=tik.scope_ubuf)
|
||||
|
||||
d_w_local_ub = tik_inst.Tensor('float16', (head_size // 16, block_size // 16, 16, 16),
|
||||
name='d_w_local_ub', scope=tik.scope_ubuf)
|
||||
|
||||
tik_inst.load2dv1(d_local_l0a[0, 0, 0, 0], d_a_l1[w_idx*(block_size//16), 0, 0, 0],
|
||||
0, (block_size*v_embedding)//(16*16), 1, 0, False)
|
||||
|
||||
# v_gm = (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16)
|
||||
# v_local_l1 = (v_embedding//16, head_size//16, 16, 16)
|
||||
with tik_inst.for_range(0, head_size//16) as brick_i:
|
||||
tik_inst.data_move(v_local_l1[0, brick_i, 0, 0],
|
||||
v_gm[bs_idx*seq_len//16+w_idx *
|
||||
(head_size//16)+brick_i, head_idx*v_embedding//16, 0, 0],
|
||||
0, v_embedding//16, 16*16*2//block_bite_size,
|
||||
0, (head_size//16-1)*16*16*2//block_bite_size)
|
||||
|
||||
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0],
|
||||
0, v_embedding*head_size//(16*16), 1, 0, True)
|
||||
|
||||
# dw
|
||||
tik_inst.mmad(d_w_local_l0c, d_local_l0a, v_local_l0b,
|
||||
block_size, v_embedding, head_size, 0)
|
||||
|
||||
tik_inst.data_move(d_w_local_ub_32[0, 0, 0, 0], d_w_local_l0c[0, 0, 0, 0], 0,
|
||||
1, head_size*block_size*4//1024,
|
||||
0, 0)
|
||||
|
||||
tik_inst.vconv(64, '', d_w_local_ub[0, 0, 0, 0], d_w_local_ub_32[0, 0, 0, 0],
|
||||
head_size * block_size//64, 1, 1, 4, 8)
|
||||
|
||||
# d_w1_gm = (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
|
||||
tik_inst.data_move(d_w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_local_ub[0, 0, 0, 0], 0,
|
||||
1, head_size*block_size*2//block_bite_size,
|
||||
0, 0)
|
||||
|
||||
# calculate d_w_global
|
||||
with tik_inst.new_stmt_scope():
|
||||
# load2d permute
|
||||
v_global_l1 = tik_inst.Tensor('float16', (v_embedding//16, global_size//16, 16, 16),
|
||||
name='v_global_l1', scope=tik.scope_cbuf)
|
||||
|
||||
# v_gm = (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16)
|
||||
# tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx, global_idx, head_idx, 0, 0, 0], 0,
|
||||
# seq_len//(4*16), (16*v_embedding*2)//block_bite_size,
|
||||
# ((4*head*v_embedding*16-16*v_embedding)*2)//block_bite_size, 0)
|
||||
with tik_inst.for_range(0, block_num) as w_idx:
|
||||
tik_inst.data_move(v_global_l1[0, w_idx, 0, 0],
|
||||
v_gm[bs_idx*seq_len//16 + (
|
||||
w_idx * (block_size//16) + global_idx), head_idx * v_embedding//16, 0, 0],
|
||||
0, v_embedding//16, 16*16*2//block_bite_size,
|
||||
0, (global_size // 16 - 1)*16*16*2//block_bite_size)
|
||||
|
||||
with tik_inst.for_range(0, block_num * ub_time, thread_num=2) as w_idx:
|
||||
# tik_inst.tikdb.debug_print("'v_embedding//(16*cpt_time): '+str(v_embedding//(16*cpt_time))")
|
||||
# d_global_l1 = tik_inst.Tensor('float16', (head_size//16, v_embedding//(16*cpt_time), 16, 16),
|
||||
# name='d_global_l1', scope=tik.scope_cbuf)
|
||||
d_global_l0a = tik_inst.Tensor('float16', (head_size // (16*ub_time),
|
||||
v_embedding // (16*cpt_time), 16, 16),
|
||||
name='d_global_l0a', scope=tik.scope_ca)
|
||||
|
||||
v_global_l0b = tik_inst.Tensor('float16', (v_embedding // (16*cpt_time),
|
||||
global_size // 16, 16, 16),
|
||||
name='v_global_l0b', scope=tik.scope_cb)
|
||||
|
||||
# d_w_global,小z大n
|
||||
d_w_global_l0c = tik_inst.Tensor('float32', (global_size//16, head_size//(16*ub_time), 16, 16),
|
||||
name='d_w_global_l0c', scope=tik.scope_cc)
|
||||
d_w_global_ub = tik_inst.Tensor('float16', (global_size // 16,
|
||||
head_size // (16*ub_time), 16, 16),
|
||||
name='d_w_global_ub', scope=tik.scope_ubuf)
|
||||
d_w_global_ub_32 = tik_inst.Tensor('float32', (global_size // 16,
|
||||
head_size // (16*ub_time), 16, 16),
|
||||
name='d_w_global_ub_32', scope=tik.scope_ubuf)
|
||||
|
||||
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
||||
tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0],
|
||||
v_global_l1[cpt_idx * v_embedding //
|
||||
(16 * cpt_time), 0, 0, 0], 0,
|
||||
global_size * v_embedding // (16 * 16 * cpt_time), 1, 0, True)
|
||||
# d_global_gm = (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
|
||||
# d_global_l1 = (head_size//16, v_embedding//(16*cpt_time), 16, 16)
|
||||
# with tik_inst.for_range(0, v_embedding//(16*cpt_time)) as brick_i:
|
||||
# tik_inst.data_move(d_global_l1[0, brick_i, 0, 0],
|
||||
# d_global_gm[bs_idx,
|
||||
# head_idx, brick_i + cpt_idx * (v_embedding//(16*cpt_time)),
|
||||
# w_idx*(head_size//16), 0, 0], 0,
|
||||
# head_size//16, (16*16*2)//block_bite_size,
|
||||
# 0, (v_embedding//(16*cpt_time)-1)*16*16*2//block_bite_size)
|
||||
#
|
||||
# tik_inst.load2dv1(d_global_l0a[0, 0, 0, 0], d_global_l1[0, 0, 0, 0],
|
||||
# 0, (head_size*v_embedding)//(16*16*cpt_time), 1, 0, False)
|
||||
|
||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
||||
# d_global_l0a = (head_size // 16, v_embedding // (16*cpt_time), 16, 16)
|
||||
with tik_inst.for_range(0, head_size//(16*ub_time)) as brick_i:
|
||||
tik_inst.load2dv1(d_global_l0a[brick_i, 0, 0, 0],
|
||||
d_a_l1[w_idx*(block_size//(16*ub_time)) + brick_i,
|
||||
cpt_idx*v_embedding//(16*cpt_time), 0, 0],
|
||||
0, (16*v_embedding)//(16*16*cpt_time), 1, 0, False)
|
||||
|
||||
# (head_size, global_size) = (head_size, v_embedding//cpttime) *
|
||||
# (v_embedding//cpttime, global_size)
|
||||
with tik_inst.if_scope(cpt_idx == 0):
|
||||
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
||||
head_size//ub_time, v_embedding//cpt_time, global_size, 0)
|
||||
with tik_inst.else_scope():
|
||||
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
||||
head_size//ub_time, v_embedding//cpt_time, global_size, 1)
|
||||
|
||||
tik_inst.data_move(d_w_global_ub_32[0, 0, 0, 0], d_w_global_l0c[0, 0, 0, 0], 0,
|
||||
1, head_size*global_size*4//(1024*ub_time),
|
||||
0, 0)
|
||||
|
||||
# tik_inst.tikdb.debug_print("'d_w_global_ub_32: '+str(d_global_l1)")
|
||||
|
||||
# (global_size // 16, head_size // 16, 16, 16)
|
||||
rpt_time = global_size//(16*8)
|
||||
with tik_inst.for_range(0, rpt_time) as conv_i:
|
||||
tik_inst.vconv(64, '',
|
||||
d_w_global_ub[conv_i*global_size //
|
||||
(16*rpt_time), 0, 0, 0],
|
||||
d_w_global_ub_32[conv_i*global_size //
|
||||
(16*rpt_time), 0, 0, 0],
|
||||
global_size * head_size//(64*rpt_time*ub_time), 1, 1, 4, 8)
|
||||
|
||||
# tik_inst.vconv(64, '', d_w_global_ub[0, 0, 0, 0], d_w_global_ub_32[0, 0, 0, 0],
|
||||
# global_size * head_size//64, 1, 1, 4, 8)
|
||||
|
||||
# d_w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
||||
# d_w_global_ub = (global_size // 16, head_size // (16*ub_time), 16, 16)
|
||||
with tik_inst.if_scope(ub_time == 1):
|
||||
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_global_ub[0, 0, 0, 0], 0,
|
||||
1, head_size*global_size *
|
||||
2//(block_bite_size),
|
||||
0, 0)
|
||||
with tik_inst.else_scope():
|
||||
w_idx_i = w_idx // 2
|
||||
h_idx = (w_idx % 2) * 2 # 0/2
|
||||
|
||||
# tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx_i, 0,h_idx,0,0], d_w_global_ub[0,0,0,0], 0,
|
||||
# 1, head_size*global_size*2//(block_bite_size*ub_time),
|
||||
# 0,0)
|
||||
|
||||
with tik_inst.for_range(0, head_size//(16*ub_time)) as m_idx:
|
||||
# d_w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
||||
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx_i, 0, h_idx + m_idx, 0, 0],
|
||||
d_w_global_ub[0, m_idx, 0, 0], 0,
|
||||
global_size//16, 16*16*2//block_bite_size,
|
||||
(head_size//(16*ub_time) - 1) *
|
||||
16*16*2//block_bite_size,
|
||||
(head_size//16 - 1)*16*16*2//block_bite_size)
|
||||
|
||||
# tik_inst.tikdb.debug_print("'d_w_global_ub_32: '+str(d_global_l1)")
|
||||
# tik_inst.tikdb.debug_print("'d_w2_gm: '+str(d_w2_gm[bs_idx, head_idx, w_idx_i, 0,h_idx,:,:])")
|
||||
|
||||
tik_inst.BuildCCE(kernel_name=kernel_name,
|
||||
inputs=[w1_gm, w2_gm, v_gm, a_gm, d_a_gm],
|
||||
outputs=[d_w1_gm, d_w2_gm, d_v_gm])
|
||||
return tik_inst
|
|
@ -0,0 +1,211 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" dense sparse to densne matmul"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
||||
from te import tik
|
||||
from topi.cce import util
|
||||
|
||||
dsd_matmul_info = TBERegOp('DSDMatmul') \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("dsdmatmul.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("DSDMatmulimpl") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_w1", False, "required", "all") \
|
||||
.input(1, "input_w2", False, "required", "all") \
|
||||
.input(2, "input_v", False, "required", "all") \
|
||||
.output(0, "output_y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(dsd_matmul_info)
|
||||
def DSDMatmulimpl(input_w1, input_w2, input_v, output_y={}, kernel_name='DSDMatmulimpl'):
|
||||
""" dense sparse to densne matmul"""
|
||||
# shape_w1 = input_w1.get('shape')
|
||||
# shape_w2 = input_w2.get('shape')
|
||||
# shape_v = input_v.get('shape')
|
||||
if util.get_product_version() == util.VERSION_MINI:
|
||||
tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
|
||||
else:
|
||||
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
||||
|
||||
# (batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
||||
input_w1_shape = input_w1.get('shape')
|
||||
# (batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
||||
input_w2_shape = input_w2.get('shape')
|
||||
input_v_shape = input_v.get('shape')
|
||||
|
||||
batch_size = input_w1_shape[0]
|
||||
head = input_w1_shape[1]
|
||||
block_num = input_w1_shape[2]
|
||||
block_size = input_w1_shape[4] * 16
|
||||
head_size = input_w1_shape[3] * 16
|
||||
global_size = input_w2_shape[3] * 16
|
||||
v_embedding = input_v_shape[1] * 16 // head
|
||||
seq_len = input_v_shape[0] * 16 // batch_size
|
||||
|
||||
block_bite_size = 32
|
||||
cpt_time = seq_len//512
|
||||
if v_embedding == 128:
|
||||
thread = 1
|
||||
else:
|
||||
thread = 2
|
||||
|
||||
# # w:zN
|
||||
# # 4, 16, 1024//64, 64//16, 64//16, 16*16
|
||||
# w1_gm = tik_inst.Tensor('float16', (1, 2, 16, 4, 4, 16, 16), name='w1_gm', scope=tik.scope_gm)
|
||||
# w2_gm = tik_inst.Tensor('float16', (1, 2, 16, 16, 4, 16, 16), name='w2_gm', scope=tik.scope_gm)
|
||||
#
|
||||
# # v:nZ
|
||||
# v_gm = tik_inst.Tensor('float16', (64, 16, 16, 16), name='v_gm', scope=tik.scope_gm)
|
||||
|
||||
# w:zN
|
||||
w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size //
|
||||
16, block_size//16, 16, 16), name='w1_gm', scope=tik.scope_gm)
|
||||
w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size //
|
||||
16, head_size//16, 16, 16), name='w2_gm', scope=tik.scope_gm)
|
||||
#
|
||||
# # v:nZ
|
||||
v_gm = tik_inst.Tensor('float16', (batch_size*seq_len//16,
|
||||
head*v_embedding//16, 16, 16), name='v_gm', scope=tik.scope_gm)
|
||||
|
||||
# zN
|
||||
output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len//16, 16, 16), name='output_gm',
|
||||
scope=tik.scope_gm)
|
||||
|
||||
channel_num = batch_size*head
|
||||
with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
|
||||
head_idx = channel_idx // batch_size
|
||||
bs_idx = channel_idx % batch_size
|
||||
|
||||
output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16),
|
||||
name='output_l0c',
|
||||
scope=tik.scope_cc)
|
||||
|
||||
output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16),
|
||||
name='output_ub_32',
|
||||
scope=tik.scope_ubuf)
|
||||
|
||||
output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16),
|
||||
name='output_ub',
|
||||
scope=tik.scope_ubuf)
|
||||
# zZ
|
||||
w1_l1 = tik_inst.Tensor(
|
||||
'float16', (block_size//16, head_size//16, 16, 16), name='w1_l1', scope=tik.scope_cbuf)
|
||||
# nZ
|
||||
v_local_l1 = tik_inst.Tensor(
|
||||
'float16', (head_size//16, v_embedding//16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf)
|
||||
|
||||
# zZ
|
||||
w2_l1 = tik_inst.Tensor('float16', (head_size//16, global_size//(16*cpt_time), 16, 16),
|
||||
name='w2_l1',
|
||||
scope=tik.scope_cbuf)
|
||||
|
||||
# nZ
|
||||
# use same v_global
|
||||
v_global_l1 = tik_inst.Tensor('float16', (global_size//16, v_embedding//16, 16, 16),
|
||||
name='v_global_l1',
|
||||
scope=tik.scope_cbuf)
|
||||
# global v
|
||||
global_idx = 3 - head_idx % 4
|
||||
tik_inst.data_move(v_global_l1[0, 0, 0, 0],
|
||||
v_gm[bs_idx*seq_len//16+global_idx,
|
||||
head_idx*v_embedding//16, 0, 0], 0,
|
||||
seq_len//(4*16), 16*v_embedding*2//block_bite_size,
|
||||
(4*head*v_embedding*16-16*v_embedding)*2//block_bite_size, 0)
|
||||
|
||||
# every block size is 64, local和global输出均为(1024,128)的小z大n矩阵
|
||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||
# global
|
||||
with tik_inst.new_stmt_scope():
|
||||
w2_l0a = tik_inst.Tensor('float16', (head_size//16, global_size//(cpt_time*16), 16, 16),
|
||||
name='w2_l0a', scope=tik.scope_ca)
|
||||
|
||||
v_global_l0b = tik_inst.Tensor('float16', (global_size//(cpt_time*16), v_embedding//16, 16, 16),
|
||||
name='v_global_l0b', scope=tik.scope_cb)
|
||||
|
||||
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
||||
with tik_inst.for_range(0, head_size//16) as brick_i:
|
||||
# tik_inst.tikdb.debug_print("'w_idx: '+str(w_idx)")
|
||||
# tik_inst.tikdb.debug_print("'w2_gm: '+str(w2_gm.shape)")
|
||||
# tik_inst.tikdb.debug_print("'brick_i: '+str(brick_i)")
|
||||
# tik_inst.tikdb.debug_print("'(block_size//16-1)*16*16*2//block_bite_size:
|
||||
# '+str((block_size//16-1)*16*16*2//block_bite_size)")
|
||||
tik_inst.data_move(w2_l1[brick_i, 0, 0, 0],
|
||||
w2_gm[bs_idx, head_idx, w_idx, cpt_idx *
|
||||
global_size//(16*cpt_time), brick_i, 0, 0], 0,
|
||||
global_size//(16*cpt_time), 16 *
|
||||
16*2//block_bite_size,
|
||||
(block_size//16-1)*16*16*2//block_bite_size, 0)
|
||||
|
||||
tik_inst.load2dv1(
|
||||
w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size*global_size//(cpt_time*16*16), 1, 0)
|
||||
|
||||
tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx*global_size//(
|
||||
16*cpt_time), 0, 0, 0], 0, global_size*v_embedding//(16*16*cpt_time), 1, 0)
|
||||
|
||||
with tik_inst.if_scope(cpt_idx == 0):
|
||||
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
||||
block_size, global_size//cpt_time, v_embedding, 0)
|
||||
with tik_inst.else_scope():
|
||||
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
||||
block_size, global_size//cpt_time, v_embedding, 1)
|
||||
|
||||
# local
|
||||
with tik_inst.new_stmt_scope():
|
||||
w1_l0a = tik_inst.Tensor('float16', (block_size//16, head_size//16, 16, 16),
|
||||
name='w1_l0a', scope=tik.scope_ca)
|
||||
v_local_l0b = tik_inst.Tensor('float16', (head_size//16, v_embedding//16, 16, 16),
|
||||
name='v_local_l0b', scope=tik.scope_cb)
|
||||
|
||||
# v
|
||||
tik_inst.data_move(v_local_l1[0, 0, 0, 0],
|
||||
v_gm[bs_idx * seq_len//16 + w_idx * 4, head_idx *
|
||||
v_embedding//16, 0, 0], 0, block_size//16,
|
||||
16 * v_embedding * 2 // block_bite_size,
|
||||
16*(head-1)*v_embedding*2//block_bite_size, 0)
|
||||
|
||||
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0,
|
||||
head_size*v_embedding//(16*16), 1, 0)
|
||||
|
||||
# w
|
||||
with tik_inst.for_range(0, block_size // 16) as brick_i:
|
||||
tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0,
|
||||
head_size // 16, (16 *
|
||||
16*2)//block_bite_size,
|
||||
(block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0)
|
||||
tik_inst.load2dv1(
|
||||
w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size*head_size//(16*16), 1, 0)
|
||||
|
||||
tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b,
|
||||
block_size, head_size, v_embedding, 1)
|
||||
|
||||
tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0,
|
||||
1, block_size * v_embedding * 4 // 1024, 0, 0)
|
||||
|
||||
tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0],
|
||||
v_embedding * block_size//64, 1, 1, 4, 8)
|
||||
|
||||
tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx*(block_size//16), 0, 0], output_ub[0, 0, 0, 0],
|
||||
0, v_embedding//16, 16*block_size*2//block_bite_size, 0,
|
||||
(seq_len - block_size)*16*2//block_bite_size)
|
||||
|
||||
tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm],
|
||||
outputs=[output_gm])
|
||||
return tik_inst
|
|
@ -0,0 +1,652 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""matmul dds impl"""
|
||||
from te import tik
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
|
||||
matmul_dds_grad_op_info = TBERegOp("CusMatmulDDSGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("matmul_dds_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("matmul_dds_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "q", False, "required", "all") \
|
||||
.input(1, "k", False, "required", "all") \
|
||||
.input(2, "local_prob", False, "required", "all") \
|
||||
.input(3, "global_prob", False, "required", "all") \
|
||||
.input(4, "local_prob_grad", False, "required", "all") \
|
||||
.input(5, "global_prob_grad", False, "required", "all") \
|
||||
.output(0, "dq", False, "required", "all") \
|
||||
.output(1, "dk", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(matmul_dds_grad_op_info)
|
||||
def matmul_dds_grad(q,
|
||||
k,
|
||||
local_prob,
|
||||
global_prob,
|
||||
local_prob_grad,
|
||||
global_prob_grad,
|
||||
dq,
|
||||
dk,
|
||||
kernel_name="matmul_dds_grad"):
|
||||
"""
|
||||
:param q: the dict of input q (bs*seq_len, embedding_size) zN
|
||||
:param k: the dict of input k (bs*seq_len, embedding_size) nZ
|
||||
:param local_mask: the dict of input mask local (bs*16*64, 64) zN
|
||||
:param global_mask: the dict of input mask global (heads*1024, 256) zN
|
||||
:param local_prob: local output (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN
|
||||
:param global_prob: global output (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN
|
||||
:param local_prob_grad: local output grad (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN
|
||||
:param global_prob_grad: global output grad (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN
|
||||
"""
|
||||
# seq_len = 1024
|
||||
# size_per_head = 128
|
||||
# block_size = 64
|
||||
# block_num = 16
|
||||
# global_size = 256
|
||||
# bs = 2
|
||||
# heads = 2
|
||||
shape_q = q.get(
|
||||
'shape')
|
||||
shape_lc = local_prob.get(
|
||||
'shape')
|
||||
shape_gc = global_prob.get(
|
||||
'shape')
|
||||
bs = shape_lc[0]
|
||||
heads = shape_gc[1]
|
||||
global_size = shape_gc[3] * shape_gc[-1]
|
||||
block_size = shape_lc[4] * shape_lc[5]
|
||||
seq_len = shape_q[1] * shape_q[2] // bs
|
||||
block_num = seq_len // block_size
|
||||
size_per_head = shape_q[0] * shape_q[-1] // heads
|
||||
|
||||
tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud'))
|
||||
mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16),
|
||||
name="mat_q",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16),
|
||||
name="mat_k",
|
||||
scope=tik.scope_gm) # nZ
|
||||
mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_lc",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_gc",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_lc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_lc_grad",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_gc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_gc_grad",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_dq = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16),
|
||||
name="mat_dq",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_dk = tik_inst.Tensor("float16", (bs * seq_len // 16, size_per_head * heads // 16, 16, 16),
|
||||
name="mat_dk",
|
||||
scope=tik.scope_gm) # zN
|
||||
|
||||
channel_num = bs * heads
|
||||
with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index:
|
||||
# apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ
|
||||
mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16),
|
||||
name='mat_l1_ones',
|
||||
scope=tik.scope_cbuf)
|
||||
with tik_inst.new_stmt_scope():
|
||||
mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16),
|
||||
name='mat_ub_ones',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vec_dup(128, mat_ub_ones, 1.0,
|
||||
(global_size + block_size) * 16 // 128, 8)
|
||||
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
||||
0, (global_size + block_size) // 16, 16, 0, 0)
|
||||
# all_head = 32 * rx + block_index
|
||||
b = tik_inst.Scalar(dtype="int32")
|
||||
b.set_as(block_index // heads)
|
||||
|
||||
# head = block_index - b * heads
|
||||
head = tik_inst.Scalar(dtype="int32")
|
||||
head.set_as(block_index - b * heads)
|
||||
# s = head // 4
|
||||
s = tik_inst.Scalar(dtype="int32")
|
||||
s.set_as(head // 4)
|
||||
# global_idx = 3 - (head - 4 * s) # global idx for global key extraction
|
||||
global_idx = tik_inst.Scalar(dtype="int32")
|
||||
global_idx.set_as(3 - (head - 4 * s))
|
||||
# apply tensor in l1 for global k (256, 128) nZ
|
||||
mat_l1_gk = tik_inst.Tensor("float16",
|
||||
(global_size // 16, size_per_head // 16, 16, 16),
|
||||
name="mat_l1_gk",
|
||||
scope=tik.scope_cbuf)
|
||||
# apply for tensor in L0C for global dk (128, 256) zN
|
||||
mat_l0c_dkg = tik_inst.Tensor("float32",
|
||||
(global_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_l0c_dkg",
|
||||
scope=tik.scope_cc)
|
||||
with tik_inst.for_range(0, global_size // 16) as gb:
|
||||
# move global key from gm to L1 nZ
|
||||
# the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16)
|
||||
tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0],
|
||||
mat_k[
|
||||
head * size_per_head // 16, b * seq_len // 16 +
|
||||
global_idx + gb * block_size // 16, 0, 0],
|
||||
0, size_per_head // 16, 16, bs * seq_len - 16, 0)
|
||||
with tik_inst.for_range(0, block_num) as block:
|
||||
# do backward softmax
|
||||
# grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
|
||||
# apply for tensor in ub for grad_x out (64, 320) zN
|
||||
mat_ub_lg_d = tik_inst.Tensor("float16",
|
||||
((global_size + block_size) //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='mat_ub_lg_d',
|
||||
scope=tik.scope_ubuf)
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply for tensor in ub for softmax out (64, 320) zN
|
||||
mat_ub_lg = tik_inst.Tensor("float16", ((global_size + block_size) // 16, block_size // 16, 16, 16),
|
||||
name='mat_ub_lg',
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in ub for softmax out grad (64, 320) zN
|
||||
mat_ub_lg_grad = tik_inst.Tensor("float16",
|
||||
((global_size + block_size) //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='mat_ub_lg_grad',
|
||||
scope=tik.scope_ubuf)
|
||||
# move local out from gm to ub zN
|
||||
# the shape of local out in gm is zN
|
||||
# the shape of local out in UB is zN
|
||||
# the stride between each (64, 16) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_lc[b, head, block, 0, 0, 0, 0], 0,
|
||||
block_size // 16, block_size,
|
||||
0, 0)
|
||||
# move global out from gm to ub zN
|
||||
# the shape of global out in gm is zN
|
||||
# the shape of global out in UB is zN
|
||||
# the stride between each (64, 16) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.data_move(mat_ub_lg[block_size // 16, 0, 0, 0], mat_gc[b, head, block, 0, 0, 0, 0], 0,
|
||||
global_size // 16, block_size,
|
||||
0, 0)
|
||||
# move local out grad from gm to ub zN
|
||||
# the shape of local out grad in gm is zN
|
||||
# the shape of local out grad in UB is zN
|
||||
# the stride between each (64, 16) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_ub_lg_grad[0, 0, 0, 0], mat_lc_grad[b, head, block, 0, 0, 0, 0], 0,
|
||||
block_size // 16, block_size,
|
||||
0, 0)
|
||||
# move global out grad from gm to ub zN
|
||||
# the shape of global out grad in gm is zN
|
||||
# the shape of global out grad in UB is zN
|
||||
# the stride between each (64, 16) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.data_move(mat_ub_lg_grad[block_size // 16, 0, 0, 0],
|
||||
mat_gc_grad[b, head, block, 0, 0, 0, 0], 0,
|
||||
global_size // 16, block_size,
|
||||
0, 0)
|
||||
# apply for tensor in ub for softmax multiply out grad (64, 320) zN
|
||||
mat_ub_ssg = tik_inst.Tensor("float16",
|
||||
((global_size + block_size) //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='mat_ub_ssg',
|
||||
scope=tik.scope_ubuf)
|
||||
# calculate softmax * softmax_grad
|
||||
tik_inst.vmul(128, mat_ub_ssg[0, 0, 0, 0], mat_ub_lg_grad[0, 0, 0, 0], mat_ub_lg[0, 0, 0, 0],
|
||||
(global_size + block_size) * block_size // 128,
|
||||
1, 1, 1, 8, 8, 8)
|
||||
|
||||
# apply for tensor in L1 for dsoftmax*softmax result (320, 64) nZ
|
||||
mat_l1_ssg_nZ = tik_inst.Tensor("float16", ((global_size + block_size) // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name='mat_l1_ssg_nZ',
|
||||
scope=tik.scope_cbuf)
|
||||
# move ones from ub to L1 for CUBE mmad
|
||||
# the shape of ones in ub is nZ
|
||||
# the shape of ones in L0A is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.data_move(mat_l1_ssg_nZ[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], 0,
|
||||
(global_size + block_size) // 16, block_size, 0, 0)
|
||||
# apply tensor in l0c for exp sum (16, 64) zN
|
||||
mat_l0c_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
||||
name='mat_l0c_ssg_sum',
|
||||
scope=tik.scope_cc)
|
||||
# apply tensor in ub for exp sum (16, 64) zN
|
||||
mat_ub_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
||||
name='mat_ub_ssg_sum',
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in L0A for q (16, 320) zZ
|
||||
mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 16, 16, 16),
|
||||
name='mat_l0a_ones', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for exp (320, 64) nZ
|
||||
mat_l0b_ssg = tik_inst.Tensor('float16', ((global_size + block_size) // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0b_exp', scope=tik.scope_cb)
|
||||
# move ones from l1 to L0A for CUBE mmad
|
||||
# the shape of ones in l1 is zZ
|
||||
# the shape of ones in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0,
|
||||
(global_size + block_size) * 16 // (16 * 16), 1, 0, False)
|
||||
# move ssg from l1 to L0B for CUBE mmad
|
||||
# the shape of ssg in l1 is nZ
|
||||
# the shape of ssg in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 128 times
|
||||
tik_inst.load2dv1(mat_l0b_ssg[0, 0, 0, 0], mat_l1_ssg_nZ[0, 0, 0, 0], 0,
|
||||
(global_size + block_size) * block_size // (16 * 16), 1, 0, False)
|
||||
tik_inst.mmad(mat_l0c_ssg_sum, mat_l0a_ones, mat_l0b_ssg,
|
||||
16, (global_size + block_size), block_size, 0)
|
||||
tik_inst.data_move(mat_ub_ssg_sum[0, 0, 0, 0], mat_l0c_ssg_sum[0, 0, 0, 0], 0,
|
||||
block_size // 16, 1, 0, 0)
|
||||
# apply for tensor in UB for global prob sum (64,)
|
||||
mat_ub_ssg_sums = tik_inst.Tensor("float32", (block_size,),
|
||||
name='mat_ub_ssg_sums',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.data_move(mat_ub_ssg_sums[0], mat_ub_ssg_sum[0, 0, 0, 0],
|
||||
0, block_size // 16, 1*2, 15*2, 0)
|
||||
# apply for tensor in UB for global prob sum (64,)
|
||||
mat_ub_ssg_sums_16 = tik_inst.Tensor("float16", (block_size,),
|
||||
name='mat_ub_ssg_sums_16',
|
||||
scope=tik.scope_ubuf)
|
||||
# convert fp32 to fp16
|
||||
tik_inst.vec_conv(
|
||||
64, "", mat_ub_ssg_sums_16[0], mat_ub_ssg_sums[0], 1, 4, 8)
|
||||
|
||||
mat_ub_ssgs = tik_inst.Tensor("float16",
|
||||
((global_size + block_size) //
|
||||
16, block_size // 16, 16, 16),
|
||||
name='mat_ub_ssgs',
|
||||
scope=tik.scope_ubuf)
|
||||
|
||||
with tik_inst.for_range(0, block_size) as bbs:
|
||||
# apply for scalar in UB for prob sum rec
|
||||
sum_ssg = tik_inst.Scalar("float16",
|
||||
name='sum_ssg',
|
||||
init_value=0)
|
||||
# set value for scalar prob sum rec
|
||||
sum_ssg.set_as(mat_ub_ssg_sums_16[bbs])
|
||||
tik_inst.vec_muls(16, mat_ub_ssgs[0, bbs // 16, bbs % 16, 0],
|
||||
mat_ub_lg[0, bbs // 16, bbs %
|
||||
16, 0], sum_ssg,
|
||||
(global_size + block_size) // 16,
|
||||
block_size, block_size)
|
||||
|
||||
tik_inst.vsub(128, mat_ub_lg_d[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], mat_ub_ssgs[0, 0, 0, 0],
|
||||
(global_size + block_size) * block_size // 128,
|
||||
1, 1, 1, 8, 8, 8)
|
||||
|
||||
# local dq calculation
|
||||
# dw X K.T
|
||||
# apply tensor in l1 for local k (64, 128) nZ
|
||||
mat_l1_lk = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_l1_lk",
|
||||
scope=tik.scope_cbuf)
|
||||
# move k from gm to l1
|
||||
# the shape of local k in gm is nZ
|
||||
# the shape of local k in l1 is zZ
|
||||
# the stride between each (16, 16) is 1024*bs-64
|
||||
# repeat 8 times
|
||||
# LOOP 4 times
|
||||
with tik_inst.for_range(0, block_size // 16) as lb:
|
||||
tik_inst.data_move(mat_l1_lk[lb, 0, 0, 0],
|
||||
mat_k[head * size_per_head // 16, b * seq_len // 16 + (
|
||||
block * block_size) // 16 + lb, 0, 0],
|
||||
0, size_per_head // 16, 16, bs * seq_len - 16, 0)
|
||||
|
||||
# apply tensor in l1 for local dw (64, 128) zZ
|
||||
mat_l1_ldw = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_l1_ldw",
|
||||
scope=tik.scope_cbuf)
|
||||
# move local d-softmax from ub to l1
|
||||
# the shape of d-softmax in ub is zN
|
||||
# the shape of d-softmax in l1 is zZ
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 16 times
|
||||
with tik_inst.for_range(0, block_size // 16) as lb:
|
||||
tik_inst.data_move(mat_l1_ldw[lb, 0, 0, 0],
|
||||
mat_ub_lg_d[0, lb, 0, 0],
|
||||
0, block_size // 16, 16, block_size - 16, 0)
|
||||
# apply for tensor in L0C for local d-q (64, 128) zN
|
||||
mat_l0c_dq = tik_inst.Tensor("float32",
|
||||
(size_per_head // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_l0c_dq",
|
||||
scope=tik.scope_cc)
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply for tensor in L0A for q (64, 64) zZ
|
||||
mat_l0a_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0a_ldw', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for global k (128, 256) nZ
|
||||
mat_l0b_lk = tik_inst.Tensor('float16', (block_size // 16, size_per_head // 16, 16, 16),
|
||||
name='mat_l0b_lk', scope=tik.scope_cb)
|
||||
# move q from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is zZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.load2dv1(mat_l0a_ldw[0, 0, 0, 0], mat_l1_ldw[0, 0, 0, 0], 0,
|
||||
block_size * block_size // (16 * 16), 1, 0, False)
|
||||
# move local k from l1 to L0B for CUBE mmad
|
||||
# the shape of local k in l1 is zZ
|
||||
# the shape of local k in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0b_lk[0, 0, 0, 0], mat_l1_lk[0, 0, 0, 0], 0,
|
||||
block_size * size_per_head // (16 * 16), 1, 0, True)
|
||||
# matmul q and local dw
|
||||
# the shape of global scores in L0C is zN
|
||||
tik_inst.mmad(mat_l0c_dq, mat_l0a_ldw, mat_l0b_lk,
|
||||
block_size, block_size, size_per_head, 0)
|
||||
|
||||
# global dq calculation
|
||||
# apply tensor in l1 for global dw (64, 256) zZ
|
||||
mat_l1_gdw = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
global_size // 16, 16, 16),
|
||||
name="mat_l1_gdw",
|
||||
scope=tik.scope_cbuf)
|
||||
# move global dw from ub to l1
|
||||
# the shape of global dw in gm is zN
|
||||
# the shape of global dw in l1 is zZ
|
||||
# the stride between each (16, 16) is 1024*bs-64
|
||||
# repeat 8 times
|
||||
# LOOP 4 times
|
||||
with tik_inst.for_range(0, block_size // 16) as lb:
|
||||
tik_inst.data_move(mat_l1_gdw[lb, 0, 0, 0],
|
||||
mat_ub_lg_d[block_size // 16, lb, 0, 0],
|
||||
0, global_size // 16, 16, block_size - 16, 0)
|
||||
# apply for tensor in ub for dq (64, 128) zN
|
||||
mat_ub_dq = tik_inst.Tensor("float32",
|
||||
(size_per_head // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_ub_dq",
|
||||
scope=tik.scope_ubuf)
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply for tensor in L0A for global dw (64, 256) zZ
|
||||
mat_l0a_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16),
|
||||
name='mat_l0a_gdw', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for global k (256, 128) nZ
|
||||
mat_l0b_gk = tik_inst.Tensor('float16', (global_size // 16, size_per_head // 16, 16, 16),
|
||||
name='mat_l0b_gk', scope=tik.scope_cb)
|
||||
# move dw global from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is zZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.load2dv1(mat_l0a_gdw[0, 0, 0, 0], mat_l1_gdw[0, 0, 0, 0], 0,
|
||||
block_size * global_size // (16 * 16), 1, 0, False)
|
||||
# move local k from l1 to L0B for CUBE mmad
|
||||
# the shape of local k in l1 is zZ
|
||||
# the shape of local k in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0b_gk[0, 0, 0, 0], mat_l1_gk[0, 0, 0, 0], 0,
|
||||
global_size * size_per_head // (16 * 16), 1, 0, True)
|
||||
# matmul k and local dw
|
||||
# the shape of global scores in L0C is zN
|
||||
tik_inst.mmad(mat_l0c_dq, mat_l0a_gdw, mat_l0b_gk,
|
||||
block_size, global_size, size_per_head, 1)
|
||||
# move dq from l0c to UB
|
||||
# the shape of dq in l9c is zN
|
||||
# the shape of dq in ub is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 8 times
|
||||
tik_inst.data_move(mat_ub_dq[0, 0, 0, 0], mat_l0c_dq[0, 0, 0, 0], 0, size_per_head // 16,
|
||||
block_size // 16, 0, 0)
|
||||
|
||||
# local dk calculation
|
||||
# dk calculation q.T X dw
|
||||
# apply for tensor in ub for dw (320, 64) nZ
|
||||
mat_ub_lg_d_nZ = tik_inst.Tensor("float16",
|
||||
(block_size // 16, (global_size +
|
||||
block_size) // 16, 16, 16),
|
||||
name='mat_ub_lg_d_nZ',
|
||||
scope=tik.scope_ubuf)
|
||||
# transpose dw from zN to nZ
|
||||
with tik_inst.for_range(0, (global_size + block_size) // 16) as lb:
|
||||
with tik_inst.for_range(0, block_size // 16) as gb:
|
||||
tik_inst.vtranspose(
|
||||
mat_ub_lg_d_nZ[gb, lb, 0, 0], mat_ub_lg_d[lb, gb, 0, 0])
|
||||
|
||||
# apply tensor in l1 for local dw (64, 64) nZ
|
||||
mat_l1_ldw_nZ = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_l1_ldw_nZ",
|
||||
scope=tik.scope_cbuf)
|
||||
# move local dw from ub to l1
|
||||
# the shape of local dw in ub is nZ
|
||||
# the shape of local dw in l1 is nZ
|
||||
# the stride between each (16, 64) is 256
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_l1_ldw_nZ[0, 0, 0, 0],
|
||||
mat_ub_lg_d_nZ[0, 0, 0, 0],
|
||||
0, block_size // 16, block_size, global_size, 0)
|
||||
# apply for tensor in L1 for q (128, 64) nZ
|
||||
mat_l1_q_b = tik_inst.Tensor("float16",
|
||||
(size_per_head // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_l1_q_b",
|
||||
scope=tik.scope_cbuf)
|
||||
# move local q from gm to l1
|
||||
# the shape of local q in gm is zN
|
||||
# the shape of local dw in l1 is zZ
|
||||
# the stride between each (16, 16) is 48
|
||||
# repeat 4 times
|
||||
# LOOP 8 times
|
||||
with tik_inst.for_range(0, size_per_head // 16) as lb:
|
||||
tik_inst.load2dv1(mat_l1_q_b[lb, 0, 0, 0],
|
||||
mat_q[head * size_per_head // 16 + lb,
|
||||
b * seq_len // 16 + (block * block_size) // 16, 0, 0],
|
||||
0, block_size // 16, 1, 0, False)
|
||||
# apply for tensor in L0C for local dk (128, 64) zN
|
||||
mat_l0c_dkl = tik_inst.Tensor("float32",
|
||||
(block_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_l0c_dkl",
|
||||
scope=tik.scope_cc)
|
||||
# apply for tensor in ub for local dk (128, 64) zN
|
||||
mat_ub_ldk = tik_inst.Tensor("float32",
|
||||
(block_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_ub_ldk",
|
||||
scope=tik.scope_ubuf)
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply for tensor in L0A for q (128, 64) zZ
|
||||
mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0a_q', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for local dw (64, 64) nZ
|
||||
mat_l0b_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0b_ldw', scope=tik.scope_cb)
|
||||
# move q from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is nZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 4 times
|
||||
# LOOP 8 times
|
||||
tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0],
|
||||
mat_l1_q_b[0, 0, 0, 0],
|
||||
0, block_size * size_per_head // 256, 1, 0, True)
|
||||
# move local dw from l1 to L0B for CUBE mmad
|
||||
# the shape of local dw in l1 is nZ
|
||||
# the shape of local dw in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0b_ldw[0, 0, 0, 0], mat_l1_ldw_nZ[0, 0, 0, 0], 0,
|
||||
block_size * block_size // (16 * 16), 1, 0, False)
|
||||
# matmul q and local dw
|
||||
# the shape of local k in L0C is zN
|
||||
tik_inst.mmad(mat_l0c_dkl, mat_l0a_q, mat_l0b_ldw,
|
||||
size_per_head, block_size, block_size, 0)
|
||||
# move local dk from l0c to UB
|
||||
# the shape of local dk in l0C is zN
|
||||
# the shape of local dk in UB is zN
|
||||
# the stride between each (16, 128) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_ub_ldk[0, 0, 0, 0], mat_l0c_dkl[0, 0, 0, 0], 0, block_size // 16,
|
||||
size_per_head // 16, 0, 0)
|
||||
|
||||
# move global dw from UB to l1
|
||||
# apply for tensor in L1 for global dw (64, 256) nZ
|
||||
mat_l1_dwg_b = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
global_size // 16, 16, 16),
|
||||
name="mat_l1_dwg_b",
|
||||
scope=tik.scope_cbuf)
|
||||
# move global dw from UB to L1
|
||||
# the shape of global dw in gm is nZ
|
||||
# the shape of global dw in gm is nZ
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 8 times
|
||||
tik_inst.data_move(mat_l1_dwg_b[0, 0, 0, 0],
|
||||
mat_ub_lg_d_nZ[0, block_size // 16, 0, 0],
|
||||
0, block_size // 16, global_size, block_size, 0)
|
||||
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply for tensor in L0A for q (128, 64) zZ
|
||||
mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0a_q', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for local dw (64, 64) nZ
|
||||
mat_l0b_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16),
|
||||
name='mat_l0b_ldw', scope=tik.scope_cb)
|
||||
# move q from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is nZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 4 times
|
||||
# LOOP 8 times
|
||||
tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0],
|
||||
mat_l1_q_b[0, 0, 0, 0],
|
||||
0, block_size * size_per_head // 256, 1, 0, True)
|
||||
# move local dw from l1 to L0B for CUBE mmad
|
||||
# the shape of local dw in l1 is nZ
|
||||
# the shape of local dw in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0b_gdw[0, 0, 0, 0], mat_l1_dwg_b[0, 0, 0, 0], 0,
|
||||
block_size * global_size // (16 * 16), 1, 0, False)
|
||||
# matmul q and local dw
|
||||
# the shape of local k in L0C is zN
|
||||
with tik_inst.if_scope(block == 0):
|
||||
tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw,
|
||||
size_per_head, block_size, global_size, 0)
|
||||
with tik_inst.else_scope():
|
||||
tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw,
|
||||
size_per_head, block_size, global_size, 1)
|
||||
|
||||
# cast dq from 32 to 16
|
||||
# apply for tensor in ub for dq (64, 128) zN
|
||||
mat_ub_dq_16 = tik_inst.Tensor("float16",
|
||||
(size_per_head // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_ub_dq_16",
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in ub for local dk (128, 64) zN
|
||||
mat_ub_ldk_16 = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_ub_ldk_16",
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vec_conv(
|
||||
64, "", mat_ub_ldk_16[0, 0, 0, 0], mat_ub_ldk[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8)
|
||||
tik_inst.vec_conv(
|
||||
64, "", mat_ub_dq_16[0, 0, 0, 0], mat_ub_dq[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8)
|
||||
|
||||
# move dq from UB to gm
|
||||
# the shape of dq in UB is zN
|
||||
# the shape of dq in gm is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 8 times
|
||||
tik_inst.data_move(mat_dq[head * size_per_head // 16,
|
||||
b * seq_len // 16 + (block * block_size) // 16, 0, 0],
|
||||
mat_ub_dq_16[0, 0, 0,
|
||||
0], 0, size_per_head // 16, block_size, 0,
|
||||
bs * seq_len - block_size)
|
||||
# move local dk from UB to gm
|
||||
# the shape of local dk in UB is zN
|
||||
# the shape of local dk in gm is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 8 times
|
||||
tik_inst.data_move(mat_dk[b * seq_len // 16 + (block * block_size) // 16,
|
||||
head * size_per_head // 16, 0, 0],
|
||||
mat_ub_ldk_16[0, 0, 0,
|
||||
0], 0, block_size // 16, size_per_head, 0,
|
||||
heads * size_per_head - size_per_head)
|
||||
with tik_inst.for_range(0, global_size // 16) as lb:
|
||||
# apply for tensor in ub for global dk (128, 16) zN
|
||||
mat_ub_gdk_32 = tik_inst.Tensor("float32",
|
||||
(1, size_per_head // 16, 16, 16),
|
||||
name="mat_ub_gdk",
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in ub for global dk (128, 16) zN
|
||||
mat_ub_gdk = tik_inst.Tensor("float16",
|
||||
(1, size_per_head // 16, 16, 16),
|
||||
name="mat_ub_gdk",
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in ub for global dk (128, 16) zN
|
||||
mat_ub_ldk2 = tik_inst.Tensor("float16",
|
||||
(1, size_per_head // 16, 16, 16),
|
||||
name="mat_ub_ldk2",
|
||||
scope=tik.scope_ubuf)
|
||||
# move global dk from l0c to UB
|
||||
# the shape of global dk in l0C is zN
|
||||
# the shape of global dk in UB is zN
|
||||
# the stride between each (16, 128) is 0
|
||||
# repeat 1 times
|
||||
tik_inst.data_move(mat_ub_gdk_32[0, 0, 0, 0], mat_l0c_dkg[lb, 0, 0, 0], 0, 1,
|
||||
size_per_head // 16, 0, 0)
|
||||
tik_inst.vec_conv(
|
||||
64, "", mat_ub_gdk[0, 0, 0, 0], mat_ub_gdk_32[0, 0, 0, 0], size_per_head * 16 // 64, 4, 8)
|
||||
# move local dk from gm to UB
|
||||
# the shape of local dk in gm is zN
|
||||
# the shape of local dk in UB is zN
|
||||
# the stride between each (16, 128) is 0
|
||||
# repeat 1 times
|
||||
tik_inst.data_move(mat_ub_ldk2[0, 0, 0, 0], mat_dk[b * seq_len // 16 + 4 * lb + global_idx,
|
||||
head * size_per_head // 16, 0, 0], 0, 1,
|
||||
size_per_head, 0, 0)
|
||||
# add local dk and global dk
|
||||
mat_ub_dk = tik_inst.Tensor("float16",
|
||||
(1, size_per_head // 16, 16, 16),
|
||||
name="mat_ub_dk",
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vec_add(128, mat_ub_dk, mat_ub_ldk2, mat_ub_gdk,
|
||||
size_per_head * 16 // 128, 8, 8, 8)
|
||||
# move dk from UB to gm
|
||||
# the shape of dk in UB is zN
|
||||
# the shape of dk in gm is zN
|
||||
# the stride between each (16, 128) is 0
|
||||
# repeat 1 times
|
||||
tik_inst.data_move(
|
||||
mat_dk[b * seq_len // 16 + 4 * lb + global_idx,
|
||||
head * size_per_head // 16, 0, 0],
|
||||
mat_ub_dk[0, 0, 0, 0], 0, 1, size_per_head, 0, 0)
|
||||
tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lc, mat_gc, mat_lc_grad, mat_gc_grad],
|
||||
outputs=[mat_dq, mat_dk])
|
||||
return tik_inst
|
|
@ -0,0 +1,614 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""matmul dds impl"""
|
||||
from te import tik
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
matmul_dds_op_info = TBERegOp("CusMatmulDDS") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("matmul_dds.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("CusMatmulDDSImpl") \
|
||||
.partial_flag(True) \
|
||||
.attr("bs", "required", "int", "all") \
|
||||
.attr("heads", "required", "int", "all") \
|
||||
.input(0, "q", False, "required", "all") \
|
||||
.input(1, "k", False, "required", "all") \
|
||||
.input(2, "local_mask", False, "required", "all") \
|
||||
.input(3, "global_mask", False, "required", "all") \
|
||||
.output(0, "local_prob", False, "required", "all") \
|
||||
.output(1, "global_prob", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(matmul_dds_op_info)
|
||||
def CusMatmulDDSImpl(q,
|
||||
k,
|
||||
local_mask,
|
||||
global_mask,
|
||||
local_prob,
|
||||
global_prob,
|
||||
bs,
|
||||
heads,
|
||||
kernel_name="CusMatmulDDSImpl"):
|
||||
"""
|
||||
:param q: the dict of input q (bs*seq_len, embedding_size) zN
|
||||
:param k: the dict of input k (bs*seq_len, embedding_size) nZ
|
||||
:param bs: batch size int
|
||||
:param heads: number of heads int
|
||||
:param local_mask: the dict of input mask local (bs*16*64, 64) zN
|
||||
:param global_mask: the dict of input mask global (heads*1024, 256) zN
|
||||
:param kernel_name: dds_softmax
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# bs = 4
|
||||
# heads = 16
|
||||
|
||||
shape_q = q.get(
|
||||
'shape') # shape_q (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) zN
|
||||
shape_k = k.get(
|
||||
'shape') # shape_k (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) nZ
|
||||
shape_local_mask = local_mask.get(
|
||||
'shape') # shape_local_mask (16*64, bs*64) > (64, bs*4, 16, 16) zN
|
||||
shape_global_mask = global_mask.get(
|
||||
'shape') # shape_global_mask (heads*256, 1024) > (bs*16, 64, 16, 16) zN
|
||||
# sequence length only support 1024 for now
|
||||
seq_len = shape_q[1] * shape_q[2] // bs
|
||||
# size per head assume 128
|
||||
size_per_head = shape_q[0] * shape_q[-1] // heads
|
||||
block_size = shape_local_mask[0] # block size only support 64 for now
|
||||
block_num = seq_len // block_size # block number only support 16 for now
|
||||
global_size = seq_len // 4 # global size only support 256 for now
|
||||
# seq_len = 1024
|
||||
# size_per_head = 128
|
||||
# block_size = 64
|
||||
# block_num = 16
|
||||
# global_size = 256
|
||||
|
||||
tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud'))
|
||||
|
||||
mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16),
|
||||
name="mat_q",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16),
|
||||
name="mat_k",
|
||||
scope=tik.scope_gm) # nZ
|
||||
mat_lm = tik_inst.Tensor("float32", (block_num * block_size // 16, bs * block_size // 16, 16, 16),
|
||||
name="mat_lm",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_gm = tik_inst.Tensor("float32", (bs * global_size // 16, seq_len // 16, 16, 16),
|
||||
name="mat_gm",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_lc",
|
||||
scope=tik.scope_gm) # zN
|
||||
mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16),
|
||||
name="mat_gc",
|
||||
scope=tik.scope_gm) # zN
|
||||
|
||||
channel_num = bs * heads
|
||||
|
||||
with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index:
|
||||
# apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ
|
||||
mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16),
|
||||
name='mat_l1_ones',
|
||||
scope=tik.scope_cbuf)
|
||||
|
||||
with tik_inst.new_stmt_scope():
|
||||
mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16),
|
||||
name='mat_ub_ones',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vec_dup(128, mat_ub_ones, 1.0,
|
||||
(global_size + block_size) * 16 // 128, 8)
|
||||
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
||||
0, (global_size + block_size) // 16, 16, 0, 0)
|
||||
# b = block_index // heads
|
||||
b = tik_inst.Scalar(dtype="int32")
|
||||
b.set_as(block_index // heads)
|
||||
|
||||
# head = block_index - b * heads
|
||||
head = tik_inst.Scalar(dtype="int32")
|
||||
head.set_as(block_index - b * heads)
|
||||
# s = head // 4
|
||||
s = tik_inst.Scalar(dtype="int32")
|
||||
s.set_as(head // 4)
|
||||
# global_idx = 3 - (head - 4 * s) # global idx for global key extraction
|
||||
global_idx = tik_inst.Scalar(dtype="int32")
|
||||
global_idx.set_as(3 - (head - 4 * s))
|
||||
# apply tensor for global key which is (128, 256) in L1 nZ
|
||||
# for each head, global k is the same, put global k in L1 in order of reuse
|
||||
mat_l1_gk = tik_inst.Tensor("float16",
|
||||
(size_per_head // 16,
|
||||
global_size // 16, 16, 16),
|
||||
name="mat_l1_gk",
|
||||
scope=tik.scope_cbuf)
|
||||
with tik_inst.for_range(0, size_per_head // 16) as gb:
|
||||
# move global key from gm to L1 nZ
|
||||
# the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16)
|
||||
tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0],
|
||||
mat_k[head * size_per_head // 16 + gb,
|
||||
b * seq_len // 16 + global_idx, 0, 0],
|
||||
0, block_num, 16, 48, 0)
|
||||
|
||||
with tik_inst.for_range(0, block_num) as block:
|
||||
# calculate qk matmul block by block
|
||||
|
||||
# apply tensor in l0c for local mask (64, 64) zN
|
||||
mat_l0c_l = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0c_l',
|
||||
scope=tik.scope_cc)
|
||||
# apply tensor in l0c for global mask (256, 64) zN
|
||||
mat_l0c_g = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_l0c_g',
|
||||
scope=tik.scope_cc)
|
||||
# apply tensor in l1 for local k (128, 64) nZ
|
||||
mat_l1_lk = tik_inst.Tensor("float16",
|
||||
(size_per_head // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name="mat_l1_lk",
|
||||
scope=tik.scope_cbuf)
|
||||
# apply for tensor in L1 for fp 16 exp result (320, 64) zN
|
||||
mat_l1_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name='mat_l1_lg_exp_16',
|
||||
scope=tik.scope_cbuf)
|
||||
# convert exp out to fp 16
|
||||
# apply for tensor in UB for fp 16 exp result (64, 320) zN
|
||||
mat_ub_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16,
|
||||
block_size // 16, 16, 16),
|
||||
name='mat_ub_lg_exp_16',
|
||||
scope=tik.scope_ubuf)
|
||||
# move local k from gm to l1 nZ
|
||||
# the shape of local k in gm is nZ
|
||||
# the shape of local k in l1 is nZ
|
||||
# the stride between each (16, 64) is 1024*bs-64
|
||||
# repeat 8 times
|
||||
tik_inst.data_move(mat_l1_lk,
|
||||
mat_k[head * size_per_head // 16, b * seq_len // 16 + (
|
||||
block * block_size) // 16, 0, 0],
|
||||
0, size_per_head // 16, block_size, bs * seq_len - block_size, 0)
|
||||
# apply tensor in l1 for q (64, 128) zN
|
||||
mat_l1_q = tik_inst.Tensor("float16",
|
||||
(block_size // 16,
|
||||
size_per_head // 16, 16, 16),
|
||||
name="mat_l1_q",
|
||||
scope=tik.scope_cbuf)
|
||||
# move q from gm to l1
|
||||
# the shape of local k in gm is zN
|
||||
# the shape of local k in l1 is zZ
|
||||
# the stride between each (16, 16) is 1024*bs-64
|
||||
# repeat 8 times
|
||||
# LOOP 4 times
|
||||
with tik_inst.for_range(0, block_size // 16) as lb:
|
||||
tik_inst.data_move(mat_l1_q[lb, 0, 0, 0],
|
||||
mat_q[head * size_per_head // 16, b * seq_len // 16 + (
|
||||
block * block_size) // 16 + lb, 0, 0],
|
||||
0, size_per_head // 16, 16, bs * seq_len - 16, 0)
|
||||
|
||||
# global
|
||||
# apply a new scope
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply tensor in ub for global mask (256, 64) zN
|
||||
mat_ub_gm = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_ub_gm',
|
||||
scope=tik.scope_ubuf)
|
||||
# move global mask from gm to ub zN
|
||||
# the shape of global mask in gm is zN
|
||||
# the shape of global mask in UB is zN
|
||||
# the stride between each (64, 16) is 960
|
||||
# repeat 16 times
|
||||
tik_inst.data_move(mat_ub_gm,
|
||||
mat_gm[b * global_size // 16,
|
||||
block * block_size // 16, 0, 0],
|
||||
0, global_size // 16, block_size * 2, seq_len * 2 - block_size * 2, 0)
|
||||
# move global mask from ub to l0c for bias add
|
||||
# the shape of global mask in ub is zN
|
||||
# the shape of global mask in l0c is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.data_move(mat_l0c_g[0, 0, 0, 0],
|
||||
mat_ub_gm[0, 0, 0, 0],
|
||||
0, global_size // 16, block_size // 16, 0, 0)
|
||||
with tik_inst.for_range(0, 4, thread_num=2) as gb:
|
||||
# apply for tensor in L0A for q (64, 128) zZ
|
||||
mat_l0a_g = tik_inst.Tensor('float16',
|
||||
(block_size // 16, size_per_head //
|
||||
(16 * 4), 16, 16),
|
||||
name='mat_l0a_g', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for global k (128, 256) nZ
|
||||
mat_l0b_g = tik_inst.Tensor('float16',
|
||||
(size_per_head // (16 * 4),
|
||||
global_size // 16, 16, 16),
|
||||
name='mat_l0b_g', scope=tik.scope_cb)
|
||||
# move q from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is zZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
with tik_inst.for_range(0, block_size // 16) as bl:
|
||||
tik_inst.load2dv1(mat_l0a_g[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0,
|
||||
16 * size_per_head // (4 * 16 * 16), 1, 0, False)
|
||||
# move global k from l1 to L0B for CUBE mmad
|
||||
# the shape of global k in l1 is nZ
|
||||
# the shape of global k in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 128 times
|
||||
tik_inst.load2dv1(mat_l0b_g[0, 0, 0, 0], mat_l1_gk[size_per_head * gb // 64, 0, 0, 0], 0,
|
||||
global_size * size_per_head // (4 * 16 * 16), 1, 0, False)
|
||||
# matmul q and global k
|
||||
# the shape of global scores in L0C is zN
|
||||
tik_inst.mmad(mat_l0c_g, mat_l0a_g, mat_l0b_g,
|
||||
block_size, size_per_head // 4, global_size, 1)
|
||||
|
||||
# with tik_inst.for_range(0, global_size // 16, thread_num=2) as gb:
|
||||
# mat_ub_g = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
||||
# name='mat_ub_g',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.data_move(mat_ub_g[0, 0, 0, 0], mat_l0c_g[gb, 0, 0, 0], 0,
|
||||
# 1, block_size // 16, 0, 0)
|
||||
# mat_ub_g_exp = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
||||
# name='mat_ub_g_exp',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vexp(64, mat_ub_g_exp[0, 0, 0, 0],
|
||||
# mat_ub_g[0, 0, 0, 0], block_size * 16 // 64, 1, 1, 8, 8)
|
||||
# # cast fp32 exp to fp16 zN
|
||||
# tik_inst.vec_conv(64, "", mat_ub_lg_exp_16[gb, 0, 0, 0],
|
||||
# mat_ub_g_exp[0, 0, 0, 0],
|
||||
# block_size * 16 // 64, 4, 8)
|
||||
|
||||
# local
|
||||
# apply a new scope
|
||||
with tik_inst.new_stmt_scope():
|
||||
# apply tensor in ub for local mask (64, 64) zN
|
||||
mat_ub_lm = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_ub_lm',
|
||||
scope=tik.scope_ubuf)
|
||||
# move local mask from gm to ub zN
|
||||
# the shape of local mask in gm is zN
|
||||
# the shape of local mask in UB is zN
|
||||
# the stride between each (64, 16) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_ub_lm,
|
||||
mat_lm[block * block_size // 16,
|
||||
b * block_size // 16, 0, 0],
|
||||
0, block_size // 16, block_size * 2, (bs * block_size - block_size) * 2, 0)
|
||||
# move local mask from ub to l0c for bias add
|
||||
# the shape of local mask in ub is zN
|
||||
# the shape of local mask in l0c is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_l0c_l[0, 0, 0, 0],
|
||||
mat_ub_lm[0, 0, 0, 0],
|
||||
0, block_size // 16, block_size // 16, 0, 0)
|
||||
with tik_inst.for_range(0, 4, thread_num=2) as gb:
|
||||
# apply for tensor in L0A for q (64, 128) zZ
|
||||
mat_l0a_l = tik_inst.Tensor('float16', (block_size // 16, size_per_head // (16 * 4), 16, 16),
|
||||
name='mat_l0a_l', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for local k (128, 64) nZ
|
||||
mat_l0b_l = tik_inst.Tensor('float16', (size_per_head // (16 * 4), block_size // 16, 16, 16),
|
||||
name='mat_l0b_l', scope=tik.scope_cb)
|
||||
# move q from l1 to L0A for CUBE mmad
|
||||
# the shape of q in l1 is zZ
|
||||
# the shape of q in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
with tik_inst.for_range(0, block_size // 16) as bl:
|
||||
tik_inst.load2dv1(mat_l0a_l[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0,
|
||||
16 * size_per_head // (4 * 16 * 16), 1, 0, False)
|
||||
# move local k from l1 to L0B for CUBE mmad
|
||||
# the shape of local k in l1 is nZ
|
||||
# the shape of local k in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0b_l[0, 0, 0, 0], mat_l1_lk[size_per_head * gb // 64, 0, 0, 0], 0,
|
||||
block_size * size_per_head // (16 * 16 * 4), 1, 0, False)
|
||||
# matmul q and local k
|
||||
# the shape of local scores in L0C is (64, 64) zN
|
||||
tik_inst.mmad(mat_l0c_l, mat_l0a_l, mat_l0b_l,
|
||||
block_size, size_per_head // 4, block_size, 1)
|
||||
|
||||
# with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb:
|
||||
# mat_ub_l = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
||||
# name='mat_ub_l',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.data_move(mat_ub_l[0, 0, 0, 0], mat_l0c_l[gb, 0, 0, 0], 0,
|
||||
# 1, block_size // 16, 0, 0)
|
||||
# mat_ub_l_exp = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
||||
# name='mat_ub_l_exp',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vexp(64, mat_ub_l_exp[0, 0, 0, 0],
|
||||
# mat_ub_l[0, 0, 0, 0], block_size * 16 // 64, 1, 1, 8, 8)
|
||||
# # cast fp32 exp to fp16 zN
|
||||
# tik_inst.vec_conv(64, "", mat_ub_lg_exp_16[global_size // 16 + gb, 0, 0, 0],
|
||||
# mat_ub_l_exp[0, 0, 0, 0],
|
||||
# block_size * 16 // 64, 4, 8)
|
||||
|
||||
with tik_inst.new_stmt_scope():
|
||||
with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb:
|
||||
mat_ub_lg = tik_inst.Tensor("float32", (1, (block_size + global_size) // 16, 16, 16),
|
||||
name='mat_ub_lg',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_l0c_g[0, gb, 0, 0], 0,
|
||||
global_size // 16, 1, block_size // 16 - 1, 0)
|
||||
tik_inst.data_move(mat_ub_lg[0, global_size // 16, 0, 0], mat_l0c_l[0, gb, 0, 0], 0,
|
||||
block_size // 16, 1, block_size // 16 - 1, 0)
|
||||
mat_ub_lg_16 = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16),
|
||||
name='mat_ub_lg_16',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vec_conv(64, "", mat_ub_lg_16[0, 0, 0, 0],
|
||||
mat_ub_lg[0, 0, 0, 0],
|
||||
(block_size + global_size) * 16 // 64, 4, 8)
|
||||
# mat_ub_lg_max = tik_inst.Tensor("float16", (2,),
|
||||
# name='mat_ub_lg_max',
|
||||
# scope=tik.scope_ubuf)
|
||||
with tik_inst.for_range(0, 16) as lb:
|
||||
mat_ub_lg_lb = tik_inst.Tensor("float16", (block_size + global_size,),
|
||||
name='mat_ub_lg_lb',
|
||||
scope=tik.scope_ubuf)
|
||||
mat_ub_lg_lb_subs = tik_inst.Tensor("float16", (block_size + global_size,),
|
||||
name='mat_ub_lg_lb_subs',
|
||||
scope=tik.scope_ubuf)
|
||||
|
||||
tik_inst.data_move(mat_ub_lg_lb[0], mat_ub_lg_16[0, 0, lb, 0], 0,
|
||||
(block_size + global_size) // 16, 1, 15, 0)
|
||||
max_value = tik_inst.Scalar("float16",
|
||||
name='max_value',
|
||||
init_value=0)
|
||||
with tik_inst.for_range(0, (block_size + global_size) // 64) as nb:
|
||||
mat_ub_lg_max = tik_inst.Tensor("float16", (2,),
|
||||
name='mat_ub_lg_max',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vcmax(64, mat_ub_lg_max[0], mat_ub_lg_lb[64 * nb], 1,
|
||||
1, 1, 4)
|
||||
mat_ub_lg_max_sub = tik_inst.Tensor("float16", (2,),
|
||||
name='mat_ub_lg_max_sub',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vmuls(
|
||||
2, mat_ub_lg_max_sub[0], mat_ub_lg_max[0], -1.0, 1, 1, 1, 1, 1)
|
||||
block_max_value = tik_inst.Scalar("float16",
|
||||
name='block_max_value',
|
||||
init_value=0)
|
||||
block_max_value.set_as(mat_ub_lg_max_sub[0])
|
||||
max_value_int8 = tik_inst.Scalar("int8",
|
||||
name='max_value_int8',
|
||||
init_value=0)
|
||||
max_value_int = tik_inst.Tensor("int8", (1,),
|
||||
name='max_value_int',
|
||||
scope=tik.scope_ubuf)
|
||||
max_value_fp16 = tik_inst.Tensor("float16", (1,),
|
||||
name='max_value_fp16',
|
||||
scope=tik.scope_ubuf)
|
||||
max_value_fp16[0].set_as(max_value)
|
||||
block_max_value_int = tik_inst.Tensor("int8", (1,),
|
||||
name='block_max_value_int',
|
||||
scope=tik.scope_ubuf)
|
||||
block_max_value_int8 = tik_inst.Scalar("int8",
|
||||
name='block_max_value_int8',
|
||||
init_value=0)
|
||||
tik_inst.vec_conv(
|
||||
1, "", max_value_int, max_value_fp16[0], 1, 1, 1)
|
||||
tik_inst.vec_conv(
|
||||
1, "", block_max_value_int, mat_ub_lg_max_sub[0], 1, 1, 1)
|
||||
max_value_int8.set_as(max_value_int[0])
|
||||
block_max_value_int8.set_as(block_max_value_int[0])
|
||||
with tik_inst.if_scope(block_max_value_int8 < max_value_int8):
|
||||
max_value.set_as(block_max_value)
|
||||
with tik_inst.else_scope():
|
||||
block_max_value.set_as(max_value)
|
||||
tik_inst.vadds(64, mat_ub_lg_lb_subs[0], mat_ub_lg_lb[0],
|
||||
max_value, (block_size + global_size) // 64, 1, 1, 4, 4)
|
||||
mat_ub_lg_exp_lb = tik_inst.Tensor("float16", (block_size + global_size,),
|
||||
name='mat_ub_lg_exp_lb',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.vexp(64, mat_ub_lg_exp_lb[0],
|
||||
mat_ub_lg_lb_subs[0], (block_size + global_size) // 64, 1, 1, 4, 4)
|
||||
tik_inst.data_move(mat_ub_lg_exp_16[0, gb, lb, 0], mat_ub_lg_exp_lb[0], 0,
|
||||
(block_size + global_size) // 16, 1, 0, block_size - 1)
|
||||
|
||||
# max_worker = tik_inst.Tensor("float16", ((block_size + global_size) // 8,),
|
||||
# name='max_worker',
|
||||
# scope=tik.scope_ubuf)
|
||||
|
||||
# with tik_inst.for_range(0, 16) as sb:
|
||||
# max_16 = tik_inst.Tensor("float16", ((block_size + global_size) // 16, 2),
|
||||
# name='max_16',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vcmax(16, max_16[0, 0], mat_ub_lg_16[0, 0, sb, 0], (block_size + global_size) // 16,
|
||||
# 1, 1, 16)
|
||||
# real_max = tik_inst.Tensor("float16", ((block_size + global_size) // 16,),
|
||||
# name='real_max',
|
||||
# scope=tik.scope_ubuf)
|
||||
# with tik_inst.for_range(0, (block_size + global_size) // 16) as mb:
|
||||
# max_v = tik_inst.Scalar("float16",
|
||||
# name='max_v',
|
||||
# init_value=0)
|
||||
# max_v.set_as(max_16[mb, 0])
|
||||
# real_max[mb].set_as(max_v)
|
||||
#
|
||||
# tik_inst.vcmax((block_size + global_size) // 16, mat_ub_lg_max[sb], real_max[0],
|
||||
# 1, 1, 1, 2)
|
||||
# tik_inst.vec_reduce_max(16, mat_ub_lg_max[sb, 0], mat_ub_lg_16[0, 0, sb, 0],
|
||||
# max_worker, (block_size + global_size) // 16, 16, cal_index=False)
|
||||
# mat_ub_lg_max_sub = tik_inst.Tensor("float16", (2,),
|
||||
# name='mat_ub_lg_max_sub',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vmuls(2, mat_ub_lg_max_sub[0], mat_ub_lg_max[0], -1.0, 1, 1, 1, 1, 1)
|
||||
# mat_ub_lg_subs = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16),
|
||||
# name='mat_ub_lg_subs',
|
||||
# scope=tik.scope_ubuf)
|
||||
# max_value = tik_inst.Scalar("float16",
|
||||
# name='max_value',
|
||||
# init_value=0)
|
||||
# # set value for scalar prob sum rec
|
||||
# max_value.set_as(mat_ub_lg_max_sub[0])
|
||||
# max_value_int = tik_inst.Tensor("int8", (1,),
|
||||
# name='max_value_int',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vec_conv(1, "", max_value_int, mat_ub_lg_max_sub[0], 1, 1, 1)
|
||||
# max_value_ints = tik_inst.Scalar("int8",
|
||||
# name='max_value_ints',
|
||||
# init_value=0)
|
||||
# max_value_ints.set_as(max_value_int[0])
|
||||
# with tik_inst.if_scope(max_value_ints > 0):
|
||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
||||
# 0.0, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
||||
# with tik_inst.else_scope():
|
||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
||||
# max_value, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
||||
# 5, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
||||
# with tik_inst.for_range(0, 16) as sb:
|
||||
# max_value = tik_inst.Scalar("float16",
|
||||
# name='max_value',
|
||||
# init_value=0)
|
||||
# # set value for scalar prob sum rec
|
||||
# max_value.set_as(mat_ub_lg_max_sub[sb])
|
||||
# tik_inst.vadds(16, mat_ub_lg_subs[0, 0, sb, 0], mat_ub_lg_16[0, 0, sb, 0],
|
||||
# max_value, (block_size + global_size) // 16, 1, 1, 16, 16)
|
||||
# mat_ub_lg_exp = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16),
|
||||
# name='mat_ub_lg_exp',
|
||||
# scope=tik.scope_ubuf)
|
||||
# tik_inst.vexp(128, mat_ub_lg_exp[0, 0, 0, 0],
|
||||
# mat_ub_lg_subs[0, 0, 0, 0], (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
||||
# tik_inst.data_move(mat_ub_lg_exp_16[0, gb, 0, 0], mat_ub_lg_exp[0, 0, 0, 0], 0,
|
||||
# (block_size + global_size) // 16, 16, 0, block_size - 16)
|
||||
|
||||
# move exp fp16 from ub to L1 for CUBE mmad
|
||||
# the shape of exp fp16 in ub is zN
|
||||
# the shape of exp fp16 in L1 is zN
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 4 times
|
||||
tik_inst.data_move(mat_l1_lg_exp_16[0, 0, 0, 0], mat_ub_lg_exp_16[0, 0, 0, 0],
|
||||
0, (global_size + block_size) // 16, block_size, 0, 0)
|
||||
# apply for tensor in UB for local attention out (64, 64) zN
|
||||
mat_ub_l_out = tik_inst.Tensor("float16", (block_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_ub_l_out',
|
||||
scope=tik.scope_ubuf)
|
||||
# apply for tensor in UB for global attention out (64, 256) zN
|
||||
mat_ub_g_out = tik_inst.Tensor("float16", (global_size // 16, block_size // 16, 16, 16),
|
||||
name='mat_ub_g_out',
|
||||
scope=tik.scope_ubuf)
|
||||
# apply tensor in l0c for exp sum (16, 64) zN
|
||||
mat_l0c_exp = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
||||
name='mat_l0c_exp',
|
||||
scope=tik.scope_cc)
|
||||
# apply tensor in ub for exp sum (16, 64) zN
|
||||
mat_ub_exp_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
||||
name='mat_ub_exp_sum',
|
||||
scope=tik.scope_ubuf)
|
||||
|
||||
with tik_inst.new_stmt_scope():
|
||||
with tik_inst.for_range(0, 4, thread_num=2) as gb:
|
||||
# apply for tensor in L0A for q (64, 128) zZ
|
||||
mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 64, 16, 16),
|
||||
name='mat_l0a_ones', scope=tik.scope_ca)
|
||||
# apply for tensor in L0B for exp (350, 64) nZ
|
||||
mat_l0b_exp = tik_inst.Tensor('float16',
|
||||
((global_size + block_size) //
|
||||
64, block_size // 16, 16, 16),
|
||||
name='mat_l0b_exp', scope=tik.scope_cb)
|
||||
# move ones from l1 to L0A for CUBE mmad
|
||||
# the shape of ones in l1 is zZ
|
||||
# the shape of ones in L0A is zZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 32 times
|
||||
tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0,
|
||||
(global_size + block_size) * 16 // (4 * 16 * 16), 1, 0, False)
|
||||
# move global k from l1 to L0B for CUBE mmad
|
||||
# the shape of global k in l1 is nZ
|
||||
# the shape of global k in L0B is nZ
|
||||
# the stride between each (16, 16) is 0
|
||||
# repeat 128 times
|
||||
tik_inst.load2dv1(mat_l0b_exp[0, 0, 0, 0],
|
||||
mat_l1_lg_exp_16[(
|
||||
global_size + block_size) * gb // 64, 0, 0, 0], 0,
|
||||
(global_size + block_size) * block_size // (4 * 16 * 16), 1, 0, False)
|
||||
with tik_inst.if_scope(gb == 0):
|
||||
tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16,
|
||||
(global_size + block_size) // 4, block_size, 0)
|
||||
with tik_inst.else_scope():
|
||||
tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16,
|
||||
(global_size + block_size) // 4, block_size, 1)
|
||||
|
||||
tik_inst.data_move(mat_ub_exp_sum[0, 0, 0, 0], mat_l0c_exp[0, 0, 0, 0], 0,
|
||||
block_size // 16, 1, 0, 0)
|
||||
# apply for tensor in UB for global prob sum (64,)
|
||||
mat_ub_lg_exp_sum = tik_inst.Tensor("float32", (block_size,),
|
||||
name='mat_ub_lg_exp_sum',
|
||||
scope=tik.scope_ubuf)
|
||||
tik_inst.data_move(mat_ub_lg_exp_sum[0], mat_ub_exp_sum[0, 0, 0, 0],
|
||||
0, block_size // 16, 1 * 2, 15 * 2, 0)
|
||||
# apply for tensor in UB for attention prob sum rec (64,)
|
||||
mat_ub_exp_sum_rec = tik_inst.Tensor("float32", (block_size,),
|
||||
name='mat_ub_exp_sum_rec',
|
||||
scope=tik.scope_ubuf)
|
||||
mat_ub_exp_sum_rec_16 = tik_inst.Tensor("float16", (block_size,),
|
||||
name='mat_ub_exp_sum_rec_16',
|
||||
scope=tik.scope_ubuf)
|
||||
worker_tensor = tik_inst.Tensor("float32", (block_size * 2,),
|
||||
name='worker_tensor',
|
||||
scope=tik.scope_ubuf)
|
||||
# calculate attention prob sum vec (64,)
|
||||
tik_inst.vec_rec_high_preci(
|
||||
64, mat_ub_exp_sum_rec, mat_ub_lg_exp_sum, worker_tensor, 1, 8, 8)
|
||||
# tik_instance.tikdb.debug_print('mat_ub_exp_sum_rec')
|
||||
tik_inst.vec_conv(block_size, "", mat_ub_exp_sum_rec_16[0],
|
||||
mat_ub_exp_sum_rec[0],
|
||||
block_size // 64, 4, 8)
|
||||
with tik_inst.for_range(0, block_size) as bbs:
|
||||
# apply for scalar in UB for prob sum rec
|
||||
sum_exp = tik_inst.Scalar("float16",
|
||||
name='sum_exp',
|
||||
init_value=0)
|
||||
# set value for scalar prob sum rec
|
||||
sum_exp.set_as(mat_ub_exp_sum_rec_16[bbs])
|
||||
tik_inst.vec_muls(16, mat_ub_l_out[0, bbs // 16, bbs % 16, 0],
|
||||
mat_ub_lg_exp_16[global_size //
|
||||
16, bbs // 16, bbs % 16, 0],
|
||||
sum_exp, block_size // 16,
|
||||
block_size, block_size)
|
||||
tik_inst.vec_muls(16, mat_ub_g_out[0, bbs // 16, bbs % 16, 0],
|
||||
mat_ub_lg_exp_16[0, bbs // 16, bbs %
|
||||
16, 0], sum_exp, global_size // 16,
|
||||
block_size, block_size)
|
||||
# move local out from UB to gm
|
||||
# the shape of local out in UB is zN
|
||||
# the shape of local out in gm is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 4 times
|
||||
# head_id = str(head)
|
||||
# block_id = str(block)
|
||||
# mat_ub_l_out_shape = str(mat_ub_l_out.shape)
|
||||
# tik_inst.tikdb.debug_print(head_id)
|
||||
# tik_inst.tikdb.debug_print(block_id)
|
||||
# tik_inst.tikdb.debug_print(mat_ub_l_out_shape)
|
||||
tik_inst.data_move(mat_lc[b, head, block, 0, 0, 0, 0], mat_ub_l_out[0, 0, 0, 0], 0,
|
||||
block_size // 16, block_size, 0, 0)
|
||||
# move global out from UB to gm
|
||||
# the shape of global out in UB is zN
|
||||
# the shape of global out in gm is zN
|
||||
# the stride between each (16, 64) is 0
|
||||
# repeat 16 times
|
||||
tik_inst.data_move(mat_gc[b, head, block, 0, 0, 0, 0], mat_ub_g_out[0, 0, 0, 0], 0,
|
||||
global_size // 16, block_size, 0, 0)
|
||||
|
||||
tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lm, mat_gm],
|
||||
outputs=[mat_lc, mat_gc])
|
||||
return tik_inst
|
Loading…
Reference in New Issue