回退 'Pull Request !36155 : [assistant][ResizeNearestNeighborV2][ResizeNearestNeighborV2Grad] cast shape type size_t to int64_t for ResizeNearestNeighborV2 & ResizeNearestNeighborV2Grad'

This commit is contained in:
yanghaoran 2022-06-22 01:22:24 +00:00 committed by Gitee
parent f98be788a3
commit 8c444c00fc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 4 additions and 1139 deletions

View File

@ -85,6 +85,8 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(kErfOpName, {1});
Register(kSparseApplyAdagradOpName, {2});
Register(kResizeNearestNeighborGradOpName, {1});
Register(kResizeNearestNeighborV2OpName, {1});
Register(kResizeNearestNeighborV2GradOpName, {1});
Register(kApplyRMSPropOpname, {5, 6, 7});
Register(kResizeBilinearV2OpName, {1});
Register(kReduceProdOpName, {1});

View File

@ -390,8 +390,6 @@ constexpr auto kHcomOpTypeReceive = "HcomReceive";
constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter";
// attr key name
constexpr auto kAttrAlignCorners = "align_corners";
constexpr auto kAttrHalfPixelCenters = "half_pixel_centers";
constexpr auto kAttrInputNames = "input_names";
constexpr auto kAttrAttrNames = "attr_names";
constexpr auto kAttrIsAiCpuKernel = "is_AICPU_kernel";

View File

@ -17,7 +17,6 @@
#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
#include <dirent.h>
#include <limits>
#include <memory>
#include <unordered_map>
#include <unordered_set>
@ -165,23 +164,6 @@ std::string GetProcessorStr(const AnfNodePtr &anf_node);
Processor GetProcessorFromContext();
std::string GetStrProcessorFromContext();
float Scaling(size_t in_size, size_t out_size, bool align_corners);
inline float Scaler(const size_t x, const float scale, bool half_pixel_centers) {
if (half_pixel_centers) {
/**
* function with a std::floor(), so instead of subtracting the 0.5 as we
* do in HalfPixelScale, we leave it as is, as the std::floor does the
* correct thing.
* */
return (static_cast<float>(x) + 0.5f) * scale;
} else {
/**
* Older incorrect scaling method that causes all resizes to have a slight
* translation leading to inconsistent results. For example, a flip then a
* resize gives different results then a resize then a flip.
* */
return static_cast<float>(x) * scale;
}
}
float ScaleGrid(const int x, const float scale);
FusionType GetFusionTypeByName(const std::string &name);
std::string GetFusionNameByType(const kernel::FusionType &type);

View File

@ -62,8 +62,6 @@ constexpr auto kGather = "Gather";
constexpr auto kIdentity = "Identity";
constexpr auto kIdentityN = "IdentityN";
constexpr auto kRandomChoiceWithMask = "RandomChoiceWithMask";
constexpr auto kResizeNearestNeighborV2 = "ResizeNearestNeighborV2";
constexpr auto kResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad";
constexpr auto kUpdateCache = "UpdateCache";
constexpr auto kCacheSwapTable = "CacheSwapTable";
constexpr auto kSubAndFilter = "SubAndFilter";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,17 +16,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
#include <algorithm>
#include <functional>
#include <vector>
#include "Eigen/Core"
#include "Eigen/Dense"
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef _WIN32
#undef ERROR
#endif
#include "Eigen/Core"
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
@ -44,115 +35,6 @@ template <typename T>
using MatrixSquare = Eigen::Matrix<T, Dynamic, Dynamic, RowMajor>;
template <typename T>
using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, RowMajor>;
template <typename T, int NDIMS = kDim1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned> Tensor;
typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstTensor;
// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, kDim1, Eigen::RowMajor, IndexType>, Eigen::Aligned> Flat;
typedef Eigen::TensorMap<Eigen::Tensor<const T, kDim1, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, kDim1, Eigen::RowMajor, IndexType>, Eigen::Aligned> Vec;
typedef Eigen::TensorMap<Eigen::Tensor<const T, kDim1, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstVec;
// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, kDim2, Eigen::RowMajor, IndexType>, Eigen::Aligned> Matrix;
typedef Eigen::TensorMap<Eigen::Tensor<const T, kDim2, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstMatrix;
};
class EigenTensor {
public:
EigenTensor() = delete;
EigenTensor(ShapeVector &shape, void *data_ptr) : tensor_shape(shape), tensor_data_ptr(data_ptr) {}
~EigenTensor() = default;
/*
* Eigen vec
* @return Eigen vec
*/
template <typename T>
typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}
/*
* Eigen matrix
* @return Eigen matrix
*/
template <typename T>
typename TTypes<T>::Matrix matrix() {
return tensor<T, kDim2>();
}
/*
* Eigen ConstMatrix
* @return Eigen ConstMatrix
*/
template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, kDim2>();
}
/*
* Eigen tensor
* @return Eigen tensor
*/
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(reinterpret_cast<T *>(tensor_data_ptr), AsEigenDSizes<NDIMS>());
}
/*
* Eigen ConstTensor
* @return Eigen ConstTensor
*/
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor tensor() const {
return typename TTypes<T, NDIMS>::ConstTensor(reinterpret_cast<const T *>(tensor_data_ptr), AsEigenDSizes<NDIMS>());
}
/*
* Eigen Flat
* @return Eigen Flat
*/
template <typename T>
typename TTypes<T>::Flat flat() {
return typename TTypes<T>::Flat(
reinterpret_cast<T *>(tensor_data_ptr),
{std::accumulate(tensor_shape.begin(), tensor_shape.end(), 1, std::multiplies<int64_t>())});
}
/*
* which case we pad the rest of the sizes with 1.
* @return Eigen::DSizes: pad the rest of the sizes with 1
*/
template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const {
Eigen::DSizes<IndexType, NDIMS> dsizes;
for (size_t d = 0; d < tensor_shape.size(); d++) {
dsizes[d] = static_cast<IndexType>(tensor_shape[d]);
}
for (size_t d = tensor_shape.size(); d < NDIMS; d++) {
dsizes[d] = 1;
}
return dsizes;
}
/*
* Fill `*dsizes` from `*this`
* @return Eigen::DSizes: pad the rest of the sizes with 1
*/
template <int NDIMS, typename IndexType = Eigen::DenseIndex>
Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const {
return AsEigenDSizesWithPadding<NDIMS, IndexType>();
}
private:
ShapeVector tensor_shape;
void *tensor_data_ptr;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_

View File

@ -1,166 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/resize_nearest_neighbor_v2_cpu_kernel.h"
#include <string>
#include "kernel/common_utils.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kResizeNearestNeighborV2InputsNum = 2;
constexpr size_t kResizeNearestNeighborV2OutputNum = 1;
} // namespace
void ResizeNearestNeighborV2CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
cnode_ptr_ = kernel_node;
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
y_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, kIndex0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
auto size_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
if (x_shape_.size() != kShape4dDims) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dimension of 'x' should be " << kShape4dDims
<< ", but got " << x_shape_.size();
}
if (size_shape.size() != kShape1dDims) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dimension of 'size' should be " << kShape1dDims
<< ", but got " << size_shape.size();
}
align_corners_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrAlignCorners);
half_pixel_centers_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrHalfPixelCenters);
std::string data_format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFormat);
if (data_format.compare(kOpFormat_NCHW) == 0) {
dim_idx_map_ = {{'N', kIndex0}, {'C', kIndex1}, {'H', kIndex2}, {'W', kIndex3}};
} else if (data_format.compare(kOpFormat_NHWC) == 0) {
dim_idx_map_ = {{'N', kIndex0}, {'H', kIndex1}, {'W', kIndex2}, {'C', kIndex3}};
} else {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr of 'data_format' only support ["
<< kOpFormat_NCHW << ", " << kOpFormat_NHWC << "].";
}
}
bool ResizeNearestNeighborV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kResizeNearestNeighborV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kResizeNearestNeighborV2OutputNum, kernel_name_);
bool res = false;
switch (y_type_) {
case kNumberTypeUInt8:
res = LaunchKernel<uint8_t>(inputs, outputs);
break;
case kNumberTypeUInt16:
res = LaunchKernel<uint16_t>(inputs, outputs);
break;
case kNumberTypeInt8:
res = LaunchKernel<int8_t>(inputs, outputs);
break;
case kNumberTypeInt16:
res = LaunchKernel<int16_t>(inputs, outputs);
break;
case kNumberTypeInt32:
res = LaunchKernel<int32_t>(inputs, outputs);
break;
case kNumberTypeInt64:
res = LaunchKernel<int64_t>(inputs, outputs);
break;
case kNumberTypeFloat16:
res = LaunchKernel<float16>(inputs, outputs);
break;
case kNumberTypeFloat32:
res = LaunchKernel<float>(inputs, outputs);
break;
case kNumberTypeFloat64:
res = LaunchKernel<double>(inputs, outputs);
break;
default:
MS_EXCEPTION(TypeError)
<< "For '" << kernel_name_
<< "', the dtype of 'x' should be float16, float32, float64, int32, int64, int16, int8, uint16 or uin8 but got "
<< TypeIdLabel(y_type_);
}
return res;
}
template <typename T>
bool ResizeNearestNeighborV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
const int64_t batch_size = x_shape_[dim_idx_map_['N']];
const int64_t in_height = x_shape_[dim_idx_map_['H']];
const int64_t in_width = x_shape_[dim_idx_map_['W']];
const int64_t channels = x_shape_[dim_idx_map_['C']];
const int64_t out_height = y_shape_[dim_idx_map_['H']];
const int64_t out_width = y_shape_[dim_idx_map_['W']];
const float height_scale = Scaling(in_height, out_height, align_corners_);
const float width_scale = Scaling(in_width, out_width, align_corners_);
auto x_4d = EigenTensor(x_shape_, inputs[kIndex0]->addr).tensor<T, kDim4>();
auto y_4d = EigenTensor(y_shape_, outputs[kIndex0]->addr).tensor<T, kDim4>();
for (int64_t b = 0; b < batch_size; ++b) {
for (int64_t y = 0; y < out_height; ++y) {
int64_t in_y =
std::min((align_corners_) ? static_cast<int64_t>(roundf(Scaler(y, height_scale, half_pixel_centers_)))
: static_cast<int64_t>(floorf(Scaler(y, height_scale, half_pixel_centers_))),
in_height - 1);
if (half_pixel_centers_) {
in_y = std::max(static_cast<int64_t>(0), in_y);
}
for (int64_t x = 0; x < out_width; ++x) {
int64_t in_x =
std::min((align_corners_) ? static_cast<int64_t>(roundf(Scaler(x, width_scale, half_pixel_centers_)))
: static_cast<int64_t>(floorf(Scaler(x, width_scale, half_pixel_centers_))),
in_width - 1);
if (half_pixel_centers_) {
in_x = std::max(static_cast<int64_t>(0), in_x);
}
// data_format = NHWC
if (dim_idx_map_['C'] == kIndex3) {
std::copy_n(&x_4d(b, in_y, in_x, 0), channels, &y_4d(b, y, x, 0));
} else {
// data_format = NCHW
for (int64_t c = 0; c < channels; ++c) {
y_4d(b, c, y, x) = x_4d(b, c, in_y, in_x);
}
}
}
}
}
return true;
}
std::vector<KernelAttr> ResizeNearestNeighborV2CpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ResizeNearestNeighborV2, ResizeNearestNeighborV2CpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -1,57 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_CPU_KERNEL_H_
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "kernel/common_utils.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class ResizeNearestNeighborV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
ResizeNearestNeighborV2CpuKernelMod() = default;
~ResizeNearestNeighborV2CpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId y_type_{kTypeUnknown};
bool align_corners_{false};
bool half_pixel_centers_{false};
std::vector<int64_t> x_shape_;
std::vector<int64_t> y_shape_;
std::unordered_map<char, size_t> dim_idx_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_CPU_KERNEL_H_

View File

@ -1,162 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "plugin/device/cpu/kernel/resize_nearest_neighbor_v2_grad_cpu_kernel.h"
#include "kernel/common_utils.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kResizeNearestNeighborV2GradInputsNum = 2;
constexpr size_t kResizeNearestNeighborV2GradOutputNum = 1;
} // namespace
void ResizeNearestNeighborV2GradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
cnode_ptr_ = kernel_node;
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
y_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, kIndex0);
grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
auto size_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex1);
if (grads_shape_.size() != kShape4dDims) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dimension of 'x' should be " << kShape4dDims
<< ", but got " << grads_shape_.size();
}
if (size_shape.size() != kShape1dDims) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dimension of 'size' should be " << kShape1dDims
<< ", but got " << size_shape.size();
}
align_corners_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrAlignCorners);
half_pixel_centers_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrHalfPixelCenters);
std::string data_format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFormat);
if (data_format.compare(kOpFormat_NCHW) == 0) {
dim_idx_map_ = {{'N', kIndex0}, {'C', kIndex1}, {'H', kIndex2}, {'W', kIndex3}};
} else if (data_format.compare(kOpFormat_NHWC) == 0) {
dim_idx_map_ = {{'N', kIndex0}, {'H', kIndex1}, {'W', kIndex2}, {'C', kIndex3}};
} else {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr of 'data_format' only support ["
<< kOpFormat_NCHW << ", " << kOpFormat_NHWC << "].";
}
}
bool ResizeNearestNeighborV2GradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kResizeNearestNeighborV2GradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kResizeNearestNeighborV2GradOutputNum, kernel_name_);
bool res = false;
switch (y_type_) {
case kNumberTypeUInt8:
res = LaunchKernel<uint8_t>(inputs, outputs);
break;
case kNumberTypeUInt16:
res = LaunchKernel<uint16_t>(inputs, outputs);
break;
case kNumberTypeInt8:
res = LaunchKernel<int8_t>(inputs, outputs);
break;
case kNumberTypeInt16:
res = LaunchKernel<int16_t>(inputs, outputs);
break;
case kNumberTypeInt32:
res = LaunchKernel<int32_t>(inputs, outputs);
break;
case kNumberTypeInt64:
res = LaunchKernel<int64_t>(inputs, outputs);
break;
case kNumberTypeFloat16:
res = LaunchKernel<float16>(inputs, outputs);
break;
case kNumberTypeFloat32:
res = LaunchKernel<float>(inputs, outputs);
break;
case kNumberTypeFloat64:
res = LaunchKernel<double>(inputs, outputs);
break;
default:
MS_EXCEPTION(TypeError)
<< "For '" << kernel_name_
<< "', the dtype of 'x' should be float16, float32, float64, int32, int64, int16, int8, uint16 or uin8 but got "
<< TypeIdLabel(y_type_);
break;
}
return res;
}
template <typename T>
bool ResizeNearestNeighborV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
const int64_t batch_size = grads_shape_[dim_idx_map_['N']];
const int64_t in_height = grads_shape_[dim_idx_map_['H']];
const int64_t in_width = grads_shape_[dim_idx_map_['W']];
const int64_t channels = grads_shape_[dim_idx_map_['C']];
const int64_t out_height = y_shape_[dim_idx_map_['H']];
const int64_t out_width = y_shape_[dim_idx_map_['W']];
const float height_scale = Scaling(out_height, in_height, align_corners_);
const float width_scale = Scaling(out_width, in_width, align_corners_);
auto grads_4d = EigenTensor(grads_shape_, inputs[kIndex0]->addr).tensor<T, kDim4>();
auto y_4d = EigenTensor(y_shape_, outputs[kIndex0]->addr).tensor<T, kDim4>();
y_4d.setZero();
for (int64_t y = 0; y < in_height; ++y) {
int64_t out_y =
std::min((align_corners_) ? static_cast<int64_t>(roundf(Scaler(y, height_scale, half_pixel_centers_)))
: static_cast<int64_t>(floorf(Scaler(y, height_scale, half_pixel_centers_))),
out_height - 1);
for (int64_t x = 0; x < in_width; ++x) {
int64_t out_x =
std::min((align_corners_) ? static_cast<int64_t>(roundf(Scaler(x, width_scale, half_pixel_centers_)))
: static_cast<int64_t>(floorf(Scaler(x, width_scale, half_pixel_centers_))),
out_width - 1);
for (int64_t b = 0; b < batch_size; ++b) {
for (int64_t c = 0; c < channels; ++c) {
// data_format = NHWC
if (dim_idx_map_['C'] == kIndex3) {
y_4d(b, out_y, out_x, c) += grads_4d(b, y, x, c);
} else {
// data_format = NCHW
y_4d(b, c, out_y, out_x) += grads_4d(b, c, y, x);
}
}
}
}
}
return true;
}
std::vector<KernelAttr> ResizeNearestNeighborV2GradCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -1,55 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_CPU_KERNEL_H_
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>
#include "kernel/common_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class ResizeNearestNeighborV2GradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
ResizeNearestNeighborV2GradCpuKernelMod() = default;
~ResizeNearestNeighborV2GradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId y_type_{kTypeUnknown};
bool align_corners_{false};
bool half_pixel_centers_{false};
std::vector<int64_t> grads_shape_;
std::vector<int64_t> y_shape_;
std::unordered_map<char, size_t> dim_idx_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_CPU_KERNEL_H_

View File

@ -80,8 +80,6 @@ PrimShapeDependMap &GetHostDependsMap() {
static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
static const auto &kReshape = prim::kPrimReshape->name();
static const auto &kResizeNearestNeighborV2 = prim::kPrimResizeNearestNeighborV2->name();
static const auto &kResizeNearestNeighborV2Grad = prim::kPrimResizeNearestNeighborV2Grad->name();
static const auto &kScatterNd = prim::kPrimScatterNd->name();
static const auto &kTruncatedNormal = prim::kPrimTruncatedNormal->name();
static const auto &kRandomGamma = prim::kPrimRandomGamma->name();
@ -124,8 +122,6 @@ PrimShapeDependMap &GetHostDependsMap() {
{kTile, ShapeSet{1}},
{kTopK, ShapeSet{1}},
{kReshape, ShapeSet{1}},
{kResizeNearestNeighborV2, ShapeSet{1}},
{kResizeNearestNeighborV2Grad, ShapeSet{1}},
{kScatterNd, ShapeSet{2}},
{kSliceGrad, ShapeSet{2, 3}},
{kFillV2, ShapeSet{0}},

View File

@ -449,8 +449,6 @@ GVAR_DEF(PrimitivePtr, kPrimParallelResizeBilinearGrad, std::make_shared<Primiti
GVAR_DEF(PrimitivePtr, kPrimResizeGrad, std::make_shared<Primitive>("ResizeGrad"));
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighbor, std::make_shared<Primitive>("ResizeNearestNeighbor"));
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighborGrad, std::make_shared<Primitive>("ResizeNearestNeighborGrad"));
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighborV2, std::make_shared<Primitive>("ResizeNearestNeighborV2"));
GVAR_DEF(PrimitivePtr, kPrimResizeNearestNeighborV2Grad, std::make_shared<Primitive>("ResizeNearestNeighborV2Grad"));
GVAR_DEF(PrimitivePtr, kPrimDynamicResizeNearestNeighbor, std::make_shared<Primitive>("DynamicResizeNearestNeighbor"));
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1D, std::make_shared<Primitive>("ResizeLinear1D"));
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1DGrad, std::make_shared<Primitive>("ResizeLinear1DGrad"));

View File

@ -1,121 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/resize_nearest_neighbor_v2_grad.h"
#include <string>
#include <algorithm>
#include <memory>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
#define IsNoneOrAnyValue(value_ptr) ((value_ptr->isa<None>()) || (value_ptr->isa<AnyValue>()))
abstract::ShapePtr ResizeNearestNeighborV2GradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto size_ptr = input_args[kInputIndex1]->BuildValue();
(void)CheckAndConvertUtils::CheckInteger("dimension of grads", SizeToLong(grads_shape.size()), kEqual,
SizeToLong(kDim4), prim_name);
(void)CheckAndConvertUtils::CheckInteger("dimension of size", SizeToLong(size_shape.size()), kEqual,
SizeToLong(kDim1), prim_name);
auto data_format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
std::map<char, size_t> dim_idx_map;
auto align_corners_ptr = primitive->GetAttr(kAlignCorners);
MS_EXCEPTION_IF_NULL(align_corners_ptr);
auto align_corners = GetValue<bool>(align_corners_ptr);
auto half_pixel_centers_ptr = primitive->GetAttr(kHalfPixelCenters);
MS_EXCEPTION_IF_NULL(half_pixel_centers_ptr);
auto half_pixel_centers = GetValue<bool>(half_pixel_centers_ptr);
if (align_corners && half_pixel_centers) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << ". If half_pixel_centers is True, align_corners must be False.";
}
if (data_format == Format::NCHW) {
dim_idx_map = {{'N', kInputIndex0}, {'C', kInputIndex1}, {'H', kInputIndex2}, {'W', kInputIndex3}};
} else if (data_format == Format::NHWC) {
dim_idx_map = {{'N', kInputIndex0}, {'H', kInputIndex1}, {'W', kInputIndex2}, {'C', kInputIndex3}};
} else {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the attr of 'data_format' only support [" << kFormatNCHW
<< ", " << kFormatNHWC << "]. But get '" << data_format << "'.";
}
bool is_compile = IsNoneOrAnyValue(size_ptr);
ShapeVector y_shape(kDim4);
if (is_compile) {
y_shape[dim_idx_map['N']] = grads_shape[dim_idx_map['N']];
y_shape[dim_idx_map['C']] = grads_shape[dim_idx_map['C']];
y_shape[dim_idx_map['H']] = abstract::Shape::SHP_ANY;
y_shape[dim_idx_map['W']] = abstract::Shape::SHP_ANY;
ShapeVector y_shape_min(y_shape);
y_shape_min[dim_idx_map['H']] = 0;
y_shape_min[dim_idx_map['W']] = 0;
ShapeVector y_shape_max(grads_shape);
return std::make_shared<abstract::Shape>(y_shape, y_shape_min, y_shape_max);
} else {
MS_EXCEPTION_IF_NULL(size_ptr);
auto size_value = CheckAndConvertUtils::CheckTensorIntValue("input size", size_ptr, prim_name);
if (size_value.size() != kDim2) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the elements number of 'size' should be 2, but get "
<< size_value.size() << " number.";
}
y_shape[dim_idx_map['N']] = grads_shape[dim_idx_map['N']];
y_shape[dim_idx_map['C']] = grads_shape[dim_idx_map['C']];
y_shape[dim_idx_map['H']] = size_value.front();
y_shape[dim_idx_map['W']] = size_value.back();
}
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr ResizeNearestNeighborV2GradInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
std::set<TypePtr> support_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64};
auto grads_type = CheckAndConvertUtils::CheckTensorTypeValid("grads", input_args[kInputIndex0]->BuildType(),
support_types, primitive->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64},
primitive->name());
return grads_type;
}
} // namespace
AbstractBasePtr ResizeNearestNeighborV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
auto infer_shape = ResizeNearestNeighborV2GradInferShape(primitive, input_args);
auto infer_type = ResizeNearestNeighborV2GradInferType(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(ResizeNearestNeighborV2Grad, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeNearestNeighborV2Grad, prim::kPrimResizeNearestNeighborV2Grad,
ResizeNearestNeighborV2GradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,46 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_H_
#define MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_H_
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/base_operator.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad";
/// \brief the grad operation of @ref mindspore.ops.ResizeNearestNeighborV2
/// Refer to Python API @ref mindspore._grad_ops.ResizeNearestNeighborV2Grad for more details.
class MIND_API ResizeNearestNeighborV2Grad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ResizeNearestNeighborV2Grad);
/// \brief Constructor.
ResizeNearestNeighborV2Grad() : BaseOperator(kNameResizeNearestNeighborV2Grad) {}
};
AbstractBasePtr ResizeNearestNeighborV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimResizeNearestNeighborV2GradPtr = std::shared_ptr<ResizeNearestNeighborV2Grad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_GRAD_H_

View File

@ -275,8 +275,6 @@ constexpr auto kSymmetric = "symmetric";
constexpr auto kDstType = "dst_type";
constexpr auto kNone = "none";
constexpr auto kMean = "mean";
constexpr auto kFormatNCHW = "NCHW";
constexpr auto kFormatNHWC = "NHWC";
constexpr auto kBatchMean = "batchmean";
constexpr auto kSum = "sum";
constexpr auto kIndices = "indices";

View File

@ -1,121 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/resize_nearest_neighbor_v2.h"
#include <string>
#include <algorithm>
#include <memory>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
namespace {
#define IsSameType(source_type, cmp_type) (cmp_type->equal(source_type))
#define IsNoneOrAnyValue(value_ptr) ((value_ptr->isa<None>()) || (value_ptr->isa<AnyValue>()))
abstract::ShapePtr ResizeNearestNeighborV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto size_ptr = input_args[kInputIndex1]->BuildValue();
(void)CheckAndConvertUtils::CheckInteger("dimension of x", SizeToLong(x_shape.size()), kEqual, SizeToLong(kDim4),
prim_name);
(void)CheckAndConvertUtils::CheckInteger("dimension of size", SizeToLong(size_shape.size()), kEqual,
SizeToLong(kDim1), prim_name);
auto align_corners_ptr = primitive->GetAttr(kAlignCorners);
MS_EXCEPTION_IF_NULL(align_corners_ptr);
auto align_corners = GetValue<bool>(align_corners_ptr);
auto half_pixel_centers_ptr = primitive->GetAttr(kHalfPixelCenters);
MS_EXCEPTION_IF_NULL(half_pixel_centers_ptr);
auto half_pixel_centers = GetValue<bool>(half_pixel_centers_ptr);
if (align_corners && half_pixel_centers) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << ". If half_pixel_centers is True, align_corners must be False.";
}
auto data_format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
std::map<char, size_t> dim_idx_map;
if (data_format == Format::NCHW) {
dim_idx_map = {{'N', kInputIndex0}, {'C', kInputIndex1}, {'H', kInputIndex2}, {'W', kInputIndex3}};
} else if (data_format == Format::NHWC) {
dim_idx_map = {{'N', kInputIndex0}, {'H', kInputIndex1}, {'W', kInputIndex2}, {'C', kInputIndex3}};
} else {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the attr of 'data_format' only support [" << kFormatNCHW
<< ", " << kFormatNHWC << "]. But get '" << data_format << "'.";
}
bool is_compile = IsNoneOrAnyValue(size_ptr);
ShapeVector y_shape(kDim4);
y_shape[dim_idx_map['N']] = x_shape[dim_idx_map['N']];
y_shape[dim_idx_map['C']] = x_shape[dim_idx_map['C']];
if (is_compile) {
y_shape[dim_idx_map['H']] = abstract::Shape::SHP_ANY;
y_shape[dim_idx_map['W']] = abstract::Shape::SHP_ANY;
ShapeVector y_shape_min(y_shape);
y_shape_min[dim_idx_map['H']] = 0;
y_shape_min[dim_idx_map['W']] = 0;
ShapeVector y_shape_max(x_shape);
return std::make_shared<abstract::Shape>(y_shape, y_shape_min, y_shape_max);
} else {
MS_EXCEPTION_IF_NULL(size_ptr);
auto size_value = CheckAndConvertUtils::CheckTensorIntValue("input size", size_ptr, prim_name);
if (size_value.size() != kDim2) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the elements number of 'size' should be 2, but get "
<< size_value.size() << " number.";
}
y_shape[dim_idx_map['H']] = size_value.front();
y_shape[dim_idx_map['W']] = size_value.back();
}
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr ResizeNearestNeighborV2InferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
std::set<TypePtr> support_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64};
auto start_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(),
support_types, primitive->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64},
primitive->name());
return start_type;
}
} // namespace
AbstractBasePtr ResizeNearestNeighborV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
constexpr int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto infer_type = ResizeNearestNeighborV2InferType(primitive, input_args);
auto infer_shape = ResizeNearestNeighborV2InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(ResizeNearestNeighborV2, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeNearestNeighborV2, prim::kPrimResizeNearestNeighborV2, ResizeNearestNeighborV2Infer,
nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,46 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_H_
#define MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_H_
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/base_operator.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeNearestNeighborV2 = "ResizeNearestNeighborV2";
/// \brief Resizes the input tensor by using the nearest neighbor algorithm.
/// Refer to Python API @ref mindspore.ops.ResizeNearestNeighborV2 for more details.
class MIND_API ResizeNearestNeighborV2 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ResizeNearestNeighborV2);
/// \brief Constructor.
ResizeNearestNeighborV2() : BaseOperator(kNameResizeNearestNeighborV2) {}
};
AbstractBasePtr ResizeNearestNeighborV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimResizeNearestNeighborV2Ptr = std::shared_ptr<ResizeNearestNeighborV2>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RESIZE_NEAREST_NEIGHBOR_V2_H_

View File

@ -16,7 +16,6 @@
"""array_ops"""
from mindspore import Tensor
from mindspore.ops.primitive import constexpr
from ...common import dtype as mstype
from ...numpy.array_ops import where
from .._grad.grad_math_ops import binop_grad_common
@ -25,7 +24,6 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.array_ops import Tril
from ..operations.array_ops import MatrixDiagV3
from ..operations.array_ops import MatrixDiagPartV3
from ..operations.array_ops import ResizeNearestNeighborV2
from ..operations.array_ops import MatrixSetDiagV3
from ..operations.array_ops import Triu
from ..operations.array_ops import IdentityN
@ -37,12 +35,6 @@ from ..operations.array_ops import Expand
from .. import functional as F
from .. import operations as P
from .._utils.utils import is_shape_unknown
from ..operations import _grad_ops as G
@constexpr
def _create_tensor(data, dtype):
return Tensor(data, dtype=dtype)
def _segment_min_or_max_grad(segment_sum_op, input_x, segment_ids, output, dout):
@ -272,25 +264,6 @@ def get_bprop_identity_n(self):
return bprop
@bprop_getters.register(ResizeNearestNeighborV2)
def get_bprop_resize_nearest_neighbor_v2(self):
"""Generate bprop for ResizeNearestNeighborV2"""
align_corners = self.align_corners
half_pixel_centers = self.half_pixel_centers
data_format = self.data_format
grad_op = G.ResizeNearestNeighborV2Grad(align_corners, half_pixel_centers, data_format)
def bprop(x, size, output, dout):
x_shape = P.Shape()(x)
grad_in_size = x_shape[1:3]
if data_format == 'NCHW':
grad_in_size = x_shape[2:4]
dx = grad_op(dout, _create_tensor(grad_in_size, mstype.int32))
return dx, zeros_like(grad_in_size)
return bprop
@bprop_getters.register(P.ExtractVolumePatches)
def get_bprop_extract_volume_patches(self):
"""Generate bprop for ExtractVolumePatches"""

View File

@ -153,8 +153,6 @@ from .reduce_prod import _reduce_prod_aicpu
from .reduce_mean import _reduce_mean_aicpu
from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
from .resize_nearest_neighbor_v2 import _resize_nearest_neighbor_v2_aicpu
from .resize_nearest_neighbor_v2_grad import _resize_nearest_neighbor_v2_grad_aicpu
from .scatter_elements import _scatter_elements_aicpu
from .non_max_suppression import _non_max_suppression_aicpu
from .square import _square_aicpu

View File

@ -1,42 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResizeNearestNeighborV2 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
resize_nearest_neighbor_v2_op_info = AiCPURegOp("ResizeNearestNeighborV2") \
.fusion_type("OPAQUE") \
.attr("align_corners", "bool") \
.attr("half_pixel_centers", "bool") \
.attr("format", "str") \
.input(0, "x", "required") \
.input(1, "size", "required") \
.output(0, "y", "dynamic") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(resize_nearest_neighbor_v2_op_info)
def _resize_nearest_neighbor_v2_aicpu():
"""ResizeNearestNeighborV2 AiCPU register"""
return

View File

@ -1,42 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResizeNearestNeighborV2Grad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
resize_nearest_neighbor_v2_grad_op_info = AiCPURegOp("ResizeNearestNeighborV2Grad") \
.fusion_type("OPAQUE") \
.attr("align_corners", "bool") \
.attr("half_pixel_centers", "bool") \
.attr("format", "str") \
.input(0, "grads", "required") \
.input(1, "size", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(resize_nearest_neighbor_v2_grad_op_info)
def _resize_nearest_neighbor_v2_grad_aicpu():
"""ResizeNearestNeighborV2Grad AiCPU register"""
return

View File

@ -1879,29 +1879,6 @@ class ResizeLinear1DGrad(Primitive):
"coordinate_transformation_mode", self.name)
class ResizeNearestNeighborV2Grad(Primitive):
"""
Compute gradient of `ResizeNearestNeighborV2` operator.
Args:
align_corners (bool): Whether the centers of the 4 corner pixels of the input
and output tensors are aligned. Default: False.
half_pixel_centers (bool): Default :False.
data_format: An optional `string` that describes the format of the input `x` Defaults to `NHWC`.
"""
@prim_attr_register
def __init__(self, align_corners=False, half_pixel_centers=False, data_format='NHWC'):
"""Initialize ResizeNearestNeighborV2Grad"""
self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
validator.check_value_type('align_corners', align_corners, [bool], self.name)
validator.check_value_type('half_pixel_centers', half_pixel_centers, [bool], self.name)
validator.check_value_type('data_format', data_format, [str], self.name)
self.format = validator.check_string(data_format, ['NHWC', 'NCHW'], 'data_format', self.name)
self.add_prim_attr('data_format', self.format)
class ROIAlignGrad(PrimitiveWithInfer):
"""
ROIAlignGrad operator.

View File

@ -4181,73 +4181,6 @@ class ResizeNearestNeighbor(Primitive):
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
class ResizeNearestNeighborV2(Primitive):
r"""
Resizes the input tensor to specific size by using the nearest neighbor algorithm.
Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
neighbor algorithm selects the value of the nearest point and does not consider the
values of neighboring points at all, yielding a piecewise-constant interpolant.
Args:
align_corners: An optional `bool`. Defaults to `False`.
If true, the centers of the 4 corner pixels of the input and output tensors are
aligned, preserving the values at the corner pixels.
half_pixel_centers: An optional `bool`. Defaults to `False`.
data_format: An optional `string` that describes the format of the input `x`. Defaults to `NHWC`.
Inputs:
- **x** (Tensor) - 4-D with shape `[batch, height, width, channels]` or `[batch, channels, height, width]`
depending on the attr 'data_format'. Support type [`int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
`float16`, `float32`, `float64`].
- **size** (Tensor) - A 1-D int32 Tensor of 2 elements: [`new_height, new_width`]. The new size for the images.
Outputs:
Tensor `y`, has the same type as input `x` with the shape of `[batch, channels, new_height, new_width]` or
`[batch, new_height, new_width, channels]` depending on attr 'data_format'.
Raises:
TypeError: If `x` or `size` is not a Tensor.
TypeError: If `x` data type not in support list.
TypeError: If `size` data type is not int32.
TypeError: If `align_corners` or `half_pixel_centers` is not `bool` value.
TypeError: If `data_format` is not `str`.
ValueError: If `data_format` not in [`NHWC`, `NCHW`].
ValueError: If any value of `size` is non positive.
ValueError: If the dimension of `x` is not 4.
ValueError: If the dimension of `size` is not 1.
ValueError: If the elements number of `size` is not 2.
ValueError: If attr `half_pixel_centers` and `align_corners` are True at the same time.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> input_tensor = Tensor(np.ones((1, 4, 4, 1)), mstype.float32)
>>> size = Tensor([2, 2], mstype.int32)
>>> resize = ops.ResizeNearestNeighborV2()
>>> output = resize(input_tensor, size)
>>> print(output)
[[[[1.]
[1.]]
[[1.]
[1.]]]]
>>> print(output.shape)
(1, 2, 2, 1)
"""
@prim_attr_register
def __init__(self, align_corners=False, half_pixel_centers=False, data_format='NHWC'):
"""Initialize ResizeNearestNeighborV2"""
self.init_prim_io_names(inputs=['x', 'size'], outputs=['y'])
validator.check_bool(align_corners, 'align_corners', self.name)
validator.check_bool(half_pixel_centers, 'half_pixel_centers', self.name)
validator.check_value_type('data_format', data_format, [str], self.name)
self.format = validator.check_string(data_format, ['NHWC', 'NCHW'], 'data_format', self.name)
self.add_prim_attr('data_format', self.format)
class GatherNd(Primitive):
r"""
Gathers slices from a tensor by indices.

View File

@ -51,8 +51,6 @@ from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.random_ops import TruncatedNormal
from mindspore.ops.operations.other_ops import SampleDistortedBoundingBoxV2
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.array_ops import ResizeNearestNeighborV2
from mindspore.ops.operations._grad_ops import ResizeNearestNeighborV2Grad
from mindspore.ops.operations.array_ops import MatrixDiagV3
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
@ -2688,16 +2686,6 @@ test_case_nn_ops = [
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
'skip': ['backward']}),
('ResizeNearestNeighborV2', {
'block': ResizeNearestNeighborV2(),
'desc_inputs': [Tensor(np.random.rand(16, 16, 32, 32).astype(np.float32)),
Tensor(np.array([8, 8]).astype(np.int32))],
'desc_bprop': [Tensor(np.random.rand(16, 16, 8, 8).astype(np.float32))]}),
('ResizeNearestNeighborV2Grad', {
'block': ResizeNearestNeighborV2Grad(),
'desc_inputs': [Tensor(np.random.rand(16, 16, 8, 8).astype(np.float32)),
Tensor(np.array([32, 32]).astype(np.int32))],
'skip': ['backward']}),
('ROIAlign', {
'block': P.ROIAlign(7, 7, 0.03125, 2),
'desc_inputs': [[2, 256, 192, 320], [1024, 5]],