forked from mindspore-Ecosystem/mindspore
batchmatmul
This commit is contained in:
parent
df85dcc3ad
commit
69f882c083
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h"
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool BatchMatMulCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() < 2 || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "batchmatmul error input output size!";
|
||||
}
|
||||
|
||||
if (batch_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "batchmatmul error batch size!";
|
||||
}
|
||||
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BatchMatMulCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
T *input_a = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *input_b = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
const int lda = (trans_a_ == TRANSPOSE_YES) ? SizeToInt(dim_m_) : SizeToInt(dim_k_);
|
||||
const int ldb = (trans_b_ == TRANSPOSE_YES) ? SizeToInt(dim_k_) : SizeToInt(dim_n_);
|
||||
const int ldc = dim_n_;
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
|
||||
for (unsigned int i = 0; i < batch_; i++) {
|
||||
(void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
|
||||
input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_output_, ldc);
|
||||
}
|
||||
}
|
||||
|
||||
void BatchMatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> dst_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
if (src_shape.size() < 3 || weight_shape.size() < 3 || dst_shape.size() < 3) {
|
||||
MS_LOG(EXCEPTION) << "batchmatmul invalid input size";
|
||||
}
|
||||
|
||||
auto dims = dst_shape.size();
|
||||
|
||||
dim_m_ = static_cast<dnnl_dim_t>(dst_shape[dims - 2]);
|
||||
dim_n_ = static_cast<dnnl_dim_t>(dst_shape[dims - 1]);
|
||||
|
||||
size_mat_a_ = src_shape[dims - 2] * src_shape[dims - 1];
|
||||
size_mat_b_ = weight_shape[dims - 2] * weight_shape[dims - 1];
|
||||
size_mat_output_ = dst_shape[dims - 2] * dst_shape[dims - 1];
|
||||
|
||||
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
|
||||
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
|
||||
|
||||
batch_ = 1;
|
||||
for (unsigned int i = 0; i < dst_shape.size() - 2; i++) {
|
||||
batch_ *= dst_shape[i];
|
||||
}
|
||||
|
||||
auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
dim_k_ = trans_a ? input1_shape[dims - 2] : input1_shape[dims - 1];
|
||||
|
||||
trans_a_ = trans_a ? TRANSPOSE_YES : TRANSPOSE_NO;
|
||||
trans_b_ = trans_b ? TRANSPOSE_YES : TRANSPOSE_NO;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BatchMatMulCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
BatchMatMulCPUKernel() = default;
|
||||
~BatchMatMulCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
char trans_a_{TRANSPOSE_NO};
|
||||
char trans_b_{TRANSPOSE_NO};
|
||||
dnnl_dim_t dim_m_{0};
|
||||
dnnl_dim_t dim_n_{0};
|
||||
dnnl_dim_t dim_k_{0};
|
||||
size_t batch_{0};
|
||||
size_t size_mat_a_{0};
|
||||
size_t size_mat_b_{0};
|
||||
size_t size_mat_output_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
BatchMatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BatchMatMulCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
BatchMatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat32),
|
||||
BatchMatMulCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
|
@ -791,7 +791,7 @@ class BatchMatMul(MatMul):
|
|||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
class BatchMatMulNet(nn.Cell):
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
super(BatchMatMulNet, self).__init__()
|
||||
self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.batch_matmul(x, y)
|
||||
|
||||
|
||||
def test_4d():
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
net = BatchMatMulNet()
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[20, 23, 26, 29]],
|
||||
[[200, 212, 224, 236]],
|
||||
[[596, 617, 638, 659]],
|
||||
[[1208, 1238, 1268, 1298]]],
|
||||
|
||||
[[[2036, 2075, 2114, 2153]],
|
||||
[[3080, 3128, 3176, 3224]],
|
||||
[[4340, 4397, 4454, 4511]],
|
||||
[[5816, 5882, 5948, 6014]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_a():
|
||||
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_a=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[20, 23, 26, 29]],
|
||||
[[200, 212, 224, 236]],
|
||||
[[596, 617, 638, 659]],
|
||||
[[1208, 1238, 1268, 1298]]],
|
||||
|
||||
[[[2036, 2075, 2114, 2153]],
|
||||
[[3080, 3128, 3176, 3224]],
|
||||
[[4340, 4397, 4454, 4511]],
|
||||
[[5816, 5882, 5948, 6014]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_b():
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_b=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[5, 14, 23, 32]],
|
||||
[[158, 194, 230, 266]],
|
||||
[[527, 590, 653, 716]],
|
||||
[[1112, 1202, 1292, 1382]]],
|
||||
|
||||
[[[1913, 2030, 2147, 2264]],
|
||||
[[2930, 3074, 3218, 3362]],
|
||||
[[4163, 4334, 4505, 4676]],
|
||||
[[5612, 5810, 6008, 6206]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_ab():
|
||||
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_a=True, transpose_b=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[5, 14, 23, 32]],
|
||||
[[158, 194, 230, 266]],
|
||||
[[527, 590, 653, 716]],
|
||||
[[1112, 1202, 1292, 1382]]],
|
||||
|
||||
[[[1913, 2030, 2147, 2264]],
|
||||
[[2930, 3074, 3218, 3362]],
|
||||
[[4163, 4334, 4505, 4676]],
|
||||
[[5612, 5810, 6008, 6206]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
Loading…
Reference in New Issue