!18679 call nnacl implement for matmul and batchmatmul on ARM
Merge pull request !18679 from zhangbuxue/call_nnacl_implement_for_matmul_and_batchmatmul_on_ARM
This commit is contained in:
commit
6d254ac409
|
@ -51,7 +51,7 @@ if(ENABLE_CPU)
|
|||
list(REMOVE_ITEM CPU_SRC_LIST ${qcc})
|
||||
endforeach()
|
||||
|
||||
if(ENABLE_CPU AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux"
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux"
|
||||
AND ${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message("compiled quantum kernel_compiler")
|
||||
set_property(SOURCE ${QUANTUM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
|
||||
|
@ -63,6 +63,10 @@ if(ENABLE_CPU)
|
|||
set(QUANTUM_SRC_LIST "")
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM64)
|
||||
add_compile_definitions(ENABLE_ARM)
|
||||
endif()
|
||||
|
||||
if("${ARM_SIMD}" STREQUAL "neon")
|
||||
set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
|
||||
add_compile_definitions(ENABLE_NEON)
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/batchmatmul_cpu_kernel.h"
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool BatchMatMulCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() < 2 || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Batchmatmul error input output size!";
|
||||
}
|
||||
|
||||
if (batch_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Batchmatmul error batch size!";
|
||||
}
|
||||
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BatchMatMulCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
T *input_a = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *input_b = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
const int lda = (trans_a_ == TRANSPOSE_YES) ? SizeToInt(dim_m_) : SizeToInt(dim_k_);
|
||||
const int ldb = (trans_b_ == TRANSPOSE_YES) ? SizeToInt(dim_k_) : SizeToInt(dim_n_);
|
||||
const int ldc = dim_n_;
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
|
||||
for (unsigned int i = 0; i < batch_; i++) {
|
||||
(void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
|
||||
input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_output_, ldc);
|
||||
}
|
||||
}
|
||||
|
||||
void BatchMatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> dst_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
if (src_shape.size() < 3 || weight_shape.size() < 3 || dst_shape.size() < 3) {
|
||||
MS_LOG(EXCEPTION) << "Batchmatmul invalid input size";
|
||||
}
|
||||
|
||||
auto dims = dst_shape.size();
|
||||
|
||||
dim_m_ = static_cast<dnnl_dim_t>(dst_shape[dims - 2]);
|
||||
dim_n_ = static_cast<dnnl_dim_t>(dst_shape[dims - 1]);
|
||||
|
||||
size_mat_a_ = src_shape[dims - 2] * src_shape[dims - 1];
|
||||
size_mat_b_ = weight_shape[dims - 2] * weight_shape[dims - 1];
|
||||
size_mat_output_ = dst_shape[dims - 2] * dst_shape[dims - 1];
|
||||
|
||||
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
|
||||
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
|
||||
|
||||
batch_ = 1;
|
||||
for (unsigned int i = 0; i < dst_shape.size() - 2; i++) {
|
||||
batch_ *= dst_shape[i];
|
||||
}
|
||||
|
||||
auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
dim_k_ = static_cast<dnnl_dim_t>(trans_a ? input1_shape[dims - 2] : input1_shape[dims - 1]);
|
||||
|
||||
trans_a_ = trans_a ? TRANSPOSE_YES : TRANSPOSE_NO;
|
||||
trans_b_ = trans_b ? TRANSPOSE_YES : TRANSPOSE_NO;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,63 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BatchMatMulCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
BatchMatMulCPUKernel() = default;
|
||||
~BatchMatMulCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
char trans_a_{TRANSPOSE_NO};
|
||||
char trans_b_{TRANSPOSE_NO};
|
||||
dnnl_dim_t dim_m_{0};
|
||||
dnnl_dim_t dim_n_{0};
|
||||
dnnl_dim_t dim_k_{0};
|
||||
size_t batch_{0};
|
||||
size_t size_mat_a_{0};
|
||||
size_t size_mat_b_{0};
|
||||
size_t size_mat_output_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
BatchMatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BatchMatMulCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
BatchMatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat32),
|
||||
BatchMatMulCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCHMATMUL_CPU_KERNEL_H_
|
|
@ -14,56 +14,271 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "common/thread_pool.h"
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
namespace {
// Offset from the end of a shape vector to its second-to-last dimension
// (used as `shape[rank - kIndexOffset]`).
constexpr size_t kIndexOffset = 2;
}  // namespace
|
||||
|
||||
if (src_shape.size() != 2 || weight_shape.size() != 2 || dst_shape.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Matmul invalid input size";
|
||||
// Picks the row/column tile sizes used when padding and packing the operands.
// The choice is made at compile time from the enabled instruction-set macro.
void MatMulCPUKernel::InitTile() {
#if defined(ENABLE_AVX)
  constexpr size_t kRowTile = C6NUM;
  constexpr size_t kColTile = C16NUM;
#elif defined(ENABLE_ARM32)
  constexpr size_t kRowTile = C12NUM;
  constexpr size_t kColTile = C4NUM;
#elif defined(ENABLE_SSE)
  constexpr size_t kRowTile = C4NUM;
  constexpr size_t kColTile = C8NUM;
#else
  constexpr size_t kRowTile = C12NUM;
  constexpr size_t kColTile = C8NUM;
#endif
  row_tile_ = kRowTile;
  col_tile_ = kColTile;
}
|
||||
|
||||
void MatMulCPUKernel::InitMatrixA(const float *src_ptr) {
|
||||
const size_t size = param_.batch * param_.row_align_ * param_.deep_;
|
||||
a_pack_ptr_ = new float[size];
|
||||
if (a_pack_ptr_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Malloc a_pack_ptr_ failed.";
|
||||
}
|
||||
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
|
||||
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
|
||||
|
||||
if (vec_matmul_) {
|
||||
const size_t count = size * sizeof(float);
|
||||
if (memcpy_s(a_pack_ptr_, count, src_ptr, count) != EOK) {
|
||||
FreeBuffer();
|
||||
MS_LOG(EXCEPTION) << "Memcpy a_pack_ptr_ failed.";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < param_.batch; i++) {
|
||||
const float *src = src_ptr + i * param_.row_ * param_.deep_;
|
||||
float *dst = a_pack_ptr_ + i * param_.row_align_ * param_.deep_;
|
||||
#ifdef ENABLE_AVX
|
||||
if (param_.a_transpose_) {
|
||||
RowMajor2Row6Major(src, dst, param_.deep_, param_.row_);
|
||||
} else {
|
||||
RowMajor2Col6Major(src, dst, param_.row_, param_.deep_);
|
||||
}
|
||||
#elif defined(ENABLE_SSE)
|
||||
if (param_.a_transpose_) {
|
||||
RowMajor2Row4Major(src, dst, param_.deep_, param_.row_);
|
||||
} else {
|
||||
RowMajor2Col4Major(src, dst, param_.row_, param_.deep_);
|
||||
}
|
||||
#else
|
||||
if (param_.a_transpose_) {
|
||||
RowMajor2Row12Major(src, dst, param_.deep_, param_.row_);
|
||||
} else {
|
||||
RowMajor2Col12Major(src, dst, param_.row_, param_.deep_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void MatMulCPUKernel::InitMatrixB(const float *src_ptr) {
|
||||
const size_t size = param_.batch * param_.col_align_ * param_.deep_;
|
||||
b_pack_ptr_ = new float[size];
|
||||
if (b_pack_ptr_ == nullptr) {
|
||||
FreeBuffer();
|
||||
MS_LOG(EXCEPTION) << "Malloc b_pack_ptr_ failed";
|
||||
}
|
||||
if (vec_matmul_) {
|
||||
if (param_.b_transpose_) {
|
||||
const size_t count = size * sizeof(float);
|
||||
if (memcpy_s(b_pack_ptr_, count, src_ptr, count) != EOK) {
|
||||
FreeBuffer();
|
||||
MS_LOG(EXCEPTION) << "Memcpy b_pack_ptr_ failed.";
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < param_.batch; i++) {
|
||||
const float *src = src_ptr + i * param_.deep_ * param_.col_;
|
||||
float *dst = b_pack_ptr_ + i * param_.deep_ * param_.col_;
|
||||
RowMajor2ColMajor(src, dst, param_.deep_, param_.col_);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < param_.batch; i++) {
|
||||
const float *src = src_ptr + i * param_.deep_ * param_.col_;
|
||||
float *dst = b_pack_ptr_ + i * param_.deep_ * param_.col_align_;
|
||||
#ifdef ENABLE_AVX
|
||||
if (param_.b_transpose_) {
|
||||
RowMajor2Col16Major(src, dst, param_.col_, param_.deep_);
|
||||
} else {
|
||||
RowMajor2Row16Major(src, dst, param_.deep_, param_.col_);
|
||||
}
|
||||
#elif defined(ENABLE_ARM32)
|
||||
if (param_.b_transpose_) {
|
||||
RowMajor2Col4Major(src, dst, param_.col_, param_.deep_);
|
||||
} else {
|
||||
RowMajor2Row4Major(src, dst, param_.deep_, param_.col_);
|
||||
}
|
||||
#else
|
||||
if (param_.b_transpose_) {
|
||||
RowMajor2Col8Major(src, dst, param_.col_, param_.deep_);
|
||||
} else {
|
||||
RowMajor2Row8Major(src, dst, param_.deep_, param_.col_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void MatMulCPUKernel::InitArmKernel(bool trans_a, bool trans_b, const std::vector<size_t> &a_shape,
|
||||
const std::vector<size_t> &o_shape) {
|
||||
InitTile();
|
||||
param_.batch = SizeToInt(batch_);
|
||||
param_.a_transpose_ = trans_a;
|
||||
param_.b_transpose_ = trans_b;
|
||||
param_.row_ = SizeToInt(o_shape[rank_ - kIndexOffset]);
|
||||
param_.deep_ = SizeToInt(trans_a ? a_shape[rank_ - kIndexOffset] : a_shape[rank_ - 1]);
|
||||
param_.col_ = SizeToInt(o_shape[rank_ - 1]);
|
||||
vec_matmul_ = (param_.row_ == 1);
|
||||
param_.row_align_ = vec_matmul_ ? 1 : UP_ROUND(param_.row_, row_tile_);
|
||||
param_.col_align_ = vec_matmul_ ? param_.col_ : UP_ROUND(param_.col_, col_tile_);
|
||||
size_t max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
thread_count_ = MSMIN(max_thread_num, UP_DIV(param_.col_align_, col_tile_));
|
||||
thread_stride_ = UP_DIV(UP_DIV(param_.col_align_, col_tile_), thread_count_);
|
||||
}
|
||||
|
||||
void MatMulCPUKernel::InitX64Kernel(bool trans_a, bool trans_b, const std::vector<size_t> &a_shape,
|
||||
const std::vector<size_t> &b_shape, const std::vector<size_t> &o_shape) {
|
||||
size_mat_a_ = a_shape[rank_ - kIndexOffset] * a_shape[rank_ - 1];
|
||||
size_mat_b_ = b_shape[rank_ - kIndexOffset] * b_shape[rank_ - 1];
|
||||
size_mat_o_ = o_shape[rank_ - kIndexOffset] * o_shape[rank_ - 1];
|
||||
if (trans_a) {
|
||||
trans_a_ = TRANSPOSE_YES;
|
||||
dim_m_ = static_cast<dnnl_dim_t>(src_shape[1]);
|
||||
dim_k_ = static_cast<dnnl_dim_t>(src_shape[0]);
|
||||
dim_k_ = static_cast<dnnl_dim_t>(a_shape[rank_ - kIndexOffset]);
|
||||
} else {
|
||||
dim_m_ = static_cast<dnnl_dim_t>(src_shape[0]);
|
||||
dim_k_ = static_cast<dnnl_dim_t>(src_shape[1]);
|
||||
dim_k_ = static_cast<dnnl_dim_t>(a_shape[rank_ - 1]);
|
||||
}
|
||||
if (trans_b) {
|
||||
trans_b_ = TRANSPOSE_YES;
|
||||
}
|
||||
dim_n_ = static_cast<dnnl_dim_t>(dst_shape[1]);
|
||||
dim_m_ = static_cast<dnnl_dim_t>(o_shape[rank_ - kIndexOffset]);
|
||||
dim_n_ = static_cast<dnnl_dim_t>(o_shape[rank_ - 1]);
|
||||
}
|
||||
|
||||
void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> a_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> b_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> o_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
const size_t rank_min = 2;
|
||||
if (a_shape.size() < rank_min || b_shape.size() < rank_min || o_shape.size() < rank_min) {
|
||||
MS_LOG(EXCEPTION) << "The tensor rank of MatMul should be greater than or equal to 2.";
|
||||
}
|
||||
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
|
||||
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
|
||||
rank_ = a_shape.size();
|
||||
batch_ = 1;
|
||||
for (size_t i = 0; i < rank_ - kIndexOffset; ++i) {
|
||||
batch_ *= a_shape[i];
|
||||
}
|
||||
#ifdef ENABLE_ARM
|
||||
InitArmKernel(trans_a, trans_b, a_shape, o_shape);
|
||||
#else
|
||||
InitX64Kernel(trans_a, trans_b, a_shape, b_shape, o_shape);
|
||||
#endif
|
||||
}
|
||||
|
||||
int MatMulCPUKernel::FloatRun(size_t task_id) {
|
||||
size_t current_stride_oc = thread_stride_ * col_tile_;
|
||||
if (IntToSize(param_.col_) <= task_id * current_stride_oc) {
|
||||
return common::SUCCESS;
|
||||
}
|
||||
|
||||
size_t current_rest_oc = IntToSize(param_.col_) - task_id * current_stride_oc;
|
||||
size_t cur_oc = MSMIN(current_stride_oc, current_rest_oc);
|
||||
auto b = batch_b_ptr_ + task_id * thread_stride_ * col_tile_ * IntToSize(param_.deep_);
|
||||
auto output = batch_o_ptr_ + task_id * thread_stride_ * col_tile_;
|
||||
float *bias = nullptr;
|
||||
if (vec_matmul_) {
|
||||
MatVecMulFp32(batch_a_ptr_, b, output, bias, param_.act_type_, param_.deep_, SizeToInt(cur_oc));
|
||||
} else {
|
||||
MatMulOpt(batch_a_ptr_, b, output, bias, param_.act_type_, param_.deep_, param_.row_, SizeToInt(cur_oc),
|
||||
param_.col_, OutType_Nhwc);
|
||||
}
|
||||
return common::SUCCESS;
|
||||
}
|
||||
|
||||
void MatMulCPUKernel::ParallelRun(float *output) {
|
||||
for (int i = 0; i < param_.batch; ++i) {
|
||||
if (vec_matmul_) {
|
||||
batch_a_ptr_ = a_pack_ptr_ + i * param_.deep_;
|
||||
batch_b_ptr_ = b_pack_ptr_ + i * param_.deep_ * param_.col_;
|
||||
batch_o_ptr_ = output + i * param_.row_ * param_.col_;
|
||||
} else {
|
||||
batch_a_ptr_ = a_pack_ptr_ + i * param_.row_align_ * param_.deep_;
|
||||
batch_b_ptr_ = b_pack_ptr_ + i * param_.deep_ * param_.col_align_;
|
||||
batch_o_ptr_ = output + i * param_.row_ * param_.col_;
|
||||
}
|
||||
std::vector<common::Task> tasks;
|
||||
size_t thread_index = 0;
|
||||
while (thread_index < thread_count_) {
|
||||
tasks.push_back(std::bind(&MatMulCPUKernel::FloatRun, this, thread_index));
|
||||
thread_index++;
|
||||
}
|
||||
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
// nnacl code path: packs both operands into temporary buffers, runs the multiplication and
// releases the buffers again.
//   input_a / input_b: row-major operand data; output: destination buffer.
void MatMulCPUKernel::LaunchARM(const float *input_a, const float *input_b, float *output) {
  try {
    InitMatrixA(input_a);
    InitMatrixB(input_b);
    ParallelRun(output);
  } catch (...) {
    // A failure after a_pack_ptr_ was allocated (e.g. the second allocation throwing) would
    // otherwise leak the packing buffers. FreeBuffer is safe to call repeatedly.
    FreeBuffer();
    throw;
  }
  FreeBuffer();
}
|
||||
|
||||
void MatMulCPUKernel::LaunchX64(const float *input_a, const float *input_b, float *output) {
|
||||
dnnl_dim_t lda = (trans_a_ == TRANSPOSE_YES ? dim_m_ : dim_k_);
|
||||
dnnl_dim_t ldb = (trans_b_ == TRANSPOSE_YES ? dim_k_ : dim_n_);
|
||||
dnnl_dim_t ldc = dim_n_;
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
for (size_t i = 0; i < batch_; i++) {
|
||||
(void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
|
||||
input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_o_, ldc);
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point invoked by the runtime: validates the argument counts, extracts the raw float
// buffers and dispatches to the platform-specific implementation (nnacl when ENABLE_ARM is
// defined, dnnl otherwise).
// The workspace argument is unused. Always returns true; invalid arguments are reported
// through MS_LOG(EXCEPTION).
bool MatMulCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                             const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() < 2 || outputs.empty()) {
    MS_LOG(EXCEPTION) << "matmul error input output size!";
  }
  const auto input_a = reinterpret_cast<float *>(inputs[0]->addr);
  const auto input_b = reinterpret_cast<float *>(inputs[1]->addr);
  auto output = reinterpret_cast<float *>(outputs[0]->addr);

#ifdef ENABLE_ARM
  LaunchARM(input_a, input_b, output);
#else
  LaunchX64(input_a, input_b, output);
#endif
  return true;
}
|
||||
|
||||
void MatMulCPUKernel::FreeBuffer() {
|
||||
if (a_pack_ptr_ != nullptr) {
|
||||
delete[] a_pack_ptr_;
|
||||
a_pack_ptr_ = nullptr;
|
||||
}
|
||||
if (b_pack_ptr_ != nullptr) {
|
||||
delete[] b_pack_ptr_;
|
||||
b_pack_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/matmul_parameter.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -33,17 +35,51 @@ class MatMulCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void InitTile();
|
||||
void InitMatrixA(const float *src_ptr);
|
||||
void InitMatrixB(const float *src_ptr);
|
||||
void InitArmKernel(bool trans_a, bool trans_b, const std::vector<size_t> &a_shape,
|
||||
const std::vector<size_t> &o_shape);
|
||||
void InitX64Kernel(bool trans_a, bool trans_b, const std::vector<size_t> &a_shape, const std::vector<size_t> &b_shape,
|
||||
const std::vector<size_t> &o_shape);
|
||||
void LaunchX64(const float *input_a, const float *input_b, float *output);
|
||||
void LaunchARM(const float *input_a, const float *input_b, float *output);
|
||||
void ParallelRun(float *output);
|
||||
int FloatRun(size_t task_id);
|
||||
void FreeBuffer();
|
||||
|
||||
char trans_a_{TRANSPOSE_NO};
|
||||
char trans_b_{TRANSPOSE_NO};
|
||||
dnnl_dim_t dim_m_{0};
|
||||
dnnl_dim_t dim_n_{0};
|
||||
dnnl_dim_t dim_k_{0};
|
||||
size_t batch_{0};
|
||||
size_t rank_{0};
|
||||
size_t row_tile_{0};
|
||||
size_t col_tile_{0};
|
||||
size_t thread_count_{0};
|
||||
size_t thread_stride_{0};
|
||||
size_t size_mat_a_{0};
|
||||
size_t size_mat_b_{0};
|
||||
size_t size_mat_o_{0};
|
||||
bool vec_matmul_{false};
|
||||
float *a_pack_ptr_{nullptr};
|
||||
float *b_pack_ptr_{nullptr};
|
||||
float *batch_a_ptr_{nullptr};
|
||||
float *batch_b_ptr_{nullptr};
|
||||
float *batch_o_ptr_{nullptr};
|
||||
MatMulParameter param_{};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MatMulCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
BatchMatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MatMulCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// w5: depth
|
||||
// w6: col
|
||||
|
||||
asm_function MatVecMulFp32
|
||||
asm_default_function MatVecMulFp32
|
||||
sub sp, sp, #128
|
||||
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
|
||||
|
|
|
@ -31,6 +31,20 @@ _\fname:
|
|||
#endif
|
||||
.endm
|
||||
|
||||
// clang-format off
|
||||
.macro asm_default_function fname
|
||||
#ifdef __APPLE__
|
||||
.globl _\fname
|
||||
_\fname:
|
||||
#else
|
||||
.global \fname
|
||||
#ifdef __ELF__
|
||||
.type \fname, %function
|
||||
#endif
|
||||
\fname:
|
||||
#endif
|
||||
.endm
|
||||
|
||||
// clang-format on
|
||||
|
||||
#endif // MINDSPORE_NNACL_ASSEMBLY_GLOBAL_H
|
||||
|
|
|
@ -466,14 +466,14 @@ void DumpJsonParser::JudgeDumpEnabled() {
|
|||
if (!async_dump_enabled_ && !e2e_dump_enabled_) {
|
||||
MS_LOG(WARNING) << "Dump json parse failed. Dump not enabled";
|
||||
}
|
||||
|
||||
auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
if (support_devices_.find(device_id) == support_devices_.end()) {
|
||||
async_dump_enabled_ = false;
|
||||
e2e_dump_enabled_ = false;
|
||||
MS_LOG(WARNING) << "Dump not enabled. device_id:" << device_id << " not support";
|
||||
if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
|
||||
auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
if (support_devices_.find(device_id) == support_devices_.end()) {
|
||||
async_dump_enabled_ = false;
|
||||
e2e_dump_enabled_ = false;
|
||||
MS_LOG(WARNING) << "Dump not enabled. device_id:" << device_id << " not support";
|
||||
}
|
||||
}
|
||||
|
||||
JsonConfigToString();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,9 +23,6 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
class BatchMatMulNet(nn.Cell):
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
super(BatchMatMulNet, self).__init__()
|
||||
|
@ -35,86 +32,128 @@ class BatchMatMulNet(nn.Cell):
|
|||
return self.batch_matmul(x, y)
|
||||
|
||||
|
||||
def test_4d():
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
|
||||
def judge_result_correct(result, expect):
    """Assert that `result` matches `expect` in dtype, in shape and (approximately) in value.

    The three checks run in that order; the first one that fails raises AssertionError.
    Values are compared with `np.allclose` using its default tolerances.
    """
    dtype_matches = result.dtype == expect.dtype
    assert dtype_matches
    shape_matches = result.shape == expect.shape
    assert shape_matches
    values_match = np.allclose(result, expect)
    assert values_match
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_no_transpose_vec():
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape((2, 4, 1, 3)), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape((2, 4, 3, 4)), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
net = BatchMatMulNet()
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[20, 23, 26, 29]],
|
||||
[[200, 212, 224, 236]],
|
||||
[[596, 617, 638, 659]],
|
||||
[[1208, 1238, 1268, 1298]]],
|
||||
expect = np.array([[[[20, 23, 26, 29]],
|
||||
[[200, 212, 224, 236]],
|
||||
[[596, 617, 638, 659]],
|
||||
[[1208, 1238, 1268, 1298]]],
|
||||
[[[2036, 2075, 2114, 2153]],
|
||||
[[3080, 3128, 3176, 3224]],
|
||||
[[4340, 4397, 4454, 4511]],
|
||||
[[5816, 5882, 5948, 6014]]]], dtype=np.float32)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
[[[2036, 2075, 2114, 2153]],
|
||||
[[3080, 3128, 3176, 3224]],
|
||||
[[4340, 4397, 4454, 4511]],
|
||||
[[5816, 5882, 5948, 6014]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_4d_no_transpose():
    """BatchMatMul of (2, 3, 2, 3) by (2, 3, 3, 4) float32 tensors without any transpose."""
    lhs_shape = (2, 3, 2, 3)
    rhs_shape = (2, 3, 3, 4)
    input_x = Tensor(np.arange(2 * 3 * 2 * 3).reshape(lhs_shape), mstype.float32)
    input_y = Tensor(np.arange(2 * 3 * 3 * 4).reshape(rhs_shape), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    output = BatchMatMulNet()(input_x, input_y)

    expect = np.array(
        [[[[20., 23., 26., 29.],
           [56., 68., 80., 92.]],
          [[344., 365., 386., 407.],
           [488., 518., 548., 578.]],
          [[1100., 1139., 1178., 1217.],
           [1352., 1400., 1448., 1496.]]],
         [[[2288., 2345., 2402., 2459.],
           [2648., 2714., 2780., 2846.]],
          [[3908., 3983., 4058., 4133.],
           [4376., 4460., 4544., 4628.]],
          [[5960., 6053., 6146., 6239.],
           [6536., 6638., 6740., 6842.]]]],
        dtype=np.float32)
    judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_a():
|
||||
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
|
||||
input_x = Tensor(np.arange(2 * 3 * 3 * 2).reshape((2, 3, 3, 2)), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 3 * 3 * 4).reshape((2, 3, 3, 4)), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_a=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[20, 23, 26, 29]],
|
||||
[[200, 212, 224, 236]],
|
||||
[[596, 617, 638, 659]],
|
||||
[[1208, 1238, 1268, 1298]]],
|
||||
|
||||
[[[2036, 2075, 2114, 2153]],
|
||||
[[3080, 3128, 3176, 3224]],
|
||||
[[4340, 4397, 4454, 4511]],
|
||||
[[5816, 5882, 5948, 6014]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
expect = np.array([[[[40., 46., 52., 58.],
|
||||
[52., 61., 70., 79.]],
|
||||
[[400., 424., 448., 472.],
|
||||
[448., 475., 502., 529.]],
|
||||
[[1192., 1234., 1276., 1318.],
|
||||
[1276., 1321., 1366., 1411.]]],
|
||||
[[[2416., 2476., 2536., 2596.],
|
||||
[2536., 2599., 2662., 2725.]],
|
||||
[[4072., 4150., 4228., 4306.],
|
||||
[4228., 4309., 4390., 4471.]],
|
||||
[[6160., 6256., 6352., 6448.],
|
||||
[6352., 6451., 6550., 6649.]]]], dtype=np.float32)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_b():
|
||||
input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
|
||||
input_x = Tensor(np.arange(2 * 3 * 2 * 3).reshape((2, 3, 2, 3)), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 3 * 4 * 3).reshape((2, 3, 4, 3)), mstype.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_b=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[5, 14, 23, 32]],
|
||||
[[158, 194, 230, 266]],
|
||||
[[527, 590, 653, 716]],
|
||||
[[1112, 1202, 1292, 1382]]],
|
||||
|
||||
[[[1913, 2030, 2147, 2264]],
|
||||
[[2930, 3074, 3218, 3362]],
|
||||
[[4163, 4334, 4505, 4676]],
|
||||
[[5612, 5810, 6008, 6206]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
expect = np.array([[[[5.000e+00, 1.400e+01, 2.300e+01, 3.200e+01],
|
||||
[1.400e+01, 5.000e+01, 8.600e+01, 1.220e+02]],
|
||||
[[2.750e+02, 3.380e+02, 4.010e+02, 4.640e+02],
|
||||
[3.920e+02, 4.820e+02, 5.720e+02, 6.620e+02]],
|
||||
[[9.770e+02, 1.094e+03, 1.211e+03, 1.328e+03],
|
||||
[1.202e+03, 1.346e+03, 1.490e+03, 1.634e+03]]],
|
||||
[[[2.111e+03, 2.282e+03, 2.453e+03, 2.624e+03],
|
||||
[2.444e+03, 2.642e+03, 2.840e+03, 3.038e+03]],
|
||||
[[3.677e+03, 3.902e+03, 4.127e+03, 4.352e+03],
|
||||
[4.118e+03, 4.370e+03, 4.622e+03, 4.874e+03]],
|
||||
[[5.675e+03, 5.954e+03, 6.233e+03, 6.512e+03],
|
||||
[6.224e+03, 6.530e+03, 6.836e+03, 7.142e+03]]]], dtype=np.float32)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_4d_transpose_ab():
|
||||
input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
|
||||
input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
|
||||
input_x = Tensor(np.arange(2 * 3 * 3 * 2).reshape((2, 3, 3, 2)), mstype.float16)
|
||||
input_y = Tensor(np.arange(2 * 3 * 4 * 3).reshape((2, 3, 4, 3)), mstype.float16)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = BatchMatMulNet(transpose_a=True, transpose_b=True)
|
||||
output = net(input_x, input_y)
|
||||
expect = [[[[5, 14, 23, 32]],
|
||||
[[158, 194, 230, 266]],
|
||||
[[527, 590, 653, 716]],
|
||||
[[1112, 1202, 1292, 1382]]],
|
||||
|
||||
[[[1913, 2030, 2147, 2264]],
|
||||
[[2930, 3074, 3218, 3362]],
|
||||
[[4163, 4334, 4505, 4676]],
|
||||
[[5612, 5810, 6008, 6206]]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
expect = np.array([[[[10., 28., 46., 64.],
|
||||
[13., 40., 67., 94.]],
|
||||
[[316., 388., 460., 532.],
|
||||
[355., 436., 517., 598.]],
|
||||
[[1054., 1180., 1306., 1432.],
|
||||
[1129., 1264., 1399., 1534.]]],
|
||||
[[[2224., 2404., 2584., 2764.],
|
||||
[2335., 2524., 2713., 2902.]],
|
||||
[[3826., 4060., 4294., 4528.],
|
||||
[3973., 4216., 4459., 4702.]],
|
||||
[[5860., 6148., 6436., 6724.],
|
||||
[6043., 6340., 6637., 6934.]]]], np.float16)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class MatMulNet(nn.Cell):
    """Minimal cell wrapping a single ``P.MatMul`` primitive.

    Args:
        transpose_a: forwarded to ``P.MatMul`` as its ``transpose_a`` flag.
        transpose_b: forwarded to ``P.MatMul`` as its ``transpose_b`` flag.
    """

    def __init__(self, transpose_a=False, transpose_b=False):
        super().__init__()
        self.matmul = P.MatMul(transpose_a=transpose_a, transpose_b=transpose_b)

    def construct(self, x, y):
        """Apply the configured MatMul primitive to ``x`` and ``y``."""
        product = self.matmul(x, y)
        return product
|
||||
|
||||
def judge_result_correct(result, expect):
    """Assert that ``result`` agrees with ``expect``.

    The checks run in this order: identical dtype, identical shape, then
    element-wise closeness via ``np.allclose`` with its default tolerances.

    Raises:
        AssertionError: if any of the three checks fails.
    """
    same_dtype = result.dtype == expect.dtype
    assert same_dtype
    same_shape = result.shape == expect.shape
    assert same_shape
    values_close = np.allclose(result, expect)
    assert values_close
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_no_transpose_vec():
    """MatMul of a (1, 3) float32 row vector with a (3, 5) float32 matrix.

    Runs on the CPU target in graph mode with both transpose flags left at
    their defaults.
    """
    lhs = Tensor(np.arange(3).reshape(1, 3), mstype.float32)
    rhs = Tensor(np.arange(15).reshape(3, 5), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    matmul_net = MatMulNet()
    actual = matmul_net(lhs, rhs).asnumpy()

    # Reference product of arange(3).reshape(1, 3) and arange(15).reshape(3, 5).
    expected = np.array([[25., 28., 31., 34., 37.]], dtype=np.float32)
    judge_result_correct(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_no_transpose():
    """MatMul of a (4, 3) float32 matrix with a (3, 5) float32 matrix.

    Runs on the CPU target in graph mode with both transpose flags left at
    their defaults.
    """
    lhs = Tensor(np.arange(12).reshape(4, 3), mstype.float32)
    rhs = Tensor(np.arange(15).reshape(3, 5), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    matmul_net = MatMulNet()
    actual = matmul_net(lhs, rhs).asnumpy()

    # Reference product of arange(12).reshape(4, 3) and arange(15).reshape(3, 5).
    expected_rows = [
        [25., 28., 31., 34., 37.],
        [70., 82., 94., 106., 118.],
        [115., 136., 157., 178., 199.],
        [160., 190., 220., 250., 280.],
    ]
    expected = np.array(expected_rows, dtype=np.float32)
    judge_result_correct(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_a():
    """MatMul with ``transpose_a=True`` on a (3, 2) and a (3, 4) float32 input.

    Runs on the CPU target in graph mode; the expected output has shape (2, 4).
    """
    lhs = Tensor(np.arange(6).reshape(3, 2), mstype.float32)
    rhs = Tensor(np.arange(12).reshape(3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    matmul_net = MatMulNet(transpose_a=True)
    actual = matmul_net(lhs, rhs).asnumpy()

    # Reference product of the transposed left operand with the right operand.
    expected_rows = [
        [40., 46., 52., 58.],
        [52., 61., 70., 79.],
    ]
    expected = np.array(expected_rows, dtype=np.float32)
    judge_result_correct(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_b():
    """MatMul with ``transpose_b=True`` on a (2, 3) and a (5, 3) float32 input.

    Runs on the CPU target in graph mode; the expected output has shape (2, 5).
    """
    lhs = Tensor(np.arange(6).reshape(2, 3), mstype.float32)
    rhs = Tensor(np.arange(15).reshape(5, 3), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    matmul_net = MatMulNet(transpose_b=True)
    actual = matmul_net(lhs, rhs).asnumpy()

    # Reference product of the left operand with the transposed right operand.
    expected_rows = [
        [5., 14., 23., 32., 41.],
        [14., 50., 86., 122., 158.],
    ]
    expected = np.array(expected_rows, dtype=np.float32)
    judge_result_correct(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul_transpose_ab():
    """MatMul with both transpose flags set on (3, 5) and (4, 3) float16 inputs.

    Runs on the CPU target in graph mode; the expected output is float16 with
    shape (5, 4).
    """
    lhs = Tensor(np.arange(15).reshape(3, 5), mstype.float16)
    rhs = Tensor(np.arange(12).reshape(4, 3), mstype.float16)

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    matmul_net = MatMulNet(transpose_a=True, transpose_b=True)
    actual = matmul_net(lhs, rhs).asnumpy()

    # Reference product of the two transposed operands.
    expected_rows = [
        [25., 70., 115., 160.],
        [28., 82., 136., 190.],
        [31., 94., 157., 220.],
        [34., 106., 178., 250.],
        [37., 118., 199., 280.],
    ]
    expected = np.array(expected_rows, dtype=np.float16)
    judge_result_correct(actual, expected)
|
Loading…
Reference in New Issue