batchmatmul

This commit is contained in:
x00540480 2020-12-23 10:18:08 +08:00
parent df85dcc3ad
commit 69f882c083
4 changed files with 276 additions and 1 deletions

View File

@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h"
#include <utility>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "utils/ms_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
bool BatchMatMulCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() < 2 || outputs.empty()) {
MS_LOG(EXCEPTION) << "batchmatmul error input output size!";
}
if (batch_ == 0) {
MS_LOG(EXCEPTION) << "batchmatmul error batch size!";
}
LaunchKernel<float>(inputs, outputs);
return true;
}
template <typename T>
void BatchMatMulCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
T *input_a = reinterpret_cast<T *>(inputs[0]->addr);
T *input_b = reinterpret_cast<T *>(inputs[1]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
const int lda = (trans_a_ == TRANSPOSE_YES) ? SizeToInt(dim_m_) : SizeToInt(dim_k_);
const int ldb = (trans_b_ == TRANSPOSE_YES) ? SizeToInt(dim_k_) : SizeToInt(dim_n_);
const int ldc = dim_n_;
const float alpha = 1;
const float beta = 0;
for (unsigned int i = 0; i < batch_; i++) {
(void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_output_, ldc);
}
}
void BatchMatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (src_shape.size() < 3 || weight_shape.size() < 3 || dst_shape.size() < 3) {
MS_LOG(EXCEPTION) << "batchmatmul invalid input size";
}
auto dims = dst_shape.size();
dim_m_ = static_cast<dnnl_dim_t>(dst_shape[dims - 2]);
dim_n_ = static_cast<dnnl_dim_t>(dst_shape[dims - 1]);
size_mat_a_ = src_shape[dims - 2] * src_shape[dims - 1];
size_mat_b_ = weight_shape[dims - 2] * weight_shape[dims - 1];
size_mat_output_ = dst_shape[dims - 2] * dst_shape[dims - 1];
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
batch_ = 1;
for (unsigned int i = 0; i < dst_shape.size() - 2; i++) {
batch_ *= dst_shape[i];
}
auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
dim_k_ = trans_a ? input1_shape[dims - 2] : input1_shape[dims - 1];
trans_a_ = trans_a ? TRANSPOSE_YES : TRANSPOSE_NO;
trans_b_ = trans_b ? TRANSPOSE_YES : TRANSPOSE_NO;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class BatchMatMulCPUKernel : public MKLCPUKernel {
public:
BatchMatMulCPUKernel() = default;
~BatchMatMulCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
char trans_a_{TRANSPOSE_NO};
char trans_b_{TRANSPOSE_NO};
dnnl_dim_t dim_m_{0};
dnnl_dim_t dim_n_{0};
dnnl_dim_t dim_k_{0};
size_t batch_{0};
size_t size_mat_a_{0};
size_t size_mat_b_{0};
size_t size_mat_output_{0};
};
MS_REG_CPU_KERNEL(
BatchMatMul,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BatchMatMulCPUKernel);
MS_REG_CPU_KERNEL(
BatchMatMul,
KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat32),
BatchMatMulCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_

View File

@ -791,7 +791,7 @@ class BatchMatMul(MatMul):
Tensor, the shape of the output tensor is :math:`(*B, N, M)`. Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32) >>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)

View File

@ -0,0 +1,120 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
class BatchMatMulNet(nn.Cell):
def __init__(self, transpose_a=False, transpose_b=False):
super(BatchMatMulNet, self).__init__()
self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
def construct(self, x, y):
return self.batch_matmul(x, y)
def test_4d():
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = BatchMatMulNet()
output = net(input_x, input_y)
expect = [[[[20, 23, 26, 29]],
[[200, 212, 224, 236]],
[[596, 617, 638, 659]],
[[1208, 1238, 1268, 1298]]],
[[[2036, 2075, 2114, 2153]],
[[3080, 3128, 3176, 3224]],
[[4340, 4397, 4454, 4511]],
[[5816, 5882, 5948, 6014]]]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_a():
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
net = BatchMatMulNet(transpose_a=True)
output = net(input_x, input_y)
expect = [[[[20, 23, 26, 29]],
[[200, 212, 224, 236]],
[[596, 617, 638, 659]],
[[1208, 1238, 1268, 1298]]],
[[[2036, 2075, 2114, 2153]],
[[3080, 3128, 3176, 3224]],
[[4340, 4397, 4454, 4511]],
[[5816, 5882, 5948, 6014]]]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_b():
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
net = BatchMatMulNet(transpose_b=True)
output = net(input_x, input_y)
expect = [[[[5, 14, 23, 32]],
[[158, 194, 230, 266]],
[[527, 590, 653, 716]],
[[1112, 1202, 1292, 1382]]],
[[[1913, 2030, 2147, 2264]],
[[2930, 3074, 3218, 3362]],
[[4163, 4334, 4505, 4676]],
[[5612, 5810, 6008, 6206]]]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_transpose_ab():
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
net = BatchMatMulNet(transpose_a=True, transpose_b=True)
output = net(input_x, input_y)
expect = [[[[5, 14, 23, 32]],
[[158, 194, 230, 266]],
[[527, 590, 653, 716]],
[[1112, 1202, 1292, 1382]]],
[[[1913, 2030, 2147, 2264]],
[[2930, 3074, 3218, 3362]],
[[4163, 4334, 4505, 4676]],
[[5612, 5810, 6008, 6206]]]]
assert (output.asnumpy() == expect).all()