Refator eltwisegrad cpu ops

This commit is contained in:
wuxuejian 2021-03-31 09:43:26 +08:00
parent d346a861bc
commit 1d5f77d075
2 changed files with 77 additions and 180 deletions

View File

@ -14,15 +14,16 @@
* limitations under the License.
*/
#include <cmath>
#include <string>
#include <thread>
#include <map>
#include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h"
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void EltWiseGradCPUKernel::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (input2[i] > 0) {
out[i] = input1[i];
@ -33,7 +34,7 @@ void EltWiseGradCPUKernel::ReluGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (input2[i] > 0 && input2[i] <= 6) {
out[i] = input1[i];
@ -44,7 +45,7 @@ void EltWiseGradCPUKernel::ReLU6Grad(const T *input1, const T *input2, T *out, s
}
template <typename T>
void EltWiseGradCPUKernel::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (input1[i] > 0) {
out[i] = input2[i];
@ -57,21 +58,21 @@ void EltWiseGradCPUKernel::AbsGrad(const T *input1, const T *input2, T *out, siz
}
template <typename T>
void EltWiseGradCPUKernel::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input2[i] * input1[i] * (1 - input1[i]);
}
}
template <typename T>
void EltWiseGradCPUKernel::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input2[i] / (input1[i] * 2);
}
}
template <typename T>
void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T tmp = input1[i] * input1[i];
out[i] = input2[i] * (1 - tmp);
@ -79,7 +80,7 @@ void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T x = input2[i];
auto double_x = static_cast<T>(x);
@ -91,7 +92,7 @@ void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = sqrt(1 - input1[i] * input1[i]);
@ -112,7 +113,7 @@ void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = -input2[i];
T divisor = sqrt(1 - input1[i] * input1[i]);
@ -133,7 +134,7 @@ void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = 1 + input1[i] * input1[i];
@ -154,7 +155,7 @@ void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, si
}
template <typename T>
void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = sqrt(1 + input1[i] * input1[i]);
@ -175,7 +176,7 @@ void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, s
}
template <typename T>
void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
void EltWiseGradCPUKernel<T>::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = sqrt(input1[i] * input1[i] - 1);
@ -195,132 +196,46 @@ void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, s
}
}
void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
template <typename T>
void EltWiseGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "ReluGrad") {
operate_type_ = RELUGRAD;
} else if (kernel_name == "ReLU6Grad") {
operate_type_ = RELU6GRAD;
} else if (kernel_name == "SigmoidGrad") {
operate_type_ = SIGMOIDGRAD;
} else if (kernel_name == "AbsGrad") {
operate_type_ = ABSGRAD;
} else if (kernel_name == "TanhGrad") {
operate_type_ = TANHGRAD;
} else if (kernel_name == "SqrtGrad") {
operate_type_ = SQRTGRAD;
} else if (kernel_name == "GeLUGrad") {
operate_type_ = GELUGRAD;
} else if (kernel_name == "AsinGrad") {
operate_type_ = ASINGRAD;
} else if (kernel_name == "ACosGrad") {
operate_type_ = ACOSGRAD;
} else if (kernel_name == "AtanGrad") {
operate_type_ = ATANGRAD;
} else if (kernel_name == "AsinhGrad") {
operate_type_ = ASINHGRAD;
} else if (kernel_name == "AcoshGrad") {
operate_type_ = ACOSHGRAD;
} else {
MS_LOG(EXCEPTION) << "Not support " << kernel_name;
}
input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.size() == 0) {
output_shape_.insert(output_shape_.begin(), 1);
}
size_t l = input_shape0_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
input_shape0_.insert(input_shape0_.begin(), 1);
}
l = input_shape1_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
input_shape1_.insert(input_shape1_.begin(), 1);
}
CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_);
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type";
}
}
bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
}
return true;
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
}
template <typename T>
void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
static const std::map<std::string,
std::function<void(EltWiseGradCPUKernel *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{"ReluGrad", &EltWiseGradCPUKernel<T>::ReluGrad}, {"ReLU6Grad", &EltWiseGradCPUKernel<T>::ReLU6Grad},
{"SigmoidGrad", &EltWiseGradCPUKernel<T>::SigmoidGrad}, {"AbsGrad", &EltWiseGradCPUKernel<T>::AbsGrad},
{"TanhGrad", &EltWiseGradCPUKernel<T>::TanhGrad}, {"SqrtGrad", &EltWiseGradCPUKernel<T>::SqrtGrad},
{"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad},
{"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad},
{"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}};
T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t count = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
if (operate_type_ == RELUGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ReluGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == RELU6GRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ReLU6Grad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ABSGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AbsGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == SIGMOIDGRAD) {
threads.emplace_back(
std::thread(&EltWiseGradCPUKernel::SigmoidGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == TANHGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::TanhGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == SQRTGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == GELUGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASINGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ACOSGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ATANGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASINHGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinhGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ACOSHGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AcoshGrad<T>, this, input1, input2, output, start, end));
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
size_t once_compute_size = (count + thread_num - 1) / thread_num;
while (start < count) {
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
auto block = [&, start, end]() {
elt_map.at(kernel_name_)(this, input1, input2, output, start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -18,11 +18,13 @@
#include <memory>
#include <vector>
#include <limits>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class EltWiseGradCPUKernel : public CPUKernel {
public:
EltWiseGradCPUKernel() = default;
@ -32,95 +34,75 @@ class EltWiseGradCPUKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
template <typename T>
void ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_;
std::vector<size_t> input_element_num1_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
OperateType operate_type_{RELUGRAD};
TypeId dtype_{kTypeUnknown};
std::string kernel_name_ = "";
};
MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL_T(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
ReLU6Grad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
SigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
SqrtGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(GeLUGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(GeLUGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AsinGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
ACosGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AtanGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
EltWiseGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
EltWiseGradCPUKernel, float);
} // namespace kernel
} // namespace mindspore