!35743 [MS][CPU]resize linear 1d grad cpu kernel

Merge pull request !35743 from mengyuanli/resize_linear_1d_grad_cpu
This commit is contained in:
i-robot 2022-06-13 01:54:57 +00:00 committed by Gitee
commit ea7c9daac3
12 changed files with 585 additions and 32 deletions

View File

@@ -179,6 +179,22 @@ struct CachedInterpolation {
float lerp;
};
struct AlignCornersFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 1 ? new_x * (old_length - 1) / (new_length - 1) : 0;
}
};
struct AsymmetricFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 0 ? new_x * old_length / new_length : 0;
}
};
struct HalfPixelFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 0 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0;
}
};
void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
CachedInterpolation *interpolation);
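For readers skimming the diff: these three functors map an output (resized) coordinate back into the input. A minimal standalone sketch, not part of the patch, evaluating them while resizing a length-3 signal to length 6:

#include <cstdio>

// Same formulas as the functors above, free-standing for illustration.
float AlignCorners(float new_x, int old_len, int new_len) {
  return new_len != 1 ? new_x * (old_len - 1) / (new_len - 1) : 0;
}
float Asymmetric(float new_x, int old_len, int new_len) {
  return new_len != 0 ? new_x * old_len / new_len : 0;
}
float HalfPixel(float new_x, int old_len, int new_len) {
  return new_len != 0 ? (new_x + 0.5f) * old_len / new_len - 0.5f : 0;
}

int main() {
  for (int i = 0; i < 6; ++i) {
    // e.g. i = 1: align_corners -> 0.40, asymmetric -> 0.50, half_pixel -> 0.25
    std::printf("out %d -> in %.2f / %.2f / %.2f\n", i,
                AlignCorners(i, 3, 6), Asymmetric(i, 3, 6), HalfPixel(i, 3, 6));
  }
  return 0;
}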

View File

@@ -17,28 +17,12 @@
#include "plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <unordered_map>
#include "mindspore/core/ops/resize_linear_1d.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
#include "kernel/common_utils.h"
namespace {
struct AlignCornersFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 1 ? new_x * (old_length - 1) / (new_length - 1) : 0;
}
};
struct AsymmetricFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 0 ? new_x * old_length / new_length : 0;
}
};
struct HalfPixelFunc {
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
return new_length != 0 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0;
}
};
} // namespace
namespace mindspore::kernel {
constexpr auto kResizeLinear1D = "ResizeLinear1D";
constexpr const size_t kResizeLinear1DInputsNum = 2;
@@ -131,19 +115,9 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DCpuKernelMod::KernelRunFun
ResizeLinear1DCpuKernelMod::CoordinateTransformationFunc ResizeLinear1DCpuKernelMod::ChooseCoordinateTransformationFunc(
CoordinateTransformationMode coordinate_transformation_mode) {
switch (coordinate_transformation_mode) {
case ALIGN_CORNERS: {
return AlignCornersFunc();
};
case HALF_PIXEL: {
return HalfPixelFunc();
};
case ASYMMETRIC: {
return AsymmetricFunc();
}
default:
return AlignCornersFunc();
}
const std::unordered_map<CoordinateTransformationMode, CoordinateTransformationFunc> coordinate_map{
{ALIGN_CORNERS, AlignCornersFunc()}, {HALF_PIXEL, HalfPixelFunc()}, {ASYMMETRIC, AsymmetricFunc()}};
return coordinate_map.at(coordinate_transformation_mode);
}
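One behavioral note on this refactor: the removed switch silently fell back to AlignCornersFunc for an unrecognized mode, whereas std::unordered_map::at throws std::out_of_range. A minimal sketch of the pattern, with hypothetical Mode/Func names rather than the kernel's own types:

#include <functional>
#include <unordered_map>

enum Mode { ALIGN_CORNERS, HALF_PIXEL, ASYMMETRIC, INVALID_MODE };
using Func = std::function<float(float, int, int)>;

Func Choose(Mode mode) {
  static const std::unordered_map<Mode, Func> table{
      {ALIGN_CORNERS, [](float x, int o, int n) { return n != 1 ? x * (o - 1) / (n - 1) : 0.0f; }},
      {ASYMMETRIC, [](float x, int o, int n) { return n != 0 ? x * o / n : 0.0f; }}};
  // .at throws std::out_of_range for an unmapped mode, unlike the old default branch.
  return table.at(mode);
}

In this kernel the lookup cannot actually fail at Launch time, since Init() rejects any coordinate_transformation_mode outside the three supported strings.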
bool ResizeLinear1DCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,

View File

@@ -0,0 +1,180 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <unordered_map>
#include "mindspore/core/ops/grad/resize_linear_1d_grad.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
#include "kernel/common_utils.h"
namespace mindspore::kernel {
constexpr auto kResizeLinear1DGrad = "ResizeLinear1DGrad";
constexpr const size_t kResizeLinear1DGradInputsNum = 2;
constexpr const size_t kResizeLinear1DGradOutputsNum = 1;
void ResizeLinear1DGradCpuKernelMod::ComputeInterpolationCaches(const size_t out_size, const size_t in_size,
const CoordinateTransformationFunc &func,
CachedInterpolation *interpolation) {
interpolation[out_size].lower = 0;
interpolation[out_size].upper = 0;
for (size_t i = 0; i < out_size; ++i) {
const float in = func(i, in_size, out_size);
const float in_floor = std::floor(in);
const float in_ceil = std::ceil(in);
interpolation[i].lower = static_cast<size_t>(in_floor > 0 ? in_floor : 0);
interpolation[i].upper = static_cast<size_t>(in_ceil < static_cast<float>(in_size - 1) ? in_ceil : in_size - 1);
interpolation[i].lerp = in - in_floor;
}
}
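As a concrete illustration, a standalone sketch, not part of the patch, that reproduces these caches for the half_pixel mode with in_size = 3 and out_size = 6 (the shape used in the new tests):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int in_size = 3, out_size = 6;
  for (int i = 0; i < out_size; ++i) {
    const float in = (i + 0.5f) * in_size / out_size - 0.5f;  // half_pixel mapping
    const float in_floor = std::floor(in);
    const float in_ceil = std::ceil(in);
    const int lower = static_cast<int>(in_floor > 0 ? in_floor : 0);
    const int upper = static_cast<int>(std::min(in_ceil, static_cast<float>(in_size - 1)));
    const float lerp = in - in_floor;
    // i=0: in=-0.25, lower=0, upper=0, lerp=0.75 (both taps clamp to index 0)
    std::printf("i=%d in=%5.2f lower=%d upper=%d lerp=%.2f\n", i, in, lower, upper, lerp);
  }
  return 0;
}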
template <typename T>
bool ResizeLinear1DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kResizeLinear1DGradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kResizeLinear1DGradOutputsNum, kernel_name_);
float *grad_output = reinterpret_cast<float *>(inputs[kIndex0]->addr);
MS_ERROR_IF_NULL_W_RET_VAL(grad_output, false);
T *grad_input = reinterpret_cast<T *>(outputs[kIndex0]->addr);
MS_ERROR_IF_NULL_W_RET_VAL(grad_input, false);
if (output_width_ == input_width_) {
auto task = [grad_output, grad_input](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
grad_input[i] = static_cast<T>(grad_output[i]);
}
};
ParallelLaunchAutoSearch(task, inputs[kIndex0]->size / sizeof(float), this, &parallel_search_info_, pool_);
return true;
}
if (memset_s(grad_input, outputs[kIndex0]->size, 0, outputs[kIndex0]->size) != EOK) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output buffer memset failed.";
}
std::vector<CachedInterpolation> xs(output_width_ + 1);
ComputeInterpolationCaches(output_width_, input_width_, coordinate_transformation_func_, xs.data());
auto task = [grad_output, grad_input, xs, this](size_t start, size_t end) {
for (size_t index = start; index < end; ++index) {
for (size_t w = 0; w < output_width_; ++w) {
const size_t xs_lower = xs[w].lower;
const size_t xs_upper = xs[w].upper;
const float xs_lerp = static_cast<float>(xs[w].lerp);
*(grad_input + index * input_width_ + xs_lower) +=
static_cast<T>((*(grad_output + index * output_width_ + w)) * (1 - xs_lerp));
*(grad_input + index * input_width_ + xs_upper) +=
static_cast<T>((*(grad_output + index * output_width_ + w)) * xs_lerp);
}
}
};
ParallelLaunchAutoSearch(task, batch_ * channel_, this, &parallel_search_info_, pool_);
return true;
}
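To check the scatter-add by hand, a standalone sketch, not part of the patch, that runs one row with align_corners, in_size = 3, out_size = 6; it prints 2.8 8.4 9.8, matching the expected gradient in the align_corners test added below:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int in_size = 3, out_size = 6;
  const std::vector<float> grad_output{1, 2, 3, 4, 5, 6};
  std::vector<float> grad_input(in_size, 0.0f);
  for (int i = 0; i < out_size; ++i) {
    float in = i * static_cast<float>(in_size - 1) / (out_size - 1);  // align_corners mapping
    int lower = static_cast<int>(std::floor(in));
    int upper = std::min(static_cast<int>(std::ceil(in)), in_size - 1);
    float lerp = in - std::floor(in);
    grad_input[lower] += grad_output[i] * (1 - lerp);  // weight toward the floor tap
    grad_input[upper] += grad_output[i] * lerp;        // weight toward the ceil tap
  }
  for (float v : grad_input) std::printf("%g ", v);  // 2.8 8.4 9.8
  return 0;
}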
#define RESIZE_LINEAR_1D_GRAD_CPU_REG(MS_T, T) \
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(MS_T).AddOutputAttr(MS_T), \
&ResizeLinear1DGradCpuKernelMod::LaunchKernel<T>
const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRunFunc>>
&ResizeLinear1DGradCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRunFunc>> func_list = {
{RESIZE_LINEAR_1D_GRAD_CPU_REG(kNumberTypeFloat32, float)},
{RESIZE_LINEAR_1D_GRAD_CPU_REG(kNumberTypeFloat64, double)},
};
return func_list;
}
ResizeLinear1DGradCpuKernelMod::CoordinateTransformationFunc
ResizeLinear1DGradCpuKernelMod::ChooseCoordinateTransformationFunc(
CoordinateTransformationMode coordinate_transformation_mode) {
const std::unordered_map<CoordinateTransformationMode, CoordinateTransformationFunc> coordinate_map{
{ALIGN_CORNERS, AlignCornersFunc()}, {HALF_PIXEL, HalfPixelFunc()}, {ASYMMETRIC, AsymmetricFunc()}};
return coordinate_map.at(coordinate_transformation_mode);
}
bool ResizeLinear1DGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ResizeLinear1DGrad>(base_operator);
MS_ERROR_IF_NULL_W_RET_VAL(kernel_ptr, false);
kernel_name_ = kernel_ptr->name();
if (inputs.size() != kResizeLinear1DGradInputsNum || outputs.size() != kResizeLinear1DGradOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kResizeLinear1DGradInputsNum
<< " and " << kResizeLinear1DGradOutputsNum << ", but got " << inputs.size() << " and "
<< outputs.size();
return false;
}
std::string coordinate_transformation_mode = kernel_ptr->get_coordinate_transformation_mode();
if (coordinate_transformation_mode == "align_corners") {
coordinate_transformation_mode_ = ALIGN_CORNERS;
} else if (coordinate_transformation_mode == "half_pixel") {
coordinate_transformation_mode_ = HALF_PIXEL;
} else if (coordinate_transformation_mode == "asymmetric") {
coordinate_transformation_mode_ = ASYMMETRIC;
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', coordinate_transformation_mode: " << coordinate_transformation_mode
<< " not support now.";
return false;
}
coordinate_transformation_func_ = ChooseCoordinateTransformationFunc(coordinate_transformation_mode_);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int ResizeLinear1DGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != 0) {
return ret;
}
std::vector<int64_t> grad_shape = inputs[kIndex0]->GetShapeVector();
auto grad_batch = LongToSize(grad_shape[kIndex0]);
auto grad_channel = LongToSize(grad_shape[kIndex1]);
output_width_ = LongToSize(grad_shape[kIndex2]);
std::vector<int64_t> shape_ = inputs[kIndex1]->GetShapeVector();
batch_ = LongToSize(shape_[kIndex0]);
channel_ = LongToSize(shape_[kIndex1]);
input_width_ = LongToSize(shape_[kIndex2]);
if (grad_batch != batch_ || grad_channel != channel_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', grad batch is : " << grad_batch
<< ", while input batch is : " << batch_ << "; "
<< "grad channel is : " << grad_channel << ", while input channel is : " << channel_;
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ResizeLinear1DGrad, []() {
return std::make_shared<ResizeLinear1DGradCpuKernelMod>(kResizeLinear1DGrad);
});
} // namespace mindspore::kernel

View File

@@ -0,0 +1,83 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore::kernel {
constexpr auto kUnknown = "Unknown";
class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<ResizeLinear1DGradCpuKernelMod> {
public:
ResizeLinear1DGradCpuKernelMod() = default;
explicit ResizeLinear1DGradCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ResizeLinear1DGradCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::OpSupport(); }
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs);
enum CoordinateTransformationMode { ALIGN_CORNERS = 0, HALF_PIXEL = 1, ASYMMETRIC = 2, INVALID_MODE = 255 };
using CoordinateTransformationFunc =
std::function<float(const float &new_x, const int &old_length, const int &new_length)>;
void ComputeInterpolationCaches(const size_t out_size, const size_t in_size, const CoordinateTransformationFunc &func,
CachedInterpolation *interpolation);
CoordinateTransformationFunc ChooseCoordinateTransformationFunc(
CoordinateTransformationMode coordinate_transformation_mode);
std::string kernel_type_{kUnknown};
bool align_corners_{false};
bool half_pixel_center_{false};
size_t batch_{0};
size_t channel_{0};
size_t input_width_{0};
size_t output_width_{0};
CoordinateTransformationMode coordinate_transformation_mode_{ALIGN_CORNERS};
CoordinateTransformationFunc coordinate_transformation_func_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_

View File

@@ -432,6 +432,7 @@ GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighbor, std::make_shared<Primitive>("
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighborGrad, std::make_shared<Primitive>("ResizeNearestNeighborGrad"));
GVAR_DEF(PrimitivePtr, kPrimDynamicResizeNearestNeighbor, std::make_shared<Primitive>("DynamicResizeNearestNeighbor"));
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1D, std::make_shared<Primitive>("ResizeLinear1D"));
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1DGrad, std::make_shared<Primitive>("ResizeLinear1DGrad"));
GVAR_DEF(PrimitivePtr, kPrimSort, std::make_shared<Primitive>("Sort"));
GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"));
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));

View File

@@ -0,0 +1,90 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "ops/grad/resize_linear_1d_grad.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ResizeLinear1DGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto grad_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
MS_EXCEPTION_IF_NULL(grad_shape_ptr);
auto grad_shape = grad_shape_ptr->shape();
auto input_x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
auto input_x_shape = input_x_shape_ptr->shape();
std::vector<int64_t> ret_shape;
ret_shape.push_back(grad_shape[kInputIndex0]);
ret_shape.push_back(grad_shape[kInputIndex1]);
ret_shape.push_back(input_x_shape[kInputIndex2]);
if (grad_shape_ptr->IsDynamic()) {
auto grad_min_shape = grad_shape_ptr->min_shape();
std::vector<int64_t> ret_min_shape;
ret_min_shape.push_back(grad_min_shape[kInputIndex0]);
ret_min_shape.push_back(grad_min_shape[kInputIndex1]);
ret_min_shape.push_back(input_x_shape[kInputIndex2]);
auto grad_max_shape = grad_shape_ptr->max_shape();
std::vector<int64_t> ret_max_shape;
ret_max_shape.push_back(grad_max_shape[kInputIndex0]);
ret_max_shape.push_back(grad_max_shape[kInputIndex1]);
ret_max_shape.push_back(input_x_shape[kInputIndex2]);
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
}
return std::make_shared<abstract::Shape>(ret_shape);
}
TypePtr ResizeLinear1DGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
return input_args[1]->BuildType();
}
} // namespace
void ResizeLinear1DGrad::set_coordinate_transformation_mode(const std::string coordinate_transformation_mode) {
(void)this->AddAttr("coordinate_transformation_mode", api::MakeValue(coordinate_transformation_mode));
}
std::string ResizeLinear1DGrad::get_coordinate_transformation_mode() const {
auto value_ptr = GetAttr("coordinate_transformation_mode");
return GetValue<std::string>(value_ptr);
}
void ResizeLinear1DGrad::Init(const std::string coordinate_transformation_mode) {
this->set_coordinate_transformation_mode(coordinate_transformation_mode);
}
MIND_API_OPERATOR_IMPL(ResizeLinear1DGrad, BaseOperator);
AbstractBasePtr ResizeLinear1DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
kEqual, input_num, prim_name);
return abstract::MakeAbstract(ResizeLinear1DGradInferShape(primitive, input_args),
ResizeLinear1DGradInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeLinear1DGrad, prim::kPrimResizeLinear1DGrad, ResizeLinear1DGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeLinear1DGrad = "ResizeLinear1DGrad";
class MIND_API ResizeLinear1DGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ResizeLinear1DGrad);
ResizeLinear1DGrad() : BaseOperator(kNameResizeLinear1DGrad) { InitIOName({"grad", "input_x"}, {"output"}); }
void Init(const std::string coordinate_transformation_mode = "align_corners");
void set_coordinate_transformation_mode(const std::string coordinate_transformation_mode);
std::string get_coordinate_transformation_mode() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_

View File

@@ -65,10 +65,12 @@ abstract::ShapePtr ResizeLinear1DInferShape(const PrimitivePtr &primitive,
auto shape0_v = shape0_ptr->shape();
auto shape1_v = shape1_ptr->shape();
if (shape0_v.size() < kInputShape0Dim) {
MS_EXCEPTION(ValueError) << "The rank of images tensor must be greater than 3. But got " << shape0_v.size();
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the rank of images tensor must be greater than 3. But got "
<< shape0_v.size();
}
if (shape1_v.size() != kInputShape1Dim) {
MS_EXCEPTION(ValueError) << "The size tensor must be a 1-D tensor. But got " << shape1_v.size() << "-D";
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the size tensor must be a 1-D tensor. But got "
<< shape1_v.size() << "-D";
}
if (size_arg->isa<abstract::AbstractTensor>() && size_arg->BuildValue()->isa<tensor::Tensor>()) {
auto size_shape_ptr = reinterpret_cast<int64_t *>(size_shape_tensor->data_c());

View File

@@ -39,6 +39,7 @@ from ..operations.nn_ops import MaxPoolV1
from ..operations._grad_ops import MaxPoolGradV1
from ..operations.nn_ops import ReLUV3
from ..operations._grad_ops import ReluGrad
from ..operations.image_ops import ResizeLinear1D
@bprop_getters.register(P.CTCLossV2)
@@ -247,3 +248,15 @@ def get_bprop_grid_sampler_2d(self):
return dx, dgrid
return bprop
@bprop_getters.register(ResizeLinear1D)
def get_bprop_resize_linear_1d(self):
"""Grad definition for `ResizeLinear1D` operation."""
resize_grad = G.ResizeLinear1DGrad(self.coordinate_transformation_mode)
def bprop(input_x, size, out, dout):
dx = resize_grad(dout, input_x)
return (dx, zeros_like(size))
return bprop

View File

@@ -1840,6 +1840,29 @@ class ResizeNearestNeighborGrad(Primitive):
self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
class ResizeLinear1DGrad(Primitive):
"""
Compute gradient of `ResizeLinear1D` operator.
Note:
This is an experimental feature and is subject to change.
Args:
coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
in the resized tensor to the coordinate in the original tensor. Other options: 'half_pixel', 'asymmetric'.
"""
@prim_attr_register
def __init__(self, coordinate_transformation_mode="align_corners"):
"""Initialize ResizeLinear1DGrad"""
self.init_prim_io_names(
inputs=['grads', 'input_x'], outputs=['y'])
validator.check_value_type(
"coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
validator.check_string(coordinate_transformation_mode, ["align_corners", "half_pixel", "asymmetric"],
"coordinate_transformation_mode", self.name)
class ROIAlignGrad(PrimitiveWithInfer):
"""
ROIAlignGrad operator.

View File

@@ -469,6 +469,9 @@ class ResizeLinear1D(Primitive):
r"""
Using the linear interpolate method resize the input tensor 'x'.
Note:
This is an experimental feature and is subject to change.
Args:
coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
in the resized tensor to the coordinate in the original tensor. Other options: 'half_pixel', 'asymmetric'.

View File

@@ -0,0 +1,126 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops.operations.image_ops import ResizeLinear1D
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class ResizeLinear1DNet(nn.Cell):
"""ResizeLinear1DNet."""
def __init__(self, coordinate_transformation_mode="align_corners"):
"""Init."""
super(ResizeLinear1DNet, self).__init__()
self.resize = ResizeLinear1D(coordinate_transformation_mode)
def construct(self, x, size):
"""Construct."""
return self.resize(x, size)
class ResizeLinear1DGradNet(nn.Cell):
"""ResizeLinear1DGradNet."""
def __init__(self, forward_cpu_net):
"""Init."""
super(ResizeLinear1DGradNet, self).__init__()
self.resize_grad = C.GradOperation(get_all=True, sens_param=True)
self.forward_cpu_net = forward_cpu_net
def construct(self, grad_output, input_x, size):
"""Construct."""
gout = self.resize_grad(self.forward_cpu_net)(
input_x, size, grad_output)
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_resize_linear_1d_grad_align_corners(dtype):
"""
Feature: ResizeLinear1DGrad cpu kernel align_corners mode
Description: test the correctness of the ResizeLinear1DGrad cpu kernel.
Expectation: the output is the same as expected.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
size = Tensor(np.array([6], dtype=np.int64))
grad_output = Tensor(np.array([[[1., 2., 3., 4., 5., 6.],
[7., 8., 9., 10., 11., 12.]]], dtype=np.float32))
net_cpu = ResizeLinear1DNet()
grad = ResizeLinear1DGradNet(net_cpu)
output = grad(grad_output, x, size)
expect = np.array([[[2.8, 8.4, 9.8],
[13.6, 22.8, 20.6]]]).astype(np.float32)
print("ms grad input: ", output[0].asnumpy())
assert np.allclose(output[0].asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_resize_linear_1d_grad_half_pixel(dtype):
"""
Feature: ResizeLinear1DGrad cpu kernel half_pixel mode
Description: test the correctness of the ResizeLinear1DGrad cpu kernel.
Expectation: the output is the same as expected.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
size = Tensor(np.array([6], dtype=np.int64))
grad_output = Tensor(np.array([[[1., 2., 3., 4., 5., 6.],
[7., 8., 9., 10., 11., 12.]]], dtype=np.float32))
net_cpu = ResizeLinear1DNet("half_pixel")
grad = ResizeLinear1DGradNet(net_cpu)
output = grad(grad_output, x, size)
expect = np.array([[[3.25, 7, 10.75],
[15.25, 19, 22.75]]]).astype(np.float32)
print("ms grad input: ", output[0].asnumpy())
assert np.allclose(output[0].asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_resize_linear_1d_grad_same_shape(dtype):
"""
Feature: ResizeLinear1DGrad cpu kernel same shape
Description: test the correctness of the ResizeLinear1DGrad cpu kernel.
Expectation: the output is the same as expected.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
size = Tensor(np.array([3], dtype=np.int64))
grad_output = Tensor(np.array([[[1., 2., 3.],
[7., 8., 9.]]], dtype=np.float32))
net_cpu = ResizeLinear1DNet()
grad = ResizeLinear1DGradNet(net_cpu)
output = grad(grad_output, x, size)
expect = np.array([[[1., 2., 3.],
[7., 8., 9.]]]).astype(np.float32)
print("ms grad input: ", output[0].asnumpy())
assert np.allclose(output[0].asnumpy(), expect)