clean code in cpu operator
This commit is contained in:
parent 4d403f5a39
commit 17d1bb97af
backend/kernel_compiler/cpu/adam_cpu_kernel.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include <cmath>
-#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"

@@ -24,26 +23,26 @@
 namespace mindspore {
 namespace kernel {
 template <typename T>
-void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs,
-                               const std::vector<kernel::AddressPtr> &outputs) {
-  T *var = reinterpret_cast<T *>(inputs[0]->addr);
-  T *m = reinterpret_cast<T *>(inputs[1]->addr);
-  T *v = reinterpret_cast<T *>(inputs[2]->addr);
-  float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
-  float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
-  float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
-  T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[6]->addr)[0]);
-  T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[7]->addr)[0]);
-  T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[8]->addr)[0]);
-  T *gradient = reinterpret_cast<T *>(inputs[9]->addr);
-  if (beta1_power - 1.0 == 0) {
+void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &) {
+  T *var = reinterpret_cast<T *>(inputs[VAR]->addr);
+  T *m = reinterpret_cast<T *>(inputs[M]->addr);
+  T *v = reinterpret_cast<T *>(inputs[V]->addr);
+  float beta1_power = reinterpret_cast<float *>(inputs[BETA1_POWER]->addr)[SCALAR_INDEX];
+  float beta2_power = reinterpret_cast<float *>(inputs[BETA2_POWER]->addr)[SCALAR_INDEX];
+  float lr = reinterpret_cast<float *>(inputs[LR]->addr)[SCALAR_INDEX];
+  T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA1]->addr)[SCALAR_INDEX]);
+  T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[BETA2]->addr)[SCALAR_INDEX]);
+  T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[EPSILON]->addr)[SCALAR_INDEX]);
+  T *gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
+  constexpr float ONE = 1.0;
+  if (beta1_power - ONE == 0) {
     MS_LOG(EXCEPTION) << "The beta1_power can't be set to 1.";
   }
-  T new_lr = static_cast<T>(lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power));
-  T one = static_cast<T>(1.0);
+  T new_lr = static_cast<T>(lr * std::sqrt(ONE - beta2_power) / (ONE - beta1_power));
   // multithreading
-  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
-  auto task = [&](size_t start, size_t end) {
+  size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(T)) : 1;
+  auto task = [this, &var, &m, &v, &gradient, new_lr, beta1, beta2, epsilon](size_t start, size_t end) {
+    T one = static_cast<T>(1.0);
     for (size_t i = start; i < end; i++) {
       m[i] += (gradient[i] - m[i]) * (one - beta1);
       v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2);
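
Reviewer note: for readers new to this operator, the lambda above is the standard Adam update applied to a [start, end) slice. A minimal standalone sketch of the same arithmetic (plain C++, illustrative names, not the kernel itself; the Nesterov branch follows the usual Adam formulation rather than code visible in this diff):

#include <cmath>
#include <cstddef>

// One Adam step over a flat buffer; new_lr is assumed to already fold in the
// bias correction lr * sqrt(1 - beta2^t) / (1 - beta1^t), as computed above.
void AdamStep(float *var, float *m, float *v, const float *grad, size_t n,
              float new_lr, float beta1, float beta2, float epsilon, bool use_nesterov) {
  for (size_t i = 0; i < n; ++i) {
    m[i] += (grad[i] - m[i]) * (1.0f - beta1);            // first-moment EMA
    v[i] += (grad[i] * grad[i] - v[i]) * (1.0f - beta2);  // second-moment EMA
    float update = use_nesterov ? grad[i] * (1.0f - beta1) + beta1 * m[i] : m[i];
    var[i] -= new_lr * update / (std::sqrt(v[i]) + epsilon);
  }
}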

@@ -59,26 +58,30 @@ void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs,
 }
 
 void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &inputs,
-                                    const std::vector<kernel::AddressPtr> &outputs) {
-  float *var = reinterpret_cast<float *>(inputs[0]->addr);
-  float *m = reinterpret_cast<float *>(inputs[1]->addr);
-  float *v = reinterpret_cast<float *>(inputs[2]->addr);
-  float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
-  float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
-  float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
-  float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
-  float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
-  float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
-  float *gradient = reinterpret_cast<float *>(inputs[9]->addr);
-  if (beta1_power - 1.0 == 0) {
+                                    const std::vector<kernel::AddressPtr> &) {
+  float *var = reinterpret_cast<float *>(inputs[VAR]->addr);
+  float *m = reinterpret_cast<float *>(inputs[M]->addr);
+  float *v = reinterpret_cast<float *>(inputs[V]->addr);
+  float beta1_power = reinterpret_cast<float *>(inputs[BETA1_POWER]->addr)[SCALAR_INDEX];
+  float beta2_power = reinterpret_cast<float *>(inputs[BETA2_POWER]->addr)[SCALAR_INDEX];
+  float lr = reinterpret_cast<float *>(inputs[LR]->addr)[SCALAR_INDEX];
+  float beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[SCALAR_INDEX];
+  float beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[SCALAR_INDEX];
+  float epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[SCALAR_INDEX];
+  float *gradient = reinterpret_cast<float *>(inputs[GRAD]->addr);
+  constexpr float ONE = 1.0;
+  if (beta1_power - ONE == 0) {
     MS_LOG(EXCEPTION) << "The beta1_power can't be set to 1.";
   }
-  float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
+  float new_lr = lr * std::sqrt(ONE - beta2_power) / (ONE - beta1_power);
 
   // multithreading
-  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
-  auto task = [&](size_t start, size_t end) {
-    AdamFp32(var, m, v, new_lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
+  size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
+  auto task = [this, &var, &m, &v, &gradient, new_lr, beta1, beta2, epsilon](size_t start, size_t end) {
+    int ret = AdamFp32(var, m, v, new_lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
+    if (ret != NNACL_OK) {
+      MS_LOG(EXCEPTION) << "AdamFp32 failed.";
+    }
   };
   CPUKernelUtils::ParallelForAutoSearch(task, lens, &parallel_search_info_);
 }
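
Reviewer note: checking AdamFp32's return value instead of discarding it is the visible change; the surrounding pattern is that the task lambda receives a contiguous [start, end) slice from ParallelForAutoSearch. A hedged sketch of that chunk-dispatch idea with a stand-in parallel-for (the real CPUKernelUtils auto-search and caching logic is not reproduced here):

#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Stand-in for CPUKernelUtils::ParallelForAutoSearch: split [0, total) into
// contiguous slices and run task(start, end) for each slice on its own thread.
void ParallelFor(const std::function<void(size_t, size_t)> &task, size_t total, size_t num_threads) {
  std::vector<std::thread> pool;
  size_t chunk = (total + num_threads - 1) / num_threads;
  for (size_t start = 0; start < total; start += chunk) {
    size_t end = start + chunk < total ? start + chunk : total;
    pool.emplace_back(task, start, end);
  }
  for (auto &t : pool) t.join();
}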

@@ -87,11 +90,11 @@ void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-  if (input_num != 10) {
+  if (input_num != INPUT_NUMS) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
   }
   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  if (output_num != 3) {
+  if (output_num != OUTPUT_NUMS) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Adam needs 3 outputs.";
   }
   use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");

@@ -99,18 +102,19 @@ void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 
 bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                            const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.size() != 10) {
+  if (inputs.size() != INPUT_NUMS) {
     MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but Adam needs 10 inputs.";
   }
-  if (outputs.size() != 3) {
+  if (outputs.size() != OUTPUT_NUMS) {
     MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but Adam needs 3 outputs.";
   }
-  if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[9]->size) {
+  if (inputs[VAR]->size != inputs[M]->size || inputs[VAR]->size != inputs[V]->size ||
+      inputs[VAR]->size != inputs[GRAD]->size) {
     MS_LOG(EXCEPTION) << "Error input data size!";
   }
   size_t f_size = sizeof(float);
-  if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
-      inputs[6]->size != f_size || inputs[7]->size != f_size || inputs[8]->size != f_size) {
+  if (inputs[BETA1_POWER]->size != f_size || inputs[BETA2_POWER]->size != f_size || inputs[LR]->size != f_size ||
+      inputs[BETA1]->size != f_size || inputs[BETA2]->size != f_size || inputs[EPSILON]->size != f_size) {
     MS_LOG(EXCEPTION) << "The attribute beta_power, beta, lr and epsilon must be float!";
   }

backend/kernel_compiler/cpu/adam_cpu_kernel.h

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -23,6 +23,10 @@
 
 namespace mindspore {
 namespace kernel {
+constexpr size_t SCALAR_INDEX = 0;
+constexpr size_t INPUT_NUMS = 10;
+constexpr size_t OUTPUT_NUMS = 3;
+
 class AdamCPUKernel : public CPUKernel {
  public:
   AdamCPUKernel() = default;

@@ -38,6 +42,7 @@ class AdamCPUKernel : public CPUKernel {
  private:
   bool use_nesterov_{false};
   TypeId dtype_{kTypeUnknown};
+  enum input_list_ { VAR, M, V, BETA1_POWER, BETA2_POWER, LR, BETA1, BETA2, EPSILON, GRAD };
 };
 
 MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel);
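
Reviewer note: the new input_list_ enum is what lets inputs[3] become inputs[BETA1_POWER] throughout the .cc file. Unscoped enumerators count up from 0, so the named slots line up with the old magic indices at zero runtime cost. A tiny illustration outside MindSpore:

#include <cassert>
#include <vector>

enum input_list_ { VAR, M, V, BETA1_POWER, BETA2_POWER, LR, BETA1, BETA2, EPSILON, GRAD };

int main() {
  std::vector<float> inputs(10);
  assert(LR == 5 && GRAD == 9);  // same positions as the old raw indices 5 and 9
  inputs[LR] = 1e-3f;            // reads as "the learning-rate slot", not "slot 5"
  return 0;
}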

backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <cmath>
 #include <string>
 #include <map>
 #include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"

@@ -37,7 +36,7 @@ void ArithmeticCPUKernel<T>::AssignAdd(T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::Add(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -86,7 +85,7 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::Mul(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -148,7 +147,7 @@ void ArithmeticCPUKernel<T>::RealDiv(const T *input1, const T *input2, T *out) {
   }
 
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -177,7 +176,7 @@ void ArithmeticCPUKernel<T>::RealDiv(const T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::Div(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -206,7 +205,7 @@ void ArithmeticCPUKernel<T>::Div(const T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::FloorDiv(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -235,7 +234,7 @@ void ArithmeticCPUKernel<T>::FloorDiv(const T *input1, const T *input2, T *out)
 template <typename T>
 void ArithmeticCPUKernel<T>::Mod(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -257,7 +256,7 @@ void ArithmeticCPUKernel<T>::Mod(const T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::FloorMod(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -275,7 +274,7 @@ template <typename T>
 void ArithmeticCPUKernel<T>::Pow(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
   if (output_size_ > MAX_POW_SERIAL_SIZE) {
-    auto task = [&](size_t start, size_t end) {
+    auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
       auto iter = base_iter;
       iter.SetPos(start);
       for (size_t i = start; i < end; i++) {

@@ -300,7 +299,7 @@ void ArithmeticCPUKernel<T>::Pow(const T *input1, const T *input2, T *out) {
 template <typename T>
 void ArithmeticCPUKernel<T>::SquaredDifference(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {

@@ -315,7 +314,7 @@ void ArithmeticCPUKernel<T>::SquaredDifference(const T *input1, const T *input2,
 template <typename T>
 void ArithmeticCPUKernel<T>::Atan2(const T *input1, const T *input2, T *out) {
   BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
     auto iter = base_iter;
     iter.SetPos(start);
     for (size_t i = start; i < end; i++) {
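
Reviewer note: every task lambda in this file moves from [&] to an explicit capture list. Behavior is unchanged; the benefit is that each lambda's dependencies are spelled out, so an accidental capture (or a future dangling reference) has to be written explicitly. A minimal illustration of the idiom (not MindSpore code):

#include <cstddef>

void Scale(const float *in, float *out, size_t n, float k) {
  // Explicit captures: adding a new local to Scale() can no longer leak into
  // the lambda unnoticed, unlike with a blanket [&] capture.
  auto task = [in, out, k](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) out[i] = k * in[i];
  };
  task(0, n);
}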

backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <cmath>
 #include <map>
 #include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h"
 #include "common/thread_pool.h"

backend/kernel_compiler/cpu/iou_cpu_kernel.cc

@@ -13,72 +13,79 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/kernel_compiler/cpu/iou_cpu_kernel.h"
-
-#include <cmath>
-#include <string>
-#include <algorithm>
+#include <string>
+#include "backend/kernel_compiler/cpu/iou_cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"
 
 namespace mindspore {
 namespace kernel {
 
 template <typename T>
 void IOUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  auto anchor_boxes_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (anchor_boxes_shape.size() != 2 || anchor_boxes_shape[1] != 4) {
+  auto anchor_boxes_shape = AnfAlgo::GetInputDeviceShape(kernel_node, ANCHOR_BOXES);
+  constexpr size_t BOX_SHAPE_SIZE = 2;
+  constexpr size_t BOX_SIZE_INDEX = 0;
+  constexpr size_t BOX_COORDINATE_INDEX = 1;
+
+  if (anchor_boxes_shape.size() != BOX_SHAPE_SIZE || anchor_boxes_shape[BOX_COORDINATE_INDEX] != BOX_COORDINATE_LEN) {
     MS_LOG(EXCEPTION) << "The anchor_boxes shape should be [N, 4].";
   }
-  anchor_boxes_size_ = anchor_boxes_shape[0];
-  auto gt_boxes_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-  if (gt_boxes_shape.size() != 2 || gt_boxes_shape[1] != 4) {
+  anchor_boxes_size_ = anchor_boxes_shape[BOX_SIZE_INDEX];
+  auto gt_boxes_shape = AnfAlgo::GetInputDeviceShape(kernel_node, GT_BOXES);
+  if (gt_boxes_shape.size() != BOX_SHAPE_SIZE || gt_boxes_shape[BOX_COORDINATE_INDEX] != BOX_COORDINATE_LEN) {
     MS_LOG(EXCEPTION) << "The gt_boxes shape should be [N, 4].";
   }
-  gt_boxes_size_ = gt_boxes_shape[0];
+  gt_boxes_size_ = gt_boxes_shape[BOX_SIZE_INDEX];
   iou_size_ = anchor_boxes_size_ * gt_boxes_size_;
   std::string iou_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "mode");
   if (iou_mode != "iou" && iou_mode != "iof") {
     MS_LOG(EXCEPTION) << "IOU mode should be 'iou' or 'iof'.";
   }
   if (iou_mode == "iof") {
-    mode_ = 1;
+    mode_ = IOF_MODE;
   }
 }
 
 template <typename T>
-bool IOUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                             const std::vector<kernel::AddressPtr> & /*workspace*/,
+bool IOUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                              const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.size() != 2) {
-    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but IOU needs 2 inputs.";
+  if (inputs.size() != INPUT_NUMS) {
+    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but IOU needs " << INPUT_NUMS << " inputs.";
   }
-  if (outputs.size() != 1) {
-    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but IOU needs 1 outputs.";
+  if (outputs.size() != OUTPUT_NUMS) {
+    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but IOU needs " << OUTPUT_NUMS << " outputs.";
   }
-  auto anchor_boxes = reinterpret_cast<T *>(inputs[0]->addr);
-  auto gt_boxes = reinterpret_cast<T *>(inputs[1]->addr);
-  auto iou_score = reinterpret_cast<T *>(outputs[0]->addr);
+  auto anchor_boxes = reinterpret_cast<T *>(inputs[ANCHOR_BOXES]->addr);
+  auto gt_boxes = reinterpret_cast<T *>(inputs[GT_BOXES]->addr);
+  auto iou_score = reinterpret_cast<T *>(outputs[IOU_VALUE]->addr);
 
   // multithreading
-  auto task = [&](size_t start, size_t end) {
+  auto task = [&anchor_boxes, &gt_boxes, &iou_score, this](size_t start, size_t end) {
+    const T ZERO = T(0);
+    const T ONE = T(1);
+    const T EPS = T(1e-10);
+    constexpr size_t Y0_SHIFT = 1;
+    constexpr size_t X1_SHIFT = 2;
+    constexpr size_t Y1_SHIFT = 3;
     for (size_t i = start; i < end; i++) {
-      int idx1 = i % anchor_boxes_size_ * 4;
-      int idx2 = i / anchor_boxes_size_ * 4;
+      int idx1 = i % anchor_boxes_size_ * BOX_COORDINATE_LEN;
+      int idx2 = i / anchor_boxes_size_ * BOX_COORDINATE_LEN;
       T I_x0 = std::max(anchor_boxes[idx1], gt_boxes[idx2]);
-      T I_y0 = std::max(anchor_boxes[idx1 + 1], gt_boxes[idx2 + 1]);
-      T I_x1 = std::min(anchor_boxes[idx1 + 2], gt_boxes[idx2 + 2]);
-      T I_y1 = std::min(anchor_boxes[idx1 + 3], gt_boxes[idx2 + 3]);
-      T overlaps = std::max(T(0), (I_x1 - I_x0 + T(1)) * (I_y1 - I_y0 + T(1)));
-      T area1 =
-        (anchor_boxes[idx1 + 2] - anchor_boxes[idx1] + T(1)) * (anchor_boxes[idx1 + 3] - anchor_boxes[idx1 + 1] + T(1));
-      T area2 = (gt_boxes[idx2 + 2] - gt_boxes[idx2] + T(1)) * (gt_boxes[idx2 + 3] - gt_boxes[idx2 + 1] + T(1));
-      if (mode_ == 0) {
-        iou_score[i] = overlaps / (area1 + area2 - overlaps + T(1e-10));
+      T I_y0 = std::max(anchor_boxes[idx1 + Y0_SHIFT], gt_boxes[idx2 + Y0_SHIFT]);
+      T I_x1 = std::min(anchor_boxes[idx1 + X1_SHIFT], gt_boxes[idx2 + X1_SHIFT]);
+      T I_y1 = std::min(anchor_boxes[idx1 + Y1_SHIFT], gt_boxes[idx2 + Y1_SHIFT]);
+      T overlaps = std::max(ZERO, (I_x1 - I_x0 + ONE) * (I_y1 - I_y0 + ONE));
+      T area1 = (anchor_boxes[idx1 + X1_SHIFT] - anchor_boxes[idx1] + ONE) *
+                (anchor_boxes[idx1 + Y1_SHIFT] - anchor_boxes[idx1 + Y0_SHIFT] + ONE);
+      T area2 = (gt_boxes[idx2 + X1_SHIFT] - gt_boxes[idx2] + ONE) *
+                (gt_boxes[idx2 + Y1_SHIFT] - gt_boxes[idx2 + Y0_SHIFT] + ONE);
+      if (mode_ == IOU_MODE) {
+        iou_score[i] = overlaps / (area1 + area2 - overlaps + EPS);
       } else {
-        iou_score[i] = overlaps / (area2 + T(1e-10));
+        iou_score[i] = overlaps / (area2 + EPS);
       }
     }
   };
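
Reviewer note: the kernel keeps the inclusive-pixel convention (width = x1 - x0 + 1) and a 1e-10 epsilon against division by zero; only the literals gained names. A standalone check of the same formula (illustrative helper, not the kernel):

#include <algorithm>
#include <cstdio>

// IOU of two axis-aligned boxes {x0, y0, x1, y1}, inclusive-pixel convention.
float Iou(const float *a, const float *b) {
  float ix0 = std::max(a[0], b[0]), iy0 = std::max(a[1], b[1]);
  float ix1 = std::min(a[2], b[2]), iy1 = std::min(a[3], b[3]);
  float overlap = std::max(0.0f, (ix1 - ix0 + 1) * (iy1 - iy0 + 1));
  float area1 = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float area2 = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return overlap / (area1 + area2 - overlap + 1e-10f);
}

int main() {
  float a[] = {0, 0, 9, 9}, b[] = {5, 5, 14, 14};  // two 10x10 boxes, 5x5 overlap
  std::printf("iou = %f\n", Iou(a, b));            // 25 / (100 + 100 - 25) ~= 0.1429
  return 0;
}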

backend/kernel_compiler/cpu/iou_cpu_kernel.h

@@ -23,6 +23,10 @@
 
 namespace mindspore {
 namespace kernel {
+constexpr size_t INPUT_NUMS = 2;
+constexpr size_t OUTPUT_NUMS = 1;
+constexpr size_t BOX_COORDINATE_LEN = 4;
+
 template <typename T>
 class IOUCPUKernel : public CPUKernel {
  public:

@@ -34,10 +38,13 @@ class IOUCPUKernel : public CPUKernel {
                const std::vector<AddressPtr> &outputs) override;
 
  private:
-  int mode_{0};
   size_t anchor_boxes_size_{0};
   size_t gt_boxes_size_{0};
   size_t iou_size_{0};
+  enum input_list_ { ANCHOR_BOXES, GT_BOXES };
+  enum output_list_ { IOU_VALUE };
+  enum iou_mod_ { IOU_MODE, IOF_MODE };
+  int mode_{IOU_MODE};
 };
 
 MS_REG_CPU_KERNEL_T(

backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc

@@ -13,8 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#include <cmath>
 #include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/cpu/cpu_device_address.h"

@@ -29,10 +27,10 @@ void LayerNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis");
   auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis");
   if (begin_norm_axis < 0) {
-    begin_norm_axis += x_shape.size();
+    begin_norm_axis += SizeToLong(x_shape.size());
   }
   if (begin_params_axis < 0) {
-    begin_params_axis += x_shape.size();
+    begin_params_axis += SizeToLong(x_shape.size());
   }
   for (size_t i = 0; i < LongToSize(begin_norm_axis); i++) {
     block_num_ *= x_shape[i];
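
Reviewer note: without the cast, `begin_norm_axis += x_shape.size()` mixes a signed int64_t with an unsigned size_t, so the addition is performed in unsigned arithmetic and trips sign-conversion lint rules; SizeToLong makes the narrowing explicit and keeps the arithmetic signed. Sketch, with SizeToLong standing in for MindSpore's checked conversion:

#include <cstdint>
#include <vector>

inline int64_t SizeToLong(size_t v) { return static_cast<int64_t>(v); }

int main() {
  std::vector<size_t> x_shape = {2, 3, 4, 5};
  int64_t begin_norm_axis = -1;                   // negative axes count from the back
  begin_norm_axis += SizeToLong(x_shape.size());  // purely signed: -1 + 4 == 3
  return begin_norm_axis == 3 ? 0 : 1;
}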

@@ -81,7 +79,7 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
   }
   std::vector<common::Task> tasks;
   tasks.reserve(thread_num);
-  auto task = [&](size_t start, size_t end) {
+  auto task = [this, &x, &gamma, &beta, &y, &mean, &var, thread_num](size_t start, size_t end) {
     for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num); ++c) {
       if (c * thread_num + start >= block_num_) {
         continue;

backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc

@@ -28,10 +28,10 @@ void LayerNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis");
   auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis");
   if (begin_norm_axis < 0) {
-    begin_norm_axis += x_shape.size();
+    begin_norm_axis += SizeToLong(x_shape.size());
   }
   if (begin_params_axis < 0) {
-    begin_params_axis += x_shape.size();
+    begin_params_axis += SizeToLong(x_shape.size());
   }
   for (size_t i = 0; i < LongToSize(begin_norm_axis); i++) {
     block_num_ *= x_shape[i];

@@ -81,7 +81,7 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   auto thread_num2 = block_num_ < thread_num ? block_num_ : thread_num;
   std::vector<common::Task> tasks2;
   tasks2.reserve(thread_num2);
-  auto task1 = [&](size_t start, size_t end) {
+  auto task1 = [this, &x, &dy, &var, &mean, &dg, &db, thread_num1](size_t start, size_t end) {
     for (size_t c = 0; c < ceil(static_cast<double>(param_num_) / thread_num1); ++c) {
       if (c * thread_num1 + start >= param_num_) {
         continue;

@@ -98,7 +98,7 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
       db[param_index] = dbeta;
     }
   };
-  auto task2 = [&](size_t start, size_t end) {
+  auto task2 = [this, &x, &dy, &var, &mean, &dx, &gamma, thread_num2](size_t start, size_t end) {
     for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num2); ++c) {
       if (c * thread_num2 + start >= block_num_) {
         continue;

backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.cc

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <string>
 #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"

backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc

@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h"
-#include <string>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"

@@ -25,29 +24,29 @@ void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
   CPUKernel::InitInputOutputSize(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_node);
   size_t type_size = sizeof(float);
-  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  size_t tensor_size = shape[1] * 2 * type_size;
+  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, Y_BACKPROP);
+  size_t tensor_size = shape[C] * SCALE_SHIFT_NUM * type_size;
   input_size_list_.pop_back();
   // [2, c] to store scale and bias
-  workspace_size_list_.emplace_back(tensor_size);
+  (void)workspace_size_list_.emplace_back(tensor_size);
   // [2, c] to store diff_scale and diff_bias
-  workspace_size_list_.emplace_back(tensor_size);
+  (void)workspace_size_list_.emplace_back(tensor_size);
 }
 
 void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() == 2) {
+  if (x_shape.size() == NC) {
     x_shape.insert(x_shape.end(), 2, 1);
-  } else if (x_shape.size() != 4) {
-    MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
+  } else if (x_shape.size() != NCHW) {
+    MS_LOG(EXCEPTION) << "Fused batchnorm supports nc or nchw input!";
   }
-  batch_size = x_shape[0];
-  channel = x_shape[1];
-  hw_size = x_shape[2] * x_shape[3];
-  nhw_size = x_shape[0] * hw_size;
+  batch_size = x_shape[N];
+  channel = x_shape[C];
+  hw_size = x_shape[H] * x_shape[W];
+  nhw_size = batch_size * hw_size;
   dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
-  dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel});
+  dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({SCALE_SHIFT_NUM, channel});
   auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
   auto prop_kind = dnnl::prop_kind::forward_training;
   auto normalization_flags = dnnl::normalization_flags::use_scale_shift;

@@ -77,36 +76,37 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &workspace,
                                     const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.size() < 5 || outputs.empty()) {
+  constexpr size_t INPUT_NUM = 5;
+  if (inputs.size() < INPUT_NUM || outputs.empty()) {
     MS_LOG(EXCEPTION) << "Error input output size!";
   }
-  auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr);
-  auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size);
+  auto wksp_in = reinterpret_cast<float *>(workspace[SCALE_BIAS]->addr);
+  auto scale_ret = memcpy_s(wksp_in, workspace[SCALE_BIAS]->size, inputs[SCALE]->addr, inputs[SCALE]->size);
   if (scale_ret != 0) {
     MS_LOG(EXCEPTION) << "Scale memcpy error!";
   }
-  auto max_size = workspace[0]->size - inputs[2]->size;
-  auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0, max_size);
+  auto max_size = workspace[SCALE_BIAS]->size - inputs[SCALE]->size;
+  auto bias_ret = memset_s(wksp_in + (inputs[SCALE]->size / sizeof(float)), max_size, 0, max_size);
   if (bias_ret != 0) {
     MS_LOG(EXCEPTION) << "Bias memset 0 error.";
   }
 
-  SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr);
-  SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr);
-  SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr);
-  SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_SCALE_SHIFT, workspace[1]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[Y_BACKPROP]->addr);
+  SetArgumentHandle(DNNL_ARG_SRC, inputs[X]->addr);
+  SetArgumentHandle(DNNL_ARG_MEAN, inputs[SAVE_MEAN]->addr);
+  SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[SAVE_VARIANCE]->addr);
+  SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[SCALE_BIAS]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[DX]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_SCALE_SHIFT, workspace[DIFF_SCALE_BIAS]->addr);
   ExecutePrimitive();
 
-  auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr);
-  auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size);
+  auto wksp_out = reinterpret_cast<float *>(workspace[DIFF_SCALE_BIAS]->addr);
+  auto diff_scale_ret = memcpy_s(outputs[DSCALE]->addr, outputs[DSCALE]->size, wksp_out, inputs[SCALE]->size);
   if (diff_scale_ret != 0) {
     MS_LOG(EXCEPTION) << "Diff_scale memcpy to output[1] error.";
   }
-  auto diff_bias_ret =
-    memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size);
+  auto diff_bias_ret = memcpy_s(outputs[DBIAS]->addr, outputs[DBIAS]->size,
+                                wksp_out + (outputs[DSCALE]->size / sizeof(float)), outputs[DBIAS]->size);
   if (diff_bias_ret != 0) {
     MS_LOG(EXCEPTION) << "Diff_bias memcpy to output[2] error.";
   }
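
Reviewer note: the memcpy_s/memset_s staging exists because oneDNN's use_scale_shift flag expects scale and shift packed into a single [2, C] buffer, and BatchNormGrad likewise returns diff_scale and diff_bias packed. Layout sketch (illustrative only, no DNNL calls):

#include <cstring>
#include <vector>

int main() {
  const size_t channel = 8;
  std::vector<float> scale(channel, 1.0f);
  // SCALE_BIAS workspace, i.e. a packed [2, C] tensor: row 0 holds scale,
  // row 1 holds shift (zeroed here, since BatchNormGrad gets no bias input).
  std::vector<float> scale_shift(2 * channel, 0.0f);
  std::memcpy(scale_shift.data(), scale.data(), channel * sizeof(float));
  return 0;
}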

backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h

@@ -21,6 +21,10 @@
 
 namespace mindspore {
 namespace kernel {
+constexpr size_t SCALE_SHIFT_NUM = 2;
+constexpr size_t NC = 2;
+constexpr size_t NCHW = 4;
+
 class BatchNormGradCPUKernel : public MKLCPUKernel {
  public:
   BatchNormGradCPUKernel() = default;

@@ -40,6 +44,10 @@ class BatchNormGradCPUKernel : public MKLCPUKernel {
   size_t channel{0};
   size_t hw_size{0};
   size_t nhw_size{0};
+  enum format_ { N, C, H, W };
+  enum input_list_ { Y_BACKPROP, X, SCALE, SAVE_MEAN, SAVE_VARIANCE, RESERVE };
+  enum workspace_list_ { SCALE_BIAS, DIFF_SCALE_BIAS };
+  enum output_list_ { DX, DSCALE, DBIAS };
 };
 
 MS_REG_CPU_KERNEL(BatchNormGrad,

backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc

@@ -46,7 +46,7 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << "Conv channels should be divisible by group!";
     }
-    weight_shape.insert(weight_shape.begin(), group);
+    (void)weight_shape.insert(weight_shape.begin(), group);
     weight_shape[1] = weight_shape[1] / group;
   }
   dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
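
Reviewer note: the (void) casts added here and in batch_norm_grad satisfy static-analysis rules about silently discarded return values; the iterator that insert/emplace_back returns is deliberately unused. Minimal form of the idiom:

#include <vector>

int main() {
  std::vector<int> v;
  // std::vector::insert returns an iterator; casting to void documents that
  // the result is intentionally ignored rather than accidentally dropped.
  (void)v.insert(v.begin(), 42);
  return 0;
}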

backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc

@@ -14,8 +14,6 @@
  * limitations under the License.
  */
 #include "backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h"
-
-#include <string>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"

backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.cc

@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 #include "backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.h"
-#include <algorithm>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"

backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.cc

@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 #include "backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.h"
-#include <algorithm>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"

backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc

@@ -16,7 +16,6 @@
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
 #include <vector>
 #include <string>
 #include <algorithm>
 #include "utils/ms_utils.h"
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"

backend/kernel_compiler/cpu/topk_cpu_kernel.cc

@@ -14,10 +14,6 @@
  * limitations under the License.
  */
 
-#include <string>
-#include <vector>
-#include <algorithm>
-#include <map>
 #include "backend/kernel_compiler/cpu/topk_cpu_kernel.h"
 #include "runtime/device/cpu/cpu_device_address.h"

backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc

@@ -15,7 +15,6 @@
  */
 
 #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
-#include <string>
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "common/thread_pool.h"