forked from mindspore-Ecosystem/mindspore
!48015 ResizeBicubic/ResizeBicubicGrad: align with torch; change format to NCHW
Merge pull request !48015 from haozhang/resize_bicubic
This commit is contained in:
commit 9db1e218b3
@@ -13,11 +13,11 @@ mindspore.ops.ResizeBicubic
- **half_pixel_centers** (bool, optional) - Whether to use half-pixel center alignment. If set to True, `align_corners` should be set to False. Default: False.

Inputs:
-    - **images** (Tensor) - The input image, a 4-D Tensor of shape :math:`(batch, height, width, channels)`. Supported data types: int8, int16, int32, int64, float16, float32, float64, uint8 and uint16.
+    - **images** (Tensor) - The input image, a 4-D Tensor of shape :math:`(batch, channels, height, width)`. Supported data types: int8, int16, int32, int64, float16, float32, float64, uint8 and uint16.
    - **size** (Tensor) - A 1-D Tensor with two elements, new_height and new_width, the height and width of the output image. Supported data type: int32.

Outputs:
-    Tensor, the resized image. A 4-D Tensor of shape :math:`(batch, new\_height, new\_width, channels)` with data type float32.
+    Tensor, the resized image. A 4-D Tensor of shape :math:`(batch, channels, new\_height, new\_width)` with data type float32.

Raises:
    - **TypeError** - The data type of `images` is not supported.
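The entire PR is one memory-layout substitution: every pixel that used to be addressed with an NHWC offset is now addressed with an NCHW offset. A minimal sketch of the two flat-index formulas (plain C++; illustrative, not code from this PR):

```cpp
#include <cassert>
#include <cstdint>

// Flat offset of element (n, c, h, w) in a dense 4-D buffer.
int64_t OffsetNHWC(int64_t n, int64_t c, int64_t h, int64_t w,
                   int64_t C, int64_t H, int64_t W) {
  return ((n * H + h) * W + w) * C + c;  // channels vary fastest
}

int64_t OffsetNCHW(int64_t n, int64_t c, int64_t h, int64_t w,
                   int64_t C, int64_t H, int64_t W) {
  return ((n * C + c) * H + h) * W + w;  // width varies fastest
}

int main() {
  constexpr int64_t C = 3, H = 4, W = 5;
  // Element (n=0, c=1, h=0, w=0): adjacent to c=0 in NHWC,
  // but a full H*W plane away in NCHW.
  assert(OffsetNHWC(0, 1, 0, 0, C, H, W) == 1);
  assert(OffsetNCHW(0, 1, 0, 0, C, H, W) == H * W);
  return 0;
}
```

Every change below — the shape0 reads, Calindex, and the CUDA position decoding — is this one substitution applied consistently.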
@@ -27,10 +27,6 @@ namespace kernel {
namespace {
constexpr size_t kResizeBicubicInputsNum = 2;
constexpr size_t kResizeBicubicOutputsNum = 1;
constexpr size_t kResizeBicubicInputs0ShapeSize = 4;
constexpr size_t kResizeBicubicInputs1ShapeSize = 1;
constexpr size_t kResizeBicubicInputs1Dim = 2;
constexpr size_t kResizeBicubicAttrSize = 2;
constexpr int64_t cached_values_hand_max = 4;
constexpr size_t caseid2 = 2;
constexpr size_t caseid3 = 3;

@@ -47,10 +43,10 @@ bool half_pixel_centers = false;
struct ResizerState {
  void CalculateSize_kernel_node(const std::vector<KernelTensorPtr> &inputs) {
    shape0 = inputs[kIndex0]->GetDeviceShapeAdaptively();
-    batch_size = shape0[0];
-    in_height = shape0[1];
-    in_width = shape0[kIndex2];
-    channels = shape0[kIndex3];
+    batch_size = shape0[kIndex0];
+    channels = shape0[kIndex1];
+    in_height = shape0[kIndex2];
+    in_width = shape0[kIndex3];
  }
  void CalculateSize_inputs(const std::vector<kernel::AddressPtr> &inputs) {
    auto *input_addr = static_cast<int32_t *>(inputs[1]->addr);

@@ -59,7 +55,7 @@ struct ResizerState {

    out_hw_size = out_height * out_width;
    in_hw_size = in_height * in_width;
-    bhwc_size = in_hw_size * channels * batch_size;
+    bchw_size = in_hw_size * channels * batch_size;
    height_scale = Scaling(in_height, out_height, align_corners);
    width_scale = Scaling(in_width, out_width, align_corners);
  }

@@ -73,7 +69,7 @@ struct ResizerState {
  float width_scale;
  int64_t out_hw_size;
  int64_t in_hw_size;
-  int64_t bhwc_size;
+  int64_t bchw_size;
};
ResizerState sta;
}  // namespace
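Both CalculateSize paths end by calling Scaling(...) to turn the input/output extents into resample scales. The PR does not show Scaling's body; for reference, a sketch of the usual TF-style rule such helpers implement (an assumption, not code from this repo):

```cpp
// Hypothetical stand-in for the Scaling() helper used above: with
// align_corners the grid endpoints map exactly onto each other,
// otherwise the scale is the plain size ratio.
float ScalingSketch(int64_t in_size, int64_t out_size, bool align_corners) {
  if (align_corners && out_size > 1) {
    return static_cast<float>(in_size - 1) / static_cast<float>(out_size - 1);
  }
  return static_cast<float>(in_size) / static_cast<float>(out_size);
}
```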
@@ -210,13 +206,6 @@ static void ComputeXWeightsAndIndices(const ResizerState &resizer_state, const b
      x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2, x_wai.index_3);
    }
  }
-
-  for (int64_t x = 0; x < resizer_state.out_width; ++x) {
-    (*x_wais)[static_cast<size_t>(x)].index_0 *= resizer_state.channels;
-    (*x_wais)[static_cast<size_t>(x)].index_1 *= resizer_state.channels;
-    (*x_wais)[static_cast<size_t>(x)].index_2 *= resizer_state.channels;
-    (*x_wais)[static_cast<size_t>(x)].index_3 *= resizer_state.channels;
-  }
}

template <typename T>

@@ -227,10 +216,9 @@ inline float Interpolate1D(const float weight_0, const float weight_1, const flo
}

template <typename T>
-static float ComputeYInterpolation(int which, int channel_num, const WeightsAndIndices &y_wai, const T *y_ptr_0,
-                                   const T *y_ptr_1, const T *y_ptr_2, const T *y_ptr_3,
-                                   const WeightsAndIndices &x_wai) {
-  int x_index;
+static float ComputeYInterpolation(int which, const WeightsAndIndices &y_wai, const T *y_ptr_0, const T *y_ptr_1,
+                                   const T *y_ptr_2, const T *y_ptr_3, const WeightsAndIndices &x_wai) {
+  int x_index;  // w
  switch (which) {
    case 0:
      x_index = x_wai.index_0;
@@ -245,103 +233,80 @@ static float ComputeYInterpolation(int which, int channel_num, const WeightsAndI
      x_index = x_wai.index_3;
      break;
  }
-  const int64_t pt_index = x_index + channel_num;
-  return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2, y_wai.weight_3, y_ptr_0[pt_index],
-                          y_ptr_1[pt_index], y_ptr_2[pt_index], y_ptr_3[pt_index]);
+  return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2, y_wai.weight_3, y_ptr_0[x_index],
+                          y_ptr_1[x_index], y_ptr_2[x_index], y_ptr_3[x_index]);
}

static float Compute_1D(const float values_[4], const float xw_0, const float xw_1, const float xw_2,
                        const float xw_3) {
  return Interpolate1D<float>(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1], values_[2], values_[3]);
}

template <typename T1>
std::vector<float> CalSwitch(const WeightsAndIndices &x_wai, std::vector<float> cached_value, const ResizerState &RS,
                             const WeightsAndIndices &y_wai, const T1 *y_ptr_0, const T1 *y_ptr_1, const T1 *y_ptr_2,
                             const T1 *y_ptr_3) {
  switch (x_wai.advance) {
    case caseid3:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + 0)] = cached_value[static_cast<size_t>(calnum4 * c + 1)];
-        cached_value[static_cast<size_t>(calnum4 * c + 1)] = cached_value[static_cast<size_t>(calnum4 * c + calnum2)];
-        cached_value[static_cast<size_t>(calnum4 * c + calnum2)] =
-          cached_value[static_cast<size_t>(calnum4 * c + calnum3)];
-      }
+      cached_value[static_cast<size_t>(0)] = cached_value[static_cast<size_t>(1)];
+      cached_value[static_cast<size_t>(1)] = cached_value[static_cast<size_t>(calnum2)];
+      cached_value[static_cast<size_t>(calnum2)] = cached_value[static_cast<size_t>(calnum3)];
      break;
    case caseid2:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + 0)] = cached_value[static_cast<size_t>(calnum4 * c + calnum2)];
-        cached_value[static_cast<size_t>(calnum4 * c + 1)] = cached_value[static_cast<size_t>(calnum4 * c + calnum3)];
-      }
+      cached_value[static_cast<size_t>(0)] = cached_value[static_cast<size_t>(calnum2)];
+      cached_value[static_cast<size_t>(1)] = cached_value[static_cast<size_t>(calnum3)];
      break;
-    case 1: {
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + 0)] = cached_value[static_cast<size_t>(calnum4 * c + calnum3)];
-      }
-    }
+    case 1:
+      cached_value[static_cast<size_t>(0)] = cached_value[static_cast<size_t>(calnum3)];
      break;
  }
  // Set the remaining '4-advance' values by computing.
-  switch (x_wai.advance) {
-    case 0:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + 0)] =
-          ComputeYInterpolation(0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-      }
-    case 1:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + 1)] =
-          ComputeYInterpolation(1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-      }
-    case caseid2:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + calnum2)] =
-          ComputeYInterpolation(calnum2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-      }
-    case caseid3:
-      for (int64_t c = 0; c < RS.channels; ++c) {
-        cached_value[static_cast<size_t>(calnum4 * c + calnum3)] =
-          ComputeYInterpolation(calnum3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-      }
-      break;
-  }
+  for (size_t i = x_wai.advance; i <= caseid3; i++) {
+    cached_value[i] = ComputeYInterpolation(i, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
+  }

  return cached_value;
}
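CalSwitch implements a rolling four-tap cache: when the output column advances, `x_wai.advance` of the cached y-interpolated values are still valid, so they are shifted toward the front and only the tail is recomputed — now one value per tap instead of one per channel, since the NCHW kernel handles a single channel per task. A standalone sketch of the same idea (illustrative names; `recompute` stands in for ComputeYInterpolation):

```cpp
#include <array>
#include <functional>

// Rolling 4-tap cache: `advance` of the cached column values survive the
// move to the next output column. Shift them to the front, then recompute
// only the tail, exactly as CalSwitch's switch + for-loop do.
void AdvanceCache(std::array<float, 4> *cache, int advance,
                  const std::function<float(int)> &recompute) {
  for (int i = 0; i < advance; ++i) {
    (*cache)[i] = (*cache)[i + 4 - advance];  // advance==3 shifts by one slot
  }
  for (int i = advance; i < 4; ++i) {
    (*cache)[i] = recompute(i);  // mirrors the new for-loop in CalSwitch
  }
}
```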

template <typename T1, typename T2>
-inline void interpolate_with_caching(const T1 *input_data, const ResizerState &RS, const bool half_pixel_centers_,
-                                     T2 output_data) {
+void ResizeBicubicCPUKernelMod::interpolate_with_caching(const T1 *input_data, const bool half_pixel_centers_,
+                                                         T2 *output_data) {
+  const ResizerState &RS = sta;
  std::vector<WeightsAndIndices> x_wais(RS.out_width);
  ComputeXWeightsAndIndices(RS, half_pixel_centers_, &x_wais);
-  const int64_t in_row_width = RS.in_width * RS.channels;
-  const int64_t in_batch_width = RS.in_height * in_row_width;
-  const T1 *input_b_ptr = input_data;
-  float *output_y_ptr = output_data;
-  // std::vector<float> cached_value(RS.channels == calnum3 ? 0 : calnum4 * RS.channels, 0);
-  std::vector<float> cached_value(calnum4 * RS.channels, 0);
-  for (int64_t b = 0; b < RS.batch_size; ++b, input_b_ptr += in_batch_width) {
-    for (int64_t y = 0; y < RS.out_height; ++y, output_y_ptr += RS.out_width * RS.channels) {
+  const int64_t in_row_width = RS.in_width * RS.in_height;    // hw
+  const int64_t in_batch_width = RS.channels * in_row_width;  // chw
+  const int64_t out_ch = RS.out_height * RS.channels;
+  const int64_t out_chw = out_ch * RS.out_width;
+  const int64_t out_hw = RS.out_height * RS.out_width;
+  const size_t parallel_num = static_cast<size_t>(out_ch * RS.batch_size);
+  std::vector<float> cached_value(calnum4, 0);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; ++i) {  // nch
+      const int64_t b = i / out_ch, c = i % out_ch / RS.out_height, y = i % RS.out_height;
      WeightsAndIndices y_wai;
      if (half_pixel_centers_) {
        GetWeightsAndIndices<HalfPixelScaler, true>(RS.height_scale, y, RS.in_height, &y_wai);
      } else {
        GetWeightsAndIndices<LegacyScaler, false>(RS.height_scale, y, RS.in_height, &y_wai);
      }
-      // Make pointers represent offsets of data in input_b_ptr.
-      const T1 *y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
-      const T1 *y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
-      const T1 *y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
-      const T1 *y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;
+      const T1 *input_b_ptr = input_data + b * in_batch_width + c * in_row_width;
+      T2 *output_y_ptr = output_data + b * out_chw + c * out_hw + y * RS.out_width;
+      // Make pointers represent offsets of data in input_b_ptr
+      const T1 *y_ptr_0 = input_b_ptr + y_wai.index_0 * RS.in_width;
+      const T1 *y_ptr_1 = input_b_ptr + y_wai.index_1 * RS.in_width;
+      const T1 *y_ptr_2 = input_b_ptr + y_wai.index_2 * RS.in_width;
+      const T1 *y_ptr_3 = input_b_ptr + y_wai.index_3 * RS.in_width;
      for (int64_t x = 0; x < RS.out_width; ++x) {
        const WeightsAndIndices &x_wai = x_wais[static_cast<size_t>(x)];
        cached_value = CalSwitch(x_wai, cached_value, RS, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3);
-        for (int64_t c = 0; c < RS.channels; ++c) {
-          output_y_ptr[x * RS.channels + c] =
-            Compute_1D(&cached_value[static_cast<size_t>(calnum4 * c)], x_wai.weight_0, x_wai.weight_1, x_wai.weight_2,
-                       x_wai.weight_3);
-        }
+        output_y_ptr[x] =
+          Compute_1D(cached_value.data(), x_wai.weight_0, x_wai.weight_1, x_wai.weight_2, x_wai.weight_3);
      }
    }
-  }
-}
+  };
+  ParallelLaunchAutoSearch(task, parallel_num, this, &parallel_search_info_);
+  return;
+}
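Instead of nested batch/row loops, the new kernel flattens (batch, channel, output row) into one index and lets ParallelLaunchAutoSearch split the range. A quick check (plain C++, illustrative) that the `b = i / out_ch, c = i % out_ch / out_height, y = i % out_height` decomposition inverts that flattening:

```cpp
#include <cassert>
#include <cstdint>

// Decompose a flat task index i in [0, N*C*H_out) into (b, c, y),
// mirroring the lambda in interpolate_with_caching.
struct Bcy { int64_t b, c, y; };

Bcy Decompose(int64_t i, int64_t channels, int64_t out_height) {
  const int64_t out_ch = out_height * channels;
  return {i / out_ch, (i % out_ch) / out_height, i % out_height};
}

int main() {
  const int64_t N = 2, C = 3, H = 5;
  for (int64_t i = 0; i < N * C * H; ++i) {
    Bcy p = Decompose(i, C, H);
    // Re-flattening must reproduce i exactly.
    assert((p.b * C + p.c) * H + p.y == i);
  }
  return 0;
}
```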

bool ResizeBicubicCPUKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,

@@ -390,11 +355,16 @@ bool ResizeBicubicCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
  auto input0_addr = static_cast<T1 *>(inputs[0]->addr);
  sta.CalculateSize_inputs(inputs);
  if (sta.out_height == sta.in_height && sta.out_width == sta.in_width) {
-    for (int64_t i = 0; i < sta.bhwc_size; ++i) {
-      output_addr[i] = static_cast<float>(input0_addr[i]);
-    }
+    auto task = [&](size_t start, size_t end) {
+      for (size_t i = start; i < end; ++i) {
+        output_addr[i] = static_cast<T2>(input0_addr[i]);
+      }
+    };
+    ParallelLaunchAutoSearch(task, static_cast<size_t>(sta.bchw_size), this, &parallel_search_info_);
+  } else {
+    interpolate_with_caching(input0_addr, half_pixel_centers, output_addr);
  }
-  interpolate_with_caching(input0_addr, sta, half_pixel_centers, output_addr);

  return true;
}
@@ -49,6 +49,10 @@ class ResizeBicubicCPUKernelMod : public NativeCpuKernelMod {
 private:
  template <typename T1, typename T2>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

+  template <typename T1, typename T2>
+  inline void interpolate_with_caching(const T1 *input_data, const bool half_pixel_centers_, T2 *output_data);
+
  using ResizeBicubicFunc = std::function<bool(ResizeBicubicCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, ResizeBicubicFunc>> func_list_;
@@ -26,10 +26,7 @@ namespace kernel {
namespace {
constexpr size_t kResizeBicubicGradInputsNum = 2;
constexpr size_t kResizeBicubicGradOutputNum = 1;
constexpr size_t kResizeBicubicGradInputs0ShapeSize = 4;
constexpr size_t kResizeBicubicGradInputs1ShapeSize = 4;
constexpr int64_t cached_values_hand_max = 4;
constexpr size_t caseid2 = 2;
constexpr size_t caseid3 = 3;
constexpr int64_t calnum8 = 8;
constexpr int64_t calnum5 = 5;

@@ -42,18 +39,26 @@ std::vector<int64_t> shape0;
std::vector<int64_t> shape1;
bool align_corners = false;
bool half_pixel_centers = false;
+int64_t origin_chw;
+int64_t origin_hw;
+int64_t resized_chw;
+int64_t resized_hw;
}  // namespace

struct ResizerGradState {
  void CalculateSize(const std::vector<int64_t> &shape0, const std::vector<int64_t> &shape1) {
-    batch_size = shape0[0];
-    channels = shape0[kIndex3];
-    resized_height = shape0[1];
-    resized_width = shape0[kIndex2];
-    original_height = shape1[1];
-    original_width = shape1[kIndex2];
+    batch_size = shape0[kIndex0];
+    channels = shape0[kIndex1];
+    resized_height = shape0[kIndex2];
+    resized_width = shape0[kIndex3];
+    original_height = shape1[kIndex2];
+    original_width = shape1[kIndex3];
    height_scale = Scaling(original_height, resized_height, align_corners);
    width_scale = Scaling(original_width, resized_width, align_corners);
+    origin_chw = channels * original_height * original_width;
+    origin_hw = original_height * original_width;
+    resized_chw = resized_height * resized_width * channels;
+    resized_hw = resized_height * resized_width;
  }
  int64_t batch_size;
  int64_t channels;
@@ -83,6 +88,7 @@ struct HalfPixelScalerGrad {
    return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
  }
};

struct LegacyScalerGrad {
  LegacyScalerGrad() {}
  inline float operator()(const int64_t x, const float scale) const { return static_cast<float>(x) * scale; }

@@ -106,19 +112,9 @@ class CachedInterpolationCalculator {
        cached_values_hand++;
      }
    }
-    switch (new_indices_hand) {
-      case 0:
-        indexes_[0] = x_0;
-        break;
-      case 1:
-        indexes_[1] = x_1;
-        break;
-      case caseid2:
-        indexes_[kIndex2] = x_2;
-        break;
-      case caseid3:
-        indexes_[kIndex3] = x_3;
-        break;
-    }
+    std::vector<int64_t> values = {x_0, x_1, x_2, x_3};
+    for (size_t i = new_indices_hand; i <= caseid3; ++i) {
+      indexes_[i] = values[i];
+    }
    return new_indices_hand;
  }
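The `Advance` call used when building x_wais (and mirrored by CalSwitch above) counts how many cached tap indices survive a window move. Its body is not part of this diff; a hedged sketch of the standard comparison such TF-style cached-interpolation calculators perform (illustrative names):

```cpp
#include <cstdint>

// Count how many cached indexes can be kept when the 4-tap window moves to
// (x0, x1, x2, x3): the surviving values line up at a fixed shift.
int AdvanceSketch(const int64_t cached[4], int64_t x0, int64_t x1, int64_t x2, int64_t x3) {
  const int64_t xs[4] = {x0, x1, x2, x3};
  for (int shift = 0; shift < 4; ++shift) {
    bool ok = true;
    for (int i = 0; i + shift < 4; ++i) {
      ok = ok && (cached[i + shift] == xs[i]);
    }
    if (ok) return 4 - shift;  // number of reusable cached values
  }
  return 0;
}
```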

@@ -208,22 +204,20 @@ static void ComputeGradientXWeightsAndIndices(const ResizerGradState &RGS, const
    }
  }
}

const int64_t Calindex(const ResizerGradState &RGS, const int64_t &x1, const int64_t &x2, const int64_t &x3,
                       const int64_t &x4, bool flag_) {
  if (!flag_) {
-    return static_cast<int64_t>(static_cast<int64_t>(x1 * RGS.original_height * RGS.original_width * RGS.channels) +
-                                static_cast<int64_t>(x2 * RGS.original_width * RGS.channels) +
-                                static_cast<int64_t>(x3 * RGS.channels) + static_cast<int64_t>(x4));
+    return x1 * origin_chw + x2 * origin_hw + x3 * RGS.original_width + x4;
  } else {
-    return static_cast<int64_t>(static_cast<int64_t>(x1 * RGS.resized_height * RGS.resized_width * RGS.channels) +
-                                static_cast<int64_t>(x2 * RGS.resized_width * RGS.channels) +
-                                static_cast<int64_t>(x3 * RGS.channels) + static_cast<int64_t>(x4));
+    return x1 * resized_chw + x2 * resized_hw + x3 * RGS.resized_width + x4;
  }
}
template <typename T>
void ResizeCommomCalc(const ResizerGradState &RGS, const bool half_pixel_centers,
                      const std::vector<WeightsAndIndices> &x_wais, const bool flag, const float *input_grad,
-                     T *output_grad, int64_t b, int64_t y) {
+                     T *output_grad, int64_t b, int64_t c, int64_t y) {
  WeightsAndIndices y_wai;
  if (half_pixel_centers) {
    GetWeightsAndIndicesGrad<HalfPixelScalerGrad, true>(RGS.height_scale, y, RGS.original_height, &y_wai);

@@ -232,46 +226,46 @@ void ResizeCommomCalc(const ResizerGradState &RGS, const bool half_pixel_centers
  }
  for (int64_t x = 0; x < RGS.resized_width; ++x) {
    const WeightsAndIndices &x_wai = x_wais[static_cast<size_t>(x)];
-    for (int64_t c = 0; c < RGS.channels; ++c) {
-      T curr_input_grad = input_grad[Calindex(RGS, b, y, x, c, flag)];
-      // row 0 of 0, 1, 2, 3
-      output_grad[Calindex(RGS, b, y_wai.index_0, x_wai.index_0, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
-      output_grad[Calindex(RGS, b, y_wai.index_0, x_wai.index_1, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
-      output_grad[Calindex(RGS, b, y_wai.index_0, x_wai.index_2, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
-      output_grad[Calindex(RGS, b, y_wai.index_0, x_wai.index_3, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
+    float curr_input_grad = input_grad[Calindex(RGS, b, c, y, x, flag)];
+    // row 0 of 0, 1, 2, 3
+    output_grad[Calindex(RGS, b, c, y_wai.index_0, x_wai.index_0, !flag)] +=
+      T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
+    output_grad[Calindex(RGS, b, c, y_wai.index_0, x_wai.index_1, !flag)] +=
+      T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
+    output_grad[Calindex(RGS, b, c, y_wai.index_0, x_wai.index_2, !flag)] +=
+      T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
+    output_grad[Calindex(RGS, b, c, y_wai.index_0, x_wai.index_3, !flag)] +=
+      T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);

-      // row 1 of 0, 1, 2, 3
-      output_grad[Calindex(RGS, b, y_wai.index_1, x_wai.index_0, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
-      output_grad[Calindex(RGS, b, y_wai.index_1, x_wai.index_1, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
-      output_grad[Calindex(RGS, b, y_wai.index_1, x_wai.index_2, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
-      output_grad[Calindex(RGS, b, y_wai.index_1, x_wai.index_3, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
-      // row 2 of 0, 1, 2, 3
-      output_grad[Calindex(RGS, b, y_wai.index_2, x_wai.index_0, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
-      output_grad[Calindex(RGS, b, y_wai.index_2, x_wai.index_1, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
-      output_grad[Calindex(RGS, b, y_wai.index_2, x_wai.index_2, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
-      output_grad[Calindex(RGS, b, y_wai.index_2, x_wai.index_3, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
-      // row 3 of 0, 1, 2, 3
-      output_grad[Calindex(RGS, b, y_wai.index_3, x_wai.index_0, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
-      output_grad[Calindex(RGS, b, y_wai.index_3, x_wai.index_1, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
-      output_grad[Calindex(RGS, b, y_wai.index_3, x_wai.index_2, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
-      output_grad[Calindex(RGS, b, y_wai.index_3, x_wai.index_3, c, !flag)] +=
-        T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
-    }
+    // row 1 of 0, 1, 2, 3
+    output_grad[Calindex(RGS, b, c, y_wai.index_1, x_wai.index_0, !flag)] +=
+      T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
+    output_grad[Calindex(RGS, b, c, y_wai.index_1, x_wai.index_1, !flag)] +=
+      T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
+    output_grad[Calindex(RGS, b, c, y_wai.index_1, x_wai.index_2, !flag)] +=
+      T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
+    output_grad[Calindex(RGS, b, c, y_wai.index_1, x_wai.index_3, !flag)] +=
+      T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);

+    // row 2 of 0, 1, 2, 3
+    output_grad[Calindex(RGS, b, c, y_wai.index_2, x_wai.index_0, !flag)] +=
+      T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
+    output_grad[Calindex(RGS, b, c, y_wai.index_2, x_wai.index_1, !flag)] +=
+      T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
+    output_grad[Calindex(RGS, b, c, y_wai.index_2, x_wai.index_2, !flag)] +=
+      T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
+    output_grad[Calindex(RGS, b, c, y_wai.index_2, x_wai.index_3, !flag)] +=
+      T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);

+    // row 3 of 0, 1, 2, 3
+    output_grad[Calindex(RGS, b, c, y_wai.index_3, x_wai.index_0, !flag)] +=
+      T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
+    output_grad[Calindex(RGS, b, c, y_wai.index_3, x_wai.index_1, !flag)] +=
+      T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
+    output_grad[Calindex(RGS, b, c, y_wai.index_3, x_wai.index_2, !flag)] +=
+      T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
+    output_grad[Calindex(RGS, b, c, y_wai.index_3, x_wai.index_3, !flag)] +=
+      T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
  }
}
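Each gradient value at resized position (y, x) is scattered into a 4x4 window of the original image with separable weights; the sixteen unrolled `+=` statements above are one such window. An equivalent compact form of the NCHW update (a sketch, not PR code; array names are illustrative):

```cpp
#include <cstdint>

// Compact equivalent of the 16 unrolled updates: the gradient g at resized
// position (y, x) distributes over the 4x4 tap window with the separable
// weight product wy[k] * wx[m].
void ScatterGrad4x4(float *output_grad, float g,
                    const int64_t yi[4], const float wy[4],
                    const int64_t xi[4], const float wx[4],
                    int64_t base,       // offset of the (b, c) plane: b * origin_chw + c * origin_hw
                    int64_t origin_w) {
  for (int k = 0; k < 4; ++k) {
    for (int m = 0; m < 4; ++m) {
      output_grad[base + yi[k] * origin_w + xi[m]] += g * wy[k] * wx[m];
    }
  }
}
```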

@@ -280,8 +274,10 @@ void CalNonUtil(const ResizerGradState &RGS, const bool half_pixel_centers,
                const std::vector<WeightsAndIndices> &x_wais, const bool flag, const float *input_grad,
                T *output_grad) {
  for (int64_t b = 0; b < RGS.batch_size; ++b) {
-    for (int64_t y = 0; y < RGS.resized_height; ++y) {
-      ResizeCommomCalc(RGS, half_pixel_centers, x_wais, flag, input_grad, output_grad, b, y);
+    for (int64_t c = 0; c < RGS.channels; ++c) {
+      for (int64_t y = 0; y < RGS.resized_height; ++y) {
+        ResizeCommomCalc(RGS, half_pixel_centers, x_wais, flag, input_grad, output_grad, b, c, y);
+      }
    }
  }
}

@@ -297,14 +293,15 @@ inline void ResizeBicubicGrad(const float *input_grad, const ResizerGradState &R
    utils_flag = true;
  }
  if (utils_flag) {
-    for (int64_t b = 0; b < RGS.batch_size; ++b) {
-      auto task = [&](int64_t start, int64_t end) {
-        for (int64_t y = start; y < end; ++y) {
-          ResizeCommomCalc(RGS, half_pixel_centers_, x_wais, flag, input_grad, output_grad, b, y);
-        }
-      };
-      CPUKernelUtils::ParallelFor(task, static_cast<size_t>(RGS.resized_height));
-    }
+    auto task = [&](size_t start, size_t end) {
+      for (size_t i = start; i < end; ++i) {
+        const int64_t b = i / (RGS.channels * RGS.resized_height), c = i / RGS.resized_height % RGS.channels;
+        const int64_t y = i % RGS.resized_height;
+        ResizeCommomCalc(RGS, half_pixel_centers_, x_wais, flag, input_grad, output_grad, b, c, y);
+      }
+    };
+    const size_t parallel_num = static_cast<size_t>(RGS.batch_size * RGS.channels * RGS.resized_height);
+    CPUKernelUtils::ParallelFor(task, parallel_num);
  } else {
    CalNonUtil(RGS, half_pixel_centers_, x_wais, flag, input_grad, output_grad);
  }
@@ -56,9 +56,9 @@ class ResizeBicubicHelperGpuKernel : public GpuKernelHelperBase {
                const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr size_t INPUT_NUM = 1;
    constexpr size_t OUTPUT_NUM = 1;
-    constexpr int INPUT_W_ORDER = 2;
-    constexpr int OUTPUT_W_ORDER = 2;
-    constexpr int INPUT_C_ORDER = 3;
+    constexpr int INPUT_C_ORDER = 1;
+    constexpr int INPUT_H_ORDER = 2;
+    constexpr int INPUT_W_ORDER = 3;
    ResetResource();
    align_corners_ = false;
    is_null_resizebicubic_input_ = false;

@@ -86,11 +86,11 @@ class ResizeBicubicHelperGpuKernel : public GpuKernelHelperBase {
      return inp_flag;
    }
    batch_ = input_shape_[0];
-    inputheight_ = input_shape_[1];
-    inputwidth_ = input_shape_[INPUT_W_ORDER];
    channel_ = input_shape_[INPUT_C_ORDER];
-    outputheight_ = output_shapesize_[1];
-    outputwidth_ = output_shapesize_[OUTPUT_W_ORDER];
+    inputheight_ = input_shape_[INPUT_H_ORDER];
+    inputwidth_ = input_shape_[INPUT_W_ORDER];
+    outputheight_ = output_shapesize_[INPUT_H_ORDER];
+    outputwidth_ = output_shapesize_[INPUT_W_ORDER];
    int out_flag =
      CalShapesSizeInBytes<S>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
@@ -56,9 +56,9 @@ class ResizeBicubicGradHelperGpuKernel : public GpuKernelHelperBase {
                const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr size_t INPUT_NUM = 1;
    constexpr size_t OUTPUT_NUM = 1;
-    constexpr int INPUT_W_ORDER = 2;
-    constexpr int OUTPUT_W_ORDER = 2;
-    constexpr int INPUT_C_ORDER = 3;
+    constexpr int INPUT_C_ORDER = 1;
+    constexpr int INPUT_H_ORDER = 2;
+    constexpr int INPUT_W_ORDER = 3;
    ResetResource();
    align_corners_ = false;
    is_null_resizebicubic_grad_input_ = false;

@@ -86,11 +86,11 @@ class ResizeBicubicGradHelperGpuKernel : public GpuKernelHelperBase {
      return inp_flag;
    }
    batch_ = input_grad_shape_[0];
-    input_grad_height_ = input_grad_shape_[1];
-    input_grad_width_ = input_grad_shape_[INPUT_W_ORDER];
    channel_ = input_grad_shape_[INPUT_C_ORDER];
-    origin_height_ = origin_shape_[1];
-    origin_width_ = origin_shape_[OUTPUT_W_ORDER];
+    input_grad_height_ = input_grad_shape_[INPUT_H_ORDER];
+    input_grad_width_ = input_grad_shape_[INPUT_W_ORDER];
+    origin_height_ = origin_shape_[INPUT_H_ORDER];
+    origin_width_ = origin_shape_[INPUT_W_ORDER];
    int out_flag =
      CalShapesSizeInBytes<S>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
@@ -30,8 +30,8 @@ __device__ int Bounds(int access, int limit) {
}

template <typename T, typename S>
-__global__ void ResizeBicubicGradSame(const T *input, S *output, int nhwc) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
+__global__ void ResizeBicubicGradSame(const T *input, S *output, int nchw) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
    S val = input[pos];
    output[pos] = val;
    return;

@@ -40,13 +40,13 @@ __global__ void ResizeBicubicGradSame(const T *input, S *output, int nchw) {

template <typename T, typename S>
__global__ void ResizeBicubicGrad(const T *input, const S A, const int n, const int c, const int grad_h,
-                                  const int grad_w, const int origin_h, const int origin_w, const int nhwc,
-                                  const int hwc, const int wc, const float h_scale, const float w_scale, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
-    int posn = pos / hwc;
-    int posc = pos % c;
-    int posh = pos / wc % grad_h;
-    int posw = pos / c % grad_w;
+                                  const int grad_w, const int origin_h, const int origin_w, const int nchw,
+                                  const int chw, const int hw, const float h_scale, const float w_scale, S *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
+    int posn = pos / chw;
+    int posc = pos / hw % c;
+    int posh = pos / grad_w % grad_h;
+    int posw = pos % grad_w;
    S posw_scaled = 0;
    S posh_scaled = 0;
    posw_scaled = w_scale * posw;

@@ -86,8 +86,8 @@ __global__ void ResizeBicubicGrad(const T *input, const S A, const int n, const
      for (int m = 0; m < 4; m++) {
        access_h = Bounds(h_low - 1 + k, origin_h);
        access_w = Bounds(w_low - 1 + m, origin_w);
-        input_start = origin_w * c * (posn * origin_h + access_h);
-        temp = input_start + (access_w * c) + posc;
+        input_start = origin_w * origin_h * (c * posn + posc) + access_h * origin_w;
+        temp = input_start + access_w;
        MsAtomicAdd(&output[temp], value * y_coeffs[k] * x_coeffs[m]);
      }
    }

@@ -98,13 +98,13 @@ __global__ void ResizeBicubicGrad(const T *input, const S A, const int n, const
template <typename T, typename S>
__global__ void ResizeBicubicGradHalfPixelCenters(const T *input, const S A, const int n, const int c, const int grad_h,
                                                  const int grad_w, const int origin_h, const int origin_w,
-                                                 const int nhwc, const int hwc, const int wc, const float h_scale,
+                                                 const int nchw, const int chw, const int hw, const float h_scale,
                                                  const float w_scale, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
-    int posn = pos / hwc;
-    int posc = pos % c;
-    int posh = pos / wc % grad_h;
-    int posw = pos / c % grad_w;
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
+    int posn = pos / chw;
+    int posc = pos / hw % c;
+    int posh = pos / grad_w % grad_h;
+    int posw = pos % grad_w;
    S posw_scaled = 0;
    S posh_scaled = 0;
    posw_scaled = (static_cast<S>(posw) + static_cast<S>(0.5)) * w_scale - static_cast<S>(0.5);

@@ -177,8 +177,8 @@ __global__ void ResizeBicubicGradHalfPixelCenters(const T *input, const S A, con
      for (int m = 0; m < 4; m++) {
        access_h = Bounds(h_low - 1 + k, origin_h);
        access_w = Bounds(w_low - 1 + m, origin_w);
-        input_start = origin_w * c * (posn * origin_h + access_h);
-        temp = input_start + (access_w * c) + posc;
+        input_start = origin_w * origin_h * (c * posn + posc) + access_h * origin_w;
+        temp = input_start + access_w;
        MsAtomicAdd(&output[temp], value * y_coeffs[k] * x_coeffs[m]);
      }
    }

@@ -186,38 +186,29 @@ __global__ void ResizeBicubicGradHalfPixelCenters(const T *input, const S A, con
  return;
}

-template <typename S>
-__global__ void InitZero(S *output, const int origin_size) {
-  for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (origin_size); pos += gridDim.x * blockDim.x) {
-    output[pos] = static_cast<S>(0);
-  }
-}
-
template <typename T, typename S>
void CalResizeBicubicGrad(const T *input, const int n, const int c, const int grad_h, const int grad_w,
                          const int origin_h, const int origin_w, const float h_scale, const float w_scale, S *output,
                          bool half_pixel_centers, const uint32_t &device_id, cudaStream_t cuda_stream) {
-  const int wc = grad_w * c;
-  const int hwc = grad_h * wc;
-  const int nhwc = n * hwc;
+  const int hw = grad_w * grad_h;
+  const int chw = c * hw;
+  const int nchw = n * chw;
  const int origin_size = n * c * origin_h * origin_w;
+  cudaMemset(static_cast<void *>(output), 0, sizeof(S) * origin_size);
  if (origin_h == grad_h && origin_w == grad_w) {
-    InitZero<<<CUDA_BLOCKS(device_id, origin_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(output, origin_size);
-    ResizeBicubicGradSame<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, output,
-                                                                                                     nhwc);
+    ResizeBicubicGradSame<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, output,
+                                                                                                     nchw);
    return;
  }
  S A = -0.75;
  if (half_pixel_centers == true) {
    A = -0.5;
-    InitZero<<<CUDA_BLOCKS(device_id, origin_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(output, origin_size);
-    ResizeBicubicGradHalfPixelCenters<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      input, A, n, c, grad_h, grad_w, origin_h, origin_w, nhwc, hwc, wc, h_scale, w_scale, output);
+    ResizeBicubicGradHalfPixelCenters<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      input, A, n, c, grad_h, grad_w, origin_h, origin_w, nchw, chw, hw, h_scale, w_scale, output);
    return;
  } else {
-    InitZero<<<CUDA_BLOCKS(device_id, origin_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(output, origin_size);
-    ResizeBicubicGrad<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      input, A, n, c, grad_h, grad_w, origin_h, origin_w, nhwc, hwc, wc, h_scale, w_scale, output);
+    ResizeBicubicGrad<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      input, A, n, c, grad_h, grad_w, origin_h, origin_w, nchw, chw, hw, h_scale, w_scale, output);
    return;
  }
}
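Every kernel in these .cu files decodes its grid-stride position with the same divide/modulo chain (`posn = pos / chw; posc = pos / hw % c; posh = pos / grad_w % grad_h; posw = pos % grad_w`). A host-side check (plain C++, illustrative) that this decode is the exact inverse of NCHW flattening:

```cpp
#include <cassert>

// Verify the kernels' NCHW position decode round-trips for every pos.
int main() {
  const int n = 2, c = 3, h = 4, w = 5;
  const int hw = h * w, chw = c * hw, nchw = n * chw;
  for (int pos = 0; pos < nchw; ++pos) {
    const int posn = pos / chw;
    const int posc = pos / hw % c;
    const int posh = pos / w % h;
    const int posw = pos % w;
    // Re-flattening must reproduce pos exactly.
    assert(((posn * c + posc) * h + posh) * w + posw == pos);
  }
  return 0;
}
```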

@@ -31,13 +31,13 @@ __device__ int Bound(int access, int limit) {

template <typename T, typename S>
__global__ void ResizeBicubic(const T *input, const float A, const int n, const int c, const int input_h,
-                              const int input_w, const int output_h, const int output_w, const int nhwc, const int hwc,
-                              const int wc, const S h_scale, const S w_scale, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
-    int posn = pos / hwc;
-    int posc = pos % c;
-    int posh = pos / wc % output_h;
-    int posw = pos / c % output_w;
+                              const int input_w, const int output_h, const int output_w, const int nchw, const int chw,
+                              const int hw, const S h_scale, const S w_scale, S *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
+    int posn = pos / chw;
+    int posc = pos / hw % c;
+    int posh = pos / output_w % output_h;
+    int posw = pos % output_w;
    float posw_scaled = 0;
    float posh_scaled = 0;
    posw_scaled = w_scale * posw;

@@ -65,14 +65,14 @@ __global__ void ResizeBicubic(const T *input, const float A, const int n, const
    for (int k = 0; k < 4; k++) {
      access_h = Bound(h_low - 1 + k, input_h);
      access_w = Bound(w_low - 1, input_w);
-      input_start = input_w * c * (posn * input_h + access_h);
-      temp0 = input[input_start + (access_w * c) + posc];
+      input_start = input_w * input_h * (c * posn + posc) + access_h * input_w;
+      temp0 = input[input_start + access_w];
      access_w = Bound(w_low, input_w);
-      temp1 = input[input_start + (access_w * c) + posc];
+      temp1 = input[input_start + access_w];
      access_w = Bound(w_low + 1, input_w);
-      temp2 = input[input_start + (access_w * c) + posc];
+      temp2 = input[input_start + access_w];
      access_w = Bound(w_low + 2, input_w);
-      temp3 = input[input_start + (access_w * c) + posc];
+      temp3 = input[input_start + access_w];
      coefficients[k] = coeffs0 * temp0 + coeffs1 * temp1 + coeffs2 * temp2 + coeffs3 * temp3;
    }
    const int64_t offset_h = lrintf(h_alpha * 1024);

@@ -95,13 +95,13 @@ __global__ void ResizeBicubic(const T *input, const float A, const int n, const
template <typename T, typename S>
__global__ void ResizeBicubicHalfPixelCenters(const T *input, const float A, const int n, const int c,
                                              const int input_h, const int input_w, const int output_h,
-                                              const int output_w, const int nhwc, const int hwc, const int wc,
+                                              const int output_w, const int nchw, const int chw, const int hw,
                                              const S h_scale, const S w_scale, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
-    int posn = pos / hwc;
-    int posc = pos % c;
-    int posh = pos / wc % output_h;
-    int posw = pos / c % output_w;
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
+    int posn = pos / chw;
+    int posc = pos / hw % c;
+    int posh = pos / output_w % output_h;
+    int posw = pos % output_w;
    float posw_scaled = 0;
    float posh_scaled = 0;
    posw_scaled = (static_cast<float>(posw) + static_cast<float>(0.5)) * w_scale - static_cast<float>(0.5);

@@ -145,14 +145,14 @@ __global__ void ResizeBicubicHalfPixelCenters(const T *input, const float A, con
    for (int k = 0; k < 4; k++) {
      access_h = Bound(h_low - 1 + k, input_h);
      access_w = Bound(w_low - 1, input_w);
-      input_start = input_w * c * (posn * input_h + access_h);
-      temp0 = input[input_start + (access_w * c) + posc];
+      input_start = input_w * input_h * (c * posn + posc) + access_h * input_w;
+      temp0 = input[input_start + access_w];
      access_w = Bound(w_low, input_w);
-      temp1 = input[input_start + (access_w * c) + posc];
+      temp1 = input[input_start + access_w];
      access_w = Bound(w_low + 1, input_w);
-      temp2 = input[input_start + (access_w * c) + posc];
+      temp2 = input[input_start + access_w];
      access_w = Bound(w_low + 2, input_w);
-      temp3 = input[input_start + (access_w * c) + posc];
+      temp3 = input[input_start + access_w];
      coefficients[k] = coeffs0 * temp0 + coeffs1 * temp1 + coeffs2 * temp2 + coeffs3 * temp3;
    }
    const int64_t offset_h = lrintf(h_alpha * 1024);

@@ -189,8 +189,8 @@ __global__ void ResizeBicubicHalfPixelCenters(const T *input, const float A, con
}

template <typename T, typename S>
-__global__ void ResizeBicubicSame(const T *input, S *output, int nhwc) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nhwc; pos += gridDim.x * blockDim.x) {
+__global__ void ResizeBicubicSame(const T *input, S *output, int nchw) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += gridDim.x * blockDim.x) {
    S val = input[pos];
    output[pos] = val;
    return;

@@ -201,22 +201,22 @@ template <typename T, typename S>
void CalResizeBicubic(const T *input, const int n, const int c, const int input_h, const int input_w,
                      const int output_h, const int output_w, const S h_scale, const S w_scale, S *output,
                      bool half_pixel_centers, const uint32_t &device_id, cudaStream_t cuda_stream) {
-  const int wc = output_w * c;
-  const int hwc = output_h * wc;
-  const int nhwc = n * hwc;
+  const int hw = output_h * output_w;
+  const int chw = c * hw;
+  const int nchw = n * chw;
  if (input_h == output_h && input_w == output_w) {
-    ResizeBicubicSame<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, output, nhwc);
+    ResizeBicubicSame<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, output, nchw);
    return;
  }
  float A = -0.75;
  if (half_pixel_centers == true) {
    A = -0.5;
-    ResizeBicubicHalfPixelCenters<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      input, A, n, c, input_h, input_w, output_h, output_w, nhwc, hwc, wc, h_scale, w_scale, output);
+    ResizeBicubicHalfPixelCenters<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      input, A, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output);
    return;
  } else {
-    ResizeBicubic<<<CUDA_BLOCKS(device_id, nhwc), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      input, A, n, c, input_h, input_w, output_h, output_w, nhwc, hwc, wc, h_scale, w_scale, output);
+    ResizeBicubic<<<CUDA_BLOCKS(device_id, nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      input, A, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output);
    return;
  }
}

@@ -269,8 +269,8 @@ template CUDA_LIB_EXPORT void CalResizeBicubic<float, float>(const float *input,
                                                             bool half_pixel_centers, const uint32_t &device_id,
                                                             cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeBicubic<double, float>(const double *input, const int n, const int c,
                                                              const int input_h, const int input_w, const int output_h,
                                                              const int output_w, const float h_scale,
                                                              const float w_scale, float *output,
                                                              bool half_pixel_centers, const uint32_t &device_id,
                                                              cudaStream_t cuda_stream);
@@ -54,6 +54,7 @@ bool ResizeBicubicGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs
bool ResizeBicubicGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                         const std::vector<KernelTensorPtr> &inputs,
                                         const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
  auto kernel_grad_ptr = std::dynamic_pointer_cast<ops::ResizeBicubicGrad>(base_operator);
  kernel_name_ = kernel_grad_ptr->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
@@ -27,7 +27,6 @@
namespace mindspore {
namespace ops {
namespace {
-constexpr size_t index3 = 3;
constexpr size_t num4 = 4;
abstract::ShapePtr ResizeBicubicGradInferShape(const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) {

@@ -47,16 +46,18 @@ abstract::ShapePtr ResizeBicubicGradInferShape(const PrimitivePtr &primitive,
                            prim_name);
  }
  if (!is_dynamic) {
-    if (grads_shape[0] != original_image_shape[0]) {
-      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of grads_shape[0] is " << grads_shape[0]
-                               << ", but the shape of original_image_shape[0] is " << original_image_shape[0]
-                               << ". The first dimension of the shape of grads_shape "
+    if (grads_shape[kInputIndex0] != original_image_shape[kInputIndex0]) {
+      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of grads_shape[0] is "
+                               << grads_shape[kInputIndex0] << ", but the shape of original_image_shape[0] is "
+                               << original_image_shape[kInputIndex0]
+                               << ". The batch dimension of the shape of grads_shape "
                               << "must be equal to that of original_image_shape.";
    }
-    if (grads_shape[index3] != original_image_shape[index3]) {
-      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of grads_shape[3] is "
-                               << grads_shape[index3] << ", but the shape of original_image_shape[3] is "
-                               << original_image_shape[index3] << ". The third dimension of the shape of grads_shape "
+    if (grads_shape[kInputIndex1] != original_image_shape[kInputIndex1]) {
+      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of grads_shape[1] is "
+                               << grads_shape[kInputIndex1] << ", but the shape of original_image_shape[1] is "
+                               << original_image_shape[kInputIndex1]
+                               << ". The channel dimension of the shape of grads_shape "
                               << "must be equal to that of original_image_shape.";
    }
  }
@@ -38,9 +38,6 @@ abstract::ShapePtr ResizeBicubicInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
-  auto max_length_ptr = primitive->GetAttr("max_length");
-  MS_EXCEPTION_IF_NULL(max_length_ptr);
-  const int64_t kMaxLen = GetValue<int64_t>(max_length_ptr);
  auto align_corners_ptr = primitive->GetAttr("align_corners");
  bool align_corners = GetValue<bool>(align_corners_ptr);
  auto half_pixel_centers_ptr = primitive->GetAttr("half_pixel_centers");

@@ -52,9 +49,9 @@ abstract::ShapePtr ResizeBicubicInferShape(const PrimitivePtr &primitive,
  auto shape0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  if (!IsDynamicRank(shape0)) {
    (void)CheckAndConvertUtils::CheckInteger("images rank", SizeToLong(shape0.size()), kEqual, shape0_dim, prim_name);
-    constexpr int64_t indexid3 = 3;
+    constexpr int64_t indexid1 = 1;
    output_shape[0] = shape0[0];
-    output_shape[indexid3] = shape0[indexid3];
+    output_shape[indexid1] = shape0[indexid1];
  }

  auto shape1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];

@@ -75,10 +72,8 @@ abstract::ShapePtr ResizeBicubicInferShape(const PrimitivePtr &primitive,
    for (size_t i = 0; i < size_value.size(); ++i) {
      CheckAndConvertUtils::CheckInteger("size", size_value[i], kGreaterThan, kNumZero, prim_name);
    }
-    output_shape[kInputIndex1] = size_value[kInputIndex0];
-    output_shape[kInputIndex2] = size_value[kInputIndex1];
-    (void)CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)),
-                                             kLessEqual, kMaxLen, prim_name);
+    output_shape[kInputIndex2] = size_value[kInputIndex0];
+    output_shape[kInputIndex3] = size_value[kInputIndex1];
  }

  return std::make_shared<abstract::Shape>(output_shape);
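With NCHW inference, batch and channel pass straight through from `images` and only the spatial dims come from `size`. A minimal sketch of the resulting static-shape rule (illustrative; the real code above also handles dynamic ranks and validation):

```cpp
#include <cstdint>
#include <vector>

// NCHW output shape: N and C pass through, size = {new_height, new_width}.
std::vector<int64_t> InferResizeBicubicShape(const std::vector<int64_t> &images_shape,
                                             const std::vector<int64_t> &size) {
  return {images_shape[0], images_shape[1], size[0], size[1]};
}
```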

@@ -519,15 +519,13 @@ class Upsample(Cell):
            raise ValueError(
                "For 'Upsample', bicubic mode needs 4D input, but got 5D input"
            )
-        transpose = P.Transpose()
        align_corners = self.align_corners if self.has_align_corners else False
        upsample = P.image_ops.ResizeBicubic(
            align_corners=align_corners,
            half_pixel_centers=not align_corners,
        )
-        tensor = transpose(tensor, (0, 2, 3, 1))
        tensor = upsample(tensor, Tensor(size, dtype=mstype.int32))
-        return transpose(tensor, (0, 3, 1, 2))
+        return tensor

    def run_trilinear(tensor, ndim, size):
        if ndim == 3:
@@ -757,14 +757,14 @@ class ResizeBicubic(Primitive):

    Inputs:
-        - **images** (Tensor) - The input image must be a 4-D tensor of shape :math:`(batch, height, width, channels)`.
+        - **images** (Tensor) - The input image must be a 4-D tensor of shape :math:`(batch, channels, height, width)`.
          The format must be NHWC.
          Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
        - **size** (Tensor) - A 1-D tensor of shape [2], with 2 elements: new_height, new_width.
          Types allowed: int32.

    Outputs:
-        A 4-D tensor of shape :math:`(batch, new\_height, new\_width, channels)` with type float32.
+        A 4-D tensor of shape :math:`(batch, channels, new\_height, new\_width)` with type float32.

    Raises:
        TypeError: If `images` type is not allowed.

@@ -806,7 +806,6 @@ class ResizeBicubic(Primitive):
    @prim_attr_register
    def __init__(self, align_corners=False, half_pixel_centers=False):
        """Initialize"""
-        self.add_prim_attr("max_length", 1000000)
        validator.check_value_type('align_corners', align_corners, bool, self.name)
        validator.check_value_type('half_pixel_centers', half_pixel_centers, bool, self.name)
        self.init_prim_io_names(inputs=['images', 'size'], outputs=['y'])

@@ -838,13 +837,12 @@ class ResizeBicubic(Primitive):
        validator.check("size[1]", size_value[1], "minimum", 0, Rel.GT, self.name)

        batch_size = images_shape[0]
+        channel = images_shape[1]
        height = size_value[0]
        width = size_value[1]
-        channel = images_shape[3]
-        out_shape = (batch_size, height, width, channel)
-        return {'shape': out_shape,
-                'dtype': mstype.float32,
-                'value': None}
+        out_shape = (batch_size, channel, height, width)
+        return {'shape': out_shape, 'dtype': mstype.float32, 'value': None}


class ResizeArea(Primitive):

@@ -3387,14 +3387,14 @@ test_case_nn_ops = [
        'desc_bprop': []}),
    ('ResizeBicubic', {
        'block': ResizeBicubic(align_corners=False, half_pixel_centers=False),
-        'desc_inputs': [Tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]]),
+        'desc_inputs': [Tensor([[[[1., 2.], [3., 4.]]]]),
                        Tensor(np.array([1, 4]).reshape(2).astype(np.int32))],
-        'desc_bprop': [Tensor([[[[1.], [1.5], [2.], [2.09375]]]], mstype.float32)]}),
+        'desc_bprop': [Tensor([[[[1., 1.5, 2., 2.09375]]]], mstype.float32)]}),
    ('ResizeBicubicGrad', {
        'block': ResizeBicubicGrad(),
-        'desc_inputs': [Tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]], mstype.float32),
-                        Tensor([[[[1.], [4.], [6.], [4.]]]])],
-        'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
+        'desc_inputs': [Tensor([[[[1., 2.], [3., 4.]]]], mstype.float32),
+                        Tensor([[[[1., 4., 6., 4.]]]], mstype.float32)],
+        'desc_bprop': [Tensor([[[[1., 4., 6., 4.]]]], mstype.float32)],
        'skip': ['backward']}),
    ('ResizeBilinear', {
        'block': P.ResizeBilinear((5, 5)),