fix code check warnings
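
The warnings fall into a few recurring patterns: object-like #define constants replaced with scoped constexpr values, C-style casts and reinterpret_cast of void * replaced with static_cast, single-statement if bodies braced, and non-mutating member functions marked const. A minimal standalone C++ sketch of the before/after idioms (the helper names below are illustrative, not taken from the kernels):

#include <algorithm>
#include <cstdio>

namespace {
// was: #define INPUT_NUM 2  (constexpr is typed and scoped)
constexpr int INPUT_NUM = 2;
}  // namespace

// was: reinterpret_cast<float *>(addr); static_cast is sufficient, and
// preferred by the checker, when converting from void *
float *AsFloat(void *addr) { return static_cast<float *>(addr); }

template <typename T>
T ClampedWidth(T width) {
  // was: std::max(width, (T)0.1)  (C-style cast replaced)
  return std::max(width, static_cast<T>(0.1));
}

int CheckShape(bool is_dynamic) {
  // was: if (is_dynamic) return -1;  (single-statement body now braced)
  if (is_dynamic) {
    return -1;
  }
  return 0;
}

int main() {
  float buf[INPUT_NUM] = {0.5f, 1.5f};
  std::printf("%f %f %d\n", *AsFloat(buf), ClampedWidth(0.05f), CheckShape(false));
  return 0;
}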

This commit is contained in:
wenkai 2022-09-22 20:08:46 +08:00
parent c4995837d2
commit cfe8e0def0
5 changed files with 58 additions and 45 deletions

View File

@@ -16,6 +16,9 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_ATOMIC_ADD_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_ATOMIC_ADD_H_
+#include <stdint.h>
#include "utils/log_adapter.h"
namespace mindspore {

View File

@@ -28,18 +28,25 @@ namespace {
constexpr int kDyOutputDimIndex = 1;
constexpr int kDyHeightIndex = 2;
constexpr int kDyWidthIndex = 3;
+constexpr int INPUT_NUM = 2;
+constexpr int OUTPUT_NUM = 1;
+constexpr int OUT_PUT_SHAPE_SIZE = 4;
+constexpr int DY_SHAPE_SIZE = 4;
+constexpr int DX_SHAPE_SIZE = 4;
+constexpr int ROI_SHAPE_SIZE = 3;
+constexpr int ROIS_NUM_INDEX = 2;
} // namespace
template <typename T>
void PSROIPoolingGradCpuKernelMod::PSROIPoolBackward(size_t start, size_t end, const T *input_diff, T *output_diff,
-T *roi_boxes) {
+T *roi_boxes) const {
auto output_channels = output_channels_;
auto pooled_width = pooled_width_;
auto pooled_height = pooled_height_;
auto feature_channels = feature_channels_;
auto feature_width = width_;
auto feature_height = height_;
-auto spatial_scale = (T)spatial_scale_;
+auto spatial_scale = static_cast<T>(spatial_scale_);
auto rois_num = rois_num_;
auto elements_per_roi_box = 5;
constexpr float zero = 0;
@@ -73,17 +80,21 @@ void PSROIPoolingGradCpuKernelMod::PSROIPoolBackward(size_t start, size_t end, c
T roi_end_height = static_cast<T>(roundf(static_cast<float>(roi_end_height_before_round)));
// keep the minimum ROI height and width above 0.1
-T roi_width = std::max(roi_end_width - roi_start_width, (T)0.1);
-T roi_height = std::max(roi_end_height - roi_start_height, (T)0.1);
+T roi_width = std::max(roi_end_width - roi_start_width, static_cast<T>(0.1));
+T roi_height = std::max(roi_end_height - roi_start_height, static_cast<T>(0.1));
// Compute bin_width and bin_height
T bin_height = roi_height / static_cast<T>(pooled_height);
T bin_width = roi_width / static_cast<T>(pooled_width);
// compute pooling area's position
-int pooling_start_x = floor(static_cast<float>(static_cast<T>(height_offset_n) * bin_height + roi_start_height));
-int pooling_start_y = floor(static_cast<float>(static_cast<T>(width_offset_n) * bin_width + roi_start_width));
-int pooling_end_x = ceil(static_cast<float>(static_cast<T>(height_offset_n + 1) * bin_height + roi_start_height));
-int pooling_end_y = ceil(static_cast<float>(static_cast<T>(width_offset_n + 1) * bin_width + roi_start_width));
+int pooling_start_x =
+  static_cast<int>(floor(static_cast<float>(static_cast<T>(height_offset_n) * bin_height + roi_start_height)));
+int pooling_start_y =
+  static_cast<int>(floor(static_cast<float>(static_cast<T>(width_offset_n) * bin_width + roi_start_width)));
+int pooling_end_x =
+  static_cast<int>(ceil(static_cast<float>(static_cast<T>(height_offset_n + 1) * bin_height + roi_start_height)));
+int pooling_end_y =
+  static_cast<int>(ceil(static_cast<float>(static_cast<T>(width_offset_n + 1) * bin_width + roi_start_width)));
// Add roi offsets and clip to input boundaries
pooling_start_x = std::min(std::max(pooling_start_x, 0), feature_height);
pooling_end_x = std::min(std::max(pooling_end_x, 0), feature_height);
@@ -133,7 +144,7 @@ bool PSROIPoolingGradCpuKernelMod::IsSupportedDtype(TypeId type_id) {
return false;
}
-int PSROIPoolingGradCpuKernelMod::ResizeCheckInputs(const std::vector<KernelTensorPtr> &inputs) {
+int PSROIPoolingGradCpuKernelMod::ResizeCheckInputs(const std::vector<KernelTensorPtr> &inputs) const {
size_t input_num = inputs.size();
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input number is expected to be " << input_num
@@ -233,11 +244,15 @@ int PSROIPoolingGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
feature_channels_ = output_channels_ * group_size_ * group_size_;
for (auto tensor_ptr : inputs) {
-if (tensor_ptr->IsDynamicShape()) return KRET_UNKNOWN_SHAPE;
+if (tensor_ptr->IsDynamicShape()) {
+  return KRET_UNKNOWN_SHAPE;
+}
}
for (auto tensor_ptr : outputs) {
-if (tensor_ptr->IsDynamicShape()) return KRET_UNKNOWN_OUT_SHAPE;
+if (tensor_ptr->IsDynamicShape()) {
+  return KRET_UNKNOWN_OUT_SHAPE;
+}
}
auto dy_shape = inputs[0]->GetShapeVector();
@@ -287,11 +302,11 @@ bool PSROIPoolingGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto output_size = output_channels_ * pooled_height_ * pooled_width_ * output_n_;
if (data_type_id_ == kNumberTypeFloat32) {
-auto top_diff = reinterpret_cast<float *>(inputs[0]->addr);
+auto top_diff = static_cast<float *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(top_diff);
-auto rois = reinterpret_cast<float *>(inputs[1]->addr);
+auto rois = static_cast<float *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(rois);
-auto output_diff = reinterpret_cast<float *>(outputs[0]->addr);
+auto output_diff = static_cast<float *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_diff);
constexpr size_t unit_size = sizeof(float);
@@ -308,11 +323,11 @@ bool PSROIPoolingGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
}
if (data_type_id_ == kNumberTypeFloat16) {
-auto top_diff = reinterpret_cast<float16 *>(inputs[0]->addr);
+auto top_diff = static_cast<float16 *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(top_diff);
-auto rois = reinterpret_cast<float16 *>(inputs[1]->addr);
+auto rois = static_cast<float16 *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(rois);
-auto output_diff = reinterpret_cast<float16 *>(outputs[0]->addr);
+auto output_diff = static_cast<float16 *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_diff);
constexpr size_t unit_size = sizeof(float16);

View File

@@ -25,14 +25,6 @@
namespace mindspore {
namespace kernel {
-#define INPUT_NUM 2
-#define OUTPUT_NUM 1
-#define OUT_PUT_SHAPE_SIZE 4
-#define DY_SHAPE_SIZE 4
-#define DX_SHAPE_SIZE 4
-#define ROI_SHAPE_SIZE 3
-#define ROIS_NUM_INDEX 2
class PSROIPoolingGradCpuKernelMod : public NativeCpuKernelMod {
public:
PSROIPoolingGradCpuKernelMod() = default;
@@ -71,9 +63,9 @@ class PSROIPoolingGradCpuKernelMod : public NativeCpuKernelMod {
TypeId data_type_id_{kNumberTypeFloat32};
template <typename T>
-void PSROIPoolBackward(size_t start, size_t end, const T *input_diff, T *output_diff, T *roi_boxes);
+void PSROIPoolBackward(size_t start, size_t end, const T *input_diff, T *output_diff, T *roi_boxes) const;
-int ResizeCheckInputs(const std::vector<KernelTensorPtr> &inputs);
+int ResizeCheckInputs(const std::vector<KernelTensorPtr> &inputs) const;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -22,6 +22,13 @@ namespace {
constexpr int kDyOutputDimIndex = 1;
constexpr int kDyHeightIndex = 2;
constexpr int kDyWidthIndex = 3;
+constexpr int INPUT_NUM = 2;
+constexpr int OUTPUT_NUM = 1;
+constexpr int OUT_PUT_SHAPE_SIZE = 4;
+constexpr int DY_SHAPE_SIZE = 4;
+constexpr int DX_SHAPE_SIZE = 4;
+constexpr int ROI_SHAPE_SIZE = 3;
+constexpr int ROIS_NUM_INDEX = 2;
} // namespace
bool PSROIPoolingBackV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator,
@@ -149,11 +156,15 @@ int PSROIPoolingBackV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
feature_channels_ = output_channels_ * group_size_ * group_size_;
for (auto tensor_ptr : inputs) {
-if (tensor_ptr->IsDynamicShape()) return KRET_UNKNOWN_SHAPE;
+if (tensor_ptr->IsDynamicShape()) {
+  return KRET_UNKNOWN_SHAPE;
+}
}
for (auto tensor_ptr : outputs) {
-if (tensor_ptr->IsDynamicShape()) return KRET_UNKNOWN_OUT_SHAPE;
+if (tensor_ptr->IsDynamicShape()) {
+  return KRET_UNKNOWN_OUT_SHAPE;
+}
}
auto dy_shape = inputs[0]->GetShapeVector();
@@ -203,28 +214,28 @@ bool PSROIPoolingBackV2GpuKernelMod::Launch(const std::vector<AddressPtr> &input
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (data_type_id_ == kNumberTypeFloat32) {
-auto top_diff = reinterpret_cast<float *>(inputs[0]->addr);
+auto top_diff = static_cast<float *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(top_diff);
-auto rois = reinterpret_cast<float *>(inputs[1]->addr);
+auto rois = static_cast<float *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(rois);
-auto output_diff = reinterpret_cast<float *>(outputs[0]->addr);
+auto output_diff = static_cast<float *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_diff);
PSROIPoolBackwardV2Launcher(top_diff, batch_size_, output_n_, static_cast<float>(spatial_scale_), feature_channels_,
height_, width_, pooled_width_, pooled_height_, output_channels_, output_diff, rois,
-reinterpret_cast<cudaStream_t>(stream_ptr), rois_num_, group_size_);
+static_cast<cudaStream_t>(stream_ptr), rois_num_, group_size_);
return true;
}
if (data_type_id_ == kNumberTypeFloat16) {
-auto top_diff = reinterpret_cast<half *>(inputs[0]->addr);
+auto top_diff = static_cast<half *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(top_diff);
-auto rois = reinterpret_cast<half *>(inputs[1]->addr);
+auto rois = static_cast<half *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(rois);
-auto output_diff = reinterpret_cast<half *>(outputs[0]->addr);
+auto output_diff = static_cast<half *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_diff);
PSROIPoolBackwardV2Launcher(top_diff, batch_size_, output_n_, static_cast<half>(spatial_scale_), feature_channels_,
height_, width_, pooled_width_, pooled_height_, output_channels_, output_diff, rois,
-reinterpret_cast<cudaStream_t>(stream_ptr), rois_num_, group_size_);
+static_cast<cudaStream_t>(stream_ptr), rois_num_, group_size_);
return true;
}

View File

@@ -29,14 +29,6 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/psroi_pooling_v2_impl.cuh"
namespace mindspore {
namespace kernel {
-#define INPUT_NUM 2
-#define OUTPUT_NUM 1
-#define OUT_PUT_SHAPE_SIZE 4
-#define DY_SHAPE_SIZE 4
-#define DX_SHAPE_SIZE 4
-#define ROI_SHAPE_SIZE 3
-#define ROIS_NUM_INDEX 2
class PSROIPoolingBackV2GpuKernelMod : public NativeGpuKernelMod {
public:
PSROIPoolingBackV2GpuKernelMod() = default;