!48745 add ops DeformableOffset/Grad

Merge pull request !48745 from zhupuxu/deformable_offset
i-robot 2023-02-21 02:25:46 +00:00 committed by Gitee
commit 5ab2fe68b3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 1262 additions and 2 deletions

View File

@@ -227,6 +227,7 @@ constexpr auto kConv3DOpName = "Conv3D";
constexpr auto kConv3DTransposeDOpName = "Conv3DTransposeD";
constexpr auto kConv3DTransposeOpName = "Conv3DTranspose";
constexpr auto kDeformableOffsetsOpName = "DeformableOffsets";
constexpr auto kDeformableOffsetsGradOpName = "DeformableOffsetsGrad";
constexpr auto kCropAndResizeOpName = "CropAndResize";
constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
constexpr auto kCropAndResizeGradBoxesOpName = "CropAndResizeGradBoxes";

View File

@@ -0,0 +1,317 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./deformable_offsets.h"
#include <memory>
#include <Eigen/Dense>
#include <map>
#include <functional>
#include <thread>
#include "cpu_kernel_utils.h"
#include "utils/kernel_util.h"
namespace aicpu {
namespace {
const char *kDeformableOffsets = "DeformableOffsets";
constexpr auto kStrides = "strides";
constexpr auto kPads = "pads";
constexpr auto kSize = "ksize";
constexpr auto kDilations = "dilations";
constexpr auto kModulated = "modulated";
constexpr auto kDeformableGroups = "deformable_groups";
constexpr size_t kInputsSize = 2;
constexpr size_t kOutputsSize = 1;
constexpr size_t kStridesSize = 4;
constexpr size_t kPadsSize = 4;
constexpr size_t kKernelSizeSize = 2;
constexpr size_t kKernelSizeHIndex = 0;
constexpr size_t kKernelSizeWIndex = 1;
constexpr size_t kDilationsSize = 4;
constexpr size_t kXShapeSize = 4;
constexpr size_t kOutputShapeSize = 4;
constexpr size_t kPadTopIndex = 0;
constexpr size_t kPadLeftIndex = 2;
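// x offsets, y offsets and the modulation mask occupy 3 blocks in the offsets input.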
constexpr size_t kOffsetsSize = 3;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
using ShapeVector = std::vector<int64_t>;
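// Bilinear sampling helper: blends the four integer neighbours of the fractional position (x, y)
// inside a (height, width) plane with bilinear weights; out-of-range neighbours contribute zero.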
template <typename T>
T DeformableBilinear(const T *input, T x, T y, int64_t width, int64_t height) {
if (y <= static_cast<T>(-1) || y >= static_cast<T>(height) || x <= static_cast<T>(-1) || x >= static_cast<T>(width)) {
return static_cast<T>(0);
}
int64_t left;
if constexpr (std::is_same<T, float>::value) {
left = static_cast<int64_t>(floorf(x));
} else {
left = static_cast<int64_t>(floor(x));
}
auto right = left + 1;
int64_t top;
if constexpr (std::is_same<T, float>::value) {
top = static_cast<int64_t>(floorf(y));
} else {
top = static_cast<int64_t>(floor(y));
}
auto bottom = top + 1;
T l = x - static_cast<T>(left);
T r = static_cast<T>(1) - l;
T t = y - static_cast<T>(top);
T b = static_cast<T>(1) - t;
T lt = static_cast<T>(0);
T lb = static_cast<T>(0);
if (left >= 0) {
if (top >= 0) {
lt = input[top * width + left];
}
if (bottom <= height - 1) {
lb = input[bottom * width + left];
}
}
T rt = static_cast<T>(0);
T rb = static_cast<T>(0);
if (right <= width - 1) {
if (top >= 0) {
rt = input[top * width + right];
}
if (bottom <= height - 1) {
rb = input[bottom * width + right];
}
}
T w_lt = r * b;
T w_rt = l * b;
T w_lb = r * t;
T w_rb = l * t;
T val = (w_lt * lt + w_rt * rt + w_lb * lb + w_rb * rb);
return val;
}
} // namespace
uint32_t DeformableOffsetsKernel::ParseAttrs(const CpuKernelContext &ctx) {
// Check args.
n_axis_ = kIndex0;
c_axis_ = kIndex1;
h_axis_ = kIndex2;
w_axis_ = kIndex3;
strides_ = ctx.GetAttr(kStrides)->GetListInt();
if (strides_.size() != kStridesSize || strides_[n_axis_] != 1 || strides_[c_axis_] != 1) {
KERNEL_LOG_ERROR(
"The 'strides' should be a vector with size %zu, and its values for the N and C dimensions must be 1.",
kStridesSize);
return KERNEL_STATUS_PARAM_INVALID;
}
pads_ = ctx.GetAttr(kPads)->GetListInt();
if (pads_.size() != kPadsSize) {
KERNEL_LOG_ERROR("The 'pads' should be a vector with size %zu.", kPadsSize);
return KERNEL_STATUS_PARAM_INVALID;
}
kernel_size_ = ctx.GetAttr(kSize)->GetListInt();
if (kernel_size_.size() != kKernelSizeSize) {
KERNEL_LOG_ERROR("The 'ksize' should be a vector with size %zu.", kKernelSizeSize);
return KERNEL_STATUS_PARAM_INVALID;
}
dilations_ = ctx.GetAttr(kDilations)->GetListInt();
if (dilations_.size() != kDilationsSize || dilations_[n_axis_] != 1 || dilations_[c_axis_] != 1) {
KERNEL_LOG_ERROR(
"The 'dilations' should be a vector with size %zu, and its values for the N and C dimensions must be 1.",
kDilationsSize);
return KERNEL_STATUS_PARAM_INVALID;
}
deformable_groups_ = ctx.GetAttr(kDeformableGroups)->GetInt();
if (deformable_groups_ <= 0) {
KERNEL_LOG_ERROR("For kernel %s, the 'deformable_groups' should be greater than 0.", kDeformableOffsets);
return KERNEL_STATUS_PARAM_INVALID;
}
modulated_ = ctx.GetAttr(kModulated)->GetBool();
if (!modulated_) {
AICPU_LOGE("The attribute 'modulated' only supports being set to true.");
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsKernel::SetDims(const CpuKernelContext &ctx) {
auto inputs_shape = ctx.Input(kIndex0)->GetTensorShape();
if (inputs_shape->GetDims() != kXShapeSize) {
KERNEL_LOG_ERROR("The shape size of input 'x' should be %zu, but got %zu ", kXShapeSize, inputs_shape->GetDims());
return KERNEL_STATUS_PARAM_INVALID;
}
auto outputs_shape = ctx.Output(kIndex0)->GetTensorShape();
if (outputs_shape->GetDims() != kOutputShapeSize) {
KERNEL_LOG_ERROR("The shape size of output 'y' should be %zu, but got %zu ", kOutputShapeSize,
outputs_shape->GetDims());
return KERNEL_STATUS_PARAM_INVALID;
}
ShapeVector x_shape = inputs_shape->GetDimSizes();
ShapeVector y_shape = outputs_shape->GetDimSizes();
n_ = x_shape[n_axis_];
c_ = x_shape[c_axis_];
input_h_ = x_shape[h_axis_];
input_w_ = x_shape[w_axis_];
output_h_ = y_shape[h_axis_];
output_w_ = y_shape[w_axis_];
position_grid_size_ = output_h_ * output_w_;
index_type_ = ctx.Input(kIndex0)->GetDataType();
workspace_size_list_.emplace_back(sizeof(int64_t) * static_cast<size_t>(position_grid_size_) * kKernelSizeSize);
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsKernel::ParseKernelParam(const CpuKernelContext &ctx) {
auto input_size = ctx.GetInputsSize();
auto output_size = ctx.GetOutputsSize();
if (input_size != kInputsSize || output_size != kOutputsSize) {
KERNEL_LOG_ERROR("It should get %zu inputs and %zu outputs, but got %zu inputs and %zu outputs.", kInputsSize,
kOutputsSize, input_size, output_size);
return KERNEL_STATUS_PARAM_INVALID;
}
if (ParseAttrs(ctx) != KERNEL_STATUS_OK) {
return KERNEL_STATUS_PARAM_INVALID;
}
if (SetDims(ctx) != KERNEL_STATUS_OK) {
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DeformableOffsetsKernel::DoCompute(const CpuKernelContext &ctx, const int64_t *position_grid_addr) {
auto *input_addr = reinterpret_cast<T *>(ctx.Input(kIndex0)->GetData());
auto *offsets_addr = reinterpret_cast<T *>(ctx.Input(kIndex1)->GetData());
auto *output_addr = reinterpret_cast<T *>(ctx.Output(kIndex0)->GetData());
int64_t pixel_h = output_h_ / kernel_size_[kKernelSizeHIndex];
int64_t pixel_w = output_w_ / kernel_size_[kKernelSizeWIndex];
int64_t output_c_dim = output_h_ * output_w_;
int64_t output_n_dim = c_ * output_c_dim;
int64_t c_size_per_dfm_group = c_ / deformable_groups_;
int64_t offset_kw_dim = pixel_h * pixel_w;
int64_t offset_kh_dim = offset_kw_dim * kernel_size_[kKernelSizeWIndex];
int64_t offset_group_dim = offset_kh_dim * kernel_size_[kKernelSizeHIndex];
int64_t offset_mask_dim = offset_group_dim * deformable_groups_;
int64_t offset_n_dim = offset_mask_dim * static_cast<int64_t>(kOffsetsSize);
int64_t input_c_dim = input_h_ * input_w_;
int64_t input_n_dim = input_c_dim * c_;
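// Per sample, the offsets input is laid out as 3 consecutive blocks (x offsets, y offsets, modulation mask),
// each of size deformable_groups * kernel_h * kernel_w * (output_h / kernel_h) * (output_w / kernel_w).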
auto task = [this, &input_addr, &offsets_addr, &output_addr, &position_grid_addr, &pixel_w, &output_c_dim,
&output_n_dim, &c_size_per_dfm_group, &offset_kw_dim, &offset_kh_dim, &offset_group_dim,
&offset_mask_dim, &offset_n_dim, &input_c_dim, &input_n_dim](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto long_i = static_cast<int64_t>(i);
// Get input position
int64_t hw_idx = long_i % output_c_dim;
int64_t position_grid_idx = hw_idx * 2;
int64_t input_x = position_grid_addr[position_grid_idx];
int64_t input_y = position_grid_addr[position_grid_idx + 1];
// Get offsets
int64_t n_index = long_i / output_n_dim;
int64_t c_index = long_i / output_c_dim % c_;
int64_t x = hw_idx % output_w_;
int64_t y = hw_idx / output_w_;
int64_t dfm_group_index = c_index / c_size_per_dfm_group;
int64_t pixel_x = x / kernel_size_[kKernelSizeWIndex];
int64_t pixel_y = y / kernel_size_[kKernelSizeHIndex];
int64_t kernel_x = x % kernel_size_[kKernelSizeWIndex];
int64_t kernel_y = y % kernel_size_[kKernelSizeHIndex];
int64_t x_offsets_offset = n_index * offset_n_dim + dfm_group_index * offset_group_dim +
kernel_y * offset_kh_dim + kernel_x * offset_kw_dim + pixel_y * pixel_w + pixel_x;
T x_offsets = offsets_addr[x_offsets_offset];
int64_t y_offsets_offset = x_offsets_offset + offset_mask_dim;
T y_offsets = offsets_addr[y_offsets_offset];
int64_t mask_offset = y_offsets_offset + offset_mask_dim;
T mask = offsets_addr[mask_offset];
T new_x = static_cast<T>(input_x) + x_offsets;
T new_y = static_cast<T>(input_y) + y_offsets;
const T *input_addr_offset = input_addr + n_index * input_n_dim + c_index * input_c_dim;
T bilinear_val = DeformableBilinear(input_addr_offset, new_x, new_y, input_w_, input_h_);
output_addr[i] = bilinear_val * mask;
}
};
int64_t num_kernels = n_ * output_n_dim;
int64_t per_unit_size = num_kernels / std::thread::hardware_concurrency();
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_kernels, per_unit_size, task),
"DeformableOffsets Compute failed.");
return KERNEL_STATUS_OK;
}
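// For every output position, precompute the base (x, y) sampling coordinate in the input feature map
// before the learned offsets are applied; the coordinates are stored as interleaved (x, y) int64 pairs.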
uint32_t DeformableOffsetsKernel::GenPositionGrid(const CpuKernelContext &ctx, int64_t *position_grid) {
auto task = [this, &position_grid](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto long_i = static_cast<int64_t>(i);
int64_t y = long_i / output_w_;
int64_t x = long_i % output_w_;
int64_t pixel_y = y / kernel_size_[kKernelSizeHIndex];
int64_t pixel_x = x / kernel_size_[kKernelSizeWIndex];
int64_t kernel_y = y % kernel_size_[kKernelSizeHIndex];
int64_t kernel_x = x % kernel_size_[kKernelSizeWIndex];
size_t index = i * 2;
position_grid[index] = pixel_x * strides_[w_axis_] + kernel_x * dilations_[w_axis_] - pads_[kPadLeftIndex];
position_grid[index + 1] = pixel_y * strides_[h_axis_] + kernel_y * dilations_[h_axis_] - pads_[kPadTopIndex];
}
};
int64_t num_kernels = output_h_ * output_w_;
int64_t per_unit_size = num_kernels / std::thread::hardware_concurrency();
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_kernels, per_unit_size, task),
"DeformableOffsets GenPositionGrid failed.");
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(ParseKernelParam(ctx), "DeformableOffsets normal check failed.");
auto *position_grid_addr = reinterpret_cast<int64_t *>(malloc(workspace_size_list_[0]));
if (position_grid_addr == nullptr) {
KERNEL_LOG_ERROR("Malloc memory failed!");
return KERNEL_STATUS_PARAM_INVALID;
}
auto ret = GenPositionGrid(ctx, position_grid_addr);
if (ret != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("Generate position grid failed.");
free(position_grid_addr);
return KERNEL_STATUS_INNER_ERROR;
}
switch (index_type_) {
case DT_FLOAT:
ret = DoCompute<float>(ctx, position_grid_addr);
break;
case DT_FLOAT16:
ret = DoCompute<Eigen::half>(ctx, position_grid_addr);
break;
default:
KERNEL_LOG_ERROR("Unsupported data type %s.", DTypeStr(index_type_).c_str());
free(position_grid_addr);
return KERNEL_STATUS_INNER_ERROR;
}
free(position_grid_addr);
return ret;
}
REGISTER_CPU_KERNEL(kDeformableOffsets, DeformableOffsetsKernel);
} // namespace aicpu

View File

@@ -0,0 +1,73 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_KERNELS_H_
#define AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_KERNELS_H_
#include <algorithm>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include "cpu_ops_kernel.h"
namespace aicpu {
class DeformableOffsetsKernel : public CpuKernel {
public:
DeformableOffsetsKernel() = default;
~DeformableOffsetsKernel() = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
void ResetResource() noexcept;
uint32_t ParseKernelParam(const CpuKernelContext &ctx);
uint32_t ParseAttrs(const CpuKernelContext &ctx);
uint32_t SetDims(const CpuKernelContext &ctx);
uint32_t GenPositionGrid(const CpuKernelContext &ctx, int64_t *position_grid);
template <typename T>
uint32_t DoCompute(const CpuKernelContext &ctx, const int64_t *position_grid);
std::vector<int64_t> strides_;
std::vector<int64_t> pads_;
std::vector<int64_t> kernel_size_;
std::vector<int64_t> dilations_;
std::vector<size_t> workspace_size_list_;
int64_t deformable_groups_{1};
bool modulated_{true};
size_t n_axis_{0};
size_t c_axis_{1};
size_t h_axis_{2};
size_t w_axis_{3};
int64_t n_{0};
int64_t c_{0};
int64_t input_h_{0};
int64_t input_w_{0};
int64_t output_h_{0};
int64_t output_w_{0};
int64_t position_grid_size_{0};
DataType index_type_{DT_FLOAT};
};
} // namespace aicpu
#endif // AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_KERNELS_H_

View File

@@ -0,0 +1,538 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./deformable_offsets_grad.h"
#include <Eigen/Dense>
#include <cmath>
#include <utility>
#include <set>
#include <algorithm>
#include <mutex>
#include <map>
#include <functional>
#include <thread>
#include <tuple>
#include "cpu_kernel_utils.h"
#include "securec.h"
#include "utils/kernel_util.h"
namespace aicpu {
namespace {
const char *kDeformableOffsetsGrad = "DeformableOffsetsGrad";
constexpr auto kDeformableGroups = "deformable_groups";
constexpr auto kPads = "pads";
constexpr auto kStrides = "strides";
constexpr auto kDilations = "dilations";
constexpr auto kSize = "ksize";
constexpr auto kDataformat = "data_format";
constexpr auto kNCHW = "NCHW";
constexpr size_t kInputNum = 3;
constexpr size_t kOutputNum = 2;
constexpr size_t kGradIndex = 0;
constexpr size_t kXIndex = 1;
constexpr size_t kOffsetIndex = 2;
constexpr size_t kGradXIndex = 0;
constexpr size_t kGradOffsetIndex = 1;
constexpr size_t kPadNum = 4;
constexpr size_t kStrideNum = 4;
constexpr size_t kDilationNum = 4;
constexpr size_t kKernelSizeNum = 2;
constexpr size_t kPadTopIndex = 0;
constexpr size_t kPadLeftIndex = 2;
constexpr size_t kStrideHIndex = 2;
constexpr size_t kStrideWIndex = 3;
constexpr size_t kDilationHIndex = 2;
constexpr size_t kDilationWIndex = 3;
constexpr size_t kKernelHIndex = 0;
constexpr size_t kKernelWIndex = 1;
constexpr size_t kCIndexForNCHW = 1;
constexpr size_t kHIndexForNCHW = 2;
constexpr size_t kWIndexForNCHW = 3;
constexpr size_t kHIndexForNHWC = 1;
constexpr size_t kWIndexForNHWC = 2;
constexpr size_t kCIndexForNHWC = 3;
// x,y,mask total occupy 3 channel
constexpr size_t kOffsetChannel = 3;
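// Element-wise strides used to address the offsets, incoming gradient and input x tensors;
// DoComputeNHWC and DoComputeNCHW fill them according to the data format.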
struct OffsetStride {
size_t kernel_w_stride;
size_t kernel_h_stride;
size_t deformable_group_stride;
size_t position_stride;
size_t offset_w_stride;
size_t offset_h_stride;
size_t n_stride;
};
struct GradStride {
size_t deformable_group_channel_stride;
size_t deformable_group_stride;
size_t kernel_w_stride;
size_t offset_w_stride;
size_t kernel_h_stride;
size_t offset_h_stride;
size_t n_stride;
};
struct InputXStride {
size_t deformable_group_channel_stride;
size_t deformable_group_stride;
size_t w_stride;
size_t h_stride;
size_t n_stride;
};
struct OffsetIndex {
size_t kernel_j;
size_t kernel_i;
size_t deformable_group_i;
size_t offset_j;
size_t offset_i;
size_t n_i;
};
struct InputXIndex {
float i;
float j;
};
} // namespace
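// Different output positions may scatter into the same element of grad_x, so the bilinear
// accumulation below is serialized with a single global mutex.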
std::mutex mutex_;
template <typename T>
void MsAtomicAdd(T *output_grad_x, const size_t &output_grad_base_pos, const T &added_value) {
std::lock_guard<std::mutex> lock(mutex_);
output_grad_x[output_grad_base_pos] += added_value;
}
inline std::tuple<size_t, size_t, size_t> CalPosition(const OffsetIndex &offset_index,
const OffsetStride &offset_stride, const GradStride &grad_stride,
const InputXStride &input_x_stride) {
const size_t offset_index_base_pos =
offset_index.n_i * offset_stride.n_stride +
offset_index.deformable_group_i * offset_stride.deformable_group_stride +
offset_index.kernel_i * offset_stride.kernel_h_stride + offset_index.kernel_j * offset_stride.kernel_w_stride +
offset_index.offset_i * offset_stride.offset_h_stride + offset_index.offset_j * offset_stride.offset_w_stride;
const size_t input_grad_base_pos =
offset_index.n_i * grad_stride.n_stride + offset_index.offset_i * grad_stride.offset_h_stride +
offset_index.offset_j * grad_stride.offset_w_stride + offset_index.kernel_i * grad_stride.kernel_h_stride +
offset_index.kernel_j * grad_stride.kernel_w_stride +
offset_index.deformable_group_i * grad_stride.deformable_group_stride;
const size_t input_x_base_pos = offset_index.n_i * input_x_stride.n_stride +
offset_index.deformable_group_i * input_x_stride.deformable_group_stride;
return {offset_index_base_pos, input_grad_base_pos, input_x_base_pos};
}
inline InputXIndex CalInputXIndex(const OffsetIndex &offset_index, const DeformableOffsetGradDims &dims) {
InputXIndex input_x_index;
input_x_index.i = -1.0f * static_cast<float>(dims.pad_top);
input_x_index.j = -1.0f * static_cast<float>(dims.pad_left);
input_x_index.i +=
static_cast<float>(offset_index.offset_i * dims.stride_h + offset_index.kernel_i * dims.dilation_h);
input_x_index.j +=
static_cast<float>(offset_index.offset_j * dims.stride_w + offset_index.kernel_j * dims.dilation_w);
return input_x_index;
}
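// Backward step for one (sample, offset position, deformable group, kernel position) tuple: re-reads the
// four bilinear neighbours of the sampling point, scatters grad * mask into grad_x with the bilinear
// weights, and accumulates the gradients w.r.t. the x/y offsets and the modulation mask.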
template <typename T>
void DeformableOffsetGradKernel(const OffsetIndex &offset_index, const OffsetStride &offset_stride,
const GradStride &grad_stride, const aicpu::DeformableOffsetGradDims &dims,
const InputXStride &input_x_stride, const T *input_x, const T *input_offset,
const T *input_grad, T *output_grad_x, T *output_grad_offset) {
const auto [offset_index_base_pos, input_grad_base_pos, input_x_base_pos] =
CalPosition(offset_index, offset_stride, grad_stride, input_x_stride);
const auto input_x_index = CalInputXIndex(offset_index, dims);
const size_t offset_index_i = offset_index_base_pos + offset_stride.position_stride;
const size_t offset_index_weight = offset_index_base_pos + 2 * offset_stride.position_stride;
float offset_i = static_cast<float>(input_offset[offset_index_i]);
float offset_j = static_cast<float>(input_offset[offset_index_base_pos]);
float scale_weight = static_cast<float>(input_offset[offset_index_weight]);
float floor_offset_i = floorf(offset_i);
float floor_offset_j = floorf(offset_j);
float ceil_offset_i = floor_offset_i + 1;
float ceil_offset_j = floor_offset_j + 1;
float floor_i = input_x_index.i + floor_offset_i;
float floor_j = input_x_index.j + floor_offset_j;
float ceil_i = input_x_index.i + ceil_offset_i;
float ceil_j = input_x_index.j + ceil_offset_j;
float ceil_weight_i = offset_i + 1 - ceil_offset_i;
float ceil_weight_j = offset_j + 1 - ceil_offset_j;
float floor_weight_i = 1 - ceil_weight_i;
float floor_weight_j = 1 - ceil_weight_j;
float floor_floor_weight = floor_weight_i * floor_weight_j;
float ceil_floor_weight = ceil_weight_i * floor_weight_j;
float floor_ceil_weight = floor_weight_i * ceil_weight_j;
float ceil_ceil_weight = ceil_weight_i * ceil_weight_j;
bool floor_floor_valid = false;
bool ceil_floor_valid = false;
bool floor_ceil_valid = false;
bool ceil_ceil_valid = false;
if (floor_i >= 0 && floor_i < dims.x_h) {
if (floor_j >= 0 && floor_j < dims.x_w) {
floor_floor_valid = true;
}
if (ceil_j >= 0 && ceil_j < dims.x_w) {
floor_ceil_valid = true;
}
}
if (ceil_i >= 0 && ceil_i < dims.x_h) {
if (floor_j >= 0 && floor_j < dims.x_w) {
ceil_floor_valid = true;
}
if (ceil_j >= 0 && ceil_j < dims.x_w) {
ceil_ceil_valid = true;
}
}
for (size_t channel = 0; channel < dims.deformable_group_channel; ++channel) {
float grad =
static_cast<float>(input_grad[input_grad_base_pos + channel * grad_stride.deformable_group_channel_stride]);
float grad_scale = grad * scale_weight;
size_t tmp_input_x_base_pos = input_x_base_pos + channel * input_x_stride.deformable_group_channel_stride;
float current_x_pos;
float floor_floor_value = 0;
float ceil_floor_value = 0;
float floor_ceil_value = 0;
float ceil_ceil_value = 0;
size_t input_x_pos = 0;
if (floor_floor_valid) {
current_x_pos = tmp_input_x_base_pos + floor_i * input_x_stride.h_stride + floor_j * input_x_stride.w_stride;
input_x_pos = static_cast<size_t>(current_x_pos);
floor_floor_value = static_cast<float>(input_x[input_x_pos]);
MsAtomicAdd(output_grad_x, input_x_pos, static_cast<T>(grad_scale * floor_floor_weight));
}
if (ceil_floor_valid) {
current_x_pos = tmp_input_x_base_pos + ceil_i * input_x_stride.h_stride + floor_j * input_x_stride.w_stride;
input_x_pos = static_cast<size_t>(current_x_pos);
ceil_floor_value = static_cast<float>(input_x[input_x_pos]);
MsAtomicAdd(output_grad_x, input_x_pos, static_cast<T>(grad_scale * ceil_floor_weight));
}
if (floor_ceil_valid) {
current_x_pos = tmp_input_x_base_pos + floor_i * input_x_stride.h_stride + ceil_j * input_x_stride.w_stride;
input_x_pos = static_cast<size_t>(current_x_pos);
floor_ceil_value = static_cast<float>(input_x[input_x_pos]);
MsAtomicAdd(output_grad_x, input_x_pos, static_cast<T>(grad_scale * floor_ceil_weight));
}
if (ceil_ceil_valid) {
current_x_pos = tmp_input_x_base_pos + ceil_i * input_x_stride.h_stride + ceil_j * input_x_stride.w_stride;
input_x_pos = static_cast<size_t>(current_x_pos);
ceil_ceil_value = static_cast<float>(input_x[input_x_pos]);
MsAtomicAdd(output_grad_x, input_x_pos, static_cast<T>(grad_scale * ceil_ceil_weight));
}
float delta = -floor_floor_value * floor_weight_j + ceil_floor_value * floor_weight_j -
floor_ceil_value * ceil_weight_j + ceil_ceil_value * ceil_weight_j;
delta *= grad_scale;
output_grad_offset[offset_index_i] += static_cast<T>(delta);
delta = -floor_floor_value * floor_weight_i - ceil_floor_value * ceil_weight_i + floor_ceil_value * floor_weight_i +
ceil_ceil_value * ceil_weight_i;
delta *= grad_scale;
output_grad_offset[offset_index_base_pos] += static_cast<T>(delta);
delta = floor_floor_value * floor_floor_weight + ceil_floor_value * ceil_floor_weight +
floor_ceil_value * floor_ceil_weight + ceil_ceil_value * ceil_ceil_weight;
delta *= grad;
output_grad_offset[offset_index_weight] += static_cast<T>(delta);
}
}
template <typename T>
uint32_t DeformableOffsetsGradKernel::DoComputeNHWC(const CpuKernelContext &ctx, size_t num_kernels,
const DeformableOffsetGradDims &dims, const T *input_x,
const T *input_offset, const T *input_grad, T *output_grad_x,
T *output_grad_offset) const {
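// NHWC: the packed offset channel (3 * deformable_group * kernel_h * kernel_w) is the innermost
// dimension, so the kernel/group strides start at 1 and the spatial and batch strides build on top.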
OffsetStride offset_stride;
offset_stride.kernel_w_stride = 1;
offset_stride.kernel_h_stride = dims.kernel_w * offset_stride.kernel_w_stride;
offset_stride.deformable_group_stride = dims.kernel_h * offset_stride.kernel_h_stride;
offset_stride.position_stride = dims.deformable_group * offset_stride.deformable_group_stride;
offset_stride.offset_w_stride = kOffsetChannel * offset_stride.position_stride;
offset_stride.offset_h_stride = dims.offset_w * offset_stride.offset_w_stride;
offset_stride.n_stride = dims.offset_h * offset_stride.offset_h_stride;
GradStride grad_stride;
grad_stride.deformable_group_channel_stride = 1;
grad_stride.deformable_group_stride = dims.deformable_group_channel * grad_stride.deformable_group_channel_stride;
grad_stride.kernel_w_stride = dims.deformable_group * grad_stride.deformable_group_stride;
grad_stride.offset_w_stride = dims.kernel_w * grad_stride.kernel_w_stride;
grad_stride.kernel_h_stride = dims.offset_w * grad_stride.offset_w_stride;
grad_stride.offset_h_stride = dims.kernel_h * grad_stride.kernel_h_stride;
grad_stride.n_stride = dims.offset_h * grad_stride.offset_h_stride;
InputXStride input_x_stride;
input_x_stride.deformable_group_channel_stride = 1;
input_x_stride.deformable_group_stride =
dims.deformable_group_channel * input_x_stride.deformable_group_channel_stride;
input_x_stride.w_stride = dims.deformable_group * input_x_stride.deformable_group_stride;
input_x_stride.h_stride = dims.x_w * input_x_stride.w_stride;
input_x_stride.n_stride = dims.x_h * input_x_stride.h_stride;
auto task = [&dims, &offset_stride, &grad_stride, &input_x, &input_offset, &input_grad, &output_grad_x,
&output_grad_offset, &input_x_stride](size_t start, size_t end) {
for (size_t index = start; index < end; ++index) {
// Keep the index bookkeeping thread-local so parallel workers do not race on it.
OffsetIndex offset_index;
offset_index.kernel_j = index % dims.kernel_w;
size_t tmp = index / dims.kernel_w;
offset_index.kernel_i = tmp % dims.kernel_h;
tmp = tmp / dims.kernel_h;
offset_index.deformable_group_i = tmp % dims.deformable_group;
tmp = tmp / dims.deformable_group;
offset_index.offset_j = tmp % dims.offset_w;
tmp = tmp / dims.offset_w;
offset_index.offset_i = tmp % dims.offset_h;
offset_index.n_i = tmp / dims.offset_h;
DeformableOffsetGradKernel(offset_index, offset_stride, grad_stride, dims, input_x_stride, input_x, input_offset,
input_grad, output_grad_x, output_grad_offset);
}
};
const int64_t per_unit_size = static_cast<int64_t>(num_kernels / std::thread::hardware_concurrency());
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_kernels, per_unit_size, task), "Compute failed.");
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DeformableOffsetsGradKernel::DoComputeNCHW(const CpuKernelContext &ctx, size_t num_kernels,
const DeformableOffsetGradDims &dims, const T *input_x,
const T *input_offset, const T *input_grad, T *output_grad_x,
T *output_grad_offset) const {
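// NCHW: width and height are the innermost dimensions, so the spatial strides start at 1 and the
// packed offset channel strides build on top.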
OffsetStride offset_stride;
offset_stride.offset_w_stride = 1;
offset_stride.offset_h_stride = dims.offset_w * offset_stride.offset_w_stride;
offset_stride.kernel_w_stride = dims.offset_h * offset_stride.offset_h_stride;
offset_stride.kernel_h_stride = dims.kernel_w * offset_stride.kernel_w_stride;
offset_stride.deformable_group_stride = dims.kernel_h * offset_stride.kernel_h_stride;
offset_stride.position_stride = dims.deformable_group * offset_stride.deformable_group_stride;
offset_stride.n_stride = kOffsetChannel * offset_stride.position_stride;
GradStride grad_stride;
grad_stride.kernel_w_stride = 1;
grad_stride.offset_w_stride = dims.kernel_w * grad_stride.kernel_w_stride;
grad_stride.kernel_h_stride = dims.offset_w * grad_stride.offset_w_stride;
grad_stride.offset_h_stride = dims.kernel_h * grad_stride.kernel_h_stride;
grad_stride.deformable_group_channel_stride = dims.offset_h * grad_stride.offset_h_stride;
grad_stride.deformable_group_stride = dims.deformable_group_channel * grad_stride.deformable_group_channel_stride;
grad_stride.n_stride = dims.deformable_group * grad_stride.deformable_group_stride;
InputXStride input_x_stride;
input_x_stride.w_stride = 1;
input_x_stride.h_stride = dims.x_w * input_x_stride.w_stride;
input_x_stride.deformable_group_channel_stride = dims.x_h * input_x_stride.h_stride;
input_x_stride.deformable_group_stride =
dims.deformable_group_channel * input_x_stride.deformable_group_channel_stride;
input_x_stride.n_stride = dims.deformable_group * input_x_stride.deformable_group_stride;
auto task = [&dims, &offset_stride, &grad_stride, &input_x, &input_offset, &input_grad, &output_grad_x,
&output_grad_offset, &input_x_stride](size_t start, size_t end) {
for (size_t index = start; index < end; ++index) {
// Keep the index bookkeeping thread-local so parallel workers do not race on it.
OffsetIndex offset_index;
offset_index.offset_j = index % dims.offset_w;
size_t tmp = index / dims.offset_w;
offset_index.offset_i = tmp % dims.offset_h;
tmp = tmp / dims.offset_h;
offset_index.kernel_j = tmp % dims.kernel_w;
tmp = tmp / dims.kernel_w;
offset_index.kernel_i = tmp % dims.kernel_h;
tmp = tmp / dims.kernel_h;
offset_index.deformable_group_i = tmp % dims.deformable_group;
offset_index.n_i = tmp / dims.deformable_group;
DeformableOffsetGradKernel(offset_index, offset_stride, grad_stride, dims, input_x_stride, input_x, input_offset,
input_grad, output_grad_x, output_grad_offset);
}
};
const int64_t per_unit_size = static_cast<int64_t>(num_kernels / std::thread::hardware_concurrency());
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_kernels, per_unit_size, task), "Compute failed.");
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsGradKernel::ParseKernelParam(const CpuKernelContext &ctx) {
const size_t &num_input = ctx.GetInputsSize();
const size_t &num_output = ctx.GetOutputsSize();
auto ret = CheckInOutNum(num_input, num_output);
if (ret != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("It should get %zu inputs and %zu outputs, but got %zu inputs and %zu outputs.", kInputNum,
kOutputNum, num_input, num_output);
return KERNEL_STATUS_PARAM_INVALID;
}
auto grad_x_output_tensor = ctx.Output(kGradXIndex);
// Get the dtype of the inputs
index_type_ = grad_x_output_tensor->GetDataType();
const auto &output_shape = grad_x_output_tensor->GetTensorShape();
index_output_size_ = output_shape->NumElements();
index_output_shape_ = output_shape->GetDimSizes();
auto grad_offset_output_tensor = ctx.Output(kGradOffsetIndex);
const auto &grad_output_shape = grad_offset_output_tensor->GetTensorShape();
grad_output_size_ = grad_output_shape->NumElements();
grad_output_shape_ = grad_output_shape->GetDimSizes();
ret = SetDims(ctx);
if (ret != KERNEL_STATUS_OK) {
KERNEL_LOG_ERROR("Set dims failed.");
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t DeformableOffsetsGradKernel::DeformableOffsetsGradTask(const CpuKernelContext &ctx) {
const size_t num_kernels =
dims_.x_n * dims_.offset_h * dims_.offset_w * dims_.kernel_h * dims_.kernel_w * dims_.deformable_group;
const T *input_grad = reinterpret_cast<T *>(ctx.Input(kGradIndex)->GetData());
const T *input_x = reinterpret_cast<T *>(ctx.Input(kXIndex)->GetData());
const T *input_offset = reinterpret_cast<T *>(ctx.Input(kOffsetIndex)->GetData());
T *output_grad_x = reinterpret_cast<T *>(ctx.Output(kGradXIndex)->GetData());
T *output_grad_offset = reinterpret_cast<T *>(ctx.Output(kGradOffsetIndex)->GetData());
auto grad_x_size = static_cast<size_t>(index_output_size_ * sizeof(T));
auto grad_offset_size = static_cast<size_t>(grad_output_size_ * sizeof(T));
// Reset output initial value to 0.
auto ret = memset_s(output_grad_x, grad_x_size, 0, grad_x_size);
if (ret != 0) {
return KERNEL_STATUS_INNER_ERROR;
}
ret = memset_s(output_grad_offset, grad_offset_size, 0, grad_offset_size);
if (ret != 0) {
return KERNEL_STATUS_INNER_ERROR;
}
if (data_format_ == kNCHW) {
ret =
DoComputeNCHW<T>(ctx, num_kernels, dims_, input_x, input_offset, input_grad, output_grad_x, output_grad_offset);
} else {
ret =
DoComputeNHWC<T>(ctx, num_kernels, dims_, input_x, input_offset, input_grad, output_grad_x, output_grad_offset);
}
return ret;
}
uint32_t DeformableOffsetsGradKernel::CheckInOutNum(size_t inputs_num, size_t outputs_num) const {
if (inputs_num != kInputNum) {
KERNEL_LOG_ERROR("The number of inputs must be %zu, but got %zu.", kInputNum, inputs_num);
return KERNEL_STATUS_PARAM_INVALID;
}
if (outputs_num != kOutputNum) {
KERNEL_LOG_ERROR("The number of outputs must be %zu, but got %zu.", kOutputNum, outputs_num);
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsGradKernel::SetDims(const CpuKernelContext &ctx) {
dims_.deformable_group = static_cast<size_t>(ctx.GetAttr(kDeformableGroups)->GetInt());
if (dims_.deformable_group == 0) {
KERNEL_LOG_ERROR("The 'deformable_groups' must be greater than 0, but got 0.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto pad = ctx.GetAttr(kPads)->GetListInt();
if (pad.size() != kPadNum) {
KERNEL_LOG_ERROR("The length of 'pads' must be %zu, but got %zu.", kPadNum, pad.size());
return KERNEL_STATUS_PARAM_INVALID;
}
dims_.pad_top = static_cast<size_t>(pad[kPadTopIndex]);
dims_.pad_left = static_cast<size_t>(pad[kPadLeftIndex]);
auto stride = ctx.GetAttr(kStrides)->GetListInt();
if (stride.size() != kStrideNum) {
KERNEL_LOG_ERROR("The length of 'strides' must be %zu, but got %zu.", kStrideNum, stride.size());
return KERNEL_STATUS_PARAM_INVALID;
}
dims_.stride_h = static_cast<size_t>(stride[kStrideHIndex]);
dims_.stride_w = static_cast<size_t>(stride[kStrideWIndex]);
auto dilation = ctx.GetAttr(kDilations)->GetListInt();
if (dilation.size() != kDilationNum) {
KERNEL_LOG_ERROR("The length of 'dilations' must be %zu, but got %zu.", kDilationNum, dilation.size());
return KERNEL_STATUS_PARAM_INVALID;
}
dims_.dilation_h = static_cast<size_t>(dilation[kDilationHIndex]);
dims_.dilation_w = static_cast<size_t>(dilation[kDilationWIndex]);
auto ksize = ctx.GetAttr(kSize)->GetListInt();
if (ksize.size() != kKernelSizeNum) {
KERNEL_LOG_ERROR("The length of 'ksize' must be %zu, but got %zu.", kKernelSizeNum, ksize.size());
return KERNEL_STATUS_PARAM_INVALID;
}
dims_.kernel_h = static_cast<size_t>(ksize[kKernelHIndex]);
dims_.kernel_w = static_cast<size_t>(ksize[kKernelWIndex]);
if (dims_.kernel_h == 0 || dims_.kernel_w == 0) {
KERNEL_LOG_ERROR("The value of 'ksize' must be larger than 0.");
return KERNEL_STATUS_PARAM_INVALID;
}
auto input_tensor = ctx.Input(kXIndex)->GetTensorShape();
auto x_shape = input_tensor->GetDimSizes();
dims_.x_n = static_cast<size_t>(x_shape[0]);
auto grad_index_input_tensor = ctx.Input(kGradIndex);
const auto &grad_shape = grad_index_input_tensor->GetTensorShape()->GetDimSizes();
data_format_ = ctx.GetAttr(kDataformat)->GetString();
if (data_format_ == kNCHW) {
dims_.grad_h = static_cast<size_t>(grad_shape[kHIndexForNCHW]);
dims_.grad_w = static_cast<size_t>(grad_shape[kWIndexForNCHW]);
dims_.x_h = static_cast<size_t>(x_shape[kHIndexForNCHW]);
dims_.x_w = static_cast<size_t>(x_shape[kWIndexForNCHW]);
dims_.deformable_group_channel = static_cast<size_t>(x_shape[kCIndexForNCHW]) / dims_.deformable_group;
} else {
dims_.grad_h = static_cast<size_t>(grad_shape[kHIndexForNHWC]);
dims_.grad_w = static_cast<size_t>(grad_shape[kWIndexForNHWC]);
dims_.x_h = static_cast<size_t>(x_shape[kHIndexForNHWC]);
dims_.x_w = static_cast<size_t>(x_shape[kWIndexForNHWC]);
dims_.deformable_group_channel = static_cast<size_t>(x_shape[kCIndexForNHWC]) / dims_.deformable_group;
}
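// The forward output tiles every offset position with the kernel window, so the offset grid size is the
// gradient's spatial size divided by the kernel size.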
dims_.offset_h = dims_.grad_h / dims_.kernel_h;
dims_.offset_w = dims_.grad_w / dims_.kernel_w;
return KERNEL_STATUS_OK;
}
uint32_t DeformableOffsetsGradKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(ParseKernelParam(ctx), "DeformableOffsetsGrad normal check failed.");
uint32_t ret = KERNEL_STATUS_OK;
switch (index_type_) {
case DT_FLOAT:
ret = DeformableOffsetsGradTask<float>(ctx);
break;
case DT_FLOAT16:
ret = DeformableOffsetsGradTask<Eigen::half>(ctx);
break;
default:
KERNEL_LOG_ERROR("Unsupported data type %s.", DTypeStr(index_type_).c_str());
return KERNEL_STATUS_INNER_ERROR;
}
return ret;
}
REGISTER_CPU_KERNEL(kDeformableOffsetsGrad, DeformableOffsetsGradKernel);
} // namespace aicpu

View File

@@ -0,0 +1,84 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_GRAD_KERNELS_H_
#define AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_GRAD_KERNELS_H_
#include <algorithm>
#include <string>
#include <vector>
#include <utility>
#include "cpu_ops_kernel.h"
namespace aicpu {
struct DeformableOffsetGradDims {
size_t x_n = 0;
size_t x_h = 0;
size_t x_w = 0;
size_t offset_h = 0;
size_t offset_w = 0;
size_t grad_h = 0;
size_t grad_w = 0;
size_t kernel_h = 0;
size_t kernel_w = 0;
size_t pad_top = 0;
size_t pad_left = 0;
size_t stride_h = 0;
size_t stride_w = 0;
size_t dilation_h = 0;
size_t dilation_w = 0;
size_t deformable_group = 0;
size_t deformable_group_channel = 0;
};
class DeformableOffsetsGradKernel : public CpuKernel {
public:
DeformableOffsetsGradKernel() = default;
~DeformableOffsetsGradKernel() = default;
protected:
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t ParseKernelParam(const CpuKernelContext &ctx);
uint32_t CheckInOutNum(size_t inputs_num, size_t outputs_num) const;
uint32_t SetDims(const CpuKernelContext &ctx);
template <typename T>
uint32_t DoComputeNHWC(const CpuKernelContext &ctx, size_t num_kernels, const DeformableOffsetGradDims &dims,
const T *input_x, const T *input_offset, const T *input_grad, T *output_grad_x,
T *output_grad_offset) const;
template <typename T>
uint32_t DoComputeNCHW(const CpuKernelContext &ctx, size_t num_kernels, const DeformableOffsetGradDims &dims,
const T *input_x, const T *input_offset, const T *input_grad, T *output_grad_x,
T *output_grad_offset) const;
template <typename T>
uint32_t DeformableOffsetsGradTask(const CpuKernelContext &ctx);
std::string data_format_ = "ND";
DeformableOffsetGradDims dims_;
DataType index_type_{DT_FLOAT};
int64_t index_output_size_ = 1;
int64_t grad_output_size_ = 1;
std::vector<int64_t> index_output_shape_;
std::vector<int64_t> grad_output_shape_;
std::vector<int64_t> output_shape_;
};
} // namespace aicpu
#endif // AICPU_OPS_AICPU_DEFORMABLE_OFFSETS_GRAD_KERNELS_H_

View File

@@ -204,6 +204,8 @@ constexpr auto kQuantDTypeCast = "QuantDTypeCast";
constexpr auto kFSEDecode = "FSEDecode";
constexpr auto kSparseSegmentSum = "SparseSegmentSum";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kDeformableOffsets = "DeformableOffsets";
constexpr auto kDeformableOffsetsGrad = "DeformableOffsetsGrad";
const std::set<std::string> kCpuKernelOps{kIdentity,
kGather,

View File

@@ -18,6 +18,7 @@
#include <vector>
#include <algorithm>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
@@ -118,6 +119,10 @@ const AnfNodePtr DeformableOffsetsFusion::Process(const FuncGraphPtr &func_graph
new_cnode->set_scope(deformable_offsets_cnode->scope());
common::AnfAlgo::CopyNodeAttrs(deformable_offsets_cnode, new_cnode);
common::AnfAlgo::SetNodeAttr(kAttrDataFormat, MakeValue("NHWC"), new_cnode);
if (!CheckAICoreSupportedAny(new_cnode)) {
MS_LOG(INFO) << "DeformableOffsets is not supported by AI Core, fall back to AICPU.";
return nullptr;
}
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_const);
MS_LOG(INFO) << "Add assist tensor for DeformableOffsets op success.";

View File

@@ -17,6 +17,7 @@
#include <memory>
#include <vector>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
@@ -62,6 +63,10 @@ const AnfNodePtr DeformableOffsetsGradFusion::Process(const FuncGraphPtr &func_g
new_cnode->set_scope(deformable_offsets_grad_cnode->scope());
common::AnfAlgo::CopyNodeAttrs(deformable_offsets_grad_cnode, new_cnode);
common::AnfAlgo::SetNodeAttr(kAttrDataFormat, MakeValue("NHWC"), new_cnode);
if (!CheckAICoreSupportedAny(new_cnode)) {
MS_LOG(INFO) << "DeformableOffsetsGrad is not supported by AI Core, fall back to AICPU.";
return nullptr;
}
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_const);
MS_LOG(INFO) << "Add assist tensor for DeformableOffsets op success.";

View File

@@ -322,6 +322,8 @@ bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const {
mindspore::kMatrixInverseOpName,
mindspore::kMultiMarginLossGradOpName,
mindspore::kSspaddmmOpName,
mindspore::kDeformableOffsetsOpName,
mindspore::kDeformableOffsetsGradOpName,
mindspore::kBatchMatMulOpName,
mindspore::kSparseToDenseV2OpName,
mindspore::kTrilOpName,

View File

@@ -412,6 +412,8 @@ from .tracegrad import _tracegrad_aicpu
from .tridiagonal_solve import _tridiagonal_solve_aicpu
from .truncated_normal import _truncated_normal_aicpu
from .glu import _glu_aicpu
from .deformable_offsets import _deformable_offsets_aicpu
from .deformable_offsets_grad import _deformable_offsets_grad_aicpu
from .multi_margin_loss import _multi_margin_loss_aicpu
from .multi_margin_loss_grad import _multi_margin_loss_grad_aicpu
from .sparse_to_dense_v2 import _sparse_to_dense_v2_aicpu

View File

@@ -21,8 +21,14 @@ deformable_offsets_op_info = AiCPURegOp("DeformableOffsets") \
.input(0, "x", "required") \
.input(1, "offsets", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.attr("strides", "listInt") \
.attr("pads", "listInt") \
.attr("ksize", "listInt") \
.attr("dilations", "listInt", "{1,1,1,1}") \
.attr("deformable_groups", "int", "1") \
.attr("modulated", "bool", "true") \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()

View File

@@ -0,0 +1,43 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeformableOffsetsGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
deformable_offsets_grad_op_info = AiCPURegOp("DeformableOffsetsGrad") \
.fusion_type("OPAQUE") \
.input(0, "grad", "required") \
.input(1, "x", "required") \
.input(2, "offsets", "required") \
.output(0, "grad_x", "required") \
.output(1, "grad_offsets", "required") \
.attr("strides", "listInt") \
.attr("pads", "listInt") \
.attr("ksize", "listInt") \
.attr("dilations", "listInt", "{1, 1, 1, 1}") \
.attr("format", "str", "NCHW") \
.attr("deformable_groups", "int", "1") \
.attr("modulated", "bool", "true") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.get_op_info()
@op_info_register(deformable_offsets_grad_op_info)
def _deformable_offsets_grad_aicpu():
    """DeformableOffsetsGrad AiCPU register"""
    return

View File

@@ -0,0 +1,182 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import composite as C
from mindspore import nn
from mindspore import Tensor
from mindspore import dtype
from mindspore.ops.operations import nn_ops
grad_all = C.GradOperation(get_all=True)
class TestNetwork(nn.Cell):
    def __init__(self):
        super(TestNetwork, self).__init__()
        stride = (1, 1, 1, 1)
        pad = (0, 0, 0, 0)
        ksize = (2, 2)
        self.deformable_offsets_grad_op = G.DeformableOffsetsGrad(stride, pad, ksize)

    def construct(self, dout, x, offsets):
        output = self.deformable_offsets_grad_op(dout, x, offsets)
        return output

def test_grad_infer():
    """
    Feature: Aicpu operation.
    Description: Test of Aicpu operation: DeformableOffsetsGrad
    Expectation: No exception raised.
    """
    dout = Tensor(np.ones([1, 1, 2, 2]), dtype.float32)
    x = Tensor(np.ones([1, 1, 2, 2]), dtype.float32)
    offsets = Tensor(np.array([0.1] * 12).astype(np.float32).reshape([1, 12, 1, 1]))
    net = TestNetwork()
    grad = net(dout, x, offsets)
    print("grad_x:", grad[0])
    print("grad_offset:", grad[1])
    return grad

class ForwardNet(nn.Cell):
    def __init__(self):
        super(ForwardNet, self).__init__()
        stride = (1, 1, 1, 1)
        pad = (0, 0, 0, 0)
        ksize = (2, 2)
        self.deformable_offsets_op = nn_ops.DeformableOffsets(stride, pad, ksize)

    def construct(self, x, offsets):
        output = self.deformable_offsets_op(x, offsets)
        return output


class BackwardNet(nn.Cell):
    def __init__(self, net):
        super(BackwardNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        out = self.net(*inputs)
        return out, grad_all(self.net)(*inputs)

def test_auto_diff():
    """
    Feature: Aicpu operation.
    Description: Test of Aicpu operation: DeformableOffsetsGrad by auto diff.
    Expectation: No exception raised.
    """
    x = Tensor(np.ones([1, 1, 2, 2]), dtype.float32)
    offsets = Tensor(np.array([0.1] * 12).astype(np.float32).reshape([1, 12, 1, 1]))
    forward_net = ForwardNet()
    net = BackwardNet(forward_net)
    grad = net(x, offsets)
    print("grad_x:", grad[0])
    print("grad_offset:", grad[1])
    return grad

class NetDeformableOffsetsGrad(nn.Cell):
    def __init__(self, data_format):
        super(NetDeformableOffsetsGrad, self).__init__()
        strides = (1, 1, 1, 1)
        pads = (0, 0, 0, 0)
        ksize = (3, 3)
        self.grad_op = G.DeformableOffsetsGrad(strides, pads, ksize, data_format=data_format)

    def construct(self, grad, input_x, offsets):
        return self.grad_op(grad, input_x, offsets)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [np.float16, np.float32])
def test_deformable_offsets_grad_nchw(data_type):
    """
    Feature: DeformableOffsetsGrad aicpu kernel
    Description: Test the correctness of the DeformableOffsetsGrad aicpu kernel with NCHW data format.
    Expectation: the output is same as expected result
    """
    net = NetDeformableOffsetsGrad(data_format="NCHW")
    dout = Tensor(np.ones([1, 2, 3, 3]).astype(data_type))
    x = Tensor(np.ones([1, 2, 4, 4]).astype(data_type))
    offsets = Tensor(np.ones([1, 27, 1, 1]).astype(data_type) * 0.1)
    output = net(dout, x, offsets)
    expect_grad_x = np.array([[[0.081, 0.09, 0.09, 0.009],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.009, 0.01, 0.01, 0.001]],
                              [[0.081, 0.09, 0.09, 0.009],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.09, 0.1, 0.1, 0.01],
                               [0.009, 0.01, 0.01, 0.001]]]
                             ).astype(data_type)
    expect_grad_offset = np.array([0] * 18 + [2.0] * 9).astype(data_type).reshape([1, 27, 1, 1])
    rtol = 1e-5
    if data_type == np.float16:
        rtol = 1e-3
    assert np.allclose(output[0].asnumpy(), expect_grad_x, rtol)
    assert np.allclose(output[1].asnumpy(), expect_grad_offset, rtol)

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [np.float16, np.float32])
def test_deformable_offsets_grad_nhwc(data_type):
    """
    Feature: DeformableOffsetsGrad aicpu kernel
    Description: Test the correctness of the DeformableOffsetsGrad aicpu kernel with NHWC data format.
    Expectation: the output is same as expected result
    """
    net = NetDeformableOffsetsGrad(data_format="NHWC")
    dout = Tensor(np.ones([1, 3, 3, 2]).astype(data_type))
    x = Tensor(np.ones([1, 4, 4, 2]).astype(data_type))
    offsets = Tensor(np.ones([1, 1, 1, 27]).astype(data_type) * 0.1)
    output = net(dout, x, offsets)
    expect_grad_x = np.array([[[0.081, 0.081],
                               [0.09, 0.09],
                               [0.09, 0.09],
                               [0.009, 0.009]],
                              [[0.09, 0.09],
                               [0.1, 0.1],
                               [0.1, 0.1],
                               [0.01, 0.01]],
                              [[0.09, 0.09],
                               [0.1, 0.1],
                               [0.1, 0.1],
                               [0.01, 0.01]],
                              [[0.009, 0.009],
                               [0.01, 0.01],
                               [0.01, 0.01],
                               [0.001, 0.001]]
                              ]
                             ).astype(data_type)
    expect_grad_offset = np.array([0] * 18 + [2.0] * 9).astype(data_type).reshape([1, 1, 1, 27])
    rtol = 1e-5
    if data_type == np.float16:
        rtol = 1e-3
    assert np.allclose(output[0].asnumpy(), expect_grad_x, rtol)
    assert np.allclose(output[1].asnumpy(), expect_grad_offset, rtol)