!39034 Clean code for ops

Merge pull request !39034 from YuJianfeng/clean
i-robot 2022-07-28 11:14:18 +00:00 committed by Gitee
commit f3a5d0d1d5
4 changed files with 25 additions and 25 deletions


@@ -180,7 +180,7 @@ int DeformableOffsetsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
output_h_ = output_shape[h_axis_];
output_w_ = output_shape[w_axis_];
position_grid_size_ = output_h_ * output_w_;
- (void)workspace_size_list_.emplace_back(sizeof(int64_t) * position_grid_size_ * kKernelSizeSize);
+ (void)workspace_size_list_.emplace_back(sizeof(int64_t) * LongToSize(position_grid_size_) * kKernelSizeSize);
return KRET_OK;
}
@@ -214,7 +214,7 @@ void DeformableOffsetsCpuKernelMod::DeformableOffsets(const T *input_addr, const
int64_t offset_kh_dim = offset_kw_dim * kernel_size_[kKernelSizeWIndex];
int64_t offset_group_dim = offset_kh_dim * kernel_size_[kKernelSizeHIndex];
int64_t offset_mask_dim = offset_group_dim * deformable_groups_;
- int64_t offset_n_dim = offset_mask_dim * kOffsetsSize;
+ int64_t offset_n_dim = offset_mask_dim * SizeToLong(kOffsetsSize);
int64_t input_c_dim = input_h_ * input_w_;
int64_t input_n_dim = input_c_dim * c_;
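
The two hunks above wrap the operands whose signedness differs from the rest of the expression in explicit conversion helpers, so each multiplication is carried out in a single type: LongToSize turns the int64_t position_grid_size_ into a size_t before it meets sizeof(int64_t), and SizeToLong turns the size_t constant kOffsetsSize into an int64_t before it meets offset_mask_dim. A minimal, self-contained sketch of the first pattern is below; the LongToSize shown here is an assumed stand-in for MindSpore's helper, not the project's own definition.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Assumed stand-in: convert a signed 64-bit value to size_t, rejecting
// negatives instead of letting them wrap around.
size_t LongToSize(int64_t v) {
  if (v < 0) {
    throw std::out_of_range("negative value cannot be converted to size_t");
  }
  return static_cast<size_t>(v);
}

int main() {
  constexpr size_t kKernelSizeSize = 2;           // illustrative constant
  const int64_t position_grid_size = 128 * 128;   // e.g. output_h_ * output_w_
  std::vector<size_t> workspace_size_list;
  // Every factor is size_t after the explicit cast, so the multiplication
  // involves no implicit signed/unsigned conversion.
  (void)workspace_size_list.emplace_back(sizeof(int64_t) * LongToSize(position_grid_size) * kKernelSizeSize);
  std::cout << workspace_size_list.front() << " bytes\n";
  return 0;
}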


@@ -44,7 +44,6 @@ class DeformableOffsetsCpuKernelMod : public NativeCpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
@@ -63,20 +62,20 @@ class DeformableOffsetsCpuKernelMod : public NativeCpuKernelMod,
std::vector<int64_t> pads_;
std::vector<int64_t> kernel_size_;
std::vector<int64_t> dilations_;
- int64_t deformable_groups_;
- bool modulated_;
+ int64_t deformable_groups_{1};
+ bool modulated_{true};
- int64_t n_axis_;
- int64_t c_axis_;
- int64_t h_axis_;
- int64_t w_axis_;
- int64_t n_;
- int64_t c_;
- int64_t input_h_;
- int64_t input_w_;
- int64_t output_h_;
- int64_t output_w_;
- int64_t position_grid_size_;
+ size_t n_axis_{kIndex0};
+ size_t c_axis_{kIndex1};
+ size_t h_axis_{kIndex2};
+ size_t w_axis_{kIndex3};
+ int64_t n_{0};
+ int64_t c_{0};
+ int64_t input_h_{0};
+ int64_t input_w_{0};
+ int64_t output_h_{0};
+ int64_t output_w_{0};
+ int64_t position_grid_size_{0};
};
} // namespace kernel
} // namespace mindspore
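
The header change above gives every member an in-class default initializer and switches the axis indices from int64_t to size_t (kIndex0 through kIndex3), so a freshly constructed kernel object has well-defined state before Init/Resize runs. The snippet below only illustrates that initializer pattern; the class and member names are made up for the example and are not the kernel's real interface.

#include <cstddef>
#include <cstdint>
#include <iostream>

class ExampleKernel {
 public:
  void Describe() const {
    std::cout << "groups=" << deformable_groups_ << " modulated=" << modulated_
              << " h_axis=" << h_axis_ << " n=" << n_ << '\n';
  }

 private:
  // Defaults mean the members are never read uninitialized, even if the
  // object is inspected before any explicit setup.
  int64_t deformable_groups_{1};
  bool modulated_{true};
  size_t h_axis_{2};
  int64_t n_{0};
};

int main() {
  ExampleKernel k;  // safe to query even before any Init()/Resize() call
  k.Describe();
  return 0;
}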


@@ -40,7 +40,6 @@ bool ArgMaxWithValue::keep_dims() const {
namespace {
abstract::TupleShapePtr ArgMaxWithValueInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
- auto prim_name = primitive->name();
auto x_shape_ptr = input_args[0]->BuildShape();
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr);
auto x_shape = x_shape_map[kShape];
@@ -67,7 +66,7 @@ abstract::TupleShapePtr ArgMaxWithValueInferShape(const PrimitivePtr &primitive,
auto cal_shape = [axis, keep_dims](ShapeVector &shape, const ShapeVector &x_shape) -> void {
(void)shape.insert(shape.end(), x_shape.begin(), x_shape.end());
if (keep_dims) {
- shape[axis] = 1;
+ shape[LongToSize(axis)] = 1;
} else {
(void)shape.erase(shape.begin() + axis);
}
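
For context, the lambda above implements the usual reduction shape rule: with keep_dims the reduced axis stays in the shape with extent 1, otherwise it is erased, and LongToSize keeps the vector subscript unsigned instead of indexing with a raw int64_t. A rough standalone sketch follows; ShapeVector and LongToSize are assumed simplifications here, not the project's own aliases and helpers.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// Simplified stand-in for the conversion helper (no negative check).
size_t LongToSize(int64_t v) { return static_cast<size_t>(v); }

ShapeVector ReducedShape(const ShapeVector &x_shape, int64_t axis, bool keep_dims) {
  ShapeVector shape(x_shape);
  if (keep_dims) {
    shape[LongToSize(axis)] = 1;              // keep the axis, with extent 1
  } else {
    (void)shape.erase(shape.begin() + axis);  // drop the reduced axis entirely
  }
  return shape;
}

int main() {
  const ShapeVector x{4, 5, 6};
  for (auto d : ReducedShape(x, 1, true)) std::cout << d << ' ';   // prints: 4 1 6
  std::cout << '\n';
  for (auto d : ReducedShape(x, 1, false)) std::cout << d << ' ';  // prints: 4 6
  std::cout << '\n';
  return 0;
}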


@@ -64,7 +64,7 @@ std::vector<int64_t> CheckAttrTuple(const PrimitivePtr &prim, const std::string
}
std::vector<int64_t> CheckAttrTupleAndNCDimensions(const PrimitivePtr &primitive, const std::string &attr_name,
- size_t num, int n_axis, int c_axis) {
+ size_t num, uint64_t n_axis, uint64_t c_axis) {
std::vector<int64_t> tuple = CheckAttrTuple(primitive, attr_name, num);
if (tuple[n_axis] != 1 || tuple[c_axis] != 1) {
MS_EXCEPTION(ValueError)
@@ -86,16 +86,18 @@ void DeformableOffsetsPadFunction(std::vector<int64_t> *output_hw, const std::ve
constexpr size_t left_index = 2;
constexpr size_t right_index = 3;
if (x_h != abstract::Shape::SHP_ANY) {
- out_h = static_cast<int64_t>(std::floor(1 + ((x_h * 1.0) + pads[top_index] + pads[bottom_index] - kernel_size[0] -
- static_cast<float>((kernel_size[0] - 1) * (dilations[h_axis] - 1))) /
+ out_h = static_cast<int64_t>(
+ std::floor(1 + ((x_h * 1.0) + pads[top_index] + pads[bottom_index] - kernel_size[0] -
+ static_cast<double>((int64_t)LongToInt(kernel_size[0] - 1) * LongToInt(dilations[h_axis] - 1))) /
strides[h_axis]));
if (is_min_shape && out_h < 1) {
out_h = 1L;
}
}
if (x_w != abstract::Shape::SHP_ANY) {
- out_w = static_cast<int64_t>(std::floor(1 + ((x_w * 1.0) + pads[left_index] + pads[right_index] - kernel_size[1] -
- static_cast<float>((kernel_size[1] - 1) * (dilations[w_axis] - 1))) /
+ out_w = static_cast<int64_t>(
+ std::floor(1 + ((x_w * 1.0) + pads[left_index] + pads[right_index] - kernel_size[1] -
+ static_cast<double>((int64_t)LongToInt(kernel_size[1] - 1) * LongToInt(dilations[w_axis] - 1))) /
strides[w_axis]));
if (is_min_shape && out_w < 1) {
out_w = 1L;
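
The expressions above are the standard convolution output-size arithmetic with dilation: the effective kernel extent is kernel + (kernel - 1) * (dilation - 1), and the result is floored after dividing by the stride, with the intermediate value kept in floating point before the final cast. A small worked sketch of that arithmetic, using illustrative values rather than anything taken from the operator itself:

#include <cmath>
#include <cstdint>
#include <iostream>

// Illustrative helper only: output extent of one spatial dimension.
int64_t OutputDim(int64_t x, int64_t pad_begin, int64_t pad_end, int64_t kernel,
                  int64_t dilation, int64_t stride) {
  // Effective kernel extent once dilation is applied.
  const double effective = static_cast<double>(kernel) +
                           static_cast<double>((kernel - 1) * (dilation - 1));
  return static_cast<int64_t>(
      std::floor(1 + (static_cast<double>(x + pad_begin + pad_end) - effective) /
                         static_cast<double>(stride)));
}

int main() {
  // x_h = 32, pads = {1, 1}, kernel = 3, dilation = 2, stride = 1:
  // effective kernel = 3 + 2 = 5, so out_h = floor(1 + (34 - 5) / 1) = 30.
  std::cout << OutputDim(32, 1, 1, 3, 2, 1) << '\n';  // prints 30
  return 0;
}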