forked from mindspore-Ecosystem/mindspore
!25534 Fix Code Stype For MatMulDDS
Merge pull request !25534 from huangxinjing/origin/master
This commit is contained in:
commit
267c67f40f
|
@ -72,7 +72,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
|
||||||
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel
|
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel
|
||||||
mindspore/mindspore/ops/_op_impl/_custom_op/dsd_impl.py:DSDMatmulimpl
|
mindspore/mindspore/ops/_op_impl/_custom_op/dsd_impl.py:dsd_matmul
|
||||||
mindspore/mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py:dsdbpropimpl
|
mindspore/mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py:dsdbpropimpl
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c:Conv1x1SW3x32Kernel, Conv1x1SW4x24Kernel, Conv1x1SW12x8Kernel, Conv1x1SW8x8Kernel, Conv1x1SW4x8Kernel, Conv1x1SW6x16Kernel, Conv1x1SW4x16Kernel, Conv1x1SW1x32Kernel, Conv1x1SW1x24Kernel, Conv1x1SW1x16Kernel, Conv1x1SW1x8Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c:Conv1x1SW3x32Kernel, Conv1x1SW4x24Kernel, Conv1x1SW12x8Kernel, Conv1x1SW8x8Kernel, Conv1x1SW4x8Kernel, Conv1x1SW6x16Kernel, Conv1x1SW4x16Kernel, Conv1x1SW1x32Kernel, Conv1x1SW1x24Kernel, Conv1x1SW1x16Kernel, Conv1x1SW1x8Kernel
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
|
||||||
|
|
|
@ -33,7 +33,7 @@ from .fake_quant_perlayer import _fake_quant_per_layer_tbe
|
||||||
from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
|
from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
|
||||||
from .minmax_update_perchannel import _minmax_update_perchannel_tbe
|
from .minmax_update_perchannel import _minmax_update_perchannel_tbe
|
||||||
from .minmax_update_perlayer import _minmax_update_perlayer_tbe
|
from .minmax_update_perlayer import _minmax_update_perlayer_tbe
|
||||||
from .matmul_dds_impl import MatmulDDSImpl
|
from .matmul_dds_impl import matmul_dds
|
||||||
from .matmul_dds_grad_impl import matmul_dds_grad
|
from .matmul_dds_grad_impl import matmul_dds_grad
|
||||||
from .dsd_impl import DSDMatmulimpl
|
from .dsd_impl import dsd_matmul
|
||||||
from .dsd_back_impl import dsdbpropimpl
|
from .dsd_back_impl import dsdbpropimpl
|
||||||
|
|
|
@ -14,10 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""dsd back impl"""
|
"""dsd back impl"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
|
||||||
from te import tik
|
from te import tik
|
||||||
from topi.cce import util
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
||||||
|
|
||||||
dsd_grad_info = TBERegOp('DSDGrad') \
|
dsd_grad_info = TBERegOp('DSDGrad') \
|
||||||
.fusion_type("OPAQUE") \
|
.fusion_type("OPAQUE") \
|
||||||
|
@ -48,11 +47,11 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
else:
|
else:
|
||||||
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
||||||
|
|
||||||
# (batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
# shape is:(batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
||||||
input_w1_shape = w1_gm.get('shape')
|
input_w1_shape = w1_gm.get('shape')
|
||||||
# (batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
# shape is:(batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
||||||
input_w2_shape = w2_gm.get('shape')
|
input_w2_shape = w2_gm.get('shape')
|
||||||
# (batch_size, seq_len//16, 16, head, v_embedding//16, 16)
|
# shape is:(batch_size, seq_len//16, 16, head, v_embedding//16, 16)
|
||||||
input_v_shape = v_gm.get('shape')
|
input_v_shape = v_gm.get('shape')
|
||||||
|
|
||||||
batch_size = input_w1_shape[0]
|
batch_size = input_w1_shape[0]
|
||||||
|
@ -64,15 +63,6 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
v_embedding = input_v_shape[1] * 16 // head
|
v_embedding = input_v_shape[1] * 16 // head
|
||||||
seq_len = input_v_shape[0] * 16 // batch_size
|
seq_len = input_v_shape[0] * 16 // batch_size
|
||||||
|
|
||||||
# batch_size = 1
|
|
||||||
# head = 1
|
|
||||||
# block_num = 1024 // 64
|
|
||||||
# block_size = 64
|
|
||||||
# head_size = 64
|
|
||||||
# global_size = 256
|
|
||||||
# v_embedding = 128
|
|
||||||
# seq_len = 1024
|
|
||||||
|
|
||||||
block_bite_size = 32
|
block_bite_size = 32
|
||||||
|
|
||||||
# 4, 16, 1024//64, 64//16, 64//16, 16*16
|
# 4, 16, 1024//64, 64//16, 64//16, 16*16
|
||||||
|
@ -88,7 +78,7 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
scope=tik.scope_gm)
|
scope=tik.scope_gm)
|
||||||
|
|
||||||
v_gm = tik_inst.Tensor('float16',
|
v_gm = tik_inst.Tensor('float16',
|
||||||
(batch_size*seq_len//16, head*v_embedding//16, 16, 16),
|
(batch_size * seq_len // 16, head * v_embedding // 16, 16, 16),
|
||||||
name='v_gm',
|
name='v_gm',
|
||||||
scope=tik.scope_gm)
|
scope=tik.scope_gm)
|
||||||
|
|
||||||
|
@ -124,12 +114,8 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
scope=tik.scope_gm)
|
scope=tik.scope_gm)
|
||||||
|
|
||||||
# v-nZ
|
# v-nZ
|
||||||
# d_v_gm = tik_inst.Tensor('float16',
|
|
||||||
# (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16),
|
|
||||||
# name='d_v_gm',
|
|
||||||
# scope=tik.scope_gm)
|
|
||||||
d_v_gm = tik_inst.Tensor('float16',
|
d_v_gm = tik_inst.Tensor('float16',
|
||||||
(batch_size*seq_len//16, head*v_embedding//16, 16, 16),
|
(batch_size * seq_len // 16, head * v_embedding // 16, 16, 16),
|
||||||
name='d_v_gm',
|
name='d_v_gm',
|
||||||
scope=tik.scope_gm)
|
scope=tik.scope_gm)
|
||||||
|
|
||||||
|
@ -140,21 +126,19 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
global_idx = 3 - head_idx % 4
|
global_idx = 3 - head_idx % 4
|
||||||
# tensor size // (byte * l0b size * thread)
|
# tensor size // (byte * l0b size * thread)
|
||||||
cpt_time = 1 if global_size * v_embedding * \
|
cpt_time = 1 if global_size * v_embedding * \
|
||||||
4//(1024 * 64) <= 1 else global_size * v_embedding * 4//(1024 * 64)
|
4 // (1024 * 64) <= 1 else global_size * v_embedding * 4 // (1024 * 64)
|
||||||
ub_time = 1 if global_size == 256 else 2
|
ub_time = 1 if global_size == 256 else 2
|
||||||
|
|
||||||
# d_a_gm = (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
|
|
||||||
d_a_l1 = tik_inst.Tensor('float16', (seq_len // 16, v_embedding // 16, 16, 16),
|
d_a_l1 = tik_inst.Tensor('float16', (seq_len // 16, v_embedding // 16, 16, 16),
|
||||||
name='d_a_l1', scope=tik.scope_cbuf)
|
name='d_a_l1', scope=tik.scope_cbuf)
|
||||||
|
|
||||||
with tik_inst.for_range(0, v_embedding//16) as brick_i:
|
with tik_inst.for_range(0, v_embedding // 16) as brick_i:
|
||||||
tik_inst.data_move(d_a_l1[0, brick_i, 0, 0], d_a_gm[bs_idx, head_idx, brick_i, 0, 0, 0], 0,
|
tik_inst.data_move(d_a_l1[0, brick_i, 0, 0], d_a_gm[bs_idx, head_idx, brick_i, 0, 0, 0], 0,
|
||||||
seq_len//16, 16*16*2//block_bite_size,
|
seq_len // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
0, (v_embedding//16-1)*16*16*2//block_bite_size)
|
0, (v_embedding // 16 - 1) * 16 * 16 * 2 // block_bite_size)
|
||||||
|
|
||||||
# dv
|
# dv
|
||||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||||
|
|
||||||
d_v_l0c = tik_inst.Tensor('float32', (v_embedding // 16, head_size // 16, 16, 16),
|
d_v_l0c = tik_inst.Tensor('float32', (v_embedding // 16, head_size // 16, 16, 16),
|
||||||
name='d_v_local_l0c', scope=tik.scope_cc)
|
name='d_v_local_l0c', scope=tik.scope_cc)
|
||||||
d_v_ub = tik_inst.Tensor('float16', (v_embedding // 16, head_size // 16, 16, 16),
|
d_v_ub = tik_inst.Tensor('float16', (v_embedding // 16, head_size // 16, 16, 16),
|
||||||
|
@ -170,31 +154,30 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
|
|
||||||
# d_v_local
|
# d_v_local
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
w_local_l1 = tik_inst.Tensor('float16', (head_size//16, block_size//16, 16, 16),
|
w_local_l1 = tik_inst.Tensor('float16', (head_size // 16, block_size // 16, 16, 16),
|
||||||
name='w_local_l1', scope=tik.scope_cbuf)
|
name='w_local_l1', scope=tik.scope_cbuf)
|
||||||
w_local_l0a = tik_inst.Tensor('float16', (head_size//16, block_size//16, 16, 16),
|
w_local_l0a = tik_inst.Tensor('float16', (head_size // 16, block_size // 16, 16, 16),
|
||||||
name='w_local_l0a', scope=tik.scope_ca)
|
name='w_local_l0a', scope=tik.scope_ca)
|
||||||
|
|
||||||
d_a_l0b = tik_inst.Tensor('float16', (block_size//16, v_embedding//16, 16, 16),
|
d_a_l0b = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
||||||
name='d_a_l0b', scope=tik.scope_cb)
|
name='d_a_l0b', scope=tik.scope_cb)
|
||||||
|
|
||||||
# (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
|
|
||||||
tik_inst.data_move(w_local_l1[0, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], 0,
|
tik_inst.data_move(w_local_l1[0, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], 0,
|
||||||
1, (block_size*head_size*2)//block_bite_size,
|
1, (block_size * head_size * 2) // block_bite_size,
|
||||||
0, 0)
|
0, 0)
|
||||||
|
|
||||||
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx * block_size//16, 0, 0, 0], 0,
|
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx * block_size // 16, 0, 0, 0], 0,
|
||||||
(block_size*v_embedding)//(16*16), 1, 0, True)
|
(block_size * v_embedding) // (16 * 16), 1, 0, True)
|
||||||
|
|
||||||
tik_inst.load2dv1(w_local_l0a[0, 0, 0, 0], w_local_l1[0, 0, 0, 0],
|
tik_inst.load2dv1(w_local_l0a[0, 0, 0, 0], w_local_l1[0, 0, 0, 0],
|
||||||
0, (head_size*block_size)//(16*16),
|
0, (head_size * block_size) // (16 * 16),
|
||||||
1, 0, True)
|
1, 0, True)
|
||||||
|
|
||||||
tik_inst.mmad(d_v_l0c, w_local_l0a, d_a_l0b,
|
tik_inst.mmad(d_v_l0c, w_local_l0a, d_a_l0b,
|
||||||
head_size, block_size, v_embedding, 0)
|
head_size, block_size, v_embedding, 0)
|
||||||
|
|
||||||
tik_inst.data_move(d_v_ub_32[0, 0, 0, 0], d_v_l0c[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_v_ub_32[0, 0, 0, 0], d_v_l0c[0, 0, 0, 0], 0,
|
||||||
1, (v_embedding * head_size)*4//1024, 0, 0)
|
1, (v_embedding * head_size) * 4 // 1024, 0, 0)
|
||||||
|
|
||||||
# d_v_global
|
# d_v_global
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
|
@ -203,24 +186,20 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
w_global_l0a = tik_inst.Tensor('float16', (1, head_size // 16, 16, 16),
|
w_global_l0a = tik_inst.Tensor('float16', (1, head_size // 16, 16, 16),
|
||||||
name='w_global_l0a', scope=tik.scope_ca)
|
name='w_global_l0a', scope=tik.scope_ca)
|
||||||
|
|
||||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
|
||||||
d_a_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16),
|
d_a_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16),
|
||||||
name='d_a_l0b', scope=tik.scope_cb)
|
name='d_a_l0b', scope=tik.scope_cb)
|
||||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx_1:
|
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx_1:
|
||||||
# w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx_1 * (block_size // 16), 0, 0, 0], 0,
|
||||||
# (1, head_size // 16, 16, 16)
|
(head_size * v_embedding) // (16 * 16), 1, 0, True)
|
||||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
|
||||||
tik_inst.load2dv1(d_a_l0b[0, 0, 0, 0], d_a_l1[w_idx_1*(block_size//16), 0, 0, 0], 0,
|
|
||||||
(head_size * v_embedding)//(16*16), 1, 0, True)
|
|
||||||
|
|
||||||
tik_inst.data_move(w_global_l1[0, 0, 0, 0], w2_gm[bs_idx, head_idx, w_idx_1, w_idx, 0, 0, 0], 0,
|
tik_inst.data_move(w_global_l1[0, 0, 0, 0], w2_gm[bs_idx, head_idx, w_idx_1, w_idx, 0, 0, 0], 0,
|
||||||
head_size // 16, 16*16*2//block_bite_size,
|
head_size // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
0, 0)
|
0, 0)
|
||||||
tik_inst.load2dv1(w_global_l0a[0, 0, 0, 0], w_global_l1[0, 0, 0, 0], 0,
|
tik_inst.load2dv1(w_global_l0a[0, 0, 0, 0], w_global_l1[0, 0, 0, 0], 0,
|
||||||
16 * head_size // (16 * 16),
|
16 * head_size // (16 * 16),
|
||||||
1, 0, True)
|
1, 0, True)
|
||||||
|
|
||||||
# d_v_l0c = (v_embedding // 16, head_size // 16, 16, 16)
|
# shape: d_v_l0c = (v_embedding // 16, head_size // 16, 16, 16)
|
||||||
with tik_inst.if_scope(w_idx_1 == 0):
|
with tik_inst.if_scope(w_idx_1 == 0):
|
||||||
tik_inst.mmad(d_v_global_32_l0c, w_global_l0a, d_a_l0b,
|
tik_inst.mmad(d_v_global_32_l0c, w_global_l0a, d_a_l0b,
|
||||||
16, head_size, v_embedding, 0)
|
16, head_size, v_embedding, 0)
|
||||||
|
@ -229,34 +208,31 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
16, head_size, v_embedding, 1)
|
16, head_size, v_embedding, 1)
|
||||||
|
|
||||||
tik_inst.data_move(d_v_global_32_ub[0, 0, 0, 0], d_v_global_32_l0c[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_v_global_32_ub[0, 0, 0, 0], d_v_global_32_l0c[0, 0, 0, 0], 0,
|
||||||
1, v_embedding*16*4//1024, 0, 0)
|
1, v_embedding * 16 * 4 // 1024, 0, 0)
|
||||||
|
|
||||||
with tik_inst.for_range(0, 4) as cpt_i:
|
with tik_inst.for_range(0, 4) as cpt_i:
|
||||||
tik_inst.vadd(64, d_v_ub_32[0, global_idx, cpt_i*4, 0], d_v_ub_32[0, global_idx, cpt_i*4, 0],
|
tik_inst.vadd(64, d_v_ub_32[0, global_idx, cpt_i * 4, 0], d_v_ub_32[0, global_idx, cpt_i * 4, 0],
|
||||||
d_v_global_32_ub[0, 0,
|
d_v_global_32_ub[0, 0,
|
||||||
cpt_i*4, 0], v_embedding//16,
|
cpt_i * 4, 0], v_embedding // 16,
|
||||||
1, 1, 1,
|
1, 1, 1,
|
||||||
head_size*16*4//block_bite_size, head_size*16*4//block_bite_size,
|
head_size * 16 * 4 // block_bite_size, head_size * 16 * 4 // block_bite_size,
|
||||||
16*16*4//block_bite_size)
|
16 * 16 * 4 // block_bite_size)
|
||||||
|
|
||||||
tik_inst.vconv(64, '', d_v_ub[0, 0, 0, 0], d_v_ub_32[0, 0, 0, 0],
|
tik_inst.vconv(64, '', d_v_ub[0, 0, 0, 0], d_v_ub_32[0, 0, 0, 0],
|
||||||
v_embedding * head_size//64, 1, 1, 4, 8)
|
v_embedding * head_size // 64, 1, 1, 4, 8)
|
||||||
|
|
||||||
with tik_inst.for_range(0, head_size // 16) as h_idx:
|
with tik_inst.for_range(0, head_size // 16) as h_idx:
|
||||||
with tik_inst.for_range(0, v_embedding//16) as v_idx:
|
with tik_inst.for_range(0, v_embedding // 16) as v_idx:
|
||||||
tik_inst.vtranspose(
|
tik_inst.vtranspose(
|
||||||
d_v_ub[v_idx, h_idx, 0, 0], d_v_ub[v_idx, h_idx, 0, 0])
|
d_v_ub[v_idx, h_idx, 0, 0], d_v_ub[v_idx, h_idx, 0, 0])
|
||||||
tik_inst.data_move(d_v_gm[bs_idx*seq_len//16+w_idx * (block_size // 16) + h_idx,
|
tik_inst.data_move(d_v_gm[bs_idx * seq_len // 16 + w_idx * (block_size // 16) + h_idx,
|
||||||
head_idx*v_embedding//16, 0, 0],
|
head_idx * v_embedding // 16, 0, 0],
|
||||||
d_v_ub[0, h_idx, 0, 0], 0,
|
d_v_ub[0, h_idx, 0, 0], 0,
|
||||||
v_embedding // 16, 16 * 16 * 2 // block_bite_size,
|
v_embedding // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
(head_size // 16 - 1) * 16 * 16 * 2 // 32, 0)
|
(head_size // 16 - 1) * 16 * 16 * 2 // 32, 0)
|
||||||
|
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
# dw = da * v^t
|
|
||||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||||
# d_local_l1 = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
|
||||||
# name='d_local_l1', scope=tik.scope_cbuf)
|
|
||||||
d_local_l0a = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
d_local_l0a = tik_inst.Tensor('float16', (block_size // 16, v_embedding // 16, 16, 16),
|
||||||
name='d_local_l0a', scope=tik.scope_ca)
|
name='d_local_l0a', scope=tik.scope_ca)
|
||||||
|
|
||||||
|
@ -275,74 +251,66 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
d_w_local_ub = tik_inst.Tensor('float16', (head_size // 16, block_size // 16, 16, 16),
|
d_w_local_ub = tik_inst.Tensor('float16', (head_size // 16, block_size // 16, 16, 16),
|
||||||
name='d_w_local_ub', scope=tik.scope_ubuf)
|
name='d_w_local_ub', scope=tik.scope_ubuf)
|
||||||
|
|
||||||
tik_inst.load2dv1(d_local_l0a[0, 0, 0, 0], d_a_l1[w_idx*(block_size//16), 0, 0, 0],
|
tik_inst.load2dv1(d_local_l0a[0, 0, 0, 0], d_a_l1[w_idx * (block_size // 16), 0, 0, 0],
|
||||||
0, (block_size*v_embedding)//(16*16), 1, 0, False)
|
0, (block_size * v_embedding) // (16 * 16), 1, 0, False)
|
||||||
|
|
||||||
# v_gm = (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16)
|
# shape is: v_gm = (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16)
|
||||||
# v_local_l1 = (v_embedding//16, head_size//16, 16, 16)
|
# shape is: v_local_l1 = (v_embedding//16, head_size//16, 16, 16)
|
||||||
with tik_inst.for_range(0, head_size//16) as brick_i:
|
with tik_inst.for_range(0, head_size // 16) as brick_i:
|
||||||
tik_inst.data_move(v_local_l1[0, brick_i, 0, 0],
|
tik_inst.data_move(v_local_l1[0, brick_i, 0, 0],
|
||||||
v_gm[bs_idx*seq_len//16+w_idx *
|
v_gm[bs_idx * seq_len // 16 + w_idx *
|
||||||
(head_size//16)+brick_i, head_idx*v_embedding//16, 0, 0],
|
(head_size // 16) + brick_i, head_idx * v_embedding // 16, 0, 0],
|
||||||
0, v_embedding//16, 16*16*2//block_bite_size,
|
0, v_embedding // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
0, (head_size//16-1)*16*16*2//block_bite_size)
|
0, (head_size // 16 - 1) * 16 * 16 * 2 // block_bite_size)
|
||||||
|
|
||||||
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0],
|
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0],
|
||||||
0, v_embedding*head_size//(16*16), 1, 0, True)
|
0, v_embedding * head_size // (16 * 16), 1, 0, True)
|
||||||
|
|
||||||
# dw
|
# dw
|
||||||
tik_inst.mmad(d_w_local_l0c, d_local_l0a, v_local_l0b,
|
tik_inst.mmad(d_w_local_l0c, d_local_l0a, v_local_l0b,
|
||||||
block_size, v_embedding, head_size, 0)
|
block_size, v_embedding, head_size, 0)
|
||||||
|
|
||||||
tik_inst.data_move(d_w_local_ub_32[0, 0, 0, 0], d_w_local_l0c[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_w_local_ub_32[0, 0, 0, 0], d_w_local_l0c[0, 0, 0, 0], 0,
|
||||||
1, head_size*block_size*4//1024,
|
1, head_size * block_size * 4 // 1024,
|
||||||
0, 0)
|
0, 0)
|
||||||
|
|
||||||
tik_inst.vconv(64, '', d_w_local_ub[0, 0, 0, 0], d_w_local_ub_32[0, 0, 0, 0],
|
tik_inst.vconv(64, '', d_w_local_ub[0, 0, 0, 0], d_w_local_ub_32[0, 0, 0, 0],
|
||||||
head_size * block_size//64, 1, 1, 4, 8)
|
head_size * block_size // 64, 1, 1, 4, 8)
|
||||||
|
|
||||||
# d_w1_gm = (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
|
|
||||||
tik_inst.data_move(d_w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_local_ub[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_w1_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_local_ub[0, 0, 0, 0], 0,
|
||||||
1, head_size*block_size*2//block_bite_size,
|
1, head_size * block_size * 2 // block_bite_size,
|
||||||
0, 0)
|
0, 0)
|
||||||
|
|
||||||
# calculate d_w_global
|
# calculate d_w_global
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
# load2d permute
|
# load2d permute
|
||||||
v_global_l1 = tik_inst.Tensor('float16', (v_embedding//16, global_size//16, 16, 16),
|
v_global_l1 = tik_inst.Tensor('float16', (v_embedding // 16, global_size // 16, 16, 16),
|
||||||
name='v_global_l1', scope=tik.scope_cbuf)
|
name='v_global_l1', scope=tik.scope_cbuf)
|
||||||
|
|
||||||
# v_gm = (batch_size, seq_len // 16, head, v_embedding // 16, 16, 16)
|
|
||||||
# tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx, global_idx, head_idx, 0, 0, 0], 0,
|
|
||||||
# seq_len//(4*16), (16*v_embedding*2)//block_bite_size,
|
|
||||||
# ((4*head*v_embedding*16-16*v_embedding)*2)//block_bite_size, 0)
|
|
||||||
with tik_inst.for_range(0, block_num) as w_idx:
|
with tik_inst.for_range(0, block_num) as w_idx:
|
||||||
tik_inst.data_move(v_global_l1[0, w_idx, 0, 0],
|
tik_inst.data_move(v_global_l1[0, w_idx, 0, 0],
|
||||||
v_gm[bs_idx*seq_len//16 + (
|
v_gm[bs_idx * seq_len // 16 + (
|
||||||
w_idx * (block_size//16) + global_idx), head_idx * v_embedding//16, 0, 0],
|
w_idx * (
|
||||||
0, v_embedding//16, 16*16*2//block_bite_size,
|
block_size // 16) + global_idx), head_idx * v_embedding // 16, 0, 0],
|
||||||
0, (global_size // 16 - 1)*16*16*2//block_bite_size)
|
0, v_embedding // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
|
0, (global_size // 16 - 1) * 16 * 16 * 2 // block_bite_size)
|
||||||
|
|
||||||
with tik_inst.for_range(0, block_num * ub_time, thread_num=2) as w_idx:
|
with tik_inst.for_range(0, block_num * ub_time, thread_num=2) as w_idx:
|
||||||
# tik_inst.tikdb.debug_print("'v_embedding//(16*cpt_time): '+str(v_embedding//(16*cpt_time))")
|
d_global_l0a = tik_inst.Tensor('float16', (head_size // (16 * ub_time),
|
||||||
# d_global_l1 = tik_inst.Tensor('float16', (head_size//16, v_embedding//(16*cpt_time), 16, 16),
|
v_embedding // (16 * cpt_time), 16, 16),
|
||||||
# name='d_global_l1', scope=tik.scope_cbuf)
|
|
||||||
d_global_l0a = tik_inst.Tensor('float16', (head_size // (16*ub_time),
|
|
||||||
v_embedding // (16*cpt_time), 16, 16),
|
|
||||||
name='d_global_l0a', scope=tik.scope_ca)
|
name='d_global_l0a', scope=tik.scope_ca)
|
||||||
|
|
||||||
v_global_l0b = tik_inst.Tensor('float16', (v_embedding // (16*cpt_time),
|
v_global_l0b = tik_inst.Tensor('float16', (v_embedding // (16 * cpt_time),
|
||||||
global_size // 16, 16, 16),
|
global_size // 16, 16, 16),
|
||||||
name='v_global_l0b', scope=tik.scope_cb)
|
name='v_global_l0b', scope=tik.scope_cb)
|
||||||
|
|
||||||
# d_w_global,小z大n
|
# d_w_global,小z大n
|
||||||
d_w_global_l0c = tik_inst.Tensor('float32', (global_size//16, head_size//(16*ub_time), 16, 16),
|
d_w_global_l0c = tik_inst.Tensor('float32', (global_size // 16, head_size // (16 * ub_time), 16, 16),
|
||||||
name='d_w_global_l0c', scope=tik.scope_cc)
|
name='d_w_global_l0c', scope=tik.scope_cc)
|
||||||
d_w_global_ub = tik_inst.Tensor('float16', (global_size // 16,
|
d_w_global_ub = tik_inst.Tensor('float16', (global_size // 16,
|
||||||
head_size // (16*ub_time), 16, 16),
|
head_size // (16 * ub_time), 16, 16),
|
||||||
name='d_w_global_ub', scope=tik.scope_ubuf)
|
name='d_w_global_ub', scope=tik.scope_ubuf)
|
||||||
d_w_global_ub_32 = tik_inst.Tensor('float32', (global_size // 16,
|
d_w_global_ub_32 = tik_inst.Tensor('float32', (global_size // 16,
|
||||||
head_size // (16*ub_time), 16, 16),
|
head_size // (16 * ub_time), 16, 16),
|
||||||
name='d_w_global_ub_32', scope=tik.scope_ubuf)
|
name='d_w_global_ub_32', scope=tik.scope_ubuf)
|
||||||
|
|
||||||
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
||||||
|
@ -350,81 +318,51 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
||||||
v_global_l1[cpt_idx * v_embedding //
|
v_global_l1[cpt_idx * v_embedding //
|
||||||
(16 * cpt_time), 0, 0, 0], 0,
|
(16 * cpt_time), 0, 0, 0], 0,
|
||||||
global_size * v_embedding // (16 * 16 * cpt_time), 1, 0, True)
|
global_size * v_embedding // (16 * 16 * cpt_time), 1, 0, True)
|
||||||
# d_global_gm = (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
|
with tik_inst.for_range(0, head_size // (16 * ub_time)) as brick_i:
|
||||||
# d_global_l1 = (head_size//16, v_embedding//(16*cpt_time), 16, 16)
|
|
||||||
# with tik_inst.for_range(0, v_embedding//(16*cpt_time)) as brick_i:
|
|
||||||
# tik_inst.data_move(d_global_l1[0, brick_i, 0, 0],
|
|
||||||
# d_global_gm[bs_idx,
|
|
||||||
# head_idx, brick_i + cpt_idx * (v_embedding//(16*cpt_time)),
|
|
||||||
# w_idx*(head_size//16), 0, 0], 0,
|
|
||||||
# head_size//16, (16*16*2)//block_bite_size,
|
|
||||||
# 0, (v_embedding//(16*cpt_time)-1)*16*16*2//block_bite_size)
|
|
||||||
#
|
|
||||||
# tik_inst.load2dv1(d_global_l0a[0, 0, 0, 0], d_global_l1[0, 0, 0, 0],
|
|
||||||
# 0, (head_size*v_embedding)//(16*16*cpt_time), 1, 0, False)
|
|
||||||
|
|
||||||
# d_a_l1 = (seq_len // 16, v_embedding // 16, 16, 16)
|
|
||||||
# d_global_l0a = (head_size // 16, v_embedding // (16*cpt_time), 16, 16)
|
|
||||||
with tik_inst.for_range(0, head_size//(16*ub_time)) as brick_i:
|
|
||||||
tik_inst.load2dv1(d_global_l0a[brick_i, 0, 0, 0],
|
tik_inst.load2dv1(d_global_l0a[brick_i, 0, 0, 0],
|
||||||
d_a_l1[w_idx*(block_size//(16*ub_time)) + brick_i,
|
d_a_l1[w_idx * (block_size // (16 * ub_time)) + brick_i,
|
||||||
cpt_idx*v_embedding//(16*cpt_time), 0, 0],
|
cpt_idx * v_embedding // (16 * cpt_time), 0, 0],
|
||||||
0, (16*v_embedding)//(16*16*cpt_time), 1, 0, False)
|
0, (16 * v_embedding) // (16 * 16 * cpt_time), 1, 0, False)
|
||||||
|
|
||||||
# (head_size, global_size) = (head_size, v_embedding//cpttime) *
|
# shape is: (head_size, global_size) = (head_size, v_embedding//cpttime) *
|
||||||
# (v_embedding//cpttime, global_size)
|
# shape is: (v_embedding//cpttime, global_size)
|
||||||
with tik_inst.if_scope(cpt_idx == 0):
|
with tik_inst.if_scope(cpt_idx == 0):
|
||||||
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
||||||
head_size//ub_time, v_embedding//cpt_time, global_size, 0)
|
head_size // ub_time, v_embedding // cpt_time, global_size, 0)
|
||||||
with tik_inst.else_scope():
|
with tik_inst.else_scope():
|
||||||
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
tik_inst.mmad(d_w_global_l0c, d_global_l0a, v_global_l0b,
|
||||||
head_size//ub_time, v_embedding//cpt_time, global_size, 1)
|
head_size // ub_time, v_embedding // cpt_time, global_size, 1)
|
||||||
|
|
||||||
tik_inst.data_move(d_w_global_ub_32[0, 0, 0, 0], d_w_global_l0c[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_w_global_ub_32[0, 0, 0, 0], d_w_global_l0c[0, 0, 0, 0], 0,
|
||||||
1, head_size*global_size*4//(1024*ub_time),
|
1, head_size * global_size * 4 // (1024 * ub_time),
|
||||||
0, 0)
|
0, 0)
|
||||||
|
|
||||||
# tik_inst.tikdb.debug_print("'d_w_global_ub_32: '+str(d_global_l1)")
|
# shape is: global_size // 16, head_size // 16, 16, 16)
|
||||||
|
rpt_time = global_size // (16 * 8)
|
||||||
# (global_size // 16, head_size // 16, 16, 16)
|
|
||||||
rpt_time = global_size//(16*8)
|
|
||||||
with tik_inst.for_range(0, rpt_time) as conv_i:
|
with tik_inst.for_range(0, rpt_time) as conv_i:
|
||||||
tik_inst.vconv(64, '',
|
tik_inst.vconv(64, '',
|
||||||
d_w_global_ub[conv_i*global_size //
|
d_w_global_ub[conv_i * global_size //
|
||||||
(16*rpt_time), 0, 0, 0],
|
(16 * rpt_time), 0, 0, 0],
|
||||||
d_w_global_ub_32[conv_i*global_size //
|
d_w_global_ub_32[conv_i * global_size //
|
||||||
(16*rpt_time), 0, 0, 0],
|
(16 * rpt_time), 0, 0, 0],
|
||||||
global_size * head_size//(64*rpt_time*ub_time), 1, 1, 4, 8)
|
global_size * head_size // (64 * rpt_time * ub_time), 1, 1, 4, 8)
|
||||||
|
|
||||||
# tik_inst.vconv(64, '', d_w_global_ub[0, 0, 0, 0], d_w_global_ub_32[0, 0, 0, 0],
|
|
||||||
# global_size * head_size//64, 1, 1, 4, 8)
|
|
||||||
|
|
||||||
# d_w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
|
||||||
# d_w_global_ub = (global_size // 16, head_size // (16*ub_time), 16, 16)
|
|
||||||
with tik_inst.if_scope(ub_time == 1):
|
with tik_inst.if_scope(ub_time == 1):
|
||||||
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_global_ub[0, 0, 0, 0], 0,
|
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx, 0, 0, 0, 0], d_w_global_ub[0, 0, 0, 0], 0,
|
||||||
1, head_size*global_size *
|
1, head_size * global_size *
|
||||||
2//(block_bite_size),
|
2 // (block_bite_size),
|
||||||
0, 0)
|
0, 0)
|
||||||
with tik_inst.else_scope():
|
with tik_inst.else_scope():
|
||||||
w_idx_i = w_idx // 2
|
w_idx_i = w_idx // 2
|
||||||
h_idx = (w_idx % 2) * 2 # 0/2
|
h_idx = (w_idx % 2) * 2 # 0/2
|
||||||
|
|
||||||
# tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx_i, 0,h_idx,0,0], d_w_global_ub[0,0,0,0], 0,
|
with tik_inst.for_range(0, head_size // (16 * ub_time)) as m_idx:
|
||||||
# 1, head_size*global_size*2//(block_bite_size*ub_time),
|
|
||||||
# 0,0)
|
|
||||||
|
|
||||||
with tik_inst.for_range(0, head_size//(16*ub_time)) as m_idx:
|
|
||||||
# d_w2_gm = (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
|
|
||||||
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx_i, 0, h_idx + m_idx, 0, 0],
|
tik_inst.data_move(d_w2_gm[bs_idx, head_idx, w_idx_i, 0, h_idx + m_idx, 0, 0],
|
||||||
d_w_global_ub[0, m_idx, 0, 0], 0,
|
d_w_global_ub[0, m_idx, 0, 0], 0,
|
||||||
global_size//16, 16*16*2//block_bite_size,
|
global_size // 16, 16 * 16 * 2 // block_bite_size,
|
||||||
(head_size//(16*ub_time) - 1) *
|
(head_size // (16 * ub_time) - 1) *
|
||||||
16*16*2//block_bite_size,
|
16 * 16 * 2 // block_bite_size,
|
||||||
(head_size//16 - 1)*16*16*2//block_bite_size)
|
(head_size // 16 - 1) * 16 * 16 * 2 // block_bite_size)
|
||||||
|
|
||||||
# tik_inst.tikdb.debug_print("'d_w_global_ub_32: '+str(d_global_l1)")
|
|
||||||
# tik_inst.tikdb.debug_print("'d_w2_gm: '+str(d_w2_gm[bs_idx, head_idx, w_idx_i, 0,h_idx,:,:])")
|
|
||||||
|
|
||||||
tik_inst.BuildCCE(kernel_name=kernel_name,
|
tik_inst.BuildCCE(kernel_name=kernel_name,
|
||||||
inputs=[w1_gm, w2_gm, v_gm, a_gm, d_a_gm],
|
inputs=[w1_gm, w2_gm, v_gm, a_gm, d_a_gm],
|
||||||
|
|
|
@ -14,17 +14,16 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" dense sparse to densne matmul"""
|
""" dense sparse to densne matmul"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
|
||||||
from te import tik
|
from te import tik
|
||||||
from topi.cce import util
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register
|
||||||
|
|
||||||
dsd_matmul_info = TBERegOp('DSDMatmul') \
|
dsd_matmul_info = TBERegOp('DSDMatmul') \
|
||||||
.fusion_type("OPAQUE") \
|
.fusion_type("OPAQUE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("dsdmatmul.so") \
|
.binfile_name("dsdmatmul.so") \
|
||||||
.compute_cost(10) \
|
.compute_cost(10) \
|
||||||
.kernel_name("DSDMatmulimpl") \
|
.kernel_name("dsd_matmul") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.input(0, "input_w1", False, "required", "all") \
|
.input(0, "input_w1", False, "required", "all") \
|
||||||
.input(1, "input_w2", False, "required", "all") \
|
.input(1, "input_w2", False, "required", "all") \
|
||||||
|
@ -35,19 +34,16 @@ dsd_matmul_info = TBERegOp('DSDMatmul') \
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(dsd_matmul_info)
|
@op_info_register(dsd_matmul_info)
|
||||||
def DSDMatmulimpl(input_w1, input_w2, input_v, output_y={}, kernel_name='DSDMatmulimpl'):
|
def dsd_matmul(input_w1, input_w2, input_v, output_y={}, kernel_name='dsd_matmul'):
|
||||||
""" dense sparse to densne matmul"""
|
""" dense sparse to densne matmul"""
|
||||||
# shape_w1 = input_w1.get('shape')
|
|
||||||
# shape_w2 = input_w2.get('shape')
|
|
||||||
# shape_v = input_v.get('shape')
|
|
||||||
if util.get_product_version() == util.VERSION_MINI:
|
if util.get_product_version() == util.VERSION_MINI:
|
||||||
tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
|
tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
|
||||||
else:
|
else:
|
||||||
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))
|
||||||
|
|
||||||
# (batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
# shape is: (batch_size, head, block_num, block_size//16, 16, head_size//16, 16)
|
||||||
input_w1_shape = input_w1.get('shape')
|
input_w1_shape = input_w1.get('shape')
|
||||||
# (batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
# shape is: (batch_size, head, block_num, head_size//16, 16, global_size//16, 16)
|
||||||
input_w2_shape = input_w2.get('shape')
|
input_w2_shape = input_w2.get('shape')
|
||||||
input_v_shape = input_v.get('shape')
|
input_v_shape = input_v.get('shape')
|
||||||
|
|
||||||
|
@ -61,151 +57,106 @@ def DSDMatmulimpl(input_w1, input_w2, input_v, output_y={}, kernel_name='DSDMatm
|
||||||
seq_len = input_v_shape[0] * 16 // batch_size
|
seq_len = input_v_shape[0] * 16 // batch_size
|
||||||
|
|
||||||
block_bite_size = 32
|
block_bite_size = 32
|
||||||
cpt_time = seq_len//512
|
cpt_time = seq_len // 512
|
||||||
if v_embedding == 128:
|
|
||||||
thread = 1
|
|
||||||
else:
|
|
||||||
thread = 2
|
|
||||||
|
|
||||||
# # w:zN
|
|
||||||
# # 4, 16, 1024//64, 64//16, 64//16, 16*16
|
|
||||||
# w1_gm = tik_inst.Tensor('float16', (1, 2, 16, 4, 4, 16, 16), name='w1_gm', scope=tik.scope_gm)
|
|
||||||
# w2_gm = tik_inst.Tensor('float16', (1, 2, 16, 16, 4, 16, 16), name='w2_gm', scope=tik.scope_gm)
|
|
||||||
#
|
|
||||||
# # v:nZ
|
|
||||||
# v_gm = tik_inst.Tensor('float16', (64, 16, 16, 16), name='v_gm', scope=tik.scope_gm)
|
|
||||||
|
|
||||||
# w:zN
|
|
||||||
w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size //
|
w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size //
|
||||||
16, block_size//16, 16, 16), name='w1_gm', scope=tik.scope_gm)
|
16, block_size // 16, 16, 16), name='w1_gm', scope=tik.scope_gm)
|
||||||
w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size //
|
w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size //
|
||||||
16, head_size//16, 16, 16), name='w2_gm', scope=tik.scope_gm)
|
16, head_size // 16, 16, 16), name='w2_gm', scope=tik.scope_gm)
|
||||||
#
|
#
|
||||||
# # v:nZ
|
v_gm = tik_inst.Tensor('float16', (batch_size * seq_len // 16,
|
||||||
v_gm = tik_inst.Tensor('float16', (batch_size*seq_len//16,
|
head * v_embedding // 16, 16, 16), name='v_gm', scope=tik.scope_gm)
|
||||||
head*v_embedding//16, 16, 16), name='v_gm', scope=tik.scope_gm)
|
|
||||||
|
|
||||||
# zN
|
# zN
|
||||||
output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len//16, 16, 16), name='output_gm',
|
output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16),
|
||||||
|
name='output_gm',
|
||||||
scope=tik.scope_gm)
|
scope=tik.scope_gm)
|
||||||
|
|
||||||
channel_num = batch_size*head
|
channel_num = batch_size * head
|
||||||
with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
|
with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
|
||||||
head_idx = channel_idx // batch_size
|
head_idx = channel_idx // batch_size
|
||||||
bs_idx = channel_idx % batch_size
|
bs_idx = channel_idx % batch_size
|
||||||
|
output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16), name='output_l0c',
|
||||||
output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16),
|
|
||||||
name='output_l0c',
|
|
||||||
scope=tik.scope_cc)
|
scope=tik.scope_cc)
|
||||||
|
output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub_32',
|
||||||
output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16),
|
|
||||||
name='output_ub_32',
|
|
||||||
scope=tik.scope_ubuf)
|
scope=tik.scope_ubuf)
|
||||||
|
output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub',
|
||||||
output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16),
|
|
||||||
name='output_ub',
|
|
||||||
scope=tik.scope_ubuf)
|
scope=tik.scope_ubuf)
|
||||||
# zZ
|
# zZ
|
||||||
w1_l1 = tik_inst.Tensor(
|
w1_l1 = tik_inst.Tensor(
|
||||||
'float16', (block_size//16, head_size//16, 16, 16), name='w1_l1', scope=tik.scope_cbuf)
|
'float16', (block_size // 16, head_size // 16, 16, 16), name='w1_l1', scope=tik.scope_cbuf)
|
||||||
# nZ
|
# nZ
|
||||||
v_local_l1 = tik_inst.Tensor(
|
v_local_l1 = tik_inst.Tensor(
|
||||||
'float16', (head_size//16, v_embedding//16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf)
|
'float16', (head_size // 16, v_embedding // 16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf)
|
||||||
|
|
||||||
# zZ
|
# zZ
|
||||||
w2_l1 = tik_inst.Tensor('float16', (head_size//16, global_size//(16*cpt_time), 16, 16),
|
w2_l1 = tik_inst.Tensor('float16', (head_size // 16, global_size // (16 * cpt_time), 16, 16),
|
||||||
name='w2_l1',
|
name='w2_l1', scope=tik.scope_cbuf)
|
||||||
scope=tik.scope_cbuf)
|
|
||||||
|
|
||||||
# nZ
|
# nZ
|
||||||
# use same v_global
|
# use same v_global
|
||||||
v_global_l1 = tik_inst.Tensor('float16', (global_size//16, v_embedding//16, 16, 16),
|
v_global_l1 = tik_inst.Tensor('float16', (global_size // 16, v_embedding // 16, 16, 16),
|
||||||
name='v_global_l1',
|
name='v_global_l1', scope=tik.scope_cbuf)
|
||||||
scope=tik.scope_cbuf)
|
|
||||||
# global v
|
# global v
|
||||||
global_idx = 3 - head_idx % 4
|
global_idx = 3 - head_idx % 4
|
||||||
tik_inst.data_move(v_global_l1[0, 0, 0, 0],
|
tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx * seq_len // 16 + global_idx,
|
||||||
v_gm[bs_idx*seq_len//16+global_idx,
|
head_idx * v_embedding // 16, 0, 0], 0, seq_len // (4 * 16),
|
||||||
head_idx*v_embedding//16, 0, 0], 0,
|
16 * v_embedding * 2 // block_bite_size,
|
||||||
seq_len//(4*16), 16*v_embedding*2//block_bite_size,
|
(4 * head * v_embedding * 16 - 16 * v_embedding) * 2 // block_bite_size, 0)
|
||||||
(4*head*v_embedding*16-16*v_embedding)*2//block_bite_size, 0)
|
# every block size is 64, the output of the local and global is (1024,128) Zn
|
||||||
|
|
||||||
# every block size is 64, local和global输出均为(1024,128)的小z大n矩阵
|
|
||||||
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
|
||||||
# global
|
# global
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
w2_l0a = tik_inst.Tensor('float16', (head_size//16, global_size//(cpt_time*16), 16, 16),
|
w2_l0a = tik_inst.Tensor('float16', (head_size // 16, global_size // (cpt_time * 16), 16, 16),
|
||||||
name='w2_l0a', scope=tik.scope_ca)
|
name='w2_l0a', scope=tik.scope_ca)
|
||||||
|
v_global_l0b = tik_inst.Tensor('float16', (global_size // (cpt_time * 16), v_embedding // 16, 16, 16),
|
||||||
v_global_l0b = tik_inst.Tensor('float16', (global_size//(cpt_time*16), v_embedding//16, 16, 16),
|
|
||||||
name='v_global_l0b', scope=tik.scope_cb)
|
name='v_global_l0b', scope=tik.scope_cb)
|
||||||
|
|
||||||
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
with tik_inst.for_range(0, cpt_time) as cpt_idx:
|
||||||
with tik_inst.for_range(0, head_size//16) as brick_i:
|
with tik_inst.for_range(0, head_size // 16) as brick_i:
|
||||||
# tik_inst.tikdb.debug_print("'w_idx: '+str(w_idx)")
|
|
||||||
# tik_inst.tikdb.debug_print("'w2_gm: '+str(w2_gm.shape)")
|
|
||||||
# tik_inst.tikdb.debug_print("'brick_i: '+str(brick_i)")
|
|
||||||
# tik_inst.tikdb.debug_print("'(block_size//16-1)*16*16*2//block_bite_size:
|
|
||||||
# '+str((block_size//16-1)*16*16*2//block_bite_size)")
|
|
||||||
tik_inst.data_move(w2_l1[brick_i, 0, 0, 0],
|
tik_inst.data_move(w2_l1[brick_i, 0, 0, 0],
|
||||||
w2_gm[bs_idx, head_idx, w_idx, cpt_idx *
|
w2_gm[bs_idx, head_idx, w_idx, cpt_idx *
|
||||||
global_size//(16*cpt_time), brick_i, 0, 0], 0,
|
global_size // (16 * cpt_time), brick_i, 0, 0], 0,
|
||||||
global_size//(16*cpt_time), 16 *
|
global_size // (16 * cpt_time), 16 * 16 * 2 // block_bite_size,
|
||||||
16*2//block_bite_size,
|
(block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0)
|
||||||
(block_size//16-1)*16*16*2//block_bite_size, 0)
|
|
||||||
|
|
||||||
tik_inst.load2dv1(
|
tik_inst.load2dv1(
|
||||||
w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size*global_size//(cpt_time*16*16), 1, 0)
|
w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size * global_size // (cpt_time * 16 * 16), 1,
|
||||||
|
0)
|
||||||
|
|
||||||
tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx*global_size//(
|
tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx * global_size // (
|
||||||
16*cpt_time), 0, 0, 0], 0, global_size*v_embedding//(16*16*cpt_time), 1, 0)
|
16 * cpt_time), 0, 0, 0], 0, global_size * v_embedding // (16 * 16 * cpt_time), 1, 0)
|
||||||
|
|
||||||
with tik_inst.if_scope(cpt_idx == 0):
|
with tik_inst.if_scope(cpt_idx == 0):
|
||||||
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
||||||
block_size, global_size//cpt_time, v_embedding, 0)
|
block_size, global_size // cpt_time, v_embedding, 0)
|
||||||
with tik_inst.else_scope():
|
with tik_inst.else_scope():
|
||||||
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
|
||||||
block_size, global_size//cpt_time, v_embedding, 1)
|
block_size, global_size // cpt_time, v_embedding, 1)
|
||||||
|
|
||||||
# local
|
# local
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
w1_l0a = tik_inst.Tensor('float16', (block_size//16, head_size//16, 16, 16),
|
w1_l0a = tik_inst.Tensor('float16', (block_size // 16, head_size // 16, 16, 16),
|
||||||
name='w1_l0a', scope=tik.scope_ca)
|
name='w1_l0a', scope=tik.scope_ca)
|
||||||
v_local_l0b = tik_inst.Tensor('float16', (head_size//16, v_embedding//16, 16, 16),
|
v_local_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16),
|
||||||
name='v_local_l0b', scope=tik.scope_cb)
|
name='v_local_l0b', scope=tik.scope_cb)
|
||||||
|
|
||||||
# v
|
|
||||||
tik_inst.data_move(v_local_l1[0, 0, 0, 0],
|
tik_inst.data_move(v_local_l1[0, 0, 0, 0],
|
||||||
v_gm[bs_idx * seq_len//16 + w_idx * 4, head_idx *
|
v_gm[bs_idx * seq_len // 16 + w_idx * 4, head_idx *
|
||||||
v_embedding//16, 0, 0], 0, block_size//16,
|
v_embedding // 16, 0, 0], 0, block_size // 16,
|
||||||
16 * v_embedding * 2 // block_bite_size,
|
16 * v_embedding * 2 // block_bite_size,
|
||||||
16*(head-1)*v_embedding*2//block_bite_size, 0)
|
16 * (head - 1) * v_embedding * 2 // block_bite_size, 0)
|
||||||
|
|
||||||
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0,
|
tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0,
|
||||||
head_size*v_embedding//(16*16), 1, 0)
|
head_size * v_embedding // (16 * 16), 1, 0)
|
||||||
|
|
||||||
# w
|
# w
|
||||||
with tik_inst.for_range(0, block_size // 16) as brick_i:
|
with tik_inst.for_range(0, block_size // 16) as brick_i:
|
||||||
tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0,
|
tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0,
|
||||||
head_size // 16, (16 *
|
head_size // 16, (16 * 16 * 2) // block_bite_size,
|
||||||
16*2)//block_bite_size,
|
|
||||||
(block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0)
|
(block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0)
|
||||||
tik_inst.load2dv1(
|
tik_inst.load2dv1(w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size * head_size // (16 * 16), 1, 0)
|
||||||
w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size*head_size//(16*16), 1, 0)
|
|
||||||
|
|
||||||
tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b,
|
tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b,
|
||||||
block_size, head_size, v_embedding, 1)
|
block_size, head_size, v_embedding, 1)
|
||||||
|
|
||||||
tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0,
|
tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0,
|
||||||
1, block_size * v_embedding * 4 // 1024, 0, 0)
|
1, block_size * v_embedding * 4 // 1024, 0, 0)
|
||||||
|
|
||||||
tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0],
|
tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0],
|
||||||
v_embedding * block_size//64, 1, 1, 4, 8)
|
v_embedding * block_size // 64, 1, 1, 4, 8)
|
||||||
|
tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx * (block_size // 16), 0, 0],
|
||||||
tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx*(block_size//16), 0, 0], output_ub[0, 0, 0, 0],
|
output_ub[0, 0, 0, 0],
|
||||||
0, v_embedding//16, 16*block_size*2//block_bite_size, 0,
|
0, v_embedding // 16, 16 * block_size * 2 // block_bite_size, 0,
|
||||||
(seq_len - block_size)*16*2//block_bite_size)
|
(seq_len - block_size) * 16 * 2 // block_bite_size)
|
||||||
|
|
||||||
tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm],
|
tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm],
|
||||||
outputs=[output_gm])
|
outputs=[output_gm])
|
||||||
return tik_inst
|
return tik_inst
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
from te import tik
|
from te import tik
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
|
||||||
matmul_dds_grad_op_info = TBERegOp("MatmulDDSGrad") \
|
matmul_dds_grad_op_info = TBERegOp("MatmulDDSGrad") \
|
||||||
.fusion_type("OPAQUE") \
|
.fusion_type("OPAQUE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
|
@ -59,13 +58,7 @@ def matmul_dds_grad(q,
|
||||||
:param local_prob_grad: local output grad (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN
|
:param local_prob_grad: local output grad (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN
|
||||||
:param global_prob_grad: global output grad (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN
|
:param global_prob_grad: global output grad (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN
|
||||||
"""
|
"""
|
||||||
# seq_len = 1024
|
|
||||||
# size_per_head = 128
|
|
||||||
# block_size = 64
|
|
||||||
# block_num = 16
|
|
||||||
# global_size = 256
|
|
||||||
# bs = 2
|
|
||||||
# heads = 2
|
|
||||||
shape_q = q.get(
|
shape_q = q.get(
|
||||||
'shape')
|
'shape')
|
||||||
shape_lc = local_prob.get(
|
shape_lc = local_prob.get(
|
||||||
|
@ -120,17 +113,16 @@ def matmul_dds_grad(q,
|
||||||
(global_size + block_size) * 16 // 128, 8)
|
(global_size + block_size) * 16 // 128, 8)
|
||||||
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
||||||
0, (global_size + block_size) // 16, 16, 0, 0)
|
0, (global_size + block_size) // 16, 16, 0, 0)
|
||||||
# all_head = 32 * rx + block_index
|
|
||||||
b = tik_inst.Scalar(dtype="int32")
|
b = tik_inst.Scalar(dtype="int32")
|
||||||
b.set_as(block_index // heads)
|
b.set_as(block_index // heads)
|
||||||
|
|
||||||
# head = block_index - b * heads
|
|
||||||
head = tik_inst.Scalar(dtype="int32")
|
head = tik_inst.Scalar(dtype="int32")
|
||||||
head.set_as(block_index - b * heads)
|
head.set_as(block_index - b * heads)
|
||||||
# s = head // 4
|
|
||||||
s = tik_inst.Scalar(dtype="int32")
|
s = tik_inst.Scalar(dtype="int32")
|
||||||
s.set_as(head // 4)
|
s.set_as(head // 4)
|
||||||
# global_idx = 3 - (head - 4 * s) # global idx for global key extraction
|
# formula: global_idx = 3 - (head - 4 * s) # global idx for global key extraction
|
||||||
global_idx = tik_inst.Scalar(dtype="int32")
|
global_idx = tik_inst.Scalar(dtype="int32")
|
||||||
global_idx.set_as(3 - (head - 4 * s))
|
global_idx.set_as(3 - (head - 4 * s))
|
||||||
# apply tensor in l1 for global k (256, 128) nZ
|
# apply tensor in l1 for global k (256, 128) nZ
|
||||||
|
@ -154,7 +146,7 @@ def matmul_dds_grad(q,
|
||||||
0, size_per_head // 16, 16, bs * seq_len - 16, 0)
|
0, size_per_head // 16, 16, bs * seq_len - 16, 0)
|
||||||
with tik_inst.for_range(0, block_num) as block:
|
with tik_inst.for_range(0, block_num) as block:
|
||||||
# do backward softmax
|
# do backward softmax
|
||||||
# grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
|
# formula: grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
|
||||||
# apply for tensor in ub for grad_x out (64, 320) zN
|
# apply for tensor in ub for grad_x out (64, 320) zN
|
||||||
mat_ub_lg_d = tik_inst.Tensor("float16",
|
mat_ub_lg_d = tik_inst.Tensor("float16",
|
||||||
((global_size + block_size) //
|
((global_size + block_size) //
|
||||||
|
@ -217,16 +209,16 @@ def matmul_dds_grad(q,
|
||||||
1, 1, 1, 8, 8, 8)
|
1, 1, 1, 8, 8, 8)
|
||||||
|
|
||||||
# apply for tensor in L1 for dsoftmax*softmax result (320, 64) nZ
|
# apply for tensor in L1 for dsoftmax*softmax result (320, 64) nZ
|
||||||
mat_l1_ssg_nZ = tik_inst.Tensor("float16", ((global_size + block_size) // 16,
|
mat_l1_ssg_nz = tik_inst.Tensor("float16", ((global_size + block_size) // 16,
|
||||||
block_size // 16, 16, 16),
|
block_size // 16, 16, 16),
|
||||||
name='mat_l1_ssg_nZ',
|
name='mat_l1_ssg_nz',
|
||||||
scope=tik.scope_cbuf)
|
scope=tik.scope_cbuf)
|
||||||
# move ones from ub to L1 for CUBE mmad
|
# move ones from ub to L1 for CUBE mmad
|
||||||
# the shape of ones in ub is nZ
|
# the shape of ones in ub is nZ
|
||||||
# the shape of ones in L0A is nZ
|
# the shape of ones in L0A is nZ
|
||||||
# the stride between each (16, 16) is 0
|
# the stride between each (16, 16) is 0
|
||||||
# repeat 32 times
|
# repeat 32 times
|
||||||
tik_inst.data_move(mat_l1_ssg_nZ[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], 0,
|
tik_inst.data_move(mat_l1_ssg_nz[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], 0,
|
||||||
(global_size + block_size) // 16, block_size, 0, 0)
|
(global_size + block_size) // 16, block_size, 0, 0)
|
||||||
# apply tensor in l0c for exp sum (16, 64) zN
|
# apply tensor in l0c for exp sum (16, 64) zN
|
||||||
mat_l0c_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
mat_l0c_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16),
|
||||||
|
@ -254,7 +246,7 @@ def matmul_dds_grad(q,
|
||||||
# the shape of ssg in L0B is nZ
|
# the shape of ssg in L0B is nZ
|
||||||
# the stride between each (16, 16) is 0
|
# the stride between each (16, 16) is 0
|
||||||
# repeat 128 times
|
# repeat 128 times
|
||||||
tik_inst.load2dv1(mat_l0b_ssg[0, 0, 0, 0], mat_l1_ssg_nZ[0, 0, 0, 0], 0,
|
tik_inst.load2dv1(mat_l0b_ssg[0, 0, 0, 0], mat_l1_ssg_nz[0, 0, 0, 0], 0,
|
||||||
(global_size + block_size) * block_size // (16 * 16), 1, 0, False)
|
(global_size + block_size) * block_size // (16 * 16), 1, 0, False)
|
||||||
tik_inst.mmad(mat_l0c_ssg_sum, mat_l0a_ones, mat_l0b_ssg,
|
tik_inst.mmad(mat_l0c_ssg_sum, mat_l0a_ones, mat_l0b_ssg,
|
||||||
16, (global_size + block_size), block_size, 0)
|
16, (global_size + block_size), block_size, 0)
|
||||||
|
@ -265,7 +257,7 @@ def matmul_dds_grad(q,
|
||||||
name='mat_ub_ssg_sums',
|
name='mat_ub_ssg_sums',
|
||||||
scope=tik.scope_ubuf)
|
scope=tik.scope_ubuf)
|
||||||
tik_inst.data_move(mat_ub_ssg_sums[0], mat_ub_ssg_sum[0, 0, 0, 0],
|
tik_inst.data_move(mat_ub_ssg_sums[0], mat_ub_ssg_sum[0, 0, 0, 0],
|
||||||
0, block_size // 16, 1*2, 15*2, 0)
|
0, block_size // 16, 1 * 2, 15 * 2, 0)
|
||||||
# apply for tensor in UB for global prob sum (64,)
|
# apply for tensor in UB for global prob sum (64,)
|
||||||
mat_ub_ssg_sums_16 = tik_inst.Tensor("float16", (block_size,),
|
mat_ub_ssg_sums_16 = tik_inst.Tensor("float16", (block_size,),
|
||||||
name='mat_ub_ssg_sums_16',
|
name='mat_ub_ssg_sums_16',
|
||||||
|
@ -423,30 +415,30 @@ def matmul_dds_grad(q,
|
||||||
# local dk calculation
|
# local dk calculation
|
||||||
# dk calculation q.T X dw
|
# dk calculation q.T X dw
|
||||||
# apply for tensor in ub for dw (320, 64) nZ
|
# apply for tensor in ub for dw (320, 64) nZ
|
||||||
mat_ub_lg_d_nZ = tik_inst.Tensor("float16",
|
mat_ub_lg_d_nz = tik_inst.Tensor("float16",
|
||||||
(block_size // 16, (global_size +
|
(block_size // 16, (global_size +
|
||||||
block_size) // 16, 16, 16),
|
block_size) // 16, 16, 16),
|
||||||
name='mat_ub_lg_d_nZ',
|
name='mat_ub_lg_d_nz',
|
||||||
scope=tik.scope_ubuf)
|
scope=tik.scope_ubuf)
|
||||||
# transpose dw from zN to nZ
|
# transpose dw from zN to nZ
|
||||||
with tik_inst.for_range(0, (global_size + block_size) // 16) as lb:
|
with tik_inst.for_range(0, (global_size + block_size) // 16) as lb:
|
||||||
with tik_inst.for_range(0, block_size // 16) as gb:
|
with tik_inst.for_range(0, block_size // 16) as gb:
|
||||||
tik_inst.vtranspose(
|
tik_inst.vtranspose(
|
||||||
mat_ub_lg_d_nZ[gb, lb, 0, 0], mat_ub_lg_d[lb, gb, 0, 0])
|
mat_ub_lg_d_nz[gb, lb, 0, 0], mat_ub_lg_d[lb, gb, 0, 0])
|
||||||
|
|
||||||
# apply tensor in l1 for local dw (64, 64) nZ
|
# apply tensor in l1 for local dw (64, 64) nZ
|
||||||
mat_l1_ldw_nZ = tik_inst.Tensor("float16",
|
mat_l1_ldw_nz = tik_inst.Tensor("float16",
|
||||||
(block_size // 16,
|
(block_size // 16,
|
||||||
block_size // 16, 16, 16),
|
block_size // 16, 16, 16),
|
||||||
name="mat_l1_ldw_nZ",
|
name="mat_l1_ldw_nz",
|
||||||
scope=tik.scope_cbuf)
|
scope=tik.scope_cbuf)
|
||||||
# move local dw from ub to l1
|
# move local dw from ub to l1
|
||||||
# the shape of local dw in ub is nZ
|
# the shape of local dw in ub is nZ
|
||||||
# the shape of local dw in l1 is nZ
|
# the shape of local dw in l1 is nZ
|
||||||
# the stride between each (16, 64) is 256
|
# the stride between each (16, 64) is 256
|
||||||
# repeat 4 times
|
# repeat 4 times
|
||||||
tik_inst.data_move(mat_l1_ldw_nZ[0, 0, 0, 0],
|
tik_inst.data_move(mat_l1_ldw_nz[0, 0, 0, 0],
|
||||||
mat_ub_lg_d_nZ[0, 0, 0, 0],
|
mat_ub_lg_d_nz[0, 0, 0, 0],
|
||||||
0, block_size // 16, block_size, global_size, 0)
|
0, block_size // 16, block_size, global_size, 0)
|
||||||
# apply for tensor in L1 for q (128, 64) nZ
|
# apply for tensor in L1 for q (128, 64) nZ
|
||||||
mat_l1_q_b = tik_inst.Tensor("float16",
|
mat_l1_q_b = tik_inst.Tensor("float16",
|
||||||
|
@ -498,7 +490,7 @@ def matmul_dds_grad(q,
|
||||||
# the shape of local dw in L0B is nZ
|
# the shape of local dw in L0B is nZ
|
||||||
# the stride between each (16, 16) is 0
|
# the stride between each (16, 16) is 0
|
||||||
# repeat 32 times
|
# repeat 32 times
|
||||||
tik_inst.load2dv1(mat_l0b_ldw[0, 0, 0, 0], mat_l1_ldw_nZ[0, 0, 0, 0], 0,
|
tik_inst.load2dv1(mat_l0b_ldw[0, 0, 0, 0], mat_l1_ldw_nz[0, 0, 0, 0], 0,
|
||||||
block_size * block_size // (16 * 16), 1, 0, False)
|
block_size * block_size // (16 * 16), 1, 0, False)
|
||||||
# matmul q and local dw
|
# matmul q and local dw
|
||||||
# the shape of local k in L0C is zN
|
# the shape of local k in L0C is zN
|
||||||
|
@ -525,7 +517,7 @@ def matmul_dds_grad(q,
|
||||||
# the stride between each (16, 64) is 0
|
# the stride between each (16, 64) is 0
|
||||||
# repeat 8 times
|
# repeat 8 times
|
||||||
tik_inst.data_move(mat_l1_dwg_b[0, 0, 0, 0],
|
tik_inst.data_move(mat_l1_dwg_b[0, 0, 0, 0],
|
||||||
mat_ub_lg_d_nZ[0, block_size // 16, 0, 0],
|
mat_ub_lg_d_nz[0, block_size // 16, 0, 0],
|
||||||
0, block_size // 16, global_size, block_size, 0)
|
0, block_size // 16, global_size, block_size, 0)
|
||||||
|
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
|
|
|
@ -21,7 +21,7 @@ matmul_dds_op_info = TBERegOp("MatmulDDS") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("matmul_dds.so") \
|
.binfile_name("matmul_dds.so") \
|
||||||
.compute_cost(10) \
|
.compute_cost(10) \
|
||||||
.kernel_name("MatmulDDSImpl") \
|
.kernel_name("matmul_dds") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.attr("bs", "required", "int", "all") \
|
.attr("bs", "required", "int", "all") \
|
||||||
.attr("heads", "required", "int", "all") \
|
.attr("heads", "required", "int", "all") \
|
||||||
|
@ -38,15 +38,15 @@ matmul_dds_op_info = TBERegOp("MatmulDDS") \
|
||||||
|
|
||||||
|
|
||||||
@op_info_register(matmul_dds_op_info)
|
@op_info_register(matmul_dds_op_info)
|
||||||
def MatmulDDSImpl(q,
|
def matmul_dds(q,
|
||||||
k,
|
k,
|
||||||
local_mask,
|
local_mask,
|
||||||
global_mask,
|
global_mask,
|
||||||
local_prob,
|
local_prob,
|
||||||
global_prob,
|
global_prob,
|
||||||
bs,
|
bs,
|
||||||
heads,
|
heads,
|
||||||
kernel_name="MatmulDDSImpl"):
|
kernel_name="matmul_dds"):
|
||||||
"""
|
"""
|
||||||
:param q: the dict of input q (bs*seq_len, embedding_size) zN
|
:param q: the dict of input q (bs*seq_len, embedding_size) zN
|
||||||
:param k: the dict of input k (bs*seq_len, embedding_size) nZ
|
:param k: the dict of input k (bs*seq_len, embedding_size) nZ
|
||||||
|
@ -58,9 +58,6 @@ def MatmulDDSImpl(q,
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# bs = 4
|
|
||||||
# heads = 16
|
|
||||||
|
|
||||||
shape_q = q.get(
|
shape_q = q.get(
|
||||||
'shape') # shape_q (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) zN
|
'shape') # shape_q (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) zN
|
||||||
shape_k = k.get(
|
shape_k = k.get(
|
||||||
|
@ -76,11 +73,6 @@ def MatmulDDSImpl(q,
|
||||||
block_size = shape_local_mask[0] # block size only support 64 for now
|
block_size = shape_local_mask[0] # block size only support 64 for now
|
||||||
block_num = seq_len // block_size # block number only support 16 for now
|
block_num = seq_len // block_size # block number only support 16 for now
|
||||||
global_size = seq_len // 4 # global size only support 256 for now
|
global_size = seq_len // 4 # global size only support 256 for now
|
||||||
# seq_len = 1024
|
|
||||||
# size_per_head = 128
|
|
||||||
# block_size = 64
|
|
||||||
# block_num = 16
|
|
||||||
# global_size = 256
|
|
||||||
|
|
||||||
tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud'))
|
tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud'))
|
||||||
|
|
||||||
|
@ -119,17 +111,14 @@ def MatmulDDSImpl(q,
|
||||||
(global_size + block_size) * 16 // 128, 8)
|
(global_size + block_size) * 16 // 128, 8)
|
||||||
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0],
|
||||||
0, (global_size + block_size) // 16, 16, 0, 0)
|
0, (global_size + block_size) // 16, 16, 0, 0)
|
||||||
# b = block_index // heads
|
|
||||||
b = tik_inst.Scalar(dtype="int32")
|
b = tik_inst.Scalar(dtype="int32")
|
||||||
b.set_as(block_index // heads)
|
b.set_as(block_index // heads)
|
||||||
|
|
||||||
# head = block_index - b * heads
|
|
||||||
head = tik_inst.Scalar(dtype="int32")
|
head = tik_inst.Scalar(dtype="int32")
|
||||||
head.set_as(block_index - b * heads)
|
head.set_as(block_index - b * heads)
|
||||||
# s = head // 4
|
|
||||||
s = tik_inst.Scalar(dtype="int32")
|
s = tik_inst.Scalar(dtype="int32")
|
||||||
s.set_as(head // 4)
|
s.set_as(head // 4)
|
||||||
# global_idx = 3 - (head - 4 * s) # global idx for global key extraction
|
|
||||||
global_idx = tik_inst.Scalar(dtype="int32")
|
global_idx = tik_inst.Scalar(dtype="int32")
|
||||||
global_idx.set_as(3 - (head - 4 * s))
|
global_idx.set_as(3 - (head - 4 * s))
|
||||||
# apply tensor for global key which is (128, 256) in L1 nZ
|
# apply tensor for global key which is (128, 256) in L1 nZ
|
||||||
|
@ -257,22 +246,6 @@ def MatmulDDSImpl(q,
|
||||||
tik_inst.mmad(mat_l0c_g, mat_l0a_g, mat_l0b_g,
|
tik_inst.mmad(mat_l0c_g, mat_l0a_g, mat_l0b_g,
|
||||||
block_size, size_per_head // 4, global_size, 1)
|
block_size, size_per_head // 4, global_size, 1)
|
||||||
|
|
||||||
# with tik_inst.for_range(0, global_size // 16, thread_num=2) as gb:
|
|
||||||
# mat_ub_g = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
|
||||||
# name='mat_ub_g',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.data_move(mat_ub_g[0, 0, 0, 0], mat_l0c_g[gb, 0, 0, 0], 0,
|
|
||||||
# 1, block_size // 16, 0, 0)
|
|
||||||
# mat_ub_g_exp = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
|
||||||
# name='mat_ub_g_exp',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vexp(64, mat_ub_g_exp[0, 0, 0, 0],
|
|
||||||
# mat_ub_g[0, 0, 0, 0], block_size * 16 // 64, 1, 1, 8, 8)
|
|
||||||
# # cast fp32 exp to fp16 zN
|
|
||||||
# tik_inst.vec_conv(64, "", mat_ub_lg_exp_16[gb, 0, 0, 0],
|
|
||||||
# mat_ub_g_exp[0, 0, 0, 0],
|
|
||||||
# block_size * 16 // 64, 4, 8)
|
|
||||||
|
|
||||||
# local
|
# local
|
||||||
# apply a new scope
|
# apply a new scope
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
|
@ -324,22 +297,6 @@ def MatmulDDSImpl(q,
|
||||||
tik_inst.mmad(mat_l0c_l, mat_l0a_l, mat_l0b_l,
|
tik_inst.mmad(mat_l0c_l, mat_l0a_l, mat_l0b_l,
|
||||||
block_size, size_per_head // 4, block_size, 1)
|
block_size, size_per_head // 4, block_size, 1)
|
||||||
|
|
||||||
# with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb:
|
|
||||||
# mat_ub_l = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
|
||||||
# name='mat_ub_l',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.data_move(mat_ub_l[0, 0, 0, 0], mat_l0c_l[gb, 0, 0, 0], 0,
|
|
||||||
# 1, block_size // 16, 0, 0)
|
|
||||||
# mat_ub_l_exp = tik_inst.Tensor("float32", (1, block_size // 16, 16, 16),
|
|
||||||
# name='mat_ub_l_exp',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vexp(64, mat_ub_l_exp[0, 0, 0, 0],
|
|
||||||
# mat_ub_l[0, 0, 0, 0], block_size * 16 // 64, 1, 1, 8, 8)
|
|
||||||
# # cast fp32 exp to fp16 zN
|
|
||||||
# tik_inst.vec_conv(64, "", mat_ub_lg_exp_16[global_size // 16 + gb, 0, 0, 0],
|
|
||||||
# mat_ub_l_exp[0, 0, 0, 0],
|
|
||||||
# block_size * 16 // 64, 4, 8)
|
|
||||||
|
|
||||||
with tik_inst.new_stmt_scope():
|
with tik_inst.new_stmt_scope():
|
||||||
with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb:
|
with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb:
|
||||||
mat_ub_lg = tik_inst.Tensor("float32", (1, (block_size + global_size) // 16, 16, 16),
|
mat_ub_lg = tik_inst.Tensor("float32", (1, (block_size + global_size) // 16, 16, 16),
|
||||||
|
@ -355,9 +312,6 @@ def MatmulDDSImpl(q,
|
||||||
tik_inst.vec_conv(64, "", mat_ub_lg_16[0, 0, 0, 0],
|
tik_inst.vec_conv(64, "", mat_ub_lg_16[0, 0, 0, 0],
|
||||||
mat_ub_lg[0, 0, 0, 0],
|
mat_ub_lg[0, 0, 0, 0],
|
||||||
(block_size + global_size) * 16 // 64, 4, 8)
|
(block_size + global_size) * 16 // 64, 4, 8)
|
||||||
# mat_ub_lg_max = tik_inst.Tensor("float16", (2,),
|
|
||||||
# name='mat_ub_lg_max',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
with tik_inst.for_range(0, 16) as lb:
|
with tik_inst.for_range(0, 16) as lb:
|
||||||
mat_ub_lg_lb = tik_inst.Tensor("float16", (block_size + global_size,),
|
mat_ub_lg_lb = tik_inst.Tensor("float16", (block_size + global_size,),
|
||||||
name='mat_ub_lg_lb',
|
name='mat_ub_lg_lb',
|
||||||
|
@ -422,74 +376,6 @@ def MatmulDDSImpl(q,
|
||||||
tik_inst.data_move(mat_ub_lg_exp_16[0, gb, lb, 0], mat_ub_lg_exp_lb[0], 0,
|
tik_inst.data_move(mat_ub_lg_exp_16[0, gb, lb, 0], mat_ub_lg_exp_lb[0], 0,
|
||||||
(block_size + global_size) // 16, 1, 0, block_size - 1)
|
(block_size + global_size) // 16, 1, 0, block_size - 1)
|
||||||
|
|
||||||
# max_worker = tik_inst.Tensor("float16", ((block_size + global_size) // 8,),
|
|
||||||
# name='max_worker',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
|
|
||||||
# with tik_inst.for_range(0, 16) as sb:
|
|
||||||
# max_16 = tik_inst.Tensor("float16", ((block_size + global_size) // 16, 2),
|
|
||||||
# name='max_16',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vcmax(16, max_16[0, 0], mat_ub_lg_16[0, 0, sb, 0], (block_size + global_size) // 16,
|
|
||||||
# 1, 1, 16)
|
|
||||||
# real_max = tik_inst.Tensor("float16", ((block_size + global_size) // 16,),
|
|
||||||
# name='real_max',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# with tik_inst.for_range(0, (block_size + global_size) // 16) as mb:
|
|
||||||
# max_v = tik_inst.Scalar("float16",
|
|
||||||
# name='max_v',
|
|
||||||
# init_value=0)
|
|
||||||
# max_v.set_as(max_16[mb, 0])
|
|
||||||
# real_max[mb].set_as(max_v)
|
|
||||||
#
|
|
||||||
# tik_inst.vcmax((block_size + global_size) // 16, mat_ub_lg_max[sb], real_max[0],
|
|
||||||
# 1, 1, 1, 2)
|
|
||||||
# tik_inst.vec_reduce_max(16, mat_ub_lg_max[sb, 0], mat_ub_lg_16[0, 0, sb, 0],
|
|
||||||
# max_worker, (block_size + global_size) // 16, 16, cal_index=False)
|
|
||||||
# mat_ub_lg_max_sub = tik_inst.Tensor("float16", (2,),
|
|
||||||
# name='mat_ub_lg_max_sub',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vmuls(2, mat_ub_lg_max_sub[0], mat_ub_lg_max[0], -1.0, 1, 1, 1, 1, 1)
|
|
||||||
# mat_ub_lg_subs = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16),
|
|
||||||
# name='mat_ub_lg_subs',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# max_value = tik_inst.Scalar("float16",
|
|
||||||
# name='max_value',
|
|
||||||
# init_value=0)
|
|
||||||
# # set value for scalar prob sum rec
|
|
||||||
# max_value.set_as(mat_ub_lg_max_sub[0])
|
|
||||||
# max_value_int = tik_inst.Tensor("int8", (1,),
|
|
||||||
# name='max_value_int',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vec_conv(1, "", max_value_int, mat_ub_lg_max_sub[0], 1, 1, 1)
|
|
||||||
# max_value_ints = tik_inst.Scalar("int8",
|
|
||||||
# name='max_value_ints',
|
|
||||||
# init_value=0)
|
|
||||||
# max_value_ints.set_as(max_value_int[0])
|
|
||||||
# with tik_inst.if_scope(max_value_ints > 0):
|
|
||||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
|
||||||
# 0.0, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
|
||||||
# with tik_inst.else_scope():
|
|
||||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
|
||||||
# max_value, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
|
||||||
# tik_inst.vadds(128, mat_ub_lg_subs[0, 0, 0, 0], mat_ub_lg_16[0, 0, 0, 0],
|
|
||||||
# 5, (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
|
||||||
# with tik_inst.for_range(0, 16) as sb:
|
|
||||||
# max_value = tik_inst.Scalar("float16",
|
|
||||||
# name='max_value',
|
|
||||||
# init_value=0)
|
|
||||||
# # set value for scalar prob sum rec
|
|
||||||
# max_value.set_as(mat_ub_lg_max_sub[sb])
|
|
||||||
# tik_inst.vadds(16, mat_ub_lg_subs[0, 0, sb, 0], mat_ub_lg_16[0, 0, sb, 0],
|
|
||||||
# max_value, (block_size + global_size) // 16, 1, 1, 16, 16)
|
|
||||||
# mat_ub_lg_exp = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16),
|
|
||||||
# name='mat_ub_lg_exp',
|
|
||||||
# scope=tik.scope_ubuf)
|
|
||||||
# tik_inst.vexp(128, mat_ub_lg_exp[0, 0, 0, 0],
|
|
||||||
# mat_ub_lg_subs[0, 0, 0, 0], (block_size + global_size) * 16 // 128, 1, 1, 8, 8)
|
|
||||||
# tik_inst.data_move(mat_ub_lg_exp_16[0, gb, 0, 0], mat_ub_lg_exp[0, 0, 0, 0], 0,
|
|
||||||
# (block_size + global_size) // 16, 16, 0, block_size - 16)
|
|
||||||
|
|
||||||
# move exp fp16 from ub to L1 for CUBE mmad
|
# move exp fp16 from ub to L1 for CUBE mmad
|
||||||
# the shape of exp fp16 in ub is zN
|
# the shape of exp fp16 in ub is zN
|
||||||
# the shape of exp fp16 in L1 is zN
|
# the shape of exp fp16 in L1 is zN
|
||||||
|
@ -537,8 +423,7 @@ def MatmulDDSImpl(q,
|
||||||
# the stride between each (16, 16) is 0
|
# the stride between each (16, 16) is 0
|
||||||
# repeat 128 times
|
# repeat 128 times
|
||||||
tik_inst.load2dv1(mat_l0b_exp[0, 0, 0, 0],
|
tik_inst.load2dv1(mat_l0b_exp[0, 0, 0, 0],
|
||||||
mat_l1_lg_exp_16[(
|
mat_l1_lg_exp_16[(global_size + block_size) * gb // 64, 0, 0, 0], 0,
|
||||||
global_size + block_size) * gb // 64, 0, 0, 0], 0,
|
|
||||||
(global_size + block_size) * block_size // (4 * 16 * 16), 1, 0, False)
|
(global_size + block_size) * block_size // (4 * 16 * 16), 1, 0, False)
|
||||||
with tik_inst.if_scope(gb == 0):
|
with tik_inst.if_scope(gb == 0):
|
||||||
tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16,
|
tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16,
|
||||||
|
@ -568,7 +453,6 @@ def MatmulDDSImpl(q,
|
||||||
# calculate attention prob sum vec (64,)
|
# calculate attention prob sum vec (64,)
|
||||||
tik_inst.vec_rec_high_preci(
|
tik_inst.vec_rec_high_preci(
|
||||||
64, mat_ub_exp_sum_rec, mat_ub_lg_exp_sum, worker_tensor, 1, 8, 8)
|
64, mat_ub_exp_sum_rec, mat_ub_lg_exp_sum, worker_tensor, 1, 8, 8)
|
||||||
# tik_instance.tikdb.debug_print('mat_ub_exp_sum_rec')
|
|
||||||
tik_inst.vec_conv(block_size, "", mat_ub_exp_sum_rec_16[0],
|
tik_inst.vec_conv(block_size, "", mat_ub_exp_sum_rec_16[0],
|
||||||
mat_ub_exp_sum_rec[0],
|
mat_ub_exp_sum_rec[0],
|
||||||
block_size // 64, 4, 8)
|
block_size // 64, 4, 8)
|
||||||
|
@ -593,12 +477,6 @@ def MatmulDDSImpl(q,
|
||||||
# the shape of local out in gm is zN
|
# the shape of local out in gm is zN
|
||||||
# the stride between each (16, 64) is 0
|
# the stride between each (16, 64) is 0
|
||||||
# repeat 4 times
|
# repeat 4 times
|
||||||
# head_id = str(head)
|
|
||||||
# block_id = str(block)
|
|
||||||
# mat_ub_l_out_shape = str(mat_ub_l_out.shape)
|
|
||||||
# tik_inst.tikdb.debug_print(head_id)
|
|
||||||
# tik_inst.tikdb.debug_print(block_id)
|
|
||||||
# tik_inst.tikdb.debug_print(mat_ub_l_out_shape)
|
|
||||||
tik_inst.data_move(mat_lc[b, head, block, 0, 0, 0, 0], mat_ub_l_out[0, 0, 0, 0], 0,
|
tik_inst.data_move(mat_lc[b, head, block, 0, 0, 0, 0], mat_ub_l_out[0, 0, 0, 0], 0,
|
||||||
block_size // 16, block_size, 0, 0)
|
block_size // 16, block_size, 0, 0)
|
||||||
# move global out from UB to gm
|
# move global out from UB to gm
|
||||||
|
|
|
@ -285,7 +285,7 @@ class FeedForward(Cell):
|
||||||
Float tensor.
|
Float tensor.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor. the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
|
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
|
||||||
[batch * seq_length, hidden_size]`.
|
[batch * seq_length, hidden_size]`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
|
Loading…
Reference in New Issue