replace matmul/batchmatmul by mul when k is 1

lingyunli63 2021-06-02 15:26:48 +08:00
parent 76be4aff12
commit 4f34e537a0
5 changed files with 220 additions and 12 deletions
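
The change relies on a simple identity: when the contraction dimension k equals 1, MatMul([m, k], [k, n]) degenerates into an outer product, which is exactly an elementwise Mul with broadcasting. A minimal NumPy sketch (illustrative only, not part of this commit) checking that equivalence:

import numpy as np

# the k == 1 case targeted by this commit: [m, 1] x [1, n]
a = np.random.normal(1, 0.01, [96, 1]).astype(np.float32)
b = np.random.normal(1, 0.01, [1, 128]).astype(np.float32)

matmul_out = np.matmul(a, b)   # [96, 128]
mul_out = a * b                # broadcast elementwise multiply, also [96, 128]

assert np.allclose(matmul_out, mul_out, 1.e-4, 1.e-7)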

View File

@@ -53,3 +53,4 @@ from .softmax_grad_ext import SoftmaxGradExt
 from .square_sum_v1 import SquareSumV1
 from .fused_mul_add import FusedMulAdd
 from .conv2d import Conv2D
+from .matmul import MatMul, BatchMatMul

View File

@@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchMatMul and MatMul"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
class MatMul(Expander):
"""
MatMul expander
"""
def __init__(self, expand_info):
super().__init__(expand_info)
self.transpose_a = self.attrs['transpose_a']
self.transpose_b = self.attrs['transpose_b']
self.left_format = self.attrs['left_format']
self.right_format = self.attrs['right_format']
self.shape_a = self.inputs[0]['shape']
self.shape_b = self.inputs[1]['shape']
def _optimize_to_mul(self):
"""check if matmul can be replace by mul"""
if self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
return False
k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
if k_a != 1 or k_b != 1:
return False
return True
def _check(self):
input_num = len(self.inputs)
if input_num < 2:
raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
def _trans_shape(self, shape):
trans_shape = list(shape)
trans_shape[-2] = shape[-1]
trans_shape[-1] = shape[-2]
return trans_shape
def _expand(self, graph_builder):
if not self._optimize_to_mul():
raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
# MatMul([b, m, k], [b, k, n]) is replaced by Mul when k == 1
input_a = self.inputs[0]
input_b = self.inputs[1]
if self.transpose_a:
shape_a_trans = self._trans_shape(self.shape_a)
input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
if self.transpose_b:
shape_b_trans = self._trans_shape(self.shape_b)
input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
result = graph_builder.emit('Mul', [input_a, input_b])
if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
return result
class BatchMatMul(MatMul):
"""BatchMatMul expander"""

View File

@@ -16,7 +16,6 @@
 import copy
 import sys
-from functools import reduce
 from .model import GraphKernelUnsupportedException as GKException
 from .model import PrimLib, DataFormat as DF
@@ -102,18 +101,59 @@ class OpInfer:
 class _Elemwise(OpInfer):
     """Common infer for elementwise operators"""

-    def _infer_shape(self):
-        """returns the input shape with largest flatten size"""
-        shape = (1,)
-        max_flatten_size = 1
-        for t in self.inputs:
-            if t.data_format != DF.DEFAULT:
-                return t.shape
-            flatten_size = reduce(lambda x, y: x * y, t.shape)
-            if flatten_size > max_flatten_size or (flatten_size == max_flatten_size and len(t.shape) > len(shape)):
-                max_flatten_size = flatten_size
-                shape = t.shape
-        return shape
+    def _broadcast_shape(self, shapes):
+        """deduce the broadcast shape using the same rules as numpy"""
+        dim_size = max([len(shape) for shape in shapes])
+        align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
+        out_shape = [1] * dim_size
+        for i in range(dim_size):
+            for align_shape in align_shapes:
+                if align_shape[i] > 1:
+                    if out_shape[i] == 1:
+                        out_shape[i] = align_shape[i]
+                    if out_shape[i] != align_shape[i]:
+                        raise GKException("shape broadcast failed!")
+        return out_shape
+
+    def _to_nz(self, default_shape):
+        """convert a default format shape to a fractal_Nz format shape"""
+        if len(default_shape) not in (1, 2):
+            raise GKException("shape is too long!")
+        # (32) or (1, 32) -> (2, 1, 1, 16)
+        if len(default_shape) == 1 or (len(default_shape) == 2 and default_shape[0] == 1):
+            shape = [default_shape[-1] // 16, 1, 1, 16]
+            if default_shape[-1] % 16 != 0:
+                raise GKException("should be multiples of 16")
+            return shape
+        # (32, 1) -> (1, 2, 16, 1)
+        if len(default_shape) == 2 and default_shape[1] == 1:
+            shape = [1, default_shape[0] // 16, 16, 1]
+            if default_shape[0] % 16 != 0:
+                raise GKException("should be multiples of 16")
+            return shape
+        # (32, 48) -> (3, 2, 16, 16)
+        shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16]
+        if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0:
+            raise GKException("should be multiples of 16")
+        return shape
+
+    def _infer_shape(self):
+        """returns the output shape with broadcast"""
+        # in case all inputs are default format/NHWC/NCHW
+        is_default = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for input in self.inputs]
+        if all(is_default):
+            return self._broadcast_shape([input.shape for input in self.inputs])
+        # in case the formats are fractal_nz, default format/NHWC/NCHW (optional)
+        is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
+                              for input in self.inputs]
+        if all(is_default_frac_nz):
+            nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape
+                         for input in self.inputs]
+            return self._broadcast_shape(nz_shapes)
+        raise GKException("Only support default and fractal_nz")

     def _infer_format(self):
         for tensor in self.inputs:
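
The new _infer_shape falls back to NumPy-style broadcasting: shapes are right-aligned, missing leading dimensions are padded with 1, and a dimension of size 1 stretches to match. A standalone sketch of the deduction (it mirrors the logic added above and is not part of the commit):

def broadcast_shape(shapes):
    """Right-align the shapes, pad with 1s and merge dimensions NumPy-style."""
    dim_size = max(len(s) for s in shapes)
    aligned = [[1] * (dim_size - len(s)) + list(s) for s in shapes]
    out = [1] * dim_size
    for i in range(dim_size):
        for s in aligned:
            if s[i] > 1:
                if out[i] == 1:
                    out[i] = s[i]
                elif out[i] != s[i]:
                    raise ValueError("shape broadcast failed!")
    return out

# e.g. the shapes used by the BatchMatMul test below
assert broadcast_shape([[16, 96, 1], [16, 1, 128]]) == [16, 96, 128]
assert broadcast_shape([[128], [96, 1]]) == [96, 128]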

View File

@@ -56,6 +56,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimLogSoftmax,
     prim::kPrimLogSoftmaxGrad,
     prim::kPrimTile,
+    prim::kPrimMatMul,
+    prim::kPrimBatchMatMul,
 #if ENABLE_D
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,

View File

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul(transpose_a=False, transpose_b=False)
def construct(self, x, y):
return self.matmul(x, y)
class Net1(Cell):
def __init__(self):
super(Net1, self).__init__()
self.bmm = P.BatchMatMul(transpose_a=False, transpose_b=False)
def construct(self, x, y):
return self.bmm(x, y)
def get_output(i0, i1, net_cls, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = net_cls()
output = net(i0, i1)
return output
def test_matmul():
i0 = Tensor(np.random.normal(1, 0.01, [96, 1]).astype(np.float32))
i1 = Tensor(np.random.normal(1, 0.01, [1, 128]).astype(np.float32))
expect = get_output(i0, i1, Net, False)
output = get_output(i0, i1, Net, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
def test_batchmatmul():
i0 = Tensor(np.random.normal(1, 0.01, [16, 96, 1]).astype(np.float32))
i1 = Tensor(np.random.normal(1, 0.01, [16, 1, 128]).astype(np.float32))
expect = get_output(i0, i1, Net1, False)
output = get_output(i0, i1, Net1, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 6.e-4, 6.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_matmul_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_matmul()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_batchmatmul_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_batchmatmul()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_matmul_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_matmul()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchmatmul_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_batchmatmul()