forked from mindspore-Ecosystem/mindspore
!35743 [MS][CPU]resize linear 1d grad cppu kernel
Merge pull request !35743 from mengyuanli/resize_linear_1d_grad_cpu
This commit is contained in:
commit
ea7c9daac3
|
@ -179,6 +179,22 @@ struct CachedInterpolation {
|
|||
float lerp;
|
||||
};
|
||||
|
||||
struct AlignCornersFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 1 ? new_x * (old_length - 1) / (new_length - 1) : 0;
|
||||
}
|
||||
};
|
||||
struct AsymmetricFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 0 ? new_x * old_length / new_length : 0;
|
||||
}
|
||||
};
|
||||
struct HalfPixelFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 0 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
|
||||
CachedInterpolation *interpolation);
|
||||
|
||||
|
|
|
@ -17,28 +17,12 @@
|
|||
#include "plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include "mindspore/core/ops/resize_linear_1d.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace {
|
||||
struct AlignCornersFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 1 ? new_x * (old_length - 1) / (new_length - 1) : 0;
|
||||
}
|
||||
};
|
||||
struct AsymmetricFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 0 ? new_x * old_length / new_length : 0;
|
||||
}
|
||||
};
|
||||
struct HalfPixelFunc {
|
||||
float operator()(const float &new_x, const int &old_length, const int &new_length) const {
|
||||
return new_length != 0 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kResizeLinear1D = "ResizeLinear1D";
|
||||
constexpr const size_t kResizeLinear1DInputsNum = 2;
|
||||
|
@ -131,19 +115,9 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DCpuKernelMod::KernelRunFun
|
|||
|
||||
ResizeLinear1DCpuKernelMod::CoordinateTransformationFunc ResizeLinear1DCpuKernelMod::ChooseCoordinateTransformationFunc(
|
||||
CoordinateTransformationMode coordinate_transformation_mode) {
|
||||
switch (coordinate_transformation_mode) {
|
||||
case ALIGN_CORNERS: {
|
||||
return AlignCornersFunc();
|
||||
};
|
||||
case HALF_PIXEL: {
|
||||
return HalfPixelFunc();
|
||||
};
|
||||
case ASYMMETRIC: {
|
||||
return AsymmetricFunc();
|
||||
}
|
||||
default:
|
||||
return AlignCornersFunc();
|
||||
}
|
||||
const std::unordered_map<CoordinateTransformationMode, CoordinateTransformationFunc> coordinate_map{
|
||||
{ALIGN_CORNERS, AlignCornersFunc()}, {HALF_PIXEL, HalfPixelFunc()}, {ASYMMETRIC, AsymmetricFunc()}};
|
||||
return coordinate_map.at(coordinate_transformation_mode);
|
||||
}
|
||||
|
||||
bool ResizeLinear1DCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include "mindspore/core/ops/grad/resize_linear_1d_grad.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kResizeLinear1DGrad = "ResizeLinear1DGrad";
|
||||
constexpr const size_t kResizeLinear1DGradInputsNum = 2;
|
||||
constexpr const size_t kResizeLinear1DGradOutputsNum = 1;
|
||||
|
||||
void ResizeLinear1DGradCpuKernelMod::ComputeInterpolationCaches(const size_t out_size, const size_t in_size,
|
||||
const CoordinateTransformationFunc &func,
|
||||
CachedInterpolation *interpolation) {
|
||||
interpolation[out_size].lower = 0;
|
||||
interpolation[out_size].upper = 0;
|
||||
for (size_t i = 0; i <= out_size - 1; ++i) {
|
||||
const float in = func(i, in_size, out_size);
|
||||
const float in_floor = std::floor(in);
|
||||
const float in_ceil = std::ceil(in);
|
||||
interpolation[i].lower = static_cast<size_t>(in_floor > 0 ? in_floor : 0);
|
||||
interpolation[i].upper = static_cast<size_t>(in_ceil < static_cast<float>(in_size - 1) ? in_ceil : in_size - 1);
|
||||
interpolation[i].lerp = in - in_floor;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ResizeLinear1DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kResizeLinear1DGradInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kResizeLinear1DGradOutputsNum, kernel_name_);
|
||||
float *grad_output = reinterpret_cast<float *>(inputs[kIndex0]->addr);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(grad_output, false);
|
||||
T *grad_input = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(grad_input, false);
|
||||
|
||||
if (output_width_ == input_width_) {
|
||||
auto task = [grad_output, grad_input](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
grad_input[i] = static_cast<T>(grad_output[i]);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, inputs[kIndex0]->size / sizeof(float), this, ¶llel_search_info_, pool_);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (memset_s(grad_input, outputs[kIndex0]->size, 0, outputs[kIndex0]->size) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output buffer memset failed.";
|
||||
}
|
||||
|
||||
std::vector<CachedInterpolation> xs(output_width_ + 1);
|
||||
ComputeInterpolationCaches(output_width_, input_width_, coordinate_transformation_func_, xs.data());
|
||||
|
||||
auto task = [grad_output, grad_input, xs, this](size_t start, size_t end) {
|
||||
for (size_t index = start; index < end; ++index) {
|
||||
for (size_t w = 0; w < output_width_; ++w) {
|
||||
const size_t xs_lower = xs[w].lower;
|
||||
const size_t xs_upper = xs[w].upper;
|
||||
const float xs_lerp = static_cast<float>(xs[w].lerp);
|
||||
*(grad_input + index * input_width_ + xs_lower) +=
|
||||
static_cast<T>((*(grad_output + index * output_width_ + w)) * (1 - xs_lerp));
|
||||
*(grad_input + index * input_width_ + xs_upper) +=
|
||||
static_cast<T>((*(grad_output + index * output_width_ + w)) * xs_lerp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ParallelLaunchAutoSearch(task, batch_ * channel_, this, ¶llel_search_info_, pool_);
|
||||
return true;
|
||||
}
|
||||
|
||||
#define RESIZE_LINEAR_1D_GRAD_CPU_REG(MS_T, T) \
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(MS_T).AddOutputAttr(MS_T), \
|
||||
&ResizeLinear1DGradCpuKernelMod::LaunchKernel<T>
|
||||
|
||||
const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRunFunc>>
|
||||
&ResizeLinear1DGradCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
{RESIZE_LINEAR_1D_GRAD_CPU_REG(kNumberTypeFloat32, float)},
|
||||
{RESIZE_LINEAR_1D_GRAD_CPU_REG(kNumberTypeFloat64, double)},
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
|
||||
ResizeLinear1DGradCpuKernelMod::CoordinateTransformationFunc
|
||||
ResizeLinear1DGradCpuKernelMod::ChooseCoordinateTransformationFunc(
|
||||
CoordinateTransformationMode coordinate_transformation_mode) {
|
||||
const std::unordered_map<CoordinateTransformationMode, CoordinateTransformationFunc> coordinate_map{
|
||||
{ALIGN_CORNERS, AlignCornersFunc()}, {HALF_PIXEL, HalfPixelFunc()}, {ASYMMETRIC, AsymmetricFunc()}};
|
||||
return coordinate_map.at(coordinate_transformation_mode);
|
||||
}
|
||||
|
||||
bool ResizeLinear1DGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::ResizeLinear1DGrad>(base_operator);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(kernel_ptr, false);
|
||||
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
if (inputs.size() != kResizeLinear1DGradInputsNum || outputs.size() != kResizeLinear1DGradOutputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kResizeLinear1DGradInputsNum
|
||||
<< " and " << kResizeLinear1DGradOutputsNum << ", but got " << inputs.size() << " and "
|
||||
<< outputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string coordinate_transformation_mode = kernel_ptr->get_coordinate_transformation_mode();
|
||||
if (coordinate_transformation_mode == "align_corners") {
|
||||
coordinate_transformation_mode_ = ALIGN_CORNERS;
|
||||
} else if (coordinate_transformation_mode == "half_pixel") {
|
||||
coordinate_transformation_mode_ = HALF_PIXEL;
|
||||
} else if (coordinate_transformation_mode == "asymmetric") {
|
||||
coordinate_transformation_mode_ = ASYMMETRIC;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', coordinate_transformation_mode: " << coordinate_transformation_mode
|
||||
<< " not support now.";
|
||||
return false;
|
||||
}
|
||||
|
||||
coordinate_transformation_func_ = ChooseCoordinateTransformationFunc(coordinate_transformation_mode_);
|
||||
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int ResizeLinear1DGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
int ret = 0;
|
||||
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<int64_t> grad_shape = inputs[kIndex0]->GetShapeVector();
|
||||
auto grad_batch = LongToSize(grad_shape[kIndex0]);
|
||||
auto grad_channel = LongToSize(grad_shape[kIndex1]);
|
||||
output_width_ = LongToSize(grad_shape[kIndex2]);
|
||||
|
||||
std::vector<int64_t> shape_ = inputs[kIndex1]->GetShapeVector();
|
||||
batch_ = LongToSize(shape_[kIndex0]);
|
||||
channel_ = LongToSize(shape_[kIndex1]);
|
||||
input_width_ = LongToSize(shape_[kIndex2]);
|
||||
|
||||
if (grad_batch != batch_ || grad_channel != channel_) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', grad batch is : " << grad_batch
|
||||
<< ", while input batch is : " << batch_ << "; "
|
||||
<< "grad channel is : " << grad_channel << ", while input channel is : " << channel_;
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ResizeLinear1DGrad, []() {
|
||||
return std::make_shared<ResizeLinear1DGradCpuKernelMod>(kResizeLinear1DGrad);
|
||||
});
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
||||
class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
|
||||
public MatchKernelHelper<ResizeLinear1DGradCpuKernelMod> {
|
||||
public:
|
||||
ResizeLinear1DGradCpuKernelMod() = default;
|
||||
explicit ResizeLinear1DGradCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ResizeLinear1DGradCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
enum CoordinateTransformationMode { ALIGN_CORNERS = 0, HALF_PIXEL = 1, ASYMMETRIC = 2, INVALID_MODE = 255 };
|
||||
using CoordinateTransformationFunc =
|
||||
std::function<float(const float &new_x, const int &old_length, const int &new_length)>;
|
||||
|
||||
void ComputeInterpolationCaches(const size_t out_size, const size_t in_size, const CoordinateTransformationFunc &func,
|
||||
CachedInterpolation *interpolation);
|
||||
|
||||
CoordinateTransformationFunc ChooseCoordinateTransformationFunc(
|
||||
CoordinateTransformationMode coordinate_transformation_mode);
|
||||
|
||||
std::string kernel_type_{kUnknown};
|
||||
bool align_corners_{false};
|
||||
bool half_pixel_center_{false};
|
||||
size_t batch_{0};
|
||||
size_t channel_{0};
|
||||
size_t input_width_{0};
|
||||
size_t output_width_{0};
|
||||
CoordinateTransformationMode coordinate_transformation_mode_{ALIGN_CORNERS};
|
||||
CoordinateTransformationFunc coordinate_transformation_func_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_RESIZE_LINEAR_1D_GRAD_CPU_KERNEL_H_
|
|
@ -432,6 +432,7 @@ GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighbor, std::make_shared<Primitive>("
|
|||
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighborGrad, std::make_shared<Primitive>("ResizeNearestNeighborGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDynamicResizeNearestNeighbor, std::make_shared<Primitive>("DynamicResizeNearestNeighbor"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1D, std::make_shared<Primitive>("ResizeLinear1D"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1DGrad, std::make_shared<Primitive>("ResizeLinear1DGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSort, std::make_shared<Primitive>("Sort"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/grad/resize_linear_1d_grad.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ResizeLinear1DGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto grad_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
MS_EXCEPTION_IF_NULL(grad_shape_ptr);
|
||||
auto grad_shape = grad_shape_ptr->shape();
|
||||
auto input_x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
auto input_x_shape = input_x_shape_ptr->shape();
|
||||
std::vector<int64_t> ret_shape;
|
||||
ret_shape.push_back(grad_shape[kInputIndex0]);
|
||||
ret_shape.push_back(grad_shape[kInputIndex1]);
|
||||
ret_shape.push_back(input_x_shape[kInputIndex2]);
|
||||
if (grad_shape_ptr->IsDynamic()) {
|
||||
auto grad_min_shape = grad_shape_ptr->min_shape();
|
||||
std::vector<int64_t> ret_min_shape;
|
||||
ret_min_shape.push_back(grad_min_shape[kInputIndex0]);
|
||||
ret_min_shape.push_back(grad_min_shape[kInputIndex1]);
|
||||
ret_min_shape.push_back(input_x_shape[kInputIndex2]);
|
||||
auto grad_max_shape = grad_shape_ptr->max_shape();
|
||||
std::vector<int64_t> ret_max_shape;
|
||||
ret_max_shape.push_back(grad_max_shape[kInputIndex0]);
|
||||
ret_max_shape.push_back(grad_max_shape[kInputIndex1]);
|
||||
ret_max_shape.push_back(input_x_shape[kInputIndex2]);
|
||||
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
|
||||
TypePtr ResizeLinear1DGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
return input_args[1]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ResizeLinear1DGrad::set_coordinate_transformation_mode(const std::string coordinate_transformation_mode) {
|
||||
(void)this->AddAttr("coordinate_transformation_mode", api::MakeValue(coordinate_transformation_mode));
|
||||
}
|
||||
std::string ResizeLinear1DGrad::get_coordinate_transformation_mode() const {
|
||||
auto value_ptr = GetAttr("coordinate_transformation_mode");
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
void ResizeLinear1DGrad::Init(const std::string coordinate_transformation_mode) {
|
||||
this->set_coordinate_transformation_mode(coordinate_transformation_mode);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ResizeLinear1DGrad, BaseOperator);
|
||||
AbstractBasePtr ResizeLinear1DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||
kEqual, input_num, prim_name);
|
||||
return abstract::MakeAbstract(ResizeLinear1DGradInferShape(primitive, input_args),
|
||||
ResizeLinear1DGradInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeLinear1DGrad, prim::kPrimResizeLinear1DGrad, ResizeLinear1DGradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameResizeLinear1DGrad = "ResizeLinear1DGrad";
|
||||
class MIND_API ResizeLinear1DGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ResizeLinear1DGrad);
|
||||
ResizeLinear1DGrad() : BaseOperator(kNameResizeLinear1DGrad) { InitIOName({"grad", "input_x"}, {"output"}); }
|
||||
|
||||
void Init(const std::string coordinate_transformation_mode = "align_corners");
|
||||
|
||||
void set_coordinate_transformation_mode(const std::string coordinate_transformation_mode);
|
||||
std::string get_coordinate_transformation_mode() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_LINEAR_1D_GRAD_H_
|
|
@ -65,10 +65,12 @@ abstract::ShapePtr ResizeLinear1DInferShape(const PrimitivePtr &primitive,
|
|||
auto shape0_v = shape0_ptr->shape();
|
||||
auto shape1_v = shape1_ptr->shape();
|
||||
if (shape0_v.size() < kInputShape0Dim) {
|
||||
MS_EXCEPTION(ValueError) << "The rank of images tensor must be greater than 3. But got " << shape0_v.size();
|
||||
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the rank of images tensor must be greater than 3. But got "
|
||||
<< shape0_v.size();
|
||||
}
|
||||
if (shape1_v.size() != kInputShape1Dim) {
|
||||
MS_EXCEPTION(ValueError) << "The size tensor must be a 1-D tensor. But got " << shape1_v.size() << "-D";
|
||||
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the size tensor must be a 1-D tensor. But got "
|
||||
<< shape1_v.size() << "-D";
|
||||
}
|
||||
if (size_arg->isa<abstract::AbstractTensor>() && size_arg->BuildValue()->isa<tensor::Tensor>()) {
|
||||
auto size_shape_ptr = reinterpret_cast<int64_t *>(size_shape_tensor->data_c());
|
||||
|
|
|
@ -39,6 +39,7 @@ from ..operations.nn_ops import MaxPoolV1
|
|||
from ..operations._grad_ops import MaxPoolGradV1
|
||||
from ..operations.nn_ops import ReLUV3
|
||||
from ..operations._grad_ops import ReluGrad
|
||||
from ..operations.image_ops import ResizeLinear1D
|
||||
|
||||
|
||||
@bprop_getters.register(P.CTCLossV2)
|
||||
|
@ -247,3 +248,15 @@ def get_bprop_grid_sampler_2d(self):
|
|||
return dx, dgrid
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(ResizeLinear1D)
|
||||
def get_bprop_resize_bilinear(self):
|
||||
"""Grad definition for `ResizeLinear1D` operation."""
|
||||
resize_grad = G.ResizeLinear1DGrad(self.coordinate_transformation_mode)
|
||||
|
||||
def bprop(input_x, size, out, dout):
|
||||
dx = resize_grad(dout, input_x)
|
||||
return (dx, zeros_like(size))
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -1840,6 +1840,29 @@ class ResizeNearestNeighborGrad(Primitive):
|
|||
self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
|
||||
|
||||
|
||||
class ResizeLinear1DGrad(Primitive):
|
||||
"""
|
||||
Compute gradient of `ResizeLinear1D` operator.
|
||||
|
||||
Note:
|
||||
This is an experimental feature and is subjected to change.
|
||||
|
||||
Args:
|
||||
coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
|
||||
in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel', 'asymmetric'.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, coordinate_transformation_mode="align_corners"):
|
||||
"""Initialize ResizeLinear1DGrad"""
|
||||
self.init_prim_io_names(
|
||||
inputs=['grads', 'input_x'], outputs=['y'])
|
||||
validator.check_value_type(
|
||||
"coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
|
||||
validator.check_string(coordinate_transformation_mode, ["align_corners", "half_pixel", "asymmetric"],
|
||||
"coordinate_transformation_mode", self.name)
|
||||
|
||||
|
||||
class ROIAlignGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
ROIAlignGrad operator.
|
||||
|
|
|
@ -469,6 +469,9 @@ class ResizeLinear1D(Primitive):
|
|||
r"""
|
||||
Using the linear interpolate method resize the input tensor 'x'.
|
||||
|
||||
Note:
|
||||
This is an experimental feature and is subjected to change.
|
||||
|
||||
Args:
|
||||
coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
|
||||
in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel', 'asymmetric'.
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations.image_ops import ResizeLinear1D
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class ResizeLinear1DNet(nn.Cell):
|
||||
"""ResizeLinear1DNet."""
|
||||
|
||||
def __init__(self, coordinate_transformation_mode="align_corners"):
|
||||
"""Init."""
|
||||
super(ResizeLinear1DNet, self).__init__()
|
||||
self.resize = ResizeLinear1D(coordinate_transformation_mode)
|
||||
|
||||
def construct(self, x, size):
|
||||
"""Construct."""
|
||||
return self.resize(x, size)
|
||||
|
||||
|
||||
class ResizeLinear1DGradNet(nn.Cell):
|
||||
"""ResizeLinear1DGradNet."""
|
||||
|
||||
def __init__(self, forward_cpu_net):
|
||||
"""Init."""
|
||||
super(ResizeLinear1DGradNet, self).__init__()
|
||||
self.resize_grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
self.forward_cpu_net = forward_cpu_net
|
||||
|
||||
def construct(self, grad_output, input_x, size):
|
||||
"""Construct."""
|
||||
gout = self.resize_grad(self.forward_cpu_net)(
|
||||
input_x, size, grad_output)
|
||||
return gout
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_resize_linear_1d_grad_align_corners(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1DGrad cpu kernel align_corners mode
|
||||
Description: test the rightness of ResizeLinear1DGrad cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
size = Tensor(np.array([6], dtype=np.int64))
|
||||
grad_output = Tensor(np.array([[[1., 2., 3., 4., 5., 6.],
|
||||
[7., 8., 9., 10., 11., 12.]]], dtype=np.float32))
|
||||
net_cpu = ResizeLinear1DNet()
|
||||
grad = ResizeLinear1DGradNet(net_cpu)
|
||||
output = grad(grad_output, x, size)
|
||||
expect = np.array([[[2.8, 8.4, 9.8],
|
||||
[13.6, 22.8, 20.6]]]).astype(np.float32)
|
||||
print("ms grad input: ", output[0].asnumpy())
|
||||
assert np.allclose(output[0].asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_resize_linear_1d_grad_half_pixel(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1DGrad cpu kernel half_pixel mode
|
||||
Description: test the rightness of ResizeLinear1DGrad cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
size = Tensor(np.array([6], dtype=np.int64))
|
||||
grad_output = Tensor(np.array([[[1., 2., 3., 4., 5., 6.],
|
||||
[7., 8., 9., 10., 11., 12.]]], dtype=np.float32))
|
||||
net_cpu = ResizeLinear1DNet("half_pixel")
|
||||
grad = ResizeLinear1DGradNet(net_cpu)
|
||||
output = grad(grad_output, x, size)
|
||||
expect = np.array([[[3.25, 7, 10.75],
|
||||
[15.25, 19, 22.75]]]).astype(np.float32)
|
||||
print("ms grad input: ", output[0].asnumpy())
|
||||
assert np.allclose(output[0].asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_resize_linear_1d_grad_same_shape(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1DGrad cpu kernel same shape
|
||||
Description: test the rightness of ResizeLinear1DGrad cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
size = Tensor(np.array([3], dtype=np.int64))
|
||||
grad_output = Tensor(np.array([[[1., 2., 3.],
|
||||
[7., 8., 9.]]], dtype=np.float32))
|
||||
net_cpu = ResizeLinear1DNet()
|
||||
grad = ResizeLinear1DGradNet(net_cpu)
|
||||
output = grad(grad_output, x, size)
|
||||
expect = np.array([[[1., 2., 3.],
|
||||
[7., 8., 9.]]]).astype(np.float32)
|
||||
print("ms grad input: ", output[0].asnumpy())
|
||||
assert np.allclose(output[0].asnumpy(), expect)
|
Loading…
Reference in New Issue