forked from mindspore-Ecosystem/mindspore
!48349 merge canndev code to mindspore
Merge pull request !48349 from 沈竞兴/canndev_last1
This commit is contained in:
commit 67abbb89d1
@ -107,3 +107,4 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "identicalConditionAfterEarlyExit"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "uninitMemberVar"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "redundantInitialization"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "redundantCondition"
@ -346,3 +346,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpSpecialComputeComplex
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel<Op>::SparseDenseCwiseOpBcastComputeComplex
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc:aicpu::ResizeBicubicGrad
@ -0,0 +1,439 @@
#include "resize_bicubic_grad.h"

#include <securec.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <vector>

#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "utils/kernel_util.h"
#include "utils/sparse_tensor.h"

namespace {
constexpr uint32_t kInputNum = 2;
constexpr uint32_t kOutputNum = 1;
static const int64_t kTableSize = (1 << 10);
const int64_t kParallelDataNum = 1024 * 256;
const char *kResizeBicubicGrad = "ResizeBicubicGrad";
std::vector<int64_t> size_;
std::vector<int64_t> shape_;
float height_scale_ = 0;
float width_scale_ = 0;
bool align_corners_ = false;
bool half_pixel_centers_ = false;
}  // namespace

namespace aicpu {
DataType dtype0_ = DT_FLOAT;
DataType dtype1_ = DT_FLOAT;
DataType dtype2_ = DT_FLOAT;

float Scaling64_(int64_t in_size, int64_t out_size, bool align_corners) {
  return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
                                         : in_size / static_cast<float>(out_size);
}
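// Worked example of the two conventions above: for in_size = 4 and out_size = 2,
// align_corners gives Scaling64_(4, 2, true) == (4 - 1) / (2 - 1) == 3.0f
// (the corner pixels of both grids coincide), while the default mapping gives
// Scaling64_(4, 2, false) == 4 / 2 == 2.0f.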
struct ResizerGradState {
  void CalculateSize(CpuKernelContext &ctx) {
    Tensor *input0_tensor = ctx.Input(0);
    Tensor *input1_tensor = ctx.Input(1);
    shape_ = input0_tensor->GetTensorShape()->GetDimSizes();
    size_ = input1_tensor->GetTensorShape()->GetDimSizes();

    batch_size = shape_[0];
    channels = shape_[3];
    resized_height = shape_[1];
    resized_width = shape_[2];

    original_height = size_[1];
    original_width = size_[2];

    height_scale = Scaling64_(original_height, resized_height, align_corners_);
    width_scale = Scaling64_(original_width, resized_width, align_corners_);
  }
  int64_t Calindex(const int64_t x1, const int64_t x2, const int64_t x3, const int64_t x4, const bool flag_) {
    if (!flag_) {
      return static_cast<int64_t>(x1 * original_height * original_width * channels) +
             static_cast<int64_t>(x2 * original_width * channels) + static_cast<int64_t>(x3 * channels) +
             static_cast<int64_t>(x4);
    } else {
      return static_cast<int64_t>(x1 * resized_height * resized_width * channels) +
             static_cast<int64_t>(x2 * resized_width * channels) + static_cast<int64_t>(x3 * channels) +
             static_cast<int64_t>(x4);
    }
  }
  int64_t batch_size;
  int64_t channels;

  int64_t original_height;
  int64_t original_width;

  int64_t resized_height;
  int64_t resized_width;

  float height_scale;
  float width_scale;
};

struct WeightsAndIndices {
  float weight_0;
  float weight_1;
  float weight_2;
  float weight_3;
  int64_t index_0;
  int64_t index_1;
  int64_t index_2;
  int64_t index_3;

  int advance;
};

struct HalfPixelScalerGrad {
  HalfPixelScalerGrad(){};
  inline float operator()(const size_t x, const float scale) const {
    return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
  }
};
struct LegacyScalerGrad {
  LegacyScalerGrad(){};
  inline float operator()(const size_t x, const float scale) const { return static_cast<float>(x) * scale; }
};

class CachedInterpolationCalculator {
 public:
  CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}
  // Returns how many of the four cached tap indices are still valid for the new
  // positions; the switch cases deliberately fall through so every stale entry
  // from the first new one onward is refreshed.
  inline int Advance(const int64_t x_0, const int64_t x_1, const int64_t x_2, const int64_t x_3) {
    const std::array<int64_t, 4> new_x_indices{{x_0, x_1, x_2, x_3}};
    int64_t cached_values_hand = 0;
    int64_t new_indices_hand = 0;
    while (cached_values_hand < 4) {
      if (indexes_[cached_values_hand] == new_x_indices[new_indices_hand]) {
        if (new_indices_hand < cached_values_hand) {
          indexes_[new_indices_hand] = indexes_[cached_values_hand];
        }
        cached_values_hand++;
        new_indices_hand++;
      } else {
        cached_values_hand++;
      }
    }
    switch (new_indices_hand) {
      case 0:
        indexes_[0] = x_0;
        // fall through
      case 1:
        indexes_[1] = x_1;
        // fall through
      case 2:
        indexes_[2] = x_2;
        // fall through
      case 3:
        indexes_[3] = x_3;
        break;
    }
    return new_indices_hand;
  }

 private:
  int64_t indexes_[4];
};

const float *InitCoeffsTable_(const double a) {
  // Entry [i * 2] caches the near kernel branch (|x| <= 1) and [i * 2 + 1] the
  // far branch (1 < |x| < 2), sampled at kTableSize + 1 points.
  float *coeffs_table = new float[(kTableSize + 1) * 2];
  for (int64_t i = 0; i <= kTableSize; ++i) {
    float x = i * 1.0 / kTableSize;
    coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1;
    x += 1.0;
    coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
  }

  return coeffs_table;
}
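// For reference, the two branches above sample the standard cubic convolution
// kernel (Keys, 1981):
//   W(x) = (a + 2)|x|^3 - (a + 3)|x|^2 + 1,   for |x| <= 1
//   W(x) = a|x|^3 - 5a|x|^2 + 8a|x| - 4a,     for 1 < |x| < 2
//   W(x) = 0,                                 otherwise
// with a = -0.5 when use_keys_cubic is true and a = -0.75 otherwise (see
// GetCoeffsTable_ below).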
const float *GetCoeffsTable_(const bool use_keys_cubic) {
  if (use_keys_cubic) {
    static const float *coeffs_table = InitCoeffsTable_(-0.5f);
    return coeffs_table;
  } else {
    static const float *coeffs_table = InitCoeffsTable_(-0.75f);
    return coeffs_table;
  }
}

inline int64_t Bound(int64_t val, int64_t limit) { return std::min(limit - 1, std::max(int64_t{0}, val)); }

template <typename Scaler, bool use_keys_cubic>
inline void GetWeightsAndIndicesGrad(const float scale, const size_t out_loc, const size_t limit,
                                     WeightsAndIndices *out) {
  const Scaler scaler;
  const float in_loc_f = scaler(out_loc, scale);
  const int64_t in_loc = std::floor(in_loc_f);
  const float delta = in_loc_f - in_loc;
  const int64_t offset = lrintf(delta * kTableSize);
  const float *coeffs_table = GetCoeffsTable_(use_keys_cubic);
  if (use_keys_cubic) {
    // Keys mode zeroes the weight of any tap that Bound() clamped at the border
    // and then renormalizes the remaining weights so they still sum to one.
    out->index_0 = Bound(in_loc - 1, limit);
    out->weight_0 = (out->index_0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f);
    out->index_1 = Bound(in_loc, limit);
    out->weight_1 = (out->index_1 == in_loc ? coeffs_table[offset * 2] : 0.0f);
    out->index_2 = Bound(in_loc + 1, limit);
    out->weight_2 = (out->index_2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] : 0.0f);
    out->index_3 = Bound(in_loc + 2, limit);
    out->weight_3 = (out->index_3 == in_loc + 2 ? coeffs_table[(kTableSize - offset) * 2 + 1] : 0.0f);

    const float weight_sum = out->weight_0 + out->weight_1 + out->weight_2 + out->weight_3;
    if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits<float>::min()) {
      const float one_over_weight_sum = 1.0f / weight_sum;
      out->weight_0 *= one_over_weight_sum;
      out->weight_1 *= one_over_weight_sum;
      out->weight_2 *= one_over_weight_sum;
      out->weight_3 *= one_over_weight_sum;
    }
  } else {
    out->weight_0 = coeffs_table[offset * 2 + 1];
    out->weight_1 = coeffs_table[offset * 2];
    out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
    out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
    out->index_0 = Bound(in_loc - 1, limit);
    out->index_1 = Bound(in_loc, limit);
    out->index_2 = Bound(in_loc + 1, limit);
    out->index_3 = Bound(in_loc + 2, limit);
  }
}
uint32_t ResizeBicubicGradCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) {
  Tensor *input0_tensor = ctx.Input(0);
  Tensor *input1_tensor = ctx.Input(1);
  Tensor *output_tensor = ctx.Output(0);
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "ResizeBicubicGrad check params failed.");

  shape_ = input0_tensor->GetTensorShape()->GetDimSizes();
  size_ = input1_tensor->GetTensorShape()->GetDimSizes();
  KERNEL_CHECK_FALSE((shape_.size() == 4), KERNEL_STATUS_PARAM_INVALID, "Dim of input[0] must be 4, but got[%zu].",
                     shape_.size());
  KERNEL_CHECK_FALSE((size_.size() == 4), KERNEL_STATUS_PARAM_INVALID, "Dim of input[1] must be 4, but got[%zu].",
                     size_.size());
  AttrValue *pattr_align_corners = ctx.GetAttr("align_corners");
  if (pattr_align_corners == nullptr) {
    align_corners_ = false;
  } else {
    align_corners_ = pattr_align_corners->GetBool();
  }
  AttrValue *pattr_half_pixel_centers = ctx.GetAttr("half_pixel_centers");
  if (pattr_half_pixel_centers == nullptr) {
    half_pixel_centers_ = false;
  } else {
    half_pixel_centers_ = pattr_half_pixel_centers->GetBool();
  }
  dtype0_ = input0_tensor->GetDataType();
  dtype1_ = input1_tensor->GetDataType();
  dtype2_ = output_tensor->GetDataType();

  KERNEL_CHECK_FALSE((dtype0_ == DT_FLOAT), KERNEL_STATUS_PARAM_INVALID,
                     "ResizeBicubicGrad op doesn't support input[0] tensor types: [%s]", DTypeStr(dtype0_).c_str());

  KERNEL_CHECK_FALSE((dtype1_ == DT_FLOAT || dtype1_ == DT_DOUBLE), KERNEL_STATUS_PARAM_INVALID,
                     "ResizeBicubicGrad op doesn't support input[1] tensor types: [%s]", DTypeStr(dtype1_).c_str());

  KERNEL_CHECK_FALSE((dtype1_ == dtype2_), KERNEL_STATUS_PARAM_INVALID,
                     "The type of input[1] and output must be the same");

  int64_t in_height = shape_[1];
  int64_t in_width = shape_[2];
  int64_t out_height = size_[1];
  int64_t out_width = size_[2];
  height_scale_ = Scaling64_(out_height, in_height, align_corners_);
  width_scale_ = Scaling64_(out_width, in_width, align_corners_);
  return KERNEL_STATUS_OK;
}

static void ComputeGradientXWeightsAndIndices(const ResizerGradState &resizer_state, const bool half_pixel_centers,
                                              std::vector<WeightsAndIndices> *x_wais) {
  CachedInterpolationCalculator calc;
  if (half_pixel_centers) {
    for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndicesGrad<HalfPixelScalerGrad, true>(resizer_state.width_scale, x, resizer_state.original_width,
                                                          &(*x_wais)[x]);
      auto &x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2, x_wai.index_3);
    }
  } else {
    for (int64_t x = 0; x < resizer_state.resized_width; ++x) {
      GetWeightsAndIndicesGrad<LegacyScalerGrad, false>(resizer_state.width_scale, x, resizer_state.original_width,
                                                        &(*x_wais)[x]);
      auto &x_wai = (*x_wais)[x];
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2, x_wai.index_3);
    }
  }
}

template <typename T>
inline void ResizeBicubicGrad(const float *input_grad, ResizerGradState &resizer_state, const bool half_pixel_centers,
                              T *output_grad, CpuKernelContext &ctx) {
  const float height_scale = resizer_state.height_scale;
  const int64_t original_height = resizer_state.original_height;
  const int64_t channels = resizer_state.channels;
  const int64_t resized_width = resizer_state.resized_width;
  const int64_t resized_height = resizer_state.resized_height;

  std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
  ComputeGradientXWeightsAndIndices(resizer_state, half_pixel_centers, &x_wais);
  const bool flag = true;
  bool utils_flag = false;
  if (resizer_state.original_width * original_height * channels * resizer_state.batch_size >= kParallelDataNum) {
    utils_flag = true;
  }
  if (utils_flag) {
    for (int64_t b = 0; b < resizer_state.batch_size; ++b) {
      uint32_t min_core_num = 1;
      int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
      if (max_core_num > resized_height) {
        max_core_num = resized_height;
      }
      auto shard_resize_bicubic_grad = [&](int64_t start, int64_t end) {
        for (int64_t y = start; y < end; ++y) {
          WeightsAndIndices y_wai;
          if (half_pixel_centers) {
            GetWeightsAndIndicesGrad<HalfPixelScalerGrad, true>(height_scale, y, original_height, &y_wai);
          } else {
            GetWeightsAndIndicesGrad<LegacyScalerGrad, false>(height_scale, y, original_height, &y_wai);
          }
          for (int64_t x = 0; x < resized_width; ++x) {
            const WeightsAndIndices &x_wai = x_wais[x];
            for (int64_t c = 0; c < channels; ++c) {
              T curr_input_grad = input_grad[resizer_state.Calindex(b, y, x, c, flag)];
              // row 0 of 0, 1, 2, 3
              output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_0, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
              output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_1, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
              output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_2, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
              output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_3, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);

              // row 1 of 0, 1, 2, 3
              output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_0, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
              output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_1, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
              output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_2, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
              output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_3, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
              // row 2 of 0, 1, 2, 3
              output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_0, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
              output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_1, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
              output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_2, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
              output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_3, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
              // row 3 of 0, 1, 2, 3
              output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_0, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
              output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_1, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
              output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_2, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
              output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_3, c, !flag)] +=
                T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
            }
          }
        }
      };
      CpuKernelUtils::ParallelFor(ctx, resized_height, resized_height / max_core_num, shard_resize_bicubic_grad);
    }
  } else {
    for (int64_t b = 0; b < resizer_state.batch_size; ++b) {
      for (int64_t y = 0; y < resized_height; ++y) {
        WeightsAndIndices y_wai;
        if (half_pixel_centers) {
          GetWeightsAndIndicesGrad<HalfPixelScalerGrad, true>(height_scale, y, original_height, &y_wai);
        } else {
          GetWeightsAndIndicesGrad<LegacyScalerGrad, false>(height_scale, y, original_height, &y_wai);
        }
        for (int64_t x = 0; x < resized_width; ++x) {
          const WeightsAndIndices &x_wai = x_wais[x];
          for (int64_t c = 0; c < channels; ++c) {
            T curr_input_grad = input_grad[resizer_state.Calindex(b, y, x, c, flag)];
            // row 0 of 0, 1, 2, 3
            output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_0, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
            output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_1, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
            output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_2, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
            output_grad[resizer_state.Calindex(b, y_wai.index_0, x_wai.index_3, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);

            // row 1 of 0, 1, 2, 3
            output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_0, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
            output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_1, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
            output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_2, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
            output_grad[resizer_state.Calindex(b, y_wai.index_1, x_wai.index_3, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
            // row 2 of 0, 1, 2, 3
            output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_0, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
            output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_1, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
            output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_2, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
            output_grad[resizer_state.Calindex(b, y_wai.index_2, x_wai.index_3, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
            // row 3 of 0, 1, 2, 3
            output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_0, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
            output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_1, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
            output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_2, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
            output_grad[resizer_state.Calindex(b, y_wai.index_3, x_wai.index_3, c, !flag)] +=
              T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
          }
        }
      }
    }
  }
}

template <typename T>
uint32_t DoCompute(CpuKernelContext &ctx) {
  auto input0_addr = reinterpret_cast<float *>(ctx.Input(0)->GetData());
  auto output_addr = reinterpret_cast<T *>(ctx.Output(0)->GetData());

  AttrValue *pattr_half_pixel_centers = ctx.GetAttr("half_pixel_centers");
  if (pattr_half_pixel_centers == nullptr) {
    half_pixel_centers_ = false;
  } else {
    half_pixel_centers_ = pattr_half_pixel_centers->GetBool();
  }
  ResizerGradState sta;
  sta.CalculateSize(ctx);

  auto ret = memset_s(output_addr, ctx.Output(0)->GetDataSize(), 0, ctx.Output(0)->GetDataSize());
  KERNEL_CHECK_FALSE((ret == EOK), ret, "Output buffer memset failed, ret: [%d].", ret);

  ResizeBicubicGrad(input0_addr, sta, half_pixel_centers_, output_addr, ctx);
  return KERNEL_STATUS_OK;
}

uint32_t ResizeBicubicGradCpuKernel::Compute(CpuKernelContext &ctx) {
  uint32_t res = GetInputAndCheck(ctx);
  KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "GetInputAndCheck failed.");

  if (dtype1_ == DT_DOUBLE) {
    res = DoCompute<double>(ctx);
  } else if (dtype1_ == DT_FLOAT) {
    res = DoCompute<float>(ctx);
  } else {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  KERNEL_CHECK_FALSE((res == KERNEL_STATUS_OK), res, "ResizeBicubicGrad Compute failed.");
  return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kResizeBicubicGrad, ResizeBicubicGradCpuKernel);
}  // namespace aicpu
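For intuition, the backward pass above is the transpose of the forward bicubic interpolation: each incoming gradient element is scattered onto the 4x4 source taps with the same weights the forward pass would use to read them. A minimal 1-D sketch of that scatter (a hypothetical helper, not part of the kernel; it assumes the legacy mapping and the a = -0.75 table):

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> ResizeCubicGrad1D(const std::vector<float> &grad_out, int in_size, float scale) {
  std::vector<float> grad_in(in_size, 0.0f);
  const float a = -0.75f;
  for (int y = 0; y < static_cast<int>(grad_out.size()); ++y) {
    const float in_loc_f = y * scale;  // LegacyScalerGrad mapping
    const int in_loc = static_cast<int>(std::floor(in_loc_f));
    const float d = in_loc_f - in_loc;
    // Cubic convolution weights for the four taps in_loc - 1 .. in_loc + 2,
    // evaluated at distances d + 1, d, 1 - d, 2 - d from the sample point.
    const float w[4] = {((a * (d + 1) - 5 * a) * (d + 1) + 8 * a) * (d + 1) - 4 * a,
                        ((a + 2) * d - (a + 3)) * d * d + 1,
                        ((a + 2) * (1 - d) - (a + 3)) * (1 - d) * (1 - d) + 1,
                        ((a * (2 - d) - 5 * a) * (2 - d) + 8 * a) * (2 - d) - 4 * a};
    for (int t = 0; t < 4; ++t) {
      const int idx = std::min(in_size - 1, std::max(0, in_loc - 1 + t));  // Bound()
      grad_in[idx] += grad_out[y] * w[t];  // scatter-add, like output_grad[...] += above
    }
  }
  return grad_in;
}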
@ -0,0 +1,38 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_KERNELS_NORMALIZED_RESIZE_BICUBIC_GRAD_H_
#define AICPU_KERNELS_NORMALIZED_RESIZE_BICUBIC_GRAD_H_

#include <string>

#include "Eigen/Core"
#include "cpu_ops_kernel.h"

namespace aicpu {

template <typename T>
uint32_t DoCompute(CpuKernelContext &ctx);

class ResizeBicubicGradCpuKernel : public CpuKernel {
 public:
  ~ResizeBicubicGradCpuKernel() = default;
  uint32_t Compute(CpuKernelContext &ctx) override;

 private:
  uint32_t GetInputAndCheck(CpuKernelContext &ctx);
};
}  // namespace aicpu
#endif
@ -0,0 +1,247 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "resize_nearest_neighbor_v2.h"

#include <stdint.h>

#include <algorithm>
#include <cmath>

#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "utils/kernel_util.h"
#include "kernel_log.h"
#include "securec.h"
#include "status.h"

namespace {
constexpr uint32_t kInputNum = 2;
constexpr uint32_t kOutputNum = 1;
constexpr uint32_t kDim1 = 1;
constexpr uint32_t kDim4 = 4;
constexpr uint32_t kValue0 = 0;
constexpr uint32_t kIndex0 = 0;
constexpr uint32_t kIndex1 = 1;
constexpr uint32_t kIndex2 = 2;
constexpr uint32_t kNumElements2 = 2;
constexpr uint32_t kMaxValue = 24;
const char *kResizeNearestNeighborV2 = "ResizeNearestNeighborV2";
}  // namespace

namespace aicpu {
inline float Scaler(const int x, const float scale, bool half_pixel_centers) {
  if (half_pixel_centers) {
    return (static_cast<float>(x) + 0.5f) * scale;
  } else {
    return static_cast<float>(x) * scale;
  }
}

inline float CalculateResizeScale(int64_t in_size, int64_t out_size, bool align_corners) {
  return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
                                         : in_size / static_cast<float>(out_size);
}

uint32_t ResizeNearestNeighborV2CpuKernel::ResizeNearestNeighborV2ParamCheck(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check params failed.", kResizeNearestNeighborV2);
  Tensor *x_ptr = ctx.Input(0);

  Tensor *size_ptr = ctx.Input(1);

  AttrValue *align_corners_ptr = ctx.GetAttr("align_corners");

  AttrValue *half_pixel_centers_ptr = ctx.GetAttr("half_pixel_centers");

  auto align_corners = false;
  auto half_pixel_centers = false;
  if (align_corners_ptr != nullptr) {
    align_corners = align_corners_ptr->GetBool();
  }
  if (half_pixel_centers_ptr != nullptr) {
    half_pixel_centers = half_pixel_centers_ptr->GetBool();
  }
  auto x_shape = x_ptr->GetTensorShape()->GetDimSizes();
  auto x_dims = x_ptr->GetTensorShape()->GetDims();
  auto size_shape = size_ptr->GetTensorShape()->GetDimSizes();
  auto size_dims = size_ptr->GetTensorShape()->GetDims();
  auto size_data = static_cast<int32_t *>(size_ptr->GetData());

  KERNEL_CHECK_FALSE((!half_pixel_centers || !align_corners), KERNEL_STATUS_PARAM_INVALID,
                     "If half_pixel_centers is True, "
                     "align_corners must be False, but got half_pixel_centers %s, "
                     "align_corners %s.",
                     half_pixel_centers ? "True" : "False", align_corners ? "True" : "False");
  KERNEL_CHECK_FALSE(x_dims == kDim4, KERNEL_STATUS_PARAM_INVALID, "x must be 4-dimensional but got %d-dimensional.",
                     x_dims);
  auto channels = x_shape[3];
  KERNEL_CHECK_FALSE(channels > kValue0, KERNEL_STATUS_PARAM_INVALID,
                     "image must have at least one channel but got %d channel.", channels);
  KERNEL_CHECK_FALSE(x_shape[kIndex1] > kValue0 && x_shape[kIndex2] > kValue0, KERNEL_STATUS_PARAM_INVALID,
                     "x image must be of non-zero size but got height %d, width %d.", x_shape[kIndex1],
                     x_shape[kIndex2]);
  KERNEL_CHECK_FALSE(x_shape[kIndex1] < INT32_MAX && x_shape[kIndex2] < INT32_MAX, KERNEL_STATUS_PARAM_INVALID,
                     "x sizes must be between 0 and max int32 but "
                     "got height %d, width %d.",
                     x_shape[kIndex1], x_shape[kIndex2]);
  auto in_height = static_cast<int32_t>(x_shape[1]);
  auto in_width = static_cast<int32_t>(x_shape[2]);
  KERNEL_CHECK_FALSE(size_dims == kDim1, KERNEL_STATUS_PARAM_INVALID, "size_shape must be 1-dimensional but got %d.",
                     size_dims);
  KERNEL_CHECK_FALSE(size_ptr->NumElements() == kNumElements2, KERNEL_STATUS_PARAM_INVALID,
                     "shape_t must have two elements but got %d element(s).", size_ptr->NumElements());
  KERNEL_CHECK_FALSE(size_data[kIndex0] > 0 && size_data[kIndex1] > 0, KERNEL_STATUS_PARAM_INVALID,
                     "output dimensions must be positive but got height %d, width %d.", size_data[kIndex0],
                     size_data[kIndex1]);
  auto out_height = size_data[0];
  auto out_width = size_data[1];

  auto height_scale = CalculateResizeScale(in_height, out_height, align_corners);
  auto width_scale = CalculateResizeScale(in_width, out_width, align_corners);
  KERNEL_CHECK_FALSE(ceilf((out_height - 1) * height_scale) <= float(INT64_MAX), KERNEL_STATUS_PARAM_INVALID,
                     "input image height scale would cause an overflow.");
  KERNEL_CHECK_FALSE(ceilf((out_width - 1) * width_scale) <= float(INT64_MAX), KERNEL_STATUS_PARAM_INVALID,
                     "input image width scale would cause an overflow.");
  KERNEL_CHECK_FALSE(in_height < (1 << kMaxValue) && in_width < (1 << kMaxValue), KERNEL_STATUS_PARAM_INVALID,
                     "nearest neighbor requires max height "
                     "& width of 2^24.");
  return KERNEL_STATUS_OK;
}
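// Note on the (1 << 24) bound above: the index math goes through float, and
// 2^24 is the largest range over which float32 represents every integer
// exactly, so larger extents could round an index onto the wrong source pixel.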
uint32_t ResizeNearestNeighborV2CpuKernel::Compute(CpuKernelContext &ctx) {
  if (ResizeNearestNeighborV2ParamCheck(ctx) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  Tensor *x = ctx.Input(0);
  DataType data_type = DataType(x->GetDataType());
  uint32_t res = KERNEL_STATUS_OK;
  switch (data_type) {
    case DT_INT8:
      res = ResizeNearestNeighborV2Compute<int8_t>(ctx);
      break;
    case DT_UINT8:
      res = ResizeNearestNeighborV2Compute<uint8_t>(ctx);
      break;
    case DT_INT16:
      res = ResizeNearestNeighborV2Compute<int16_t>(ctx);
      break;
    case DT_UINT16:
      res = ResizeNearestNeighborV2Compute<uint16_t>(ctx);
      break;
    case DT_INT32:
      res = ResizeNearestNeighborV2Compute<int32_t>(ctx);
      break;
    case DT_INT64:
      res = ResizeNearestNeighborV2Compute<int64_t>(ctx);
      break;
    case DT_FLOAT16:
      res = ResizeNearestNeighborV2Compute<Eigen::half>(ctx);
      break;
    case DT_FLOAT:
      res = ResizeNearestNeighborV2Compute<float>(ctx);
      break;
    case DT_DOUBLE:
      res = ResizeNearestNeighborV2Compute<double>(ctx);
      break;
    default:
      KERNEL_LOG_ERROR("For ResizeNearestNeighborV2, invalid input type [%s].", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }
  return res;
}
template <typename T>
void ResizeNearestNeighborV2CpuKernel::InnerCompute(
  Eigen::Index b, Eigen::Index y,
  Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> x_4d,
  Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> y_4d) {
  Eigen::Index in_y =
    std::min((align_corners) ? static_cast<Eigen::Index>(roundf(Scaler(y, height_scale, half_pixel_centers)))
                             : static_cast<Eigen::Index>(floorf(Scaler(y, height_scale, half_pixel_centers))),
             in_height - 1);
  if (half_pixel_centers) {
    in_y = std::max(static_cast<Eigen::Index>(0), in_y);
  }
  for (Eigen::Index x = 0; x < out_width; ++x) {
    Eigen::Index in_x =
      std::min((align_corners) ? static_cast<Eigen::Index>(roundf(Scaler(x, width_scale, half_pixel_centers)))
                               : static_cast<Eigen::Index>(floorf(Scaler(x, width_scale, half_pixel_centers))),
               in_width - 1);
    if (half_pixel_centers) {
      in_x = std::max(static_cast<Eigen::Index>(0), in_x);
    }
    if (data_format == "NHWC") {
      std::copy_n(&x_4d(b, in_y, in_x, 0), channels, &y_4d(b, y, x, 0));
    } else {
      // data_format = NCHW
      for (Eigen::Index c = 0; c < channels; ++c) {
        y_4d(b, c, y, x) = x_4d(b, c, in_y, in_x);
      }
    }
  }
}

template <typename T>
uint32_t ResizeNearestNeighborV2CpuKernel::ResizeNearestNeighborV2Compute(CpuKernelContext &ctx) {
  Tensor *input_x = ctx.Input(0);
  Tensor *output_y = ctx.Output(0);
  std::vector<int64_t> x_shape = input_x->GetTensorShape()->GetDimSizes();
  std::vector<int64_t> y_shape = output_y->GetTensorShape()->GetDimSizes();
  AttrValue *align_corners_ptr = ctx.GetAttr("align_corners");
  AttrValue *half_pixel_centers_ptr = ctx.GetAttr("half_pixel_centers");
  AttrValue *data_format_ptr = ctx.GetAttr("data_format");
  if (data_format_ptr != nullptr) {
    data_format = data_format_ptr->GetString();
  }
  if (data_format == "NHWC") {
    dim_idx_map_ = {
      {'N', kFormatNHWCIndexN}, {'H', kFormatNHWCIndexH}, {'W', kFormatNHWCIndexW}, {'C', kFormatNHWCIndexC}};
  } else if (data_format == "NCHW") {
    dim_idx_map_ = {
      {'N', kFormatNCHWIndexN}, {'C', kFormatNCHWIndexC}, {'H', kFormatNCHWIndexH}, {'W', kFormatNCHWIndexW}};
  } else {
    KERNEL_LOG_ERROR(
      "For ResizeNearestNeighborV2, data_format only support [NCHW, NHWC], "
      "but get [%s].",
      data_format.c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  align_corners = false;
  half_pixel_centers = false;
  if (align_corners_ptr != nullptr) {
    align_corners = align_corners_ptr->GetBool();
  }
  if (half_pixel_centers_ptr != nullptr) {
    half_pixel_centers = half_pixel_centers_ptr->GetBool();
  }
  batch_size = x_shape[dim_idx_map_['N']];
  in_height = x_shape[dim_idx_map_['H']];
  in_width = x_shape[dim_idx_map_['W']];
  channels = x_shape[dim_idx_map_['C']];
  out_height = y_shape[dim_idx_map_['H']];
  out_width = y_shape[dim_idx_map_['W']];
  height_scale = CalculateResizeScale(in_height, out_height, align_corners);
  width_scale = CalculateResizeScale(in_width, out_width, align_corners);
  EigenTensor x_et(input_x, input_x->GetData());
  EigenTensor y_et(output_y, output_y->GetData());
  auto x_4d = x_et.tensor<T, 4>();
  auto y_4d = y_et.tensor<T, 4>();
  for (Eigen::Index b = 0; b < batch_size; ++b) {
    for (Eigen::Index y = 0; y < out_height; ++y) {
      InnerCompute<T>(b, y, x_4d, y_4d);
    }
  }
  return KERNEL_STATUS_OK;
}

REGISTER_CPU_KERNEL(kResizeNearestNeighborV2, ResizeNearestNeighborV2CpuKernel);
}  // namespace aicpu
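A quick worked example of the index mapping implemented above (a hypothetical standalone program mirroring Scaler and CalculateResizeScale for one axis, not part of the kernel):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int in_w = 4, out_w = 3;
  const float scale = in_w / static_cast<float>(out_w);  // align_corners = false -> 1.333...
  for (int x = 0; x < out_w; ++x) {
    // half_pixel_centers = true: sample at the centre of each output pixel.
    const float src = (x + 0.5f) * scale;
    const int in_x = std::min(in_w - 1, std::max(0, static_cast<int>(std::floor(src))));
    std::printf("out x=%d -> in x=%d\n", x, in_x);  // prints 0, 2, 3
  }
  return 0;
}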
@ -0,0 +1,57 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_KERNELS_NORMALIZED_RESIZE_NEAREST_NEIGHBOR_V2_H
#define AICPU_KERNELS_NORMALIZED_RESIZE_NEAREST_NEIGHBOR_V2_H

#include <string>
#include <unordered_map>

#include "cpu_ops_kernel.h"
#include "utils/eigen_tensor.h"

namespace aicpu {
constexpr uint32_t kValue4 = 4;
class ResizeNearestNeighborV2CpuKernel : public CpuKernel {
 public:
  ~ResizeNearestNeighborV2CpuKernel() = default;

  uint32_t Compute(CpuKernelContext &ctx) override;

 private:
  uint32_t ResizeNearestNeighborV2ParamCheck(CpuKernelContext &ctx);
  template <typename T>
  uint32_t ResizeNearestNeighborV2Compute(CpuKernelContext &ctx);
  template <typename T>
  void InnerCompute(
    Eigen::Index b, Eigen::Index y,
    Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> x_4d,
    Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> y_4d);

  std::string data_format = "NHWC";
  std::unordered_map<char, size_t> dim_idx_map_;
  bool align_corners = false;
  bool half_pixel_centers = false;
  Eigen::Index batch_size;
  Eigen::Index in_height;
  Eigen::Index in_width;
  Eigen::Index channels;
  Eigen::Index out_height;
  Eigen::Index out_width;
  float height_scale;
  float width_scale;
};
}  // namespace aicpu
#endif
@ -0,0 +1,205 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "resize_nearest_neighbor_v2_grad.h"

#include <stdint.h>

#include <algorithm>
#include <cmath>

#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "utils/kernel_util.h"
#include "kernel_log.h"
#include "securec.h"
#include "status.h"
#include "utils/eigen_tensor.h"

namespace {
constexpr uint32_t kInputNum = 2;
constexpr uint32_t kOutputNum = 1;
constexpr uint32_t kDim1 = 1;
constexpr uint32_t kDim4 = 4;
constexpr uint32_t kIndex0 = 0;
constexpr uint32_t kIndex1 = 1;
constexpr uint32_t kNumElements2 = 2;
const char *kResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad";
}  // namespace

namespace aicpu {
inline float Scaler(const int x, const float scale, bool half_pixel_centers) {
  if (half_pixel_centers) {
    return (static_cast<float>(x) + 0.5f) * scale;
  } else {
    return static_cast<float>(x) * scale;
  }
}

inline float CalculateResizeScale(int64_t in_size, int64_t out_size, bool align_corners) {
  return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
                                         : in_size / static_cast<float>(out_size);
}

uint32_t ResizeNearestNeighborV2GradCpuKernel::ResizeNearestNeighborV2GradParamCheck(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check params failed.",
                      kResizeNearestNeighborV2Grad);
  Tensor *grads_ptr = ctx.Input(0);
  Tensor *size_ptr = ctx.Input(1);

  auto grads_shape = grads_ptr->GetTensorShape()->GetDimSizes();
  auto grads_dims = grads_ptr->GetTensorShape()->GetDims();
  auto size_shape = size_ptr->GetTensorShape()->GetDimSizes();
  auto size_dims = size_ptr->GetTensorShape()->GetDims();
  auto size_data = static_cast<int32_t *>(size_ptr->GetData());

  KERNEL_CHECK_FALSE(grads_dims == kDim4, KERNEL_STATUS_PARAM_INVALID,
                     "grads must be 4-dimensional but got %d-dimensional.", grads_dims);

  KERNEL_CHECK_FALSE(size_dims == kDim1, KERNEL_STATUS_PARAM_INVALID, "size_shape must be 1-dimensional but got %d.",
                     size_dims);

  KERNEL_CHECK_FALSE(size_ptr->NumElements() == kNumElements2, KERNEL_STATUS_PARAM_INVALID,
                     "size must have two elements but got %d element(s).", size_ptr->NumElements());
  KERNEL_CHECK_FALSE(size_data[kIndex0] > 0 && size_data[kIndex1] > 0, KERNEL_STATUS_PARAM_INVALID,
                     "size elements must be positive but got height %d, width %d.", size_data[kIndex0],
                     size_data[kIndex1]);
  return KERNEL_STATUS_OK;
}

template <typename T>
void ResizeNearestNeighborV2GradCpuKernel::InnerCompute(
  Eigen::Index y, Eigen::Index out_y, Eigen::Index x,
  Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> grads_4d,
  Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> y_4d) {
  const Eigen::Index out_x =
    std::min((align_corners) ? static_cast<Eigen::Index>(roundf(Scaler(x, width_scale, half_pixel_centers)))
                             : static_cast<Eigen::Index>(floorf(Scaler(x, width_scale, half_pixel_centers))),
             out_width - 1);
  for (Eigen::Index b = 0; b < batch_size; ++b) {
    for (Eigen::Index c = 0; c < channels; ++c) {
      if (data_format == "NHWC") {
        y_4d(b, out_y, out_x, c) += grads_4d(b, y, x, c);
      } else {
        // data_format = NCHW
        y_4d(b, c, out_y, out_x) += grads_4d(b, c, y, x);
      }
    }
  }
}

uint32_t ResizeNearestNeighborV2GradCpuKernel::Compute(CpuKernelContext &ctx) {
  if (ResizeNearestNeighborV2GradParamCheck(ctx) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  Tensor *grads = ctx.Input(0);
  DataType data_type = DataType(grads->GetDataType());
  uint32_t res = KERNEL_STATUS_OK;
  switch (data_type) {
    case DT_INT8:
      res = ResizeNearestNeighborV2GradCompute<int8_t>(ctx);
      break;
    case DT_UINT8:
      res = ResizeNearestNeighborV2GradCompute<uint8_t>(ctx);
      break;
    case DT_INT16:
      res = ResizeNearestNeighborV2GradCompute<int16_t>(ctx);
      break;
    case DT_UINT16:
      res = ResizeNearestNeighborV2GradCompute<uint16_t>(ctx);
      break;
    case DT_INT32:
      res = ResizeNearestNeighborV2GradCompute<int32_t>(ctx);
      break;
    case DT_INT64:
      res = ResizeNearestNeighborV2GradCompute<int64_t>(ctx);
      break;
    case DT_FLOAT16:
      res = ResizeNearestNeighborV2GradCompute<Eigen::half>(ctx);
      break;
    case DT_FLOAT:
      res = ResizeNearestNeighborV2GradCompute<float>(ctx);
      break;
    case DT_DOUBLE:
      res = ResizeNearestNeighborV2GradCompute<double>(ctx);
      break;
    default:
      KERNEL_LOG_ERROR("For ResizeNearestNeighborV2Grad, invalid input type [%s].", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }
  return res;
}

template <typename T>
uint32_t ResizeNearestNeighborV2GradCpuKernel::ResizeNearestNeighborV2GradCompute(CpuKernelContext &ctx) {
  Tensor *input_grads = ctx.Input(0);
  Tensor *output_y = ctx.Output(0);
  std::vector<int64_t> grads_shape = input_grads->GetTensorShape()->GetDimSizes();
  std::vector<int64_t> y_shape = output_y->GetTensorShape()->GetDimSizes();
  AttrValue *align_corners_ptr = ctx.GetAttr("align_corners");
  AttrValue *half_pixel_centers_ptr = ctx.GetAttr("half_pixel_centers");
  AttrValue *data_format_ptr = ctx.GetAttr("data_format");
  if (data_format_ptr != nullptr) {
    data_format = data_format_ptr->GetString();
  }
  if (data_format == "NHWC") {
    dim_idx_map_ = {
      {'N', kFormatNHWCIndexN}, {'H', kFormatNHWCIndexH}, {'W', kFormatNHWCIndexW}, {'C', kFormatNHWCIndexC}};
  } else if (data_format == "NCHW") {
    dim_idx_map_ = {
      {'N', kFormatNCHWIndexN}, {'C', kFormatNCHWIndexC}, {'H', kFormatNCHWIndexH}, {'W', kFormatNCHWIndexW}};
  } else {
    KERNEL_LOG_ERROR(
      "For ResizeNearestNeighborV2Grad, data_format only support [NCHW, "
      "NHWC], but get [%s].",
      data_format.c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  align_corners = false;
  half_pixel_centers = false;
  if (align_corners_ptr != nullptr) {
    align_corners = align_corners_ptr->GetBool();
  }
  if (half_pixel_centers_ptr != nullptr) {
    half_pixel_centers = half_pixel_centers_ptr->GetBool();
  }
  batch_size = grads_shape[dim_idx_map_['N']];
  in_height = grads_shape[dim_idx_map_['H']];
  in_width = grads_shape[dim_idx_map_['W']];
  channels = grads_shape[dim_idx_map_['C']];

  out_height = y_shape[dim_idx_map_['H']];
  out_width = y_shape[dim_idx_map_['W']];

  // The scales are intentionally computed output-to-input: the grad kernel
  // walks the incoming gradient grid (in_*) and maps each element back to the
  // pixel it was sampled from in the forward pass (out_*).
  height_scale = CalculateResizeScale(out_height, in_height, align_corners);
  width_scale = CalculateResizeScale(out_width, in_width, align_corners);

  EigenTensor grads_et(input_grads, input_grads->GetData());
  EigenTensor y_et(output_y, output_y->GetData());
  auto grads_4d = grads_et.tensor<T, 4>();
  auto y_4d = y_et.tensor<T, 4>();
  y_4d.setZero();
  for (Eigen::Index y = 0; y < in_height; ++y) {
    const Eigen::Index out_y =
      std::min((align_corners) ? static_cast<Eigen::Index>(roundf(Scaler(y, height_scale, half_pixel_centers)))
                               : static_cast<Eigen::Index>(floorf(Scaler(y, height_scale, half_pixel_centers))),
               out_height - 1);
    for (Eigen::Index x = 0; x < in_width; ++x) {
      InnerCompute(y, out_y, x, grads_4d, y_4d);
    }
  }
  return KERNEL_STATUS_OK;
}

REGISTER_CPU_KERNEL(kResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradCpuKernel);
}  // namespace aicpu
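The grad kernel above runs the forward mapping in reverse: it walks the incoming gradient grid, finds the nearest source pixel each element was read from, and accumulates with +=, since several input positions can collapse onto one output position. A 1-D sketch of the same idea (a hypothetical helper, not part of the kernel):

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> NearestNeighborGrad1D(const std::vector<float> &grads, int out_size, float scale) {
  std::vector<float> y(out_size, 0.0f);  // counterpart of y_4d.setZero()
  for (int i = 0; i < static_cast<int>(grads.size()); ++i) {
    // Legacy mapping (align_corners = half_pixel_centers = false).
    const int out_i = std::min(out_size - 1, static_cast<int>(std::floor(i * scale)));
    y[out_i] += grads[i];  // += because distinct i can map to the same out_i
  }
  return y;
}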
@ -0,0 +1,60 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_KERNELS_NORMALIZED_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_H
#define AICPU_KERNELS_NORMALIZED_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_H

#include <string>
#include <unordered_map>

#include "cpu_ops_kernel.h"
#include "utils/eigen_tensor.h"

namespace aicpu {
constexpr uint32_t kValue4 = 4;
class ResizeNearestNeighborV2GradCpuKernel : public CpuKernel {
 public:
  ~ResizeNearestNeighborV2GradCpuKernel() = default;

  uint32_t Compute(CpuKernelContext &ctx) override;

 private:
  uint32_t ResizeNearestNeighborV2GradParamCheck(CpuKernelContext &ctx);

  template <typename T>
  uint32_t ResizeNearestNeighborV2GradCompute(CpuKernelContext &ctx);

  template <typename T>
  void InnerCompute(
    Eigen::Index y, Eigen::Index out_y, Eigen::Index x,
    Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> grads_4d,
    Eigen::TensorMap<Eigen::Tensor<T, kValue4, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> y_4d);
  std::unordered_map<char, size_t> dim_idx_map_;
  std::string data_format = "NHWC";
  bool align_corners = false;
  bool half_pixel_centers = false;
  Eigen::Index batch_size;
  Eigen::Index in_height;
  Eigen::Index in_width;
  Eigen::Index channels;

  Eigen::Index out_height;
  Eigen::Index out_width;

  float height_scale;
  float width_scale;
};
}  // namespace aicpu
#endif
@ -159,6 +159,7 @@ constexpr auto kAcosh = "Acosh";
constexpr auto kAsin = "Asin";
constexpr auto kAsinh = "Asinh";
constexpr auto kAtanh = "Atanh";
constexpr auto kAdaptiveMaxPool3DGrad = "AdaptiveMaxPool3DGrad";
constexpr auto kCosh = "Cosh";
constexpr auto kTan = "Tan";
constexpr auto kTanhGrad = "TanhGrad";
@ -348,7 +349,8 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
  {kStridedSliceV2, "StridedSlice"},
  {kAdaptiveMaxPool3D, "AdaptiveMaxPool3d"},
  {kRandpermV2, "StatelessRandperm"},
- {kStridedSliceV2Grad, "StridedSliceGrad"}};
+ {kStridedSliceV2Grad, "StridedSliceGrad"},
+ {kAdaptiveMaxPool3DGrad, "AdaptiveMaxPool3dGrad"}};

class AicpuOpUtil {
 public:
@ -105,7 +105,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
  mindspore::kBiasAddGradOpName,
  mindspore::kBincountOpName,
  mindspore::kBlackmanWindowOpName,
  mindspore::kBroadcastOpName,
  mindspore::kBroadcastToOpName,
  mindspore::kMedianGradOpName,
  mindspore::kNMSWithMaskOpName,
  mindspore::kReduceSumOpName,
@ -161,6 +161,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
  mindspore::kNanToNumOpName,
  mindspore::kQrOpName,
  mindspore::kResizeBicubicOpName,
  mindspore::kResizeBicubicGradOpName,
  mindspore::kNuclearNormOpName,
  mindspore::kQuantileOpName,
  mindspore::kSparseSegmentSqrtNOpName,
@ -173,6 +174,8 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
  mindspore::kMulNoNanOpName,
  mindspore::kMultilabelMarginLossGradOpName,
  mindspore::kNthElementOpName,
  mindspore::kResizeNearestNeighborV2OpName,
  mindspore::kResizeNearestNeighborV2GradOpName,
  mindspore::kNonMaxSuppressionWithOverlapsOpName,
  mindspore::kOneHotOpName,
  mindspore::kOrgqrOpName,
@ -175,6 +175,9 @@ from .tan import _tan_aicpu
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
from .resize_bicubic_grad import _resize_bicubic_grad_aicpu
from .resize_nearest_neighbor_v2 import _resize_nearest_neighbor_v2_aicpu
from .resize_nearest_neighbor_v2_grad import _resize_nearest_neighbor_v2_grad_aicpu
from .scatter_elements import _scatter_elements_aicpu
from .non_max_suppression import _non_max_suppression_aicpu
from .square import _square_aicpu