Merge canndev code into MindSpore

This commit is contained in:
shen_jingxing 2023-02-11 13:21:56 +08:00
parent 6bb7c08540
commit 9437d7da9d
20 changed files with 2168 additions and 0 deletions

View File

@ -89,6 +89,7 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constArgument"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "unknownMacro"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "unsignedLessThanZero"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unreadVariable"

View File

@ -339,6 +339,7 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/maximum.cc:aicpu::MaximumCpuKernel::BcastComputeOneKernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/lu_unpack.cc:aicpu::LuUnpackCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc:mindspore::opt::AICpuLibSelectPass::Process
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/crop_and_resize_grad_boxes.cc:aicpu::CropAndResizeGradBoxesCpuKernel::GradOfBoxesCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_max_pool_grad.cc:aicpu::FractionalMaxPoolGradCpuKernel::DoCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_avg_pool_grad.cc:aicpu::FractionalAvgPoolGradCpuKernel::DoCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fractional_max_pool.cc:aicpu::FractionalMaxPoolCpuKernel::DoCompute

View File

@ -0,0 +1,136 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coalesce.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include "cpu_kernel_utils.h"
#include <Eigen/Dense>
#include <algorithm>
#include <iostream>
#include <map>
#include <numeric>
#include <vector>
namespace {
const uint32_t kInputNum = 3;
const uint32_t kOutputNum = 3;
const char *kCoalesce = "Coalesce";
} // namespace
namespace aicpu {
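// Coalesce merges duplicate indices of a COO sparse tensor and sums their values.
// For example (illustrative only): x_indices = [[0, 0, 1], [1, 1, 0]] with x_values = [1, 2, 3]
// coalesces to y_indices = [[0, 1], [1, 0]] with y_values = [3, 3].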
uint32_t CoalesceCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Coalesce normal check failed.");
auto x_values_type = ctx.Input(1)->GetDataType();
if (x_values_type == DT_FLOAT) {
return ComputeKernel<float>(ctx);
} else {
return ComputeKernel<Eigen::half>(ctx);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CoalesceCpuKernel::ComputeKernel(CpuKernelContext &ctx) {
Tensor *x_indices = ctx.Input(0);
Tensor *x_values = ctx.Input(1);
Tensor *x_shape = ctx.Input(2);
Tensor *y_indices = ctx.Output(0);
Tensor *y_values = ctx.Output(1);
Tensor *y_shape = ctx.Output(2);
auto x_indices_ptr = reinterpret_cast<int64_t *>(x_indices->GetData());
auto x_values_ptr = reinterpret_cast<T *>(x_values->GetData());
auto x_shape_ptr = reinterpret_cast<int64_t *>(x_shape->GetData());
auto y_indices_ptr = reinterpret_cast<int64_t *>(y_indices->GetData());
auto y_values_ptr = reinterpret_cast<T *>(y_values->GetData());
auto y_shape_ptr = reinterpret_cast<int64_t *>(y_shape->GetData());
int64_t x_nnz = x_indices->GetTensorShape()->GetDimSize(1);
int64_t num_dims = x_indices->GetTensorShape()->GetDimSize(0);
for (int64_t i = 0; i < x_nnz; i++) {
for (int64_t j = 0; j < num_dims; j++) {
KERNEL_CHECK_FALSE(
(x_indices_ptr[j * x_nnz + i] >= 0), KERNEL_STATUS_PARAM_INVALID,
"For Coalesce, values of elements of x_indices should be non-negative, but got x_indices[%d][%d] = %d.", j, i,
x_indices_ptr[j * x_nnz + i])
KERNEL_CHECK_FALSE((x_indices_ptr[j * x_nnz + i] < x_shape_ptr[j]), KERNEL_STATUS_PARAM_INVALID,
"For Coalesce, values of elements of x_indices should not exceed the limit set by x_shape, "
"but got x_indices[%d][%d] = %d, got x_shape[%d] = %d.",
j, i, x_indices_ptr[j * x_nnz + i], j, x_shape_ptr[j])
}
}
std::vector<int64_t> reorder(x_nnz);
std::iota(reorder.begin(), reorder.end(), 0);
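// Sort the entry positions lexicographically by their multi-dimensional index so that duplicate
// indices become adjacent and can be merged in a single pass.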
auto sorter = [x_indices_ptr, num_dims, x_nnz](int64_t i, int64_t j) -> bool {
for (int64_t n = 0; n < num_dims; n++) {
if (x_indices_ptr[n * x_nnz + i] < x_indices_ptr[n * x_nnz + j]) {
return true;
}
if (x_indices_ptr[n * x_nnz + i] > x_indices_ptr[n * x_nnz + j]) {
return false;
}
}
return false;  // equal indices are equivalent; returning false keeps the comparator a strict weak ordering
};
std::sort(reorder.begin(), reorder.end(), sorter);
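// Walk the sorted order: del[i] marks an entry whose index equals the previous one; its value is
// accumulated into the current output slot instead of starting a new one.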
std::vector<bool> del(x_nnz);
del[0] = false;
int64_t jump = 0;
y_values_ptr[0] = x_values_ptr[reorder[0]];
for (int64_t i = 1; i < x_nnz; i++) {
del[i] = true;
for (int64_t j = 0; j < num_dims; j++) {
if (x_indices_ptr[j * x_nnz + reorder[i]] != x_indices_ptr[j * x_nnz + reorder[i - 1]]) {
del[i] = false;
break;
}
}
if (del[i]) {
y_values_ptr[jump] += x_values_ptr[reorder[i]];
} else {
jump++;
y_values_ptr[jump] = x_values_ptr[reorder[i]];
}
}
int64_t up = 0;
for (int64_t i = 0; i < x_nnz; i++) {
if (!del[i]) {
for (int64_t j = 0; j < num_dims; j++) {
y_indices_ptr[j * (jump + 1) + up] = x_indices_ptr[j * x_nnz + reorder[i]];
}
up++;
}
}
for (int64_t i = 0; i < num_dims; i++) {
y_shape_ptr[i] = x_shape_ptr[i];
}
std::vector<int64_t> dims = {num_dims, jump + 1};
auto y_indices_shape = y_indices->GetTensorShape();
y_indices_shape->SetDimSizes(dims);
dims = {jump + 1};
auto y_values_shape = y_values->GetTensorShape();
y_values_shape->SetDimSizes(dims);
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCoalesce, CoalesceCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_COALESCE_H_
#define AICPU_KERNELS_NORMALIZED_COALESCE_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class CoalesceCpuKernel : public CpuKernel {
public:
CoalesceCpuKernel() = default;
~CoalesceCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
static uint32_t ComputeKernel(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,226 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crop_and_resize_grad_boxes.h"
#include <cmath>
#include <iostream>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <chrono>
#include <cstdlib>
#include <vector>
#include "Eigen/Dense"
namespace {
constexpr uint32_t kInputNum = 4;
constexpr uint32_t kOutputNum = 1;
const char *kCropAndResizeGradBoxes = "CropAndResizeGradBoxes";
} // namespace
namespace aicpu {
uint32_t CropAndResizeGradBoxesCpuKernel::checkInputTypeAndGetDatas(CpuKernelContext &ctx) {
Tensor *input_data0 = ctx.Input(0);
Tensor *input_data1 = ctx.Input(1);
Tensor *input_data2 = ctx.Input(2);
Tensor *input_data3 = ctx.Input(3);
Tensor *output = ctx.Output(0);
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CropAndResizeGradBoxes check params failed.");
image_shape_ = input_data1->GetTensorShape()->GetDimSizes();
boxes_shape_ = input_data2->GetTensorShape()->GetDimSizes();
box_in_shape_ = input_data3->GetTensorShape()->GetDimSizes();
grads_shape_ = input_data0->GetTensorShape()->GetDimSizes();
output_shape_ = output->GetTensorShape()->GetDimSizes();
KERNEL_CHECK_FALSE((grads_shape_.size() == 4), KERNEL_STATUS_PARAM_INVALID,
"Dim of input[0] must be 4, but got %zu.", grads_shape_.size());
KERNEL_CHECK_FALSE((image_shape_.size() == 4), KERNEL_STATUS_PARAM_INVALID,
"Dim of input[1] must be 4, but got %zu.", image_shape_.size());
KERNEL_CHECK_FALSE((image_shape_[1] > 0 && image_shape_[2] > 0), KERNEL_STATUS_PARAM_INVALID,
"The height and width of the input image of "
"CropAndResizeGradBoxes must be greater than 0.");
KERNEL_CHECK_FALSE((grads_shape_[1] > 0 && grads_shape_[2] > 0), KERNEL_STATUS_PARAM_INVALID,
"The height and width of the input grads of "
"CropAndResizeGradBoxes must be greater than 0.");
KERNEL_CHECK_FALSE((boxes_shape_.size() == 2), KERNEL_STATUS_PARAM_INVALID, "Dim of input[2] must be 2.");
KERNEL_CHECK_FALSE((box_in_shape_.size() == 1), KERNEL_STATUS_PARAM_INVALID, "Dim of input[3] must be 1.");
KERNEL_CHECK_FALSE((output_shape_.size() == 2), KERNEL_STATUS_PARAM_INVALID, "Dim of output must be 2.");
KERNEL_CHECK_FALSE((grads_shape_[0] == boxes_shape_[0]), KERNEL_STATUS_PARAM_INVALID,
"boxes and grads must have compatible Batch.");
data_type_ = input_data1->GetDataType();
return KERNEL_STATUS_OK;
}
uint32_t CropAndResizeGradBoxesCpuKernel::Compute(CpuKernelContext &ctx) {
uint32_t res = checkInputTypeAndGetDatas(ctx);
KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "GetInputAndCheck failed.");
switch (data_type_) {
case DT_UINT8:
res = GradOfBoxesCompute<uint8_t>(ctx);
break;
case DT_UINT16:
res = GradOfBoxesCompute<uint16_t>(ctx);
break;
case DT_INT8:
res = GradOfBoxesCompute<int8_t>(ctx);
break;
case DT_INT16:
res = GradOfBoxesCompute<int16_t>(ctx);
break;
case DT_INT32:
res = GradOfBoxesCompute<int32_t>(ctx);
break;
case DT_INT64:
res = GradOfBoxesCompute<int64_t>(ctx);
break;
case DT_FLOAT16:
res = GradOfBoxesCompute<Eigen::half>(ctx);
break;
case DT_FLOAT:
res = GradOfBoxesCompute<float>(ctx);
break;
case DT_DOUBLE:
res = GradOfBoxesCompute<double>(ctx);
break;
default:
KERNEL_LOG_ERROR("CropAndResizeGradBoxes op doesn't support input tensor types: [%s]",
DTypeStr(data_type_).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "CropAndResizeGradBoxes Compute failed.");
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CropAndResizeGradBoxesCpuKernel::GradOfBoxesCompute(CpuKernelContext &ctx) {
Tensor *grads_tensor = ctx.Input(0);
Tensor *image_tensor = ctx.Input(1);
Tensor *boxes_tensor = ctx.Input(2);
Tensor *box_ind_tensor = ctx.Input(3);
Tensor *output_tensor = ctx.Output(0);
const float *grads = reinterpret_cast<float *>(grads_tensor->GetData());
const T *image = reinterpret_cast<T *>(image_tensor->GetData());
const float *boxes = reinterpret_cast<float *>(boxes_tensor->GetData());
float *outputDatas = reinterpret_cast<float *>(output_tensor->GetData());
const int32_t *box_ind = reinterpret_cast<int32_t *>(box_ind_tensor->GetData());
const int image_batch = image_shape_[0];
const int image_height = image_shape_[1];
const int image_width = image_shape_[2];
const int depth = image_shape_[3];
const int nums_boxes = grads_shape_[0];
const int crop_height = grads_shape_[1];
const int crop_width = grads_shape_[2];
const int crop_depth = grads_shape_[3];
KERNEL_CHECK_FALSE((depth == crop_depth), KERNEL_STATUS_PARAM_INVALID, "image and grads must have the same depth.");
const int boxesCoordinateNum = boxes_shape_[1];
const int num_image2 = image_height * image_width * depth;
const int num_image3 = image_width * depth;
const int num_crop2 = crop_height * crop_width * crop_depth;
const int num_crop3 = crop_width * crop_depth;
// Output zeroing.
int num = nums_boxes * 4;
for (int i = 0; i < num; i++) {
*(outputDatas + i) = 0;
}
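// For each box, redo the bilinear sampling of the forward pass and accumulate the chain-ruled
// gradient of the incoming crop gradients with respect to the box coordinates (y1, x1, y2, x2).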
for (int b = 0; b < nums_boxes; b++) {
const float y1 = *(boxes + b * boxesCoordinateNum + 0);
const float x1 = *(boxes + b * boxesCoordinateNum + 1);
const float y2 = *(boxes + b * boxesCoordinateNum + 2);
const float x2 = *(boxes + b * boxesCoordinateNum + 3);
const int b_in = *(box_ind + b);
if (b_in >= image_batch || b_in < 0) {
continue;
}
const float height_ratio = (crop_height > 1) ? static_cast<float>(image_height - 1) / (crop_height - 1) : 0;
const float width_ratio = (crop_width > 1) ? static_cast<float>(image_width - 1) / (crop_width - 1) : 0;
const float height_scale = (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;
for (int y = 0; y < crop_height; y++) {
const float y_in =
(crop_height > 1) ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1);
if (y_in < 0 || y_in > image_height - 1) {
continue;
}
const int top_y_index = floorf(y_in);
const int bottom_y_index = ceilf(y_in);
const float y_lerp = y_in - top_y_index;
for (int x = 0; x < crop_width; x++) {
const float x_in =
(crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1);
if (x_in < 0 || x_in > image_width - 1) {
continue;
}
const int left_x_index = floorf(x_in);
const int right_x_index = ceilf(x_in);
const float x_lerp = x_in - left_x_index;
for (int d = 0; d < depth; d++) {
const float top_left_value(
static_cast<float>(*(image + b_in * num_image2 + top_y_index * num_image3 + left_x_index * depth + d)));
const float top_right_value(
static_cast<float>(*(image + b_in * num_image2 + top_y_index * num_image3 + right_x_index * depth + d)));
const float bottom_left_value(
static_cast<float>(*(image + b_in * num_image2 + bottom_y_index * num_image3 + left_x_index * depth + d)));
const float bottom_right_value(
static_cast<float>(*(image + b_in * num_image2 + bottom_y_index * num_image3 + right_x_index * depth + d)));
// Compute the image gradient
float image_ygrad_value =
(1 - x_lerp) * (bottom_left_value - top_left_value) + x_lerp * (bottom_right_value - top_right_value);
float image_xgrad_value =
(1 - y_lerp) * (top_right_value - top_left_value) + y_lerp * (bottom_right_value - bottom_left_value);
// Modulate the image gradient with the incoming gradient
const float top_grad = *(grads + b * num_crop2 + y * num_crop3 + x * crop_depth + d);
image_ygrad_value *= top_grad;
image_xgrad_value *= top_grad;
// dy1,dy2
if (crop_height > 1) {
*(outputDatas + b * 4 + 0) += image_ygrad_value * (image_height - 1 - y * height_ratio);
*(outputDatas + b * 4 + 2) += image_ygrad_value * (y * height_ratio);
} else {
*(outputDatas + b * 4 + 0) += image_ygrad_value * 0.5 * (image_height - 1);
*(outputDatas + b * 4 + 2) += image_ygrad_value * 0.5 * (image_height - 1);
}
// dx1,dx2
if (crop_width > 1) {
*(outputDatas + b * 4 + 1) += image_xgrad_value * (image_width - 1 - x * width_ratio);
*(outputDatas + b * 4 + 3) += image_xgrad_value * (x * width_ratio);
} else {
*(outputDatas + b * 4 + 1) += image_xgrad_value * 0.5 * (image_width - 1);
*(outputDatas + b * 4 + 3) += image_xgrad_value * 0.5 * (image_width - 1);
}
}
}
}
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCropAndResizeGradBoxes, CropAndResizeGradBoxesCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CROPANDRESIZEGRADBOXES_H_
#define AICPU_KERNELS_NORMALIZED_CROPANDRESIZEGRADBOXES_H_
#include "Eigen/Core"
#include "cpu_ops_kernel.h"
#include "utils/bcast.h"
namespace aicpu {
class CropAndResizeGradBoxesCpuKernel : public CpuKernel {
public:
CropAndResizeGradBoxesCpuKernel() = default;
~CropAndResizeGradBoxesCpuKernel() = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t checkInputTypeAndGetDatas(CpuKernelContext &ctx);
template <typename T>
uint32_t GradOfBoxesCompute(CpuKernelContext &ctx);
std::vector<int64_t> grads_shape_;
std::vector<int64_t> image_shape_;
std::vector<int64_t> boxes_shape_;
std::vector<int64_t> box_in_shape_;
std::vector<int64_t> output_shape_;
DataType data_type_ = DT_INT32;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,215 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crop_and_resize_grad_image.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <cmath>
#include <iostream>
namespace {
constexpr uint32_t kInputNum = 4;
constexpr uint32_t kOutputNum = 1;
const char *kCropAndResizeGradImage = "CropAndResizeGradImage";
} // namespace
namespace aicpu {
uint32_t CropAndResizeGradImageCpuKernel::checkInputTypeAndGetDatas(CpuKernelContext &ctx) {
Tensor *grads = ctx.Input(0);
Tensor *boxes = ctx.Input(1);
Tensor *box_index = ctx.Input(2);
Tensor *image_size = ctx.Input(3);
Tensor *output = ctx.Output(0);
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "CropAndResizeGradImage check params failed.");
grads_shape_ = grads->GetTensorShape()->GetDimSizes();
boxes_shape_ = boxes->GetTensorShape()->GetDimSizes();
box_ind_shape_ = box_index->GetTensorShape()->GetDimSizes();
image_size_shape_ = image_size->GetTensorShape()->GetDimSizes();
output_shape_ = output->GetTensorShape()->GetDimSizes();
KERNEL_CHECK_FALSE((grads_shape_.size() == 4), KERNEL_STATUS_PARAM_INVALID,
"Dim of grads must be 4, but got %zu.", grads_shape_.size());
KERNEL_CHECK_FALSE((boxes_shape_.size() == 2), KERNEL_STATUS_PARAM_INVALID,
"Dim of boxes must be 2, but got %zu.", boxes_shape_.size());
KERNEL_CHECK_FALSE((box_ind_shape_.size() == 1), KERNEL_STATUS_PARAM_INVALID, "Dim of box_index must be 1.");
KERNEL_CHECK_FALSE((image_size_shape_.size() == 1 && image_size_shape_[0] == 4), KERNEL_STATUS_PARAM_INVALID,
"the input of image_size must be 1D and have 4 elements.");
KERNEL_CHECK_FALSE((output_shape_.size() == 4), KERNEL_STATUS_PARAM_INVALID, "Dim of output must be 4.");
KERNEL_CHECK_FALSE((grads_shape_[1] > 0 && grads_shape_[2] > 0), KERNEL_STATUS_PARAM_INVALID,
"grads dimensions must be positive.");
KERNEL_CHECK_FALSE((grads_shape_[0] == boxes_shape_[0]), KERNEL_STATUS_PARAM_INVALID,
"boxes and grads have incompatible shape.");
data_type_ = output->GetDataType();
return KERNEL_STATUS_OK;
}
uint32_t CropAndResizeGradImageCpuKernel::Compute(CpuKernelContext &ctx) {
uint32_t res = checkInputTypeAndGetDatas(ctx);
KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "GetInputAndCheck failed.");
switch (data_type_) {
case DT_FLOAT16:
res = GradOfImageComputeShared<Eigen::half>(ctx);
break;
case DT_FLOAT:
res = GradOfImageComputeShared<float>(ctx);
break;
case DT_DOUBLE:
res = GradOfImageComputeShared<double>(ctx);
break;
default:
KERNEL_LOG_ERROR("CropAndResizeGradImage op doesn't support input tensor types: [%s]",
DTypeStr(data_type_).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "CropAndResizeGradImage Compute failed.");
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CropAndResizeGradImageCpuKernel::GradOfImageCompute(CpuKernelContext &ctx, int64_t start, int64_t end) {
Tensor *grads_tensor = ctx.Input(0);
Tensor *boxes_tensor = ctx.Input(1);
Tensor *box_index_tensor = ctx.Input(2);
Tensor *image_size_tensor = ctx.Input(3);
Tensor *output_tensor = ctx.Output(0);
float *grads = reinterpret_cast<float *>(grads_tensor->GetData());
int32_t *image_size = reinterpret_cast<int32_t *>(image_size_tensor->GetData());
float *boxes = reinterpret_cast<float *>(boxes_tensor->GetData());
T *outputDatas = reinterpret_cast<T *>(output_tensor->GetData());
int32_t *box_index = reinterpret_cast<int32_t *>(box_index_tensor->GetData());
const int64_t image_batch = *(image_size + 0);
const int64_t image_height = *(image_size + 1);
const int64_t image_width = *(image_size + 2);
const int64_t depth = *(image_size + 3);
const int64_t crop_height = grads_shape_[1];
const int64_t crop_width = grads_shape_[2];
const int64_t crop_depth = grads_shape_[3];
const int64_t boxesCoordinateNum = boxes_shape_[1];
const int64_t num_image2 = image_height * image_width * depth;
const int64_t num_image3 = image_width * depth;
const int64_t num_crop2 = crop_height * crop_width * crop_depth;
const int64_t num_crop3 = crop_width * crop_depth;
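// For each box in [start, end), scatter every cell of the incoming crop gradient back into the image
// gradient at the positions the forward CropAndResize sampled, weighted by the interpolation coefficients.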
for (int64_t b = start; b < end; b++) {
const float y1 = *(boxes + b * boxesCoordinateNum + 0);
const float x1 = *(boxes + b * boxesCoordinateNum + 1);
const float y2 = *(boxes + b * boxesCoordinateNum + 2);
const float x2 = *(boxes + b * boxesCoordinateNum + 3);
const int64_t b_in = *(box_index + b);
if (b_in < 0 || b_in > image_batch - 1) {
continue;
}
float height_scale = 0;
float width_scale = 0;
if (crop_height > 1) {
height_scale = (y2 - y1) * (image_height - 1) / (crop_height - 1);
}
if (crop_width > 1) {
width_scale = (x2 - x1) * (image_width - 1) / (crop_width - 1);
}
for (int64_t y = 0; y < crop_height; y++) {
float in_y = 0.5 * (y1 + y2) * (image_height - 1);
if (crop_height > 1) {
in_y = y1 * (image_height - 1) + y * height_scale;
}
if (in_y < 0 || in_y > image_height - 1) {
continue;
}
const int64_t top_y_index = floorf(in_y);
const int64_t bottom_y_index = ceilf(in_y);
const float y_lerp = in_y - top_y_index;
for (int64_t x = 0; x < crop_width; x++) {
float in_x = 0.5 * (x1 + x2) * (image_width - 1);
if (crop_width > 1) {
in_x = x1 * (image_width - 1) + x * width_scale;
}
if (in_x < 0 || in_x > image_width - 1) {
continue;
}
AttrValue *attr = ctx.GetAttr("method");
std::string str = (attr == nullptr) ? "bilinear" : attr->GetString();
if (str == "bilinear") {
const int64_t left_x_index = floorf(in_x);
const int64_t right_x_index = ceilf(in_x);
const float x_lerp = in_x - left_x_index;
for (int64_t d = 0; d < depth; d++) {
const float dtop = (*(grads + b * num_crop2 + y * num_crop3 + x * crop_depth + d)) * (1 - y_lerp);
*(outputDatas + b_in * num_image2 + top_y_index * num_image3 + left_x_index * depth + d) +=
static_cast<T>((1 - x_lerp) * dtop);
*(outputDatas + b_in * num_image2 + top_y_index * num_image3 + right_x_index * depth + d) +=
static_cast<T>(x_lerp * dtop);
const float dbottom = (*(grads + b * num_crop2 + y * num_crop3 + x * crop_depth + d)) * y_lerp;
*(outputDatas + b_in * num_image2 + bottom_y_index * num_image3 + left_x_index * depth + d) +=
static_cast<T>((1 - x_lerp) * dbottom);
*(outputDatas + b_in * num_image2 + bottom_y_index * num_image3 + right_x_index * depth + d) +=
static_cast<T>(x_lerp * dbottom);
}
} else {
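// Any method other than "bilinear" (e.g. nearest neighbour): the whole gradient cell is added to the
// rounded sample position.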
for (int64_t d = 0; d < depth; d++) {
const int close_x_index = roundf(in_x);
const int close_y_index = roundf(in_y);
*(outputDatas + b_in * num_image2 + close_y_index * num_image3 + close_x_index * depth + d) +=
static_cast<T>(*(grads + b * num_crop2 + y * num_crop3 + x * crop_depth + d));
}
}
}
}
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t CropAndResizeGradImageCpuKernel::GradOfImageComputeShared(CpuKernelContext &ctx) {
Tensor *image_size_tensor = ctx.Input(3);
Tensor *output_tensor = ctx.Output(0);
int32_t *image_size = reinterpret_cast<int32_t *>(image_size_tensor->GetData());
T *outputDatas = reinterpret_cast<T *>(output_tensor->GetData());
const int64_t image_height = *(image_size + 1);
const int64_t image_width = *(image_size + 2);
const int64_t depth = *(image_size + 3);
KERNEL_CHECK_FALSE((image_height > 0 && image_width > 0), KERNEL_STATUS_PARAM_INVALID,
"image dimensions must be positive.");
const int64_t nums_boxes = grads_shape_[0];
const int64_t crop_depth = grads_shape_[3];
KERNEL_CHECK_FALSE((depth == crop_depth), KERNEL_STATUS_PARAM_INVALID, "image_size and grads are incompatible.");
const int64_t num_image1 = nums_boxes * image_height * image_width * depth;
// Set the output data to 0.
T temp = static_cast<T>(0.0);
for (int64_t i = 0; i < num_image1; i++) {
*(outputDatas + i) = temp;
}
auto shared_CropAndResizeGradImage = [&](size_t start, size_t end) {
uint32_t res = GradOfImageCompute<T>(ctx, start, end);
return res;
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, nums_boxes, 1, shared_CropAndResizeGradImage),
"CropAndResizeGradImage Compute failed.");
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kCropAndResizeGradImage, CropAndResizeGradImageCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CROPANDRESIZEGRADIMAGE_H_
#define AICPU_KERNELS_NORMALIZED_CROPANDRESIZEGRADIMAGE_H_
#include "Eigen/Core"
#include "cpu_ops_kernel.h"
#include "utils/bcast.h"
namespace aicpu {
class CropAndResizeGradImageCpuKernel : public CpuKernel {
public:
CropAndResizeGradImageCpuKernel() = default;
~CropAndResizeGradImageCpuKernel() = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t checkInputTypeAndGetDatas(CpuKernelContext &ctx);
template <typename T>
uint32_t GradOfImageCompute(CpuKernelContext &ctx, int64_t start, int64_t end);
template <typename T>
uint32_t GradOfImageComputeShared(CpuKernelContext &ctx);
std::vector<int64_t> grads_shape_;
std::vector<int64_t> image_size_shape_;
std::vector<int64_t> boxes_shape_;
std::vector<int64_t> box_ind_shape_;
std::vector<int64_t> output_shape_;
DataType data_type_;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,160 @@
#ifndef AICPU_KERNELS_DEPTHTOSPACE_CC_
#define AICPU_KERNELS_DEPTHTOSPACE_CC_
#include "depth_to_space.h"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include <iostream>
#include <thread>
#include <unordered_map>
#include <mutex>
namespace {
const uint32_t kInputNum = 1;
const uint32_t kOutputNum = 1;
const char *kDepthToSpace = "DepthToSpace";
#define DEPTHTOSPACE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = DoCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("DepthToSpace kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
template <typename T>
uint32_t DepthToSpaceCpuKernel::DoCompute(CpuKernelContext &ctx) {
auto input_shape = ctx.Input(0)->GetTensorShape();
auto output_shape = ctx.Output(0)->GetTensorShape();
auto input_dims = input_shape->GetDimSizes();
AttrValue *attr_data_format = ctx.GetAttr("data_format");
data_format_ = (attr_data_format == nullptr) ? "NHWC" : (attr_data_format->GetString());
AttrValue *attr_block_size = ctx.GetAttr("block_size");
KERNEL_CHECK_NULLPTR(attr_block_size, KERNEL_STATUS_PARAM_INVALID, "Get attr [block_size] failed.")
int64_t block_size = attr_block_size->GetInt();
int64_t zero = 0;
int64_t two = 2;
int64_t n_nhwc = 0;
int64_t h_nhwc = 1;
int64_t w_nhwc = 2;
int64_t c_nhwc = 3;
int64_t n_nchw = 0;
int64_t h_nchw = 1;
int64_t w_nchw = 2;
int64_t c_nchw = 3;
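// Note: the *_nchw constants are literal dimension positions; in the NCHW layout index 1 is the
// channel dimension, so the divisibility check below still operates on channels.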
KERNEL_CHECK_FALSE((block_size >= two), KERNEL_STATUS_PARAM_INVALID,
"The value of block_size must be at least 2.");
std::vector<int64_t> output_dims;
if (data_format_ == "NHWC") {
KERNEL_CHECK_FALSE((input_dims[c_nhwc] % (block_size * block_size) == zero), KERNEL_STATUS_PARAM_INVALID,
"Channels must be divisible by block_size * block_size.");
output_dims = {input_dims[n_nhwc], input_dims[h_nhwc] * block_size, input_dims[w_nhwc] * block_size,
input_dims[c_nhwc] / (block_size * block_size)};
output_shape->SetDimSizes(output_dims);
input_dims = {input_dims[n_nhwc], input_dims[c_nhwc], input_dims[h_nhwc], input_dims[w_nhwc]};
output_dims = {output_dims[n_nhwc], output_dims[c_nhwc], output_dims[h_nhwc], output_dims[w_nhwc]};
} else if (data_format_ == "NCHW") {
KERNEL_CHECK_FALSE((input_dims[h_nchw] % (block_size * block_size) == zero), KERNEL_STATUS_PARAM_INVALID,
"Channels must be divisible by block_size * block_size.");
output_dims = {input_dims[n_nchw], input_dims[h_nchw] / (block_size * block_size), input_dims[w_nchw] * block_size,
input_dims[c_nchw] * block_size};
output_shape->SetDimSizes(output_dims);
}
auto input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t x = 0;
const size_t data_num = (size_t)ctx.Input(0)->NumElements();
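// Rearrange the data: each group of block_size consecutive channel elements is written to the spatial
// output position computed below. For example (illustrative only), with data_format "NHWC" and
// block_size = 2, an input of shape [1, 1, 1, 4] becomes an output of shape [1, 2, 2, 1].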
for (size_t i = 0; i < data_num; i = i + block_size) {
for (size_t j = i; j < block_size + i; ++j) {
if (j % (input_dims[h_nhwc] * input_dims[c_nhwc]) == 0) {
x = -1;
}
if (j % output_dims[h_nhwc] == 0) {
++x;
}
size_t number = 0, output_pos = 0;
size_t loc = j / output_dims[h_nhwc];
number += (loc / (output_dims[w_nhwc] * output_dims[c_nhwc])) * output_dims[w_nhwc] * output_dims[c_nhwc];
// Mark the position of this segment of the vector in the entire segment.
number += (input_dims[h_nhwc] * input_dims[c_nhwc] / output_dims[h_nhwc]) *
(loc / (input_dims[h_nhwc] * input_dims[c_nhwc] / output_dims[h_nhwc]));
// Label the position of the block within a segment of the vector.
number += ((loc % input_dims[h_nhwc]) / block_size) * block_size * input_dims[c_nhwc];
// Mark the relative position within the small block.
number += loc % block_size + (x / input_dims[h_nhwc]) * block_size;
output_pos = j % output_dims[h_nhwc] + number * output_dims[h_nhwc];
output[output_pos] = input[j];
}
}
return KERNEL_STATUS_OK;
} // DoCompute
uint32_t DepthToSpaceCpuKernel::STDParamCheck(CpuKernelContext &ctx) {
// check params
auto input = ctx.Input(0);
auto output = ctx.Output(0);
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "DepthToSpace check input and output number failed.");
KERNEL_LOG_DEBUG(
"DepthToSpaceCpuKernel[%s], input0: size[%llu];"
"output: size[%llu].",
ctx.GetOpType().c_str(), input->GetDataSize(), output->GetDataSize());
// check data_format
AttrValue *attr_data_format = ctx.GetAttr("data_format");
data_format_ = (attr_data_format == nullptr) ? "NHWC" : (attr_data_format->GetString());
KERNEL_CHECK_FALSE((data_format_ == "NHWC" || data_format_ == "NCHW"), KERNEL_STATUS_PARAM_INVALID,
"The data_format must be NHWC or NCHW, but got: [%s]", data_format_.c_str());
return KERNEL_STATUS_OK;
}
uint32_t DepthToSpaceCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(STDParamCheck(ctx), "DepthToSpace check params failed.");
Tensor *input0_tensor = ctx.Input(0);
auto input_data_type = input0_tensor->GetDataType();
switch (input_data_type) {
DEPTHTOSPACE_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_FLOAT, float, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_INT8, int8_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_INT16, int16_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_QINT8, int8_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_QINT16, int16_t, ctx)
DEPTHTOSPACE_COMPUTE_CASE(DT_QINT32, int32_t, ctx)
default:
KERNEL_LOG_ERROR("DepthToSpace kernel data type[%s] not support.", DTypeStr(input_data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kDepthToSpace, DepthToSpaceCpuKernel);
} // namespace aicpu
#endif // AICPU_KERNELS_DEPTHTOSPACE_CC_

View File

@ -0,0 +1,23 @@
#ifndef AICPU_KERNELS_DEPTHTOSPACE_H_
#define AICPU_KERNELS_DEPTHTOSPACE_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class DepthToSpaceCpuKernel : public CpuKernel {
public:
DepthToSpaceCpuKernel() = default;
~DepthToSpaceCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override final;
private:
uint32_t STDParamCheck(CpuKernelContext &ctx);
template <typename T>
uint32_t DoCompute(CpuKernelContext &ctx);
std::string data_format_;
}; // DepthToSpaceCpuKernel
} // namespace aicpu
#endif // AICPU_KERNELS_DEPTHTOSPACE_H_

View File

@ -0,0 +1,401 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "div.h"
#include <algorithm>
#include <cmath>
#include <complex>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 2;
const char *kDiv = "Div";
// when input data size is more than kParallelDataNum, use Parallel func
const int64_t kParallelDataNum = 2 * 1024;
const int64_t kParallelDataNumMid = 16 * 1024;
const int64_t kParallelDataNumSameShape = 7 * 1024;
const int64_t kParallelDataNumSameShapeMid = 35 * 1024;
#define DIV_COMPUTE_CASEINT(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = DivComputeInt<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Div kernel compute failed."); \
return result; \
} \
break; \
}
#define DIV_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = DivCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Div kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t DivCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kDiv);
KERNEL_HANDLE_ERROR(DivParamCheck(ctx), "Div check params failed.");
auto data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
DIV_COMPUTE_CASEINT(DT_INT8, int8_t, ctx)
DIV_COMPUTE_CASEINT(DT_INT16, int16_t, ctx)
DIV_COMPUTE_CASEINT(DT_INT32, int32_t, ctx)
DIV_COMPUTE_CASEINT(DT_INT64, int64_t, ctx)
DIV_COMPUTE_CASEINT(DT_UINT8, uint8_t, ctx)
DIV_COMPUTE_CASEINT(DT_UINT16, uint16_t, ctx)
DIV_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
DIV_COMPUTE_CASE(DT_FLOAT, float, ctx)
DIV_COMPUTE_CASE(DT_DOUBLE, double, ctx)
DIV_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
DIV_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
default:
KERNEL_LOG_ERROR("Div kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t DivCpuKernel::DivParamCheck(CpuKernelContext &ctx) {
// the non null of input_0, input_1, output has been verified in NormalCheck
Tensor *input_0 = ctx.Input(0);
Tensor *input_1 = ctx.Input(1);
Tensor *output = ctx.Output(0);
KERNEL_CHECK_NULLPTR(input_0->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.")
KERNEL_CHECK_NULLPTR(input_1->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.")
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
DataType input0_type = input_0->GetDataType();
DataType input1_type = input_1->GetDataType();
KERNEL_CHECK_FALSE((input0_type == input1_type), KERNEL_STATUS_PARAM_INVALID,
"The data type of input0 [%s] must be the same as "
"that of input1 [%s].",
DTypeStr(input0_type).c_str(), DTypeStr(input1_type).c_str())
KERNEL_LOG_DEBUG(
"DivCpuKernel[%s], input0: size[%llu];"
"input1: size[%llu], output: size[%llu].",
ctx.GetOpType().c_str(), input_0->GetDataSize(), input_1->GetDataSize(), output->GetDataSize());
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::DivParamCheck_Zero(CpuKernelContext &ctx) {
auto input1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
int64_t input1_elements_nums = ctx.Input(1)->NumElements();
for (int64_t i = 0; i < input1_elements_nums; i++) {
if (static_cast<double>(*(input1 + i)) == 0) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
}
}
return KERNEL_STATUS_OK;
}
/**
special compute is used in the following situations:
1. the shapes of input1 and input2 are the same
2. input1 is a 1D tensor with only one element or input1 is a scalar
3. input2 is a 1D tensor with only one element or input2 is a scalar
otherwise the inputs are broadcast and handled by BcastCompute
*/
template <typename T>
uint32_t DivCpuKernel::SpecialComputeInt(BcastShapeType type, int64_t start, int64_t end, const T *input1,
const T *input2, T *output) {
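// Integer division here follows floor semantics: when the operands have opposite signs and the
// remainder is non-zero, the truncated C++ quotient is decremented by one.
// For example (illustrative only): 7 / -2 -> -4 and -7 / 2 -> -4.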
switch (type) {
case BcastShapeType::SAME_SHAPE:
for (int64_t i = start; i < end; ++i) {
if (*(input2 + i) == static_cast<T>(0)) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
} else {
T mod;
mod = (*(input1 + i)) % (*(input2 + i));
if (((*(input1 + i)) * (*(input2 + i)) < static_cast<T>(0)) && (mod != 0))
*(output + i) = (*(input1 + i)) / (*(input2 + i)) - static_cast<T>(1);
else
*(output + i) = (*(input1 + i)) / (*(input2 + i));
}
}
break;
case BcastShapeType::X_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
if (*(input2 + i) == static_cast<T>(0)) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
} else {
T mod;
mod = (*input1) % (*(input2 + i));
if (((*input1) * (*(input2 + i)) < static_cast<T>(0)) && (mod != 0))
*(output + i) = (*input1) / (*(input2 + i)) - static_cast<T>(1);
else
*(output + i) = (*input1) / (*(input2 + i));
}
}
break;
case BcastShapeType::Y_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
if (*input2 == static_cast<T>(0)) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
} else {
T mod;
mod = (*(input1 + i)) % (*input2);
if (((*(input1 + i)) * (*input2) < static_cast<T>(0)) && (mod != 0))
*(output + i) = (*(input1 + i)) / (*input2) - static_cast<T>(1);
else
*(output + i) = (*(input1 + i)) / (*input2);
}
}
break;
default:
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
break;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::SpecialCompute(BcastShapeType type, int64_t start, int64_t end, const T *input1, const T *input2,
T *output) {
switch (type) {
case BcastShapeType::SAME_SHAPE:
for (int64_t i = start; i < end; ++i) {
*(output + i) = *(input1 + i) / *(input2 + i);
}
break;
case BcastShapeType::X_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
*(output + i) = (*input1) / (*(input2 + i));
}
break;
case BcastShapeType::Y_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
*(output + i) = *(input1 + i) / (*input2);
}
break;
default:
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
break;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::NoBcastComputeInt(CpuKernelContext &ctx) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t in0_elements_nums = ctx.Input(0)->NumElements();
int64_t in1_elements_nums = ctx.Input(1)->NumElements();
int64_t data_num = ctx.Output(0)->NumElements();
BcastShapeType type = in0_elements_nums == in1_elements_nums
? BcastShapeType::SAME_SHAPE
: (in0_elements_nums == 1 ? BcastShapeType::X_ONE_ELEMENT : BcastShapeType::Y_ONE_ELEMENT);
if (data_num >= kParallelDataNumSameShape) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumSameShapeMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_div = [&](int64_t start, int64_t end) { SpecialComputeInt<T>(type, start, end, in0, in1, out); };
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_div),
"Div Compute failed.");
} else {
SpecialComputeInt<T>(type, 0, data_num, in0, in1, out);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::NoBcastCompute(CpuKernelContext &ctx) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t in0_elements_nums = ctx.Input(0)->NumElements();
int64_t in1_elements_nums = ctx.Input(1)->NumElements();
int64_t data_num = ctx.Output(0)->NumElements();
BcastShapeType type = in0_elements_nums == in1_elements_nums
? BcastShapeType::SAME_SHAPE
: (in0_elements_nums == 1 ? BcastShapeType::X_ONE_ELEMENT : BcastShapeType::Y_ONE_ELEMENT);
if (data_num >= kParallelDataNumSameShape) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumSameShapeMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_div = [&](int64_t start, int64_t end) { SpecialCompute<T>(type, start, end, in0, in1, out); };
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_div),
"Div Compute failed.");
} else {
SpecialCompute<T>(type, 0, data_num, in0, in1, out);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::BcastComputeInt(CpuKernelContext &ctx, Bcast &bcast) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t data_num = ctx.Output(0)->NumElements();
if (data_num >= kParallelDataNum) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_div = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
if (*(in1 + bcast.GetBroadcastYIndex(i)) == static_cast<T>(0)) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
} else {
T mod;
mod = *(in0 + bcast.GetBroadcastXIndex(i)) % *(in1 + bcast.GetBroadcastYIndex(i));
if (((*(in0 + bcast.GetBroadcastXIndex(i))) * (*(in1 + bcast.GetBroadcastYIndex(i))) < static_cast<T>(0)) &&
(mod != 0))
*(out + i) =
*(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i)) - static_cast<T>(1);
else
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
}
return KERNEL_STATUS_OK;
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_div),
"Div Compute failed.");
} else {
for (int64_t i = 0; i < data_num; ++i) {
if (*(in1 + bcast.GetBroadcastYIndex(i)) == static_cast<T>(0)) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return KERNEL_STATUS_INNER_ERROR;
} else {
T mod;
mod = *(in0 + bcast.GetBroadcastXIndex(i)) % *(in1 + bcast.GetBroadcastYIndex(i));
if (((*(in0 + bcast.GetBroadcastXIndex(i))) * (*(in1 + bcast.GetBroadcastYIndex(i))) < static_cast<T>(0)) &&
(mod != 0))
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i)) - static_cast<T>(1);
else
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
}
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::BcastCompute(CpuKernelContext &ctx, Bcast &bcast) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t data_num = ctx.Output(0)->NumElements();
if (data_num >= kParallelDataNum) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_div = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_div),
"Div Compute failed.");
} else {
for (int64_t i = 0; i < data_num; ++i) {
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::DivComputeInt(CpuKernelContext &ctx) {
Tensor *input0_tensor = ctx.Input(0);
auto input0_shape = input0_tensor->GetTensorShape()->GetDimSizes();
int64_t input0_elements_nums = input0_tensor->NumElements();
Tensor *input1_tensor = ctx.Input(1);
auto input1_shape = input1_tensor->GetTensorShape()->GetDimSizes();
int64_t input1_elements_nums = input1_tensor->NumElements();
bool isNeedBcast = (input0_shape == input1_shape) || (input0_elements_nums == 1) || (input1_elements_nums == 1);
if (isNeedBcast) {
uint32_t result1 = DivParamCheck_Zero<T>(ctx);
if (result1 != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return result1;
}
return NoBcastComputeInt<T>(ctx);
} else {
Bcast bcast(input0_shape, input1_shape);
if (!bcast.IsValid()) {
KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
uint32_t result1 = DivParamCheck_Zero<T>(ctx);
if (result1 != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("Invalid argumengt: Division by zero.");
return result1;
}
return BcastComputeInt<T>(ctx, bcast);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivCpuKernel::DivCompute(CpuKernelContext &ctx) {
Tensor *input0_tensor = ctx.Input(0);
auto input0_shape = input0_tensor->GetTensorShape()->GetDimSizes();
int64_t input0_elements_nums = input0_tensor->NumElements();
Tensor *input1_tensor = ctx.Input(1);
auto input1_shape = input1_tensor->GetTensorShape()->GetDimSizes();
int64_t input1_elements_nums = input1_tensor->NumElements();
bool isNeedBcast = (input0_shape == input1_shape) || (input0_elements_nums == 1) || (input1_elements_nums == 1);
if (isNeedBcast) {
return NoBcastCompute<T>(ctx);
} else {
Bcast bcast(input0_shape, input1_shape);
if (!bcast.IsValid()) {
KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return BcastCompute<T>(ctx, bcast);
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kDiv, DivCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_DIV_H_
#define AICPU_KERNELS_NORMALIZED_DIV_H_
#define EIGEN_USE_THREADS
#define EIGEN_USE_SIMPLE_THREAD_POOL
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
namespace aicpu {
class DivCpuKernel : public CpuKernel {
public:
DivCpuKernel() = default;
~DivCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t DivParamCheck(CpuKernelContext &ctx);
private:
template <typename T>
uint32_t DivParamCheck_Zero(CpuKernelContext &ctx);
template <typename T>
uint32_t SpecialComputeInt(BcastShapeType type, int64_t start, int64_t end, const T *input1, const T *input2,
T *output);
template <typename T>
uint32_t SpecialCompute(BcastShapeType type, int64_t start, int64_t end, const T *input1, const T *input2, T *output);
template <typename T>
uint32_t NoBcastComputeInt(CpuKernelContext &ctx);
template <typename T>
uint32_t NoBcastCompute(CpuKernelContext &ctx);
template <typename T>
uint32_t BcastComputeInt(CpuKernelContext &ctx, Bcast &bcast);
template <typename T>
uint32_t BcastCompute(CpuKernelContext &ctx, Bcast &bcast);
template <typename T>
uint32_t DivComputeInt(CpuKernelContext &ctx);
template <typename T>
uint32_t DivCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,216 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "divnonan.h"
#include <algorithm>
#include <cmath>
#include <complex>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 2;
const char *kDivNoNan = "DivNoNan";
// when input data size is more than kParallelDataNum, use Parallel func
const int64_t kParallelDataNum = 2 * 1024;
const int64_t kParallelDataNumMid = 16 * 1024;
const int64_t kParallelDataNumSameShape = 7 * 1024;
const int64_t kParallelDataNumSameShapeMid = 35 * 1024;
#define DIVNONAN_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = DivNoNanCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("DivNoNan kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t DivNoNanCpuKernel::Compute(CpuKernelContext &ctx) {
// check params
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kDivNoNan);
KERNEL_HANDLE_ERROR(DivNoNanParamCheck(ctx), "DivNoNan check params failed.");
auto data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
DIVNONAN_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
DIVNONAN_COMPUTE_CASE(DT_FLOAT, float, ctx)
DIVNONAN_COMPUTE_CASE(DT_DOUBLE, double, ctx)
DIVNONAN_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
DIVNONAN_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
default:
KERNEL_LOG_ERROR("Div kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t DivNoNanCpuKernel::DivNoNanParamCheck(CpuKernelContext &ctx) {
// the non null of input_0, input_1, output has been verified in NormalCheck
Tensor *input_0 = ctx.Input(0);
Tensor *input_1 = ctx.Input(1);
Tensor *output = ctx.Output(0);
KERNEL_CHECK_NULLPTR(input_0->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.")
KERNEL_CHECK_NULLPTR(input_1->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.")
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
DataType input0_type = input_0->GetDataType();
DataType input1_type = input_1->GetDataType();
KERNEL_CHECK_FALSE((input0_type == input1_type), KERNEL_STATUS_PARAM_INVALID,
"The data type of input0 [%s] must be the same as "
"that of input1 [%s].",
DTypeStr(input0_type).c_str(), DTypeStr(input1_type).c_str())
KERNEL_LOG_DEBUG(
"DivNoNanCpuKernel[%s], input0: size[%llu];"
"input1: size[%llu], output: size[%llu].",
ctx.GetOpType().c_str(), input_0->GetDataSize(), input_1->GetDataSize(), output->GetDataSize());
return KERNEL_STATUS_OK;
}
/**
special compute is used in the following situations:
1. the shapes of input1 and input2 are the same
2. input1 is a 1D tensor with only one element or input1 is a scalar
3. input2 is a 1D tensor with only one element or input2 is a scalar
otherwise the inputs are broadcast and handled by BcastCompute
*/
template <typename T>
uint32_t DivNoNanCpuKernel::SpecialCompute(BcastShapeType type, int64_t start, int64_t end, const T *input1,
const T *input2, T *output) {
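// DivNoNan semantics: wherever the divisor element is 0 the output is set to 0 instead of Inf/NaN;
// otherwise the ordinary element-wise quotient is written.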
switch (type) {
case BcastShapeType::SAME_SHAPE:
for (int64_t i = start; i < end; ++i) {
if (*(input2 + i) == static_cast<T>(0)) {
*(output + i) = static_cast<T>(0);
} else
*(output + i) = *(input1 + i) / *(input2 + i);
}
break;
case BcastShapeType::X_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
if (*(input2 + i) == static_cast<T>(0)) {
*(output + i) = static_cast<T>(0);
} else
*(output + i) = (*input1) / (*(input2 + i));
}
break;
case BcastShapeType::Y_ONE_ELEMENT:
for (int64_t i = start; i < end; ++i) {
if (*input2 == static_cast<T>(0)) {
*(output + i) = static_cast<T>(0);
} else
*(output + i) = *(input1 + i) / (*input2);
}
break;
default:
KERNEL_LOG_WARN("Invalid type [%d]", static_cast<int32_t>(type));
break;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivNoNanCpuKernel::NoBcastCompute(CpuKernelContext &ctx) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t in0_elements_nums = ctx.Input(0)->NumElements();
int64_t in1_elements_nums = ctx.Input(1)->NumElements();
int64_t data_num = ctx.Output(0)->NumElements();
BcastShapeType type = in0_elements_nums == in1_elements_nums
? BcastShapeType::SAME_SHAPE
: (in0_elements_nums == 1 ? BcastShapeType::X_ONE_ELEMENT : BcastShapeType::Y_ONE_ELEMENT);
if (data_num >= kParallelDataNumSameShape) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumSameShapeMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto sharder_divnonan = [&](int64_t start, int64_t end) { SpecialCompute<T>(type, start, end, in0, in1, out); };
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_divnonan),
"Div Compute failed.");
} else {
SpecialCompute<T>(type, 0, data_num, in0, in1, out);
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivNoNanCpuKernel::BcastCompute(CpuKernelContext &ctx, Bcast &bcast) {
auto in0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto in1 = reinterpret_cast<T *>(ctx.Input(1)->GetData());
auto out = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t data_num = ctx.Output(0)->NumElements();
if (data_num >= kParallelDataNum) {
uint32_t min_core_num = 1;
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (data_num <= kParallelDataNumMid) {
max_core_num = std::min(max_core_num, 4U); // up to 4 cpu cores
}
if (max_core_num > data_num) {
max_core_num = data_num;
}
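// Each output element maps back to its broadcast source indices in x and y;
// a zero divisor yields 0 rather than Inf/NaN.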
auto sharder_divnonan = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
if (*(in1 + bcast.GetBroadcastYIndex(i)) == static_cast<T>(0)) {
*(out + i) = static_cast<T>(0);
} else {
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, sharder_divnonan),
"DivNoNan Compute failed.");
} else {
for (int64_t i = 0; i < data_num; ++i) {
if (*(in1 + bcast.GetBroadcastYIndex(i)) == static_cast<T>(0)) {
*(out + i) = static_cast<T>(0);
} else {
*(out + i) = *(in0 + bcast.GetBroadcastXIndex(i)) / *(in1 + bcast.GetBroadcastYIndex(i));
}
}
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DivNoNanCpuKernel::DivNoNanCompute(CpuKernelContext &ctx) {
Tensor *input0_tensor = ctx.Input(0);
auto input0_shape = input0_tensor->GetTensorShape()->GetDimSizes();
int64_t input0_elements_nums = input0_tensor->NumElements();
Tensor *input1_tensor = ctx.Input(1);
auto input1_shape = input1_tensor->GetTensorShape()->GetDimSizes();
int64_t input1_elements_nums = input1_tensor->NumElements();
// Broadcasting can be skipped when the shapes already match or when either
// input is a single-element tensor / scalar.
bool no_bcast = (input0_shape == input1_shape) || (input0_elements_nums == 1) || (input1_elements_nums == 1);
if (no_bcast) {
return NoBcastCompute<T>(ctx);
} else {
Bcast bcast(input0_shape, input1_shape);
if (!bcast.IsValid()) {
KERNEL_LOG_ERROR("[%s] broadcast failed.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return BcastCompute<T>(ctx, bcast);
}
}
REGISTER_CPU_KERNEL(kDivNoNan, DivNoNanCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_DIVNONAN_H_
#define AICPU_KERNELS_NORMALIZED_DIVNONAN_H_
#define EIGEN_USE_THREADS
#define EIGEN_USE_SIMPLE_THREAD_POOL
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
namespace aicpu {
class DivNoNanCpuKernel : public CpuKernel {
public:
DivNoNanCpuKernel() = default;
~DivNoNanCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t DivNoNanParamCheck(CpuKernelContext &ctx);
template <typename T>
uint32_t SpecialCompute(BcastShapeType type, int64_t start, int64_t end, const T *input1, const T *input2, T *output);
template <typename T>
uint32_t NoBcastCompute(CpuKernelContext &ctx);
template <typename T>
uint32_t BcastCompute(CpuKernelContext &ctx, Bcast &bcast);
template <typename T>
uint32_t DivNoNanCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,174 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "expm1.h"
#include "cpu_kernel_utils.h"
#include "math.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include <math.h>
#include <iostream>
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 1;
const char *kExpm1 = "Expm1";
#define EXPM1_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = Expm1Compute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Expm1 kernel compute failed."); \
return result; \
} \
break; \
}
#define EXPM1_COMPUTE_CASE2(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = Expm1Compute2<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Expm1 kernel compute failed."); \
return result; \
} \
break; \
}
#define EXPM1_COMPUTE_CASE3(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = Expm1Compute3<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Expm1 kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t Expm1CpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kExpm1);
KERNEL_HANDLE_ERROR(Expm1Check(ctx), "[%s] check params failed.", kExpm1);
DataType data_type = ctx.Input(0)->GetDataType();
switch (data_type) {
EXPM1_COMPUTE_CASE2(DT_FLOAT16, Eigen::half, ctx)
EXPM1_COMPUTE_CASE3(DT_FLOAT, float, ctx)
EXPM1_COMPUTE_CASE3(DT_DOUBLE, double, ctx)
EXPM1_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
EXPM1_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
default:
KERNEL_LOG_ERROR("Expm1 kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t Expm1CpuKernel::Expm1Check(CpuKernelContext &ctx) {
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed.")
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t Expm1CpuKernel::Expm1Compute(CpuKernelContext &ctx) {
auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
auto data_type = ctx.Input(0)->GetDataType();
size_t data_num = ctx.Input(0)->NumElements();
int64_t data_size = data_num * sizeof(T);
T num0 = static_cast<T>(-1.0);
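// For complex inputs expm1 is evaluated as exp(x) - 1 via Eigen; payloads of
// at most 64 KiB stay on the serial path below.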
if (((data_type == DT_COMPLEX64) && (data_size <= 64 * 1024)) ||
((data_type == DT_COMPLEX128) && (data_size <= 64 * 1024))) {
for (size_t i = 0; i < data_num; i++) {
(*(output_y + i)) = Eigen::numext::exp(*(input_x + i)) + num0;
}
} else {
uint32_t min_core_num = 1;
size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto shard_expm1 = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
(*(output_y + i)) = Eigen::numext::exp(*(input_x + i)) + num0;
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_expm1),
"Expm1 Compute failed.");
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t Expm1CpuKernel::Expm1Compute2(CpuKernelContext &ctx) {
auto input_x = reinterpret_cast<Eigen::half *>(ctx.Input(0)->GetData());
auto output_y = reinterpret_cast<Eigen::half *>(ctx.Output(0)->GetData());
size_t data_num = ctx.Input(0)->NumElements();
int64_t data_size = data_num * sizeof(Eigen::half);
Eigen::half num0 = static_cast<Eigen::half>(-1.0);
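// The half-precision path computes exp(x) - 1 directly; payloads of at most
// 32 KiB are handled serially.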
if (data_size <= 32 * 1024) {
for (size_t i = 0; i < data_num; i++) {
*(output_y + i) = exp(*(input_x + i)) + num0;
}
} else {
uint32_t min_core_num = 1;
size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto shard_expm1 = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
*(output_y + i) = exp(*(input_x + i)) + num0;
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_expm1),
"Expm1 Compute failed.");
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t Expm1CpuKernel::Expm1Compute3(CpuKernelContext &ctx) {
auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_y = reinterpret_cast<T *>(ctx.Output(0)->GetData());
auto data_type = ctx.Input(0)->GetDataType();
size_t data_num = ctx.Input(0)->NumElements();
int64_t data_size = data_num * sizeof(T);
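// float and double use the library expm1, which is more accurate than
// computing exp(x) - 1 for inputs near zero; small payloads take the serial path.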
if ((data_type == DT_DOUBLE && data_size <= 64 * 1024) || (data_type == DT_FLOAT && data_size <= 16 * 1024)) {
for (size_t i = 0; i < data_num; i++) {
*(output_y + i) = expm1(*(input_x + i));
}
} else {
uint32_t min_core_num = 1;
size_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto shard_expm1 = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
*(output_y + i) = expm1(*(input_x + i));
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_expm1),
"Expm1 Compute failed.");
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kExpm1, Expm1CpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_EXPM1_H_
#define AICPU_KERNELS_NORMALIZED_EXPM1_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class Expm1CpuKernel : public CpuKernel {
public:
Expm1CpuKernel() = default;
~Expm1CpuKernel() override = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t Expm1Check(CpuKernelContext &ctx);
template <typename T>
uint32_t Expm1Compute(CpuKernelContext &ctx);
template <typename T>
uint32_t Expm1Compute2(CpuKernelContext &ctx);
template <typename T>
uint32_t Expm1Compute3(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,151 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "hamming_window.h"
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "utils/kernel_util.h"
#include "utils/eigen_tensor.h"
namespace {
const char *kHammingWindow = "HammingWindow";
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 1;
constexpr int64_t kParallelDataNums = 16 * 1024;
constexpr int64_t kParallelDataNumsMid = 7 * 1024;
#define WINDOW_LENGTH_CASE(DTYPE, TYPE, LENGTH, CTX) \
case (DTYPE): { \
TYPE *length_addr = reinterpret_cast<TYPE *>(ctx.Input(0)->GetData()); \
LENGTH = static_cast<int64_t>(*length_addr); \
break; \
}
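// SWITCH_PARALLEL runs SHARD serially for small windows and otherwise splits
// the range across CPU cores (capped at 4 cores for mid-sized windows).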
#define SWITCH_PARALLEL(SHARD, end_num) \
if (end_num >= kParallelDataNumsMid) { \
uint32_t min_core_num = 1; \
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); \
if (end_num < kParallelDataNums) { \
max_core_num = std::min(max_core_num, 4L); \
} \
if (max_core_num > end_num) { \
max_core_num = end_num; \
} \
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, end_num, end_num / max_core_num, SHARD), \
"HammingWindow #SHARD Compute failed."); \
} else { \
SHARD(0, end_num); \
}
} // namespace
namespace aicpu {
uint32_t HammingWindowCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "HammingWindow check input and output number failed.");
int64_t dtype = 0;
AttrValue *dtype_attr = ctx.GetAttr("dtype");
if (dtype_attr != nullptr) {
dtype = dtype_attr->GetInt();
}
DataType data_type = static_cast<DataType>(dtype);
ctx.Output(0)->SetDataType(data_type);
switch (data_type) {
case DT_FLOAT:
return HammingWindowCompute<float>(ctx);
case DT_FLOAT16:
return HammingWindowCompute<Eigen::half>(ctx);
case DT_DOUBLE:
return HammingWindowCompute<double>(ctx);
default:
KERNEL_LOG_ERROR(
"Attribute dtype only supports floating point types, "
"but got:[%s].",
DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
template <typename T>
uint32_t HammingWindowCpuKernel::HammingWindowCompute(CpuKernelContext &ctx) {
DataType input_type = ctx.Input(0)->GetDataType();
int64_t length;
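// The scalar input carries the window length; any integer dtype is accepted
// and widened to int64_t.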
switch (input_type) {
WINDOW_LENGTH_CASE(DT_INT8, int8_t, length, ctx)
WINDOW_LENGTH_CASE(DT_INT16, int16_t, length, ctx)
WINDOW_LENGTH_CASE(DT_INT32, int32_t, length, ctx)
WINDOW_LENGTH_CASE(DT_INT64, int64_t, length, ctx)
WINDOW_LENGTH_CASE(DT_UINT8, uint8_t, length, ctx)
WINDOW_LENGTH_CASE(DT_UINT16, uint16_t, length, ctx)
WINDOW_LENGTH_CASE(DT_UINT32, uint32_t, length, ctx)
WINDOW_LENGTH_CASE(DT_UINT64, uint64_t, length, ctx)
default:
KERNEL_LOG_ERROR("HammingWindow input data type [%s] not support.", DTypeStr(input_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
KERNEL_CHECK_FALSE((length >= 0), KERNEL_STATUS_PARAM_INVALID,
"Input window length cannot be negative, bug got [%d].", length);
Tensor *y_tensor = ctx.Output(0);
auto y_shape = y_tensor->GetTensorShape();
std::vector<int64_t> y_dim = y_shape->GetDimSizes();
y_dim.clear();
if (length != 0) {
y_dim.push_back(length);
}
y_shape->SetDimSizes(y_dim);
y_tensor->SetTensorShape(y_shape.get());
y_tensor->SetDataSize(length * sizeof(T));
T *y_addr = reinterpret_cast<T *>(y_tensor->GetData());
if (length == 0) {
return KERNEL_STATUS_OK;
} else if (length == 1) {
*y_addr = T{1};
return KERNEL_STATUS_OK;
} else {
bool periodic = true;
AttrValue *periodic_attr = ctx.GetAttr("periodic");
if (periodic_attr != nullptr) {
periodic = periodic_attr->GetBool();
}
int64_t window_length = length;
if (periodic) {
length += 1;
}
float alpha = 0.54;
AttrValue *alpha_attr = ctx.GetAttr("alpha");
if (alpha_attr != nullptr) {
alpha = alpha_attr->GetFloat();
}
float beta = 0.46;
AttrValue *beta_attr = ctx.GetAttr("beta");
if (beta_attr != nullptr) {
beta = beta_attr->GetFloat();
}
constexpr double t_pi = 6.283185307179586476925286766559;
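// Hamming window: w[n] = alpha - beta * cos(2 * pi * n / (length - 1)); for a
// periodic window, length was incremented above so the formula spans one extra
// point while only window_length values are written.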
auto shard_hamming_window = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
double result = alpha - beta * std::cos(i * t_pi / (length - 1));
*(y_addr + i) = static_cast<T>(result);
}
};
SWITCH_PARALLEL(shard_hamming_window, window_length);
return KERNEL_STATUS_OK;
}
}
REGISTER_CPU_KERNEL(kHammingWindow, HammingWindowCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_HAMMING_WINDOW_H_
#define AICPU_KERNELS_HAMMING_WINDOW_H_
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
namespace aicpu {
class HammingWindowCpuKernel : public CpuKernel {
public:
HammingWindowCpuKernel() = default;
~HammingWindowCpuKernel() override = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename T>
uint32_t HammingWindowCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,99 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "imag.h"
#include "Eigen/Eigen"
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kOutputNum = 1;
const uint32_t kInputNum = 1;
const char *kImag = "Imag";
constexpr int64_t kFloatDataNums = 8 * 128 * 1024;
constexpr int64_t kDoubleDataNums = 16 * 128 * 1024;
#define Imag_COMPUTE_CASE(IN_DTYPE, IN_TYPE, OUT_DTYPE, CTX) \
case (IN_DTYPE): { \
switch (OUT_DTYPE) { \
case (DT_FLOAT): { \
uint32_t result = ImagCompute<IN_TYPE, float>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Imag kernel compute failed."); \
return result; \
} \
break; \
} \
case (DT_DOUBLE): { \
uint32_t result = ImagCompute<IN_TYPE, double>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Imag kernel compute failed."); \
return result; \
} \
break; \
} \
default: \
KERNEL_LOG_ERROR("Imag kernel output data type [%s] not support.", DTypeStr(OUT_DTYPE).c_str()); \
return KERNEL_STATUS_PARAM_INVALID; \
} \
break; \
}
} // namespace
namespace aicpu {
uint32_t ImagCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kImag);
DataType input_type = ctx.Input(0)->GetDataType();
switch (input_type) {
Imag_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, DT_FLOAT, ctx)
Imag_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, DT_DOUBLE, ctx)
default:
KERNEL_LOG_ERROR("Imag kernel input data type [%s] not support.", DTypeStr(input_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T, typename t>
uint32_t ImagCpuKernel::ImagCompute(CpuKernelContext &ctx) {
auto input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output = reinterpret_cast<t *>(ctx.Output(0)->GetData());
auto data_type = ctx.Input(0)->GetDataType();
int64_t data_num = ctx.Output(0)->NumElements();
int64_t data_size = data_num * sizeof(T);
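// imag() is applied element-wise; small payloads stay serial, larger ones are
// sharded across CPU cores.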
if ((data_type == DT_COMPLEX64 && data_size <= kFloatDataNums) ||
(data_type == DT_COMPLEX128 && data_size <= kDoubleDataNums)) {
for (int64_t index = 0; index < data_num; ++index) {
*(output + index) = (*(input + index)).imag();
}
} else {
uint32_t min_core_num = 1;
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
if (max_core_num > data_num) {
max_core_num = data_num;
}
auto shard_imag = [&](size_t start, size_t end) {
for (size_t index = start; index < end; ++index) {
*(output + index) = (*(input + index)).imag();
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_imag),
"imag Compute failed");
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kImag, ImagCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,40 @@
/**
* Copyright (C) 2020-2021. Huawei Technologies Co., Ltd. All rights reserved.
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*
* @brief
*
* @version 1.0
*
*/
#ifndef AICPU_KERNELS_NORMALIZED_IMAG_H_
#define AICPU_KERNELS_NORMALIZED_IMAG_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
class ImagCpuKernel : public CpuKernel {
public:
ImagCpuKernel() = default;
~ImagCpuKernel() override = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t ImagCheck(CpuKernelContext &ctx);
template <typename T, typename t>
uint32_t ImagCompute(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif