From 69f882c0838cc4d36662d4d8920c27ccd47d9bd8 Mon Sep 17 00:00:00 2001
From: x00540480
Date: Wed, 23 Dec 2020 10:18:08 +0800
Subject: [PATCH] batchmatmul

---
 .../cpu/mkldnn/batchmatmul_cpu_kernel.cc      | 109 ++++++++++++++++
 .../cpu/mkldnn/batchmatmul_cpu_kernel.h       |  67 ++++++++++
 mindspore/ops/operations/math_ops.py          |   2 +-
 tests/st/ops/cpu/test_batch_matmul.py         | 125 ++++++++++++++++++
 4 files changed, 302 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h
 create mode 100644 tests/st/ops/cpu/test_batch_matmul.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.cc
new file mode 100644
index 00000000000..cca3d889a51
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.cc
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h"
+#include <utility>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "utils/ms_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+// Checks the argument counts and the cached batch size, then runs the float32 batched GEMM.
+// Raises through MS_LOG(EXCEPTION) on fewer than two inputs, no output, or batch_ == 0; otherwise returns true.
+bool BatchMatMulCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                  const std::vector<AddressPtr> &outputs) {
+  if (inputs.size() < 2 || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "batchmatmul error input output size!";
+  }
+
+  if (batch_ == 0) {
+    MS_LOG(EXCEPTION) << "batchmatmul error batch size!";
+  }
+
+  LaunchKernel<float>(inputs, outputs);
+
+  return true;
+}
+
+// Computes output[i] = op(a[i]) * op(b[i]) for each of the batch_ matrix pairs with dnnl_sgemm, where op() is a
+// transpose when the matching trans_*_ flag is set. Consecutive matrices are size_mat_a_ / size_mat_b_ /
+// size_mat_output_ elements apart in inputs[0] / inputs[1] / outputs[0].
+template <typename T>
+void BatchMatMulCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  T *input_a = reinterpret_cast<T *>(inputs[0]->addr);
+  T *input_b = reinterpret_cast<T *>(inputs[1]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+
+  // Leading dimension of each operand = its row length as stored, which depends on whether it is transposed.
+  const dnnl_dim_t lda = (trans_a_ == TRANSPOSE_YES) ? dim_m_ : dim_k_;
+  const dnnl_dim_t ldb = (trans_b_ == TRANSPOSE_YES) ? dim_k_ : dim_n_;
+  const dnnl_dim_t ldc = dim_n_;
+
+  const float alpha = 1;
+  const float beta = 0;
+
+  for (size_t i = 0; i < batch_; i++) {
+    const dnnl_status_t status =
+      dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
+                 input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_output_, ldc);
+    // Report a failed GEMM instead of silently discarding the status and returning an unverified result.
+    if (status != dnnl_success) {
+      MS_LOG(EXCEPTION) << "batchmatmul dnnl_sgemm failed, batch index: " << i << ", status: " << status;
+    }
+  }
+}
+
+// Reads the shapes and transpose attributes of kernel_node and caches the GEMM sizes (m, n, k), the batch count
+// and the per-matrix element counts used by LaunchKernel. Raises through MS_LOG(EXCEPTION) on invalid ranks.
+void BatchMatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  std::vector<size_t> dst_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+  if (src_shape.size() < 3 || weight_shape.size() < 3 || dst_shape.size() < 3) {
+    MS_LOG(EXCEPTION) << "batchmatmul invalid input size";
+  }
+
+  auto dims = dst_shape.size();
+  auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  // Every shape below is indexed with dims - 2 and dims - 1, so all of them must have the output's rank.
+  if (src_shape.size() != dims || weight_shape.size() != dims || input1_shape.size() != dims) {
+    MS_LOG(EXCEPTION) << "batchmatmul input and output rank mismatch";
+  }
+
+  dim_m_ = static_cast<dnnl_dim_t>(dst_shape[dims - 2]);
+  dim_n_ = static_cast<dnnl_dim_t>(dst_shape[dims - 1]);
+
+  size_mat_a_ = src_shape[dims - 2] * src_shape[dims - 1];
+  size_mat_b_ = weight_shape[dims - 2] * weight_shape[dims - 1];
+  size_mat_output_ = dst_shape[dims - 2] * dst_shape[dims - 1];
+
+  bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
+  bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
+
+  batch_ = 1;
+  for (size_t i = 0; i < dims - 2; i++) {
+    batch_ *= dst_shape[i];
+  }
+
+  dim_k_ = static_cast<dnnl_dim_t>(trans_a ? input1_shape[dims - 2] : input1_shape[dims - 1]);
+
+  trans_a_ = trans_a ? TRANSPOSE_YES : TRANSPOSE_NO;
+  trans_b_ = trans_b ? TRANSPOSE_YES : TRANSPOSE_NO;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h
new file mode 100644
index 00000000000..e29a7200288
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+// CPU kernel for BatchMatMul: multiplies the trailing two dimensions of two float32 tensors for every batch
+// index, delegating each matrix product to dnnl_sgemm.
+class BatchMatMulCPUKernel : public MKLCPUKernel {
+ public:
+  BatchMatMulCPUKernel() = default;
+  ~BatchMatMulCPUKernel() override = default;
+
+  // Caches the GEMM dimensions, batch count and transpose flags read from kernel_node.
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  // Runs the batched multiplication of inputs[0] and inputs[1] into outputs[0]; workspace is not used.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+ private:
+  char trans_a_{TRANSPOSE_NO};  // transpose flag of the first input, passed to dnnl_sgemm
+  char trans_b_{TRANSPOSE_NO};  // transpose flag of the second input, passed to dnnl_sgemm
+  dnnl_dim_t dim_m_{0};         // rows of each output matrix
+  dnnl_dim_t dim_n_{0};         // columns of each output matrix
+  dnnl_dim_t dim_k_{0};         // contracted dimension shared by both inputs
+  size_t batch_{0};             // product of the output's leading (batch) dimensions
+  size_t size_mat_a_{0};        // elements per matrix of the first input
+  size_t size_mat_b_{0};        // elements per matrix of the second input
+  size_t size_mat_output_{0};   // elements per matrix of the output
+};
+
+MS_REG_CPU_KERNEL(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BatchMatMulCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat32),
+  BatchMatMulCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 726b67df6b7..9211fdca811 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -791,7 +791,7 @@ class BatchMatMul(MatMul):
         Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
diff --git a/tests/st/ops/cpu/test_batch_matmul.py b/tests/st/ops/cpu/test_batch_matmul.py
new file mode 100644
index 00000000000..13e81697ec4
--- /dev/null
+++ b/tests/st/ops/cpu/test_batch_matmul.py
@@ -0,0 +1,125 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.ops import operations as P
+
+
+class BatchMatMulNet(nn.Cell):
+    """Cell that applies P.BatchMatMul with the given transpose flags to its two inputs."""
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMulNet, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+
+    def construct(self, x, y):
+        return self.batch_matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_4d():
+    """BatchMatMul of (2, 4, 1, 3) by (2, 4, 3, 4) float32 inputs without transposition."""
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    net = BatchMatMulNet()
+    output = net(input_x, input_y)
+    expect = [[[[20, 23, 26, 29]],
+               [[200, 212, 224, 236]],
+               [[596, 617, 638, 659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_4d_transpose_a():
+    """BatchMatMul of (2, 4, 3, 1) by (2, 4, 3, 4) float32 inputs with transpose_a=True."""
+    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    net = BatchMatMulNet(transpose_a=True)
+    output = net(input_x, input_y)
+    expect = [[[[20, 23, 26, 29]],
+               [[200, 212, 224, 236]],
+               [[596, 617, 638, 659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_4d_transpose_b():
+    """BatchMatMul of (2, 4, 1, 3) by (2, 4, 4, 3) float32 inputs with transpose_b=True."""
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    net = BatchMatMulNet(transpose_b=True)
+    output = net(input_x, input_y)
+    expect = [[[[5, 14, 23, 32]],
+               [[158, 194, 230, 266]],
+               [[527, 590, 653, 716]],
+               [[1112, 1202, 1292, 1382]]],
+
+              [[[1913, 2030, 2147, 2264]],
+               [[2930, 3074, 3218, 3362]],
+               [[4163, 4334, 4505, 4676]],
+               [[5612, 5810, 6008, 6206]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_4d_transpose_ab():
+    """BatchMatMul of (2, 4, 3, 1) by (2, 4, 4, 3) float32 inputs with both transpose flags set."""
+    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    net = BatchMatMulNet(transpose_a=True, transpose_b=True)
+    output = net(input_x, input_y)
+    expect = [[[[5, 14, 23, 32]],
+               [[158, 194, 230, 266]],
+               [[527, 590, 653, 716]],
+               [[1112, 1202, 1292, 1382]]],
+
+              [[[1913, 2030, 2147, 2264]],
+               [[2930, 3074, 3218, 3362]],
+               [[4163, 4334, 4505, 4676]],
+               [[5612, 5810, 6008, 6206]]]]
+    assert (output.asnumpy() == expect).all()