!27717 optimize cpu ops

Merge pull request !27717 from 范吉斌/cpu_ops_optimize
This commit is contained in:
i-robot 2021-12-20 16:05:10 +00:00 committed by Gitee
commit b421e46957
14 changed files with 184 additions and 269 deletions

View File

@ -32,137 +32,91 @@ constexpr size_t kOutputsNum = 1;
} // namespace
template <typename T>
void ArithmeticLogicCPUKernel<T>::Less(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
if (output_size_ > kMaxLessSerialSize) {
auto task = [&](size_t start, size_t end) {
template <typename Op>
void ArithmeticLogicCPUKernel<T>::BinaryOp(const T *input1, const T *input2, bool *out, Op op) {
size_t input1_size = 1;
size_t input2_size = 2;
for (size_t i = 0; i < output_shape_.size(); i++) {
input1_size *= input_shape1_[i];
input2_size *= input_shape2_[i];
}
if (input_shape1_ == input_shape2_) {
auto task = [this, input1, input2, out, op](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = op(input1[i], input2[i]);
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} else if (input1_size == 1) {
auto task = [this, input1, input2, out, op](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = op(input1[0], input2[i]);
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} else if (input2_size == 1) {
auto task = [this, input1, input2, out, op](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = op(input1[i], input2[0]);
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} else {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [this, input1, input2, out, op, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::less<T>()(x, y);
out[i] = op(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} else {
base_iter.SetPos(0);
for (size_t i = 0; i < output_size_; i++) {
auto x = input1[base_iter.GetInputPosA()];
auto y = input2[base_iter.GetInputPosB()];
out[i] = std::less<T>()(x, y);
base_iter.GenNextPos();
}
}
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::Less(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::less<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::Equal(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::equal_to<T>()(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::equal_to<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::NotEqual(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::not_equal_to<T>()(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::not_equal_to<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::LogicalAnd(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
out[i] = input1[iter.GetInputPosA()] && input2[iter.GetInputPosB()];
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::logical_and<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::LogicalOr(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
out[i] = input1[iter.GetInputPosA()] || input2[iter.GetInputPosB()];
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::logical_or<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::Greater(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::greater<T>()(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::greater<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::GreaterEqual(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::greater_equal<T>()(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::greater_equal<T>());
}
template <typename T>
void ArithmeticLogicCPUKernel<T>::LessEqual(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
auto x = input1[iter.GetInputPosA()];
auto y = input2[iter.GetInputPosB()];
out[i] = std::less_equal<T>()(x, y);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
BinaryOp(input1, input2, out, std::less_equal<T>());
}
template <typename T>

View File

@ -39,6 +39,9 @@ class ArithmeticLogicCPUKernel : public CPUKernel {
private:
void InitComputeFunc();
template <typename Op>
void BinaryOp(const T *input1, const T *input2, bool *out, Op op);
void Less(const T *input1, const T *input2, bool *out);
void Equal(const T *input1, const T *input2, bool *out);
void NotEqual(const T *input1, const T *input2, bool *out);

View File

@ -156,6 +156,8 @@ MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),

View File

@ -74,46 +74,26 @@ bool AssignCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::v
<< "', memcpy size must be less than or equal to max size, but got memcpy size: " << total_size
<< ", and max size: " << max_size;
}
constexpr size_t kBlockSize = 10000;
size_t thread_num = (total_size + kBlockSize - 1) / kBlockSize;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
if (thread_num == 0) {
return true;
}
size_t stride = total_size / thread_num;
std::vector<common::Task> tasks;
size_t thread_index = 0;
auto input0_addr = reinterpret_cast<int8_t *>(inputs[0]->addr);
auto input1_addr = reinterpret_cast<int8_t *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
size_t length = stride;
while (thread_index < thread_num) {
auto thread_stride = stride * thread_index;
size_t max_length = total_size - thread_stride;
if (thread_index == thread_num - 1) {
length = max_length;
auto task = [&](size_t start, size_t end) {
int8_t *input0 = input0_addr + start;
int8_t *input1 = input1_addr + start;
int8_t *output = output_addr + start;
size_t length = end - start;
size_t max_length = total_size - start;
int ret = memcpy_s(input0, max_length, input1, length);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret;
}
int8_t *input0 = input0_addr + thread_stride;
int8_t *input1 = input1_addr + thread_stride;
int8_t *output = output_addr + thread_stride;
auto block = [input0, input1, output, max_length, length]() {
int ret = memcpy_s(input0, max_length, input1, length);
if (ret != 0) {
MS_LOG(ERROR) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret;
return common::FAIL;
}
ret = memcpy_s(output, max_length, input1, length);
if (ret != 0) {
MS_LOG(ERROR) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret;
return common::FAIL;
}
return common::SUCCESS;
};
(void)tasks.emplace_back(block);
thread_index++;
}
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
ret = memcpy_s(output, max_length, input1, length);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret;
}
};
ParallelLaunchAutoSearch(task, total_size, this, &parallel_search_info_);
return true;
}
} // namespace kernel

View File

@ -106,31 +106,12 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
const auto *input_addr = reinterpret_cast<float *>(inputs[0]->addr);
const auto *indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto *output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t thread_num = indices_lens_ / kBlockSize + 1;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
std::vector<common::Task> tasks;
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;
MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens;
for (i = 0; i < thread_num; i++) {
if (task_offset >= indices_lens_) {
break;
}
MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens;
auto task = [input_addr, indices_addr, output_addr, task_offset, task_proc_lens, this]() {
LookUpTableTask<T>(input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_,
task_proc_lens, outer_dim_size_, static_cast<T>(offset_), first_dim_size_, kernel_name_);
return common::SUCCESS;
};
(void)tasks.emplace_back(task);
task_offset += task_proc_lens;
if (task_offset + task_proc_lens > indices_lens_) {
task_proc_lens = indices_lens_ - task_offset;
}
}
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
auto task = [&](size_t start, size_t end) {
size_t task_proc_lens = end - start;
LookUpTableTask<T>(input_addr, indices_addr + start, output_addr + start * outer_dim_size_, task_proc_lens,
outer_dim_size_, static_cast<T>(offset_), first_dim_size_, kernel_name_);
};
ParallelLaunchAutoSearch(task, indices_lens_, this, &parallel_search_info_);
}
bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,

View File

@ -66,14 +66,6 @@ bool GatherV2CPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
axis_ = axis_ + dims;
}
int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum());
ParallelRun(input_tensor, indices_data, output_addr, max_thread_num);
return true;
}
template <typename T>
void GatherV2CPUKernel<T>::ParallelRun(const int8_t *input_addr, const int *indices_data, int8_t *output_addr,
int thread_num) {
size_t outer_size = 1, inner_size = 1;
auto axis = static_cast<size_t>(axis_);
for (size_t i = 0; i < axis; ++i) {
@ -87,32 +79,17 @@ void GatherV2CPUKernel<T>::ParallelRun(const int8_t *input_addr, const int *indi
indices_element_size *= indices_shape_.at(i);
}
auto limit = input_shape_.at(axis);
size_t stride = UP_DIV(outer_size, IntToSize(thread_num));
std::vector<common::Task> tasks;
int thread_index = 0;
while (thread_index < thread_num) {
int count = MSMIN(SizeToInt(stride), SizeToInt(outer_size) - SizeToInt(stride) * thread_index);
if (count <= 0) {
break;
auto task = [&](size_t start, size_t end) {
int count = SizeToInt(end - start);
const int8_t *in = input_tensor + start * limit * inner_size * sizeof(T);
int8_t *out = output_addr + start * indices_element_size * inner_size * sizeof(T);
int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T));
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
}
auto thread_stride = static_cast<size_t>(stride * thread_index);
const int8_t *in = input_addr + thread_stride * limit * inner_size * sizeof(T);
int8_t *out = output_addr + thread_stride * indices_element_size * inner_size * sizeof(T);
auto block = [this, in, indices_data, count, inner_size, limit, indices_element_size, out, thread_index]() {
int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T));
if (ret != 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', run error task_id[" << thread_index << "] error_code[" << ret
<< "]";
return common::FAIL;
}
return common::SUCCESS;
};
(void)tasks.emplace_back(block);
thread_index++;
}
if (!common::ThreadPool::GetInstance().SyncRun(tasks)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', syncRun error.";
}
};
ParallelLaunchAutoSearch(task, outer_size, this, &parallel_search_info_);
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -37,7 +37,6 @@ class GatherV2CPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void ParallelRun(const int8_t *input_addr, const int *indices_data, int8_t *output_addr, int thread_num);
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;

View File

@ -35,6 +35,13 @@ void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
}
}
void AddFloat(const float *in_0, const float *in_1, float *out, int start, int end) {
int ret = ElementAdd(in_0 + start, in_1 + start, out + start, end - start);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "Add failed.";
}
}
void AddDouble(const double *in0, const double *in1, double *out, int start, int end) {
for (int index = start; index < end; index++) {
out[index] = in0[index] + in1[index];
@ -81,35 +88,43 @@ bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const
ExecutePrimitive();
}
} else if (dtype_ == kNumberTypeInt32) {
size_t elements_num = outputs[0]->size / sizeof(int);
const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
auto output = reinterpret_cast<int *>(outputs[0]->addr);
auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task_0, elements_num, this, &parallel_search_info_);
for (size_t index = 2; index < input_num_; ++index) {
const auto input = reinterpret_cast<int *>(inputs[index]->addr);
auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task, elements_num, this, &parallel_search_info_);
}
LaunchNnacl<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
size_t elements_num = outputs[0]->size / sizeof(double);
const auto input_0 = reinterpret_cast<double *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<double *>(inputs[1]->addr);
auto output = reinterpret_cast<double *>(outputs[0]->addr);
auto task_0 = std::bind(AddDouble, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task_0, elements_num);
for (size_t index = 2; index < input_num_; ++index) {
const auto input = reinterpret_cast<double *>(inputs[index]->addr);
auto task = std::bind(AddDouble, input, output, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task, elements_num);
}
LaunchNnacl<double>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "AddN only support float32, float64 and int32, but got " << TypeIdToType(dtype_)->ToString();
}
return true;
}
template <typename T>
void AddNCPUKernel::LaunchNnacl(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
std::function<void(const T *, const T *, T *, int, int)> m_func;
if constexpr (std::is_same<T, float>::value) {
m_func = AddFloat;
} else if constexpr (std::is_same<T, int>::value) {
m_func = AddInt;
} else if constexpr (std::is_same<T, double>::value) {
m_func = AddDouble;
} else {
MS_LOG(EXCEPTION) << "AddN only support float32, float64 and int32, but got " << TypeIdToType(dtype_)->ToString();
}
size_t elements_num = outputs[0]->size / sizeof(T);
const auto input_0 = reinterpret_cast<T *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<T *>(inputs[1]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
auto task_0 = std::bind(m_func, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task_0, elements_num, this, &parallel_search_info_);
const size_t iter_start = 2;
for (size_t index = iter_start; index < input_num_; ++index) {
const auto input = reinterpret_cast<T *>(inputs[index]->addr);
auto task = std::bind(m_func, input, output, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task, elements_num, this, &parallel_search_info_);
}
}
void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);

View File

@ -33,6 +33,8 @@ class AddNCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchNnacl(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void CheckParam(const CNodePtr &kernel_node);
size_t input_num_{0};
std::vector<size_t> output_shape_;

View File

@ -45,8 +45,6 @@ MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
EltWiseCPUKernel);
MS_REG_CPU_KERNEL(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@ -19,6 +19,8 @@
#include "common/thread_pool.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
#include "backend/kernel_compiler/cpu/nnacl/matmul_parameter.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -27,39 +29,64 @@ namespace kernel {
namespace {
constexpr size_t kMatMulInputsNum = 2;
constexpr size_t kMatMulOutputsNum = 1;
const size_t kIndexOffset = 2;
constexpr size_t kIndexOffset = 2;
constexpr size_t kRankMin = 2;
using dims = dnnl::memory::dims;
} // namespace
void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
std::vector<size_t> a_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> b_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> o_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
const size_t rank_min = 2;
if (a_shape.size() < rank_min || b_shape.size() < rank_min || o_shape.size() < rank_min) {
MS_LOG(EXCEPTION) << "The tensor rank of MatMul should be greater than or equal to 2.";
if (a_shape.size() < kRankMin || b_shape.size() < kRankMin || o_shape.size() < kRankMin) {
MS_LOG(EXCEPTION) << "The tensor rank of MatMul should be greater than or equal to " << kRankMin;
}
bool trans_a = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_A);
bool trans_b = AnfAlgo::GetNodeAttr<bool>(kernel_node, TRANSPOSE_B);
rank_ = a_shape.size();
batch_ = 1;
for (size_t i = 0; i < rank_ - kIndexOffset; ++i) {
batch_ *= a_shape[i];
auto rank = a_shape.size();
int64_t batch = 1;
for (size_t i = 0; i < rank - kIndexOffset; ++i) {
batch *= SizeToLong(a_shape[i]);
}
size_mat_a_ = a_shape[rank_ - kIndexOffset] * a_shape[rank_ - 1];
size_mat_b_ = b_shape[rank_ - kIndexOffset] * b_shape[rank_ - 1];
size_mat_o_ = o_shape[rank_ - kIndexOffset] * o_shape[rank_ - 1];
int64_t dim_m = SizeToLong(o_shape[rank - kIndexOffset]);
int64_t dim_n = SizeToLong(o_shape[rank - 1]);
int64_t dim_k = 1;
if (trans_a) {
trans_a_ = TRANSPOSE_YES;
dim_k_ = static_cast<dnnl_dim_t>(a_shape[rank_ - kIndexOffset]);
dim_k = SizeToLong(a_shape[rank - kIndexOffset]);
} else {
dim_k_ = static_cast<dnnl_dim_t>(a_shape[rank_ - 1]);
dim_k = SizeToLong(a_shape[rank - 1]);
}
if (trans_b) {
trans_b_ = TRANSPOSE_YES;
dims src_dims, weights_dims, dst_dims, a_strides, b_strides, o_strides;
if (batch > 1) {
src_dims = {batch, dim_m, dim_k};
weights_dims = {batch, dim_k, dim_n};
dst_dims = {batch, dim_m, dim_n};
a_strides = {trans_a ? dims{dim_m * dim_k, 1, dim_m} : dims{dim_m * dim_k, dim_k, 1}};
b_strides = {trans_b ? dims{dim_n * dim_k, 1, dim_k} : dims{dim_n * dim_k, dim_n, 1}};
o_strides = {dim_n * dim_m, dim_n, 1};
} else {
src_dims = {dim_m, dim_k};
weights_dims = {dim_k, dim_n};
dst_dims = {dim_m, dim_n};
a_strides = {trans_a ? dims{1, dim_m} : dims{dim_k, 1}};
b_strides = {trans_b ? dims{1, dim_k} : dims{dim_n, 1}};
o_strides = {dim_n, 1};
}
dim_m_ = static_cast<dnnl_dim_t>(o_shape[rank_ - kIndexOffset]);
dim_n_ = static_cast<dnnl_dim_t>(o_shape[rank_ - 1]);
dnnl::memory::desc src_md(src_dims, dnnl::memory::data_type::f32, a_strides);
dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides);
dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides);
dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md);
dnnl::matmul::primitive_desc prim_desc(matmul_desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::matmul>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_md);
AddArgument(DNNL_ARG_WEIGHTS, weights_md);
AddArgument(DNNL_ARG_DST, dst_md);
}
bool MatMulCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
@ -70,15 +97,10 @@ bool MatMulCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, cons
const auto input_b = reinterpret_cast<float *>(inputs[1]->addr);
auto output = reinterpret_cast<float *>(outputs[0]->addr);
dnnl_dim_t lda = (trans_a_ == TRANSPOSE_YES ? dim_m_ : dim_k_);
dnnl_dim_t ldb = (trans_b_ == TRANSPOSE_YES ? dim_k_ : dim_n_);
dnnl_dim_t ldc = dim_n_;
float alpha = 1.0;
float beta = 0.0;
for (size_t i = 0; i < batch_; i++) {
(void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda,
input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_o_, ldc);
}
SetArgumentHandle(DNNL_ARG_SRC, input_a);
SetArgumentHandle(DNNL_ARG_WEIGHTS, input_b);
SetArgumentHandle(DNNL_ARG_DST, output);
ExecutePrimitive();
return true;
}
} // namespace kernel

View File

@ -32,18 +32,6 @@ class MatMulCPUKernel : public MKLCPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
dnnl_dim_t dim_m_{0};
dnnl_dim_t dim_n_{0};
dnnl_dim_t dim_k_{0};
size_t batch_{0};
size_t rank_{0};
size_t size_mat_a_{0};
size_t size_mat_b_{0};
size_t size_mat_o_{0};
char trans_a_{TRANSPOSE_NO};
char trans_b_{TRANSPOSE_NO};
};
MS_REG_CPU_KERNEL(
MatMul,

View File

@ -133,13 +133,7 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
template <typename T>
void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count) {
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
const size_t thread_num =
count < block_size * max_thread_num ? FloatToSize(std::ceil(count / block_size)) : max_thread_num;
std::vector<common::Task> tasks;
std::function<void(const T *, T *, const int *, TransposeParameter *, int, int)> TransposeDims;
if constexpr (std::is_same_v<T, int8_t>) {
TransposeDims = &TransposeDimsInt8;
} else if constexpr (std::is_same_v<T, int16_t>) {
@ -163,14 +157,14 @@ void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, con
} else if constexpr (std::is_same_v<T, bool>) {
TransposeDims = &TransposeDimsBool;
}
for (int task_id = 0; task_id < SizeToInt(thread_num); ++task_id) {
auto task = [this, &TransposeDims, &input_addr, &output_addr, &output_shape, task_id, thread_num]() {
TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, task_id, SizeToInt(thread_num));
return common::SUCCESS;
};
(void)tasks.emplace_back(task);
}
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
auto thread_pool = GetActorMgrInnerThreadPool();
size_t thread_num = thread_pool->GetKernelThreadNum();
auto task = [this, &TransposeDims, input_addr, output_addr, output_shape, thread_num](size_t start, size_t end) {
for (size_t idx = start; idx < end; idx++) {
TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, SizeToInt(idx), SizeToInt(thread_num));
}
};
ParallelLaunchAutoSearch(task, thread_num, this, &parallel_search_info_);
}
} // namespace kernel
} // namespace mindspore

View File

@ -150,10 +150,10 @@ class ActorDispatcher {
// The first five executions are for warm-up, the next five executions are statistics of multi thread execution time,
// and the next next five executions are statistics of single thread execution time.
static constexpr size_t kMultiThreadExecutionCountBegin{6};
static constexpr size_t kMultiThreadExecutionCountEnd{10};
static constexpr size_t kSingleThreadExecutionCountBegin{11};
static constexpr size_t kSingleThreadExecutionCountEnd{15};
static constexpr size_t kMultiThreadExecutionCountBegin{31};
static constexpr size_t kMultiThreadExecutionCountEnd{40};
static constexpr size_t kSingleThreadExecutionCountBegin{41};
static constexpr size_t kSingleThreadExecutionCountEnd{50};
// The single thread execution constraint.
static constexpr size_t kSingleThreadExecutionActorMaxNum{100};