!40472 [assistant][ops] Add new aicpu operator AffineGridGrad
Merge pull request !40472 from 姚高超/AffineGridGrad
This commit is contained in:
commit
d733827a2f
|
@ -0,0 +1,325 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/affine_grid_grad_cpu_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kAffineGridGradInputsNum = 2;
|
||||
constexpr size_t kAffineGridGradOutputsNum = 1;
|
||||
|
||||
#define AFFINEGRIDGRAD_LAUNCH_CASE(DTYPE, TYPE, DTYPE0, INPUTS, OUTPUTS) \
|
||||
case DTYPE: { \
|
||||
if ((DTYPE0) == kNumberTypeInt32) { \
|
||||
LaunchKernel<TYPE, int32_t>(INPUTS, OUTPUTS); \
|
||||
} else { \
|
||||
LaunchKernel<TYPE, int64_t>(INPUTS, OUTPUTS); \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
const int kRowNum1 = 1;
|
||||
const int kColNum1 = 1;
|
||||
const int kRowNum2 = 2;
|
||||
const int kColNum2 = 2;
|
||||
const int kRowNum3 = 3;
|
||||
const int kColNum3 = 3;
|
||||
const int kRowNum4 = 4;
|
||||
const int kColNum4 = 4;
|
||||
const int kLenXSize3D = 4;
|
||||
const int kLenXSize4D = 5;
|
||||
const int kXSizeH3D = 2;
|
||||
const int kXSizeW3D = 3;
|
||||
const int kXSizeD4D = 2;
|
||||
const int kXSizeH4D = 3;
|
||||
const int kXSizeW4D = 4;
|
||||
} // namespace
|
||||
|
||||
bool AffineGridGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
auto prim = base_operator->GetPrim();
|
||||
align_corners_ = GetValue<bool>(prim->GetAttr("align_corners"));
|
||||
auto type_id = inputs[0]->GetDtype();
|
||||
input_info_.push_back(type_id);
|
||||
type_id = inputs[1]->GetDtype();
|
||||
input_info_.push_back(type_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
int AffineGridGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
x_size_dims_ = inputs[1]->GetDeviceShapeAdaptively();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T, typename T0>
|
||||
void AffineGridGradCpuKernelMod::LaunchKernel_3D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
|
||||
int64_t H = x_size_data[kXSizeH3D];
|
||||
int64_t W = x_size_data[kXSizeW3D];
|
||||
Eigen::VectorXf vecX, vecY;
|
||||
vecX.setZero(W, 1);
|
||||
vecY.setZero(H, 1);
|
||||
if (W != 1) {
|
||||
vecX = Eigen::VectorXf::LinSpaced(vecX.size(), -1.0, 1.0);
|
||||
}
|
||||
if (H != 1) {
|
||||
vecY = Eigen::VectorXf::LinSpaced(vecY.size(), -1.0, 1.0);
|
||||
}
|
||||
if (!align_corners_) {
|
||||
float x_ = static_cast<float>((W - 1)) / static_cast<float>(W);
|
||||
float y_ = static_cast<float>((H - 1)) / static_cast<float>(H);
|
||||
for (int64_t i = 0; i < W; i++) {
|
||||
vecX[i] = vecX[i] * x_;
|
||||
}
|
||||
for (int64_t i = 0; i < H; i++) {
|
||||
vecY[i] = vecY[i] * y_;
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::MatrixXf all(kRowNum3, W * H);
|
||||
all = make_base_grid_3D<T0>(inputs, vecX, vecY);
|
||||
DoCompute_3D<T, T0>(inputs, outputs, all);
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
Eigen::MatrixXf AffineGridGradCpuKernelMod::make_base_grid_3D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
Eigen::VectorXf vecX, Eigen::VectorXf vecY) {
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
int64_t H = x_size_data[kXSizeH3D];
|
||||
int64_t W = x_size_data[kXSizeW3D];
|
||||
Eigen::MatrixXf all(kRowNum3, W * H);
|
||||
int64_t datanums = H * W;
|
||||
auto task1 = [&](const int64_t start, const int64_t end) {
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
int64_t j = i % W;
|
||||
int64_t k = i / W;
|
||||
all(0, k * W + j) = vecX(j);
|
||||
all(kRowNum1, k * W + j) = vecY(k);
|
||||
all(kRowNum2, k * W + j) = 1.0;
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task1, datanums, this, ¶llel_search_info_);
|
||||
return all;
|
||||
}
|
||||
|
||||
template <typename T, typename T0>
|
||||
void AffineGridGradCpuKernelMod::DoCompute_3D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs, Eigen::MatrixXf all) {
|
||||
auto data_y_grad = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(data_y_grad);
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
int64_t N = x_size_data[0];
|
||||
int64_t H = x_size_data[kXSizeH3D];
|
||||
int64_t W = x_size_data[kXSizeW3D];
|
||||
|
||||
Eigen::MatrixXf y_grad(H * W, kColNum2);
|
||||
Eigen::MatrixXf result(kRowNum3, kColNum2);
|
||||
float result_0;
|
||||
float result_1;
|
||||
float result_2;
|
||||
int64_t k_num = 0;
|
||||
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
for (int64_t k = 0; k < H * W; k++) {
|
||||
y_grad(k, 0) = static_cast<float>(*(data_y_grad + (n * H * W * kColNum2 + k * kColNum2)));
|
||||
y_grad(k, kColNum1) = static_cast<float>(*(data_y_grad + (n * H * W * kColNum2 + k * kColNum2) + kColNum1));
|
||||
}
|
||||
result = all * y_grad;
|
||||
|
||||
for (int64_t k = 0; k < kColNum2; k++) {
|
||||
result_0 = result(0, k);
|
||||
result_1 = result(kRowNum1, k);
|
||||
result_2 = result(kRowNum2, k);
|
||||
*(output + k_num) = static_cast<T>(result_0);
|
||||
*(output + k_num + kColNum1) = static_cast<T>(result_1);
|
||||
*(output + k_num + kColNum2) = static_cast<T>(result_2);
|
||||
k_num += kColNum3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T0>
|
||||
void AffineGridGradCpuKernelMod::LaunchKernel_4D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
int64_t D = x_size_data[kXSizeD4D];
|
||||
int64_t H = x_size_data[kXSizeH4D];
|
||||
int64_t W = x_size_data[kXSizeW4D];
|
||||
|
||||
Eigen::VectorXf vecX, vecY, vecZ;
|
||||
vecX.setZero(W, 1);
|
||||
vecY.setZero(H, 1);
|
||||
vecZ.setZero(D, 1);
|
||||
if (W != 1) {
|
||||
vecX = Eigen::VectorXf::LinSpaced(vecX.size(), -1.0, 1.0);
|
||||
}
|
||||
if (H != 1) {
|
||||
vecY = Eigen::VectorXf::LinSpaced(vecY.size(), -1.0, 1.0);
|
||||
}
|
||||
if (D != 1) {
|
||||
vecZ = Eigen::VectorXf::LinSpaced(vecZ.size(), -1.0, 1.0);
|
||||
}
|
||||
if (!align_corners_) {
|
||||
float x_ = static_cast<float>((W - 1)) / static_cast<float>(W);
|
||||
float y_ = static_cast<float>((H - 1)) / static_cast<float>(H);
|
||||
float z_ = static_cast<float>((D - 1)) / static_cast<float>(D);
|
||||
for (int64_t i = 0; i < W; i++) {
|
||||
vecX[i] = vecX[i] * x_;
|
||||
}
|
||||
for (int64_t i = 0; i < H; i++) {
|
||||
vecY[i] = vecY[i] * y_;
|
||||
}
|
||||
for (int64_t i = 0; i < D; i++) {
|
||||
vecZ[i] = vecZ[i] * z_;
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::MatrixXf all(kRowNum4, D * W * H);
|
||||
all = make_base_grid_4D<T0>(inputs, vecX, vecY, vecZ);
|
||||
DoCompute_4D<T, T0>(inputs, outputs, all);
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
Eigen::MatrixXf AffineGridGradCpuKernelMod::make_base_grid_4D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
Eigen::VectorXf vecX, Eigen::VectorXf vecY,
|
||||
Eigen::VectorXf vecZ) {
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
int64_t D = x_size_data[kXSizeD4D];
|
||||
int64_t H = x_size_data[kXSizeH4D];
|
||||
int64_t W = x_size_data[kXSizeW4D];
|
||||
Eigen::MatrixXf all(kRowNum4, D * W * H);
|
||||
int64_t datanums = D * H * W;
|
||||
auto task1 = [&](const int64_t start, const int64_t end) {
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
int64_t m = i / (H * W);
|
||||
int64_t j = (i % (H * W)) / W;
|
||||
int64_t k = (i % (H * W)) % W;
|
||||
all(0, m * H * W + j * W + k) = vecX(k);
|
||||
all(kRowNum1, m * H * W + j * W + k) = vecY(j);
|
||||
all(kRowNum2, m * H * W + j * W + k) = vecZ(m);
|
||||
all(kRowNum3, m * H * W + j * W + k) = 1.0;
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task1, datanums, this, ¶llel_search_info_);
|
||||
return all;
|
||||
}
|
||||
|
||||
template <typename T, typename T0>
|
||||
void AffineGridGradCpuKernelMod::DoCompute_4D(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs, Eigen::MatrixXf all) {
|
||||
auto data_y_grad = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(data_y_grad);
|
||||
auto x_size_data = reinterpret_cast<T0 *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(x_size_data);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
int64_t N = x_size_data[0];
|
||||
int64_t D = x_size_data[kXSizeD4D];
|
||||
int64_t H = x_size_data[kXSizeH4D];
|
||||
int64_t W = x_size_data[kXSizeW4D];
|
||||
|
||||
Eigen::MatrixXf y_grad(D * H * W, kColNum3);
|
||||
Eigen::MatrixXf result(kRowNum4, kColNum3);
|
||||
float result_0;
|
||||
float result_1;
|
||||
float result_2;
|
||||
float result_3;
|
||||
int64_t k_num = 0;
|
||||
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
for (int64_t k = 0; k < D * H * W; k++) {
|
||||
y_grad(k, 0) = static_cast<float>(*(data_y_grad + (n * D * H * W * kColNum3 + k * kColNum3) + 0));
|
||||
y_grad(k, kColNum1) = static_cast<float>(*(data_y_grad + (n * D * H * W * kColNum3 + k * kColNum3) + kColNum1));
|
||||
y_grad(k, kColNum2) = static_cast<float>(*(data_y_grad + (n * D * H * W * kColNum3 + k * kColNum3) + kColNum2));
|
||||
}
|
||||
result = all * y_grad;
|
||||
for (int64_t k = 0; k < kColNum3; k++) {
|
||||
result_0 = result(0, k);
|
||||
result_1 = result(kRowNum1, k);
|
||||
result_2 = result(kRowNum2, k);
|
||||
result_3 = result(kRowNum3, k);
|
||||
*(output + k_num) = static_cast<T>(result_0);
|
||||
*(output + k_num + kColNum1) = static_cast<T>(result_1);
|
||||
*(output + k_num + kColNum2) = static_cast<T>(result_2);
|
||||
*(output + k_num + kColNum3) = static_cast<T>(result_3);
|
||||
k_num += kColNum4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T0>
|
||||
bool AffineGridGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (x_size_dims_[0] == kLenXSize3D) {
|
||||
LaunchKernel_3D<T, T0>(inputs, outputs);
|
||||
} else if (x_size_dims_[0] == kLenXSize4D) {
|
||||
LaunchKernel_4D<T, T0>(inputs, outputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AffineGridGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
constexpr int INPUTSNUM = 1;
|
||||
// CheckParams();
|
||||
TypeId input_type = input_info_[0];
|
||||
TypeId x_size_type = input_info_[INPUTSNUM];
|
||||
switch (input_type) {
|
||||
AFFINEGRIDGRAD_LAUNCH_CASE(kNumberTypeFloat16, float16, x_size_type, inputs, outputs)
|
||||
AFFINEGRIDGRAD_LAUNCH_CASE(kNumberTypeFloat32, float, x_size_type, inputs, outputs)
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported input data type: " << TypeIdLabel(input_type)
|
||||
<< ".";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
using AffineGridGradPair = std::pair<KernelAttr, AffineGridGradCpuKernelMod::KernelRunFunc>;
|
||||
const std::vector<AffineGridGradPair> &AffineGridGradCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, AffineGridGradCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
{KernelAttr().AddSkipCheckAttr(true), &AffineGridGradCpuKernelMod::Launch},
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AffineGridGrad, AffineGridGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_AFFINEGRIDGRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_AFFINEGRIDGRAD_CPU_KERNEL_H_
|
||||
|
||||
#include <Eigen/Dense>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AffineGridGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<AffineGridGradCpuKernelMod> {
|
||||
public:
|
||||
AffineGridGradCpuKernelMod() = default;
|
||||
~AffineGridGradCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename T0>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T, typename T0>
|
||||
void LaunchKernel_3D(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T, typename T0>
|
||||
void LaunchKernel_4D(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T0>
|
||||
Eigen::MatrixXf make_base_grid_3D(const std::vector<kernel::AddressPtr> &inputs, Eigen::VectorXf vecX,
|
||||
Eigen::VectorXf vecY);
|
||||
template <typename T, typename T0>
|
||||
void DoCompute_3D(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
|
||||
Eigen::MatrixXf all);
|
||||
template <typename T0>
|
||||
Eigen::MatrixXf make_base_grid_4D(const std::vector<kernel::AddressPtr> &inputs, Eigen::VectorXf vecX,
|
||||
Eigen::VectorXf vecY, Eigen::VectorXf vecZ);
|
||||
template <typename T, typename T0>
|
||||
void DoCompute_4D(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
|
||||
Eigen::MatrixXf all);
|
||||
using AffineGridGradFunc =
|
||||
std::function<bool(AffineGridGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
std::vector<int64_t> x_size_dims_;
|
||||
bool align_corners_{false};
|
||||
std::vector<TypeId> input_info_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_AFFINEGRIDGRAD_CPU_KERNEL_H_
|
|
@ -232,6 +232,7 @@ constexpr auto kScatterAddWithAxis = "ScatterAddWithAxis";
|
|||
constexpr auto kCSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
|
||||
constexpr auto kSlice = "Slice";
|
||||
constexpr auto kAffineGrid = "AffineGrid";
|
||||
constexpr auto kAffineGridGrad = "AffineGridGrad";
|
||||
constexpr auto kGatherDGrad = "GatherDGrad";
|
||||
constexpr auto kGatherDGradV2 = "GatherDGradV2";
|
||||
constexpr auto kSparseTensorToCSRSparseMatrix = "SparseTensorToCSRSparseMatrix";
|
||||
|
@ -741,6 +742,7 @@ GVAR_DEF(PrimitivePtr, kPrimSegmentMax, std::make_shared<Primitive>(kSegmentMax)
|
|||
GVAR_DEF(PrimitivePtr, kPrimSegmentMin, std::make_shared<Primitive>(kSegmentMin));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSegmentSum, std::make_shared<Primitive>(kSegmentSum));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAffineGrid, std::make_shared<Primitive>(kAffineGrid));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAffineGridGrad, std::make_shared<Primitive>(kAffineGridGrad));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSegmentMean, std::make_shared<Primitive>(kSegmentMean));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSegmentProd, std::make_shared<Primitive>(kSegmentProd));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSparseMinimum, std::make_shared<Primitive>(kSparseSparseMinimum));
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/affine_grid_grad.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int kLenImageSize2D = 4;
|
||||
constexpr int kLenImageSize3D = 5;
|
||||
|
||||
abstract::ShapePtr AffineGridGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
|
||||
auto y_grad_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
|
||||
if (y_grad_shape_ptr->IsDynamic()) {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||
}
|
||||
auto x_size_arg = input_args[kInputIndex1];
|
||||
auto x_size_value_ptr = x_size_arg->BuildValue();
|
||||
if ((x_size_arg->isa<abstract::AbstractTuple>() && x_size_value_ptr->isa<ValueTuple>()) ||
|
||||
(x_size_arg->isa<abstract::AbstractTensor>() && x_size_value_ptr->isa<tensor::Tensor>())) {
|
||||
ShapeVector x_size_val;
|
||||
if (x_size_value_ptr->isa<ValueTuple>()) {
|
||||
x_size_val = CheckAndConvertUtils::CheckTupleInt("input[x_size]", x_size_value_ptr, prim_name);
|
||||
} else if (x_size_value_ptr->isa<tensor::Tensor>()) { // 2-rd infer will be a tensor
|
||||
x_size_val = CheckAndConvertUtils::CheckTensorIntValue("x_size", x_size_value_ptr, prim_name);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', "
|
||||
<< "the input[x_size] must be a tuple of int.";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("x_size", x_size_val, prim_name);
|
||||
int64_t x_size_val_size = SizeToLong(x_size_val.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("size of 'x_size'", x_size_val_size, kIncludeBoth,
|
||||
{kLenImageSize2D, kLenImageSize3D}, prim_name);
|
||||
if (x_size_val[kInputIndex0] != y_grad_shape[kInputIndex0]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "the x_size[0] must be equal to the shape[0] of y_grad, "
|
||||
<< "but got the x_size[0] is " << x_size_val[kInputIndex0]
|
||||
<< " and the shape[0] of y_grad is " << y_grad_shape[kInputIndex0] << ".";
|
||||
}
|
||||
auto y_grad_rank = SizeToLong(y_grad_shape.size());
|
||||
ShapeVector x_grad_shape;
|
||||
if (x_size_val_size == kLenImageSize2D) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'y_grad'", y_grad_rank, kEqual, kLenImageSize2D, prim_name);
|
||||
if (y_grad_shape[kInputIndex1] == x_size_val[kInputIndex2] &&
|
||||
y_grad_shape[kInputIndex2] == x_size_val[kInputIndex3] && y_grad_shape[kInputIndex3] == kInputIndex2) {
|
||||
auto N = static_cast<int64_t>(x_size_val[kInputIndex0]);
|
||||
x_grad_shape = {N, kInputIndex2, kInputIndex3};
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "the shape of 'y_grad' must be [N, H, W, 2] and "
|
||||
<< "the shape of 'x_size' must be [N, C, H, W] for 2D; "
|
||||
<< "But got the shape of 'y_grad is [" << y_grad_shape[kInputIndex0] << ", "
|
||||
<< y_grad_shape[kInputIndex1] << ", " << y_grad_shape[kInputIndex2] << ", "
|
||||
<< y_grad_shape[kInputIndex3] << "] "
|
||||
<< "and the size of 'x_size' is [" << x_size_val[kInputIndex0] << ", "
|
||||
<< x_size_val[kInputIndex1] << ", " << x_size_val[kInputIndex2] << ", "
|
||||
<< x_size_val[kInputIndex3] << "] ";
|
||||
}
|
||||
} else if (x_size_val_size == kLenImageSize3D) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'y_grad'", y_grad_rank, kEqual, kLenImageSize3D, prim_name);
|
||||
if (y_grad_shape[kInputIndex1] == x_size_val[kInputIndex2] &&
|
||||
y_grad_shape[kInputIndex2] == x_size_val[kInputIndex3] &&
|
||||
y_grad_shape[kInputIndex3] == x_size_val[kInputIndex4] && y_grad_shape[kInputIndex4] == kInputIndex3) {
|
||||
auto N = static_cast<int64_t>(x_size_val[kInputIndex0]);
|
||||
x_grad_shape = {N, kInputIndex3, kInputIndex4};
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "the shape of 'y_grad' must be [N, D, H, W, 3] and "
|
||||
<< "the shape of 'x_size' must be [N, C, D, H, W] for 3D; "
|
||||
<< "But got the shape of 'y_grad is [" << y_grad_shape[kInputIndex0] << ", "
|
||||
<< y_grad_shape[kInputIndex1] << ", " << y_grad_shape[kInputIndex2] << ", "
|
||||
<< y_grad_shape[kInputIndex3] << ", " << y_grad_shape[kInputIndex4] << "] "
|
||||
<< "and the size of 'x_size' is [" << x_size_val[kInputIndex0] << ", "
|
||||
<< x_size_val[kInputIndex1] << ", " << x_size_val[kInputIndex2] << ", "
|
||||
<< x_size_val[kInputIndex3] << ", " << x_size_val[kInputIndex4] << "] ";
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "the size of 'x_size' must be 4 for 2D or 5 for 3D. ";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(x_grad_shape);
|
||||
} else {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr AffineGridGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::string op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, kInputIndex0);
|
||||
auto y_grad_type = input_args[kInputIndex0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(y_grad_type);
|
||||
const std::set<TypePtr> y_grad_valid_types = {kFloat16, kFloat32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("y_grad", y_grad_type, y_grad_valid_types, op_name);
|
||||
auto x_size_type = input_args[kInputIndex1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_size_type);
|
||||
const std::set<TypePtr> x_size_valid_types = {kTensorType, kTuple}; // 2-rd infer will be a tensor.
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("x_size", x_size_type, x_size_valid_types, op_name);
|
||||
return y_grad_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AffineGridGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto infer_type = AffineGridGradInferType(primitive, input_args);
|
||||
auto infer_shape = AffineGridGradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
void AffineGridGrad::Init(const bool align_corners) { set_align_corners(align_corners); }
|
||||
|
||||
void AffineGridGrad::set_align_corners(const bool align_corners) {
|
||||
(void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners));
|
||||
}
|
||||
|
||||
bool AffineGridGrad::get_align_corners() const {
|
||||
auto value_ptr = GetAttr(kAlignCorners);
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_HOST_DEPENDS(kNameAffineGridGrad, {1});
|
||||
MIND_API_OPERATOR_IMPL(AffineGridGrad, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AffineGridGrad, prim::kPrimAffineGridGrad, AffineGridGradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_AFFINE_GRID_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_AFFINE_GRID_GRAD_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAffineGridGrad = "AffineGridGrad";
|
||||
class MIND_API AffineGridGrad : public BaseOperator {
|
||||
public:
|
||||
AffineGridGrad() : BaseOperator(kNameAffineGridGrad) { InitIOName({"y_grad", "x_size"}, {"x_grad"}); }
|
||||
MIND_API_BASE_MEMBER(AffineGridGrad);
|
||||
void Init(const bool align_corners = false);
|
||||
void set_align_corners(const bool align_corners);
|
||||
bool get_align_corners() const;
|
||||
};
|
||||
abstract::AbstractBasePtr AffineGridGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_AFFINE_GRID_GRAD_H_
|
|
@ -54,6 +54,7 @@ from mindspore.ops.operations import _inner_ops as inner
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore import context
|
||||
|
||||
|
||||
@bprop_getters.register(StridedSliceV2)
|
||||
|
@ -619,6 +620,7 @@ def get_bprop_affinegrid(self):
|
|||
"""Generate bprop for AffineGrid"""
|
||||
|
||||
align_corners = self.align_corners
|
||||
input_grad = G.AffineGridGrad(align_corners)
|
||||
ones = P.Ones()
|
||||
transpose = P.Transpose()
|
||||
concat = P.Concat(1)
|
||||
|
@ -834,12 +836,19 @@ def get_bprop_affinegrid(self):
|
|||
dtheta = transpose(dtheta, perm2)
|
||||
return dtheta, tre
|
||||
|
||||
def bprop(theta, output_size, out, dout):
|
||||
def bprop_gpu(theta, output_size, out, dout):
|
||||
is_tensor, _ = convert_to_tensor(output_size)
|
||||
if is_tensor:
|
||||
return dyn_bprop(theta, output_size, out, dout)
|
||||
return static_bprop(theta, output_size, out, dout)
|
||||
|
||||
def bprop(theta, output_size, out, dout):
|
||||
dx = input_grad(dout, output_size)
|
||||
return dx, zeros_like(output_size)
|
||||
|
||||
if context.get_context('device_target') == "GPU":
|
||||
return bprop_gpu
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AffineGridGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
affine_grid_grad_op_info = AiCPURegOp("AffineGridGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("align_corners", "bool")\
|
||||
.input(0, "y_grad", "required") \
|
||||
.input(1, "x_size", "required") \
|
||||
.output(0, "x_grad", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(affine_grid_grad_op_info)
|
||||
def _affine_grid_grad_aicpu():
|
||||
"""AffineGridGrad aicpu register"""
|
||||
return
|
|
@ -3766,6 +3766,42 @@ class FractionalMaxPoolGradWithFixedKsize(Primitive):
|
|||
self.init_prim_io_names(inputs=['origin_input', 'out_backprop', 'argmax'], outputs=['y'])
|
||||
|
||||
|
||||
class AffineGridGrad(Primitive):
|
||||
r"""
|
||||
Computes gradients for AffineGrid operation.
|
||||
|
||||
Args:
|
||||
align_corners (bool): if True, consider -1 and 1 to refer to the centers
|
||||
of the corner pixels rather than the image corners. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **y_grad** (Tensor) - Data type must be float16 or float32.
|
||||
- **x_size** (tuple) - Data type must be int32 or int64.
|
||||
|
||||
Outputs:
|
||||
Tensor, with data type same as `y_grad`.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.ops.operations._grad_ops as _grad_ops
|
||||
>>> affinegridgrad = _grad_ops.AffineGridGrad()
|
||||
>>> y_grad = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
|
||||
>>> x_size = (1, 2, 2, 2)
|
||||
>>> x_grad = affinegridgrad(y_grad, x_size)
|
||||
>>> print(x_grad)
|
||||
[[[0. 0. 4.]
|
||||
[0. 0. 4.]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, align_corners=False):
|
||||
"""Initialize AffineGridGrad."""
|
||||
validator.check_value_type("align_corners", align_corners, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
|
||||
|
||||
|
||||
class HSigmoidGrad(Primitive):
|
||||
"""Gets the gradient of HSigmoid operation."""
|
||||
@prim_attr_register
|
||||
|
|
|
@ -205,6 +205,7 @@ from mindspore.ops.operations.array_ops import RightShift
|
|||
from mindspore.ops.operations.array_ops import LeftShift
|
||||
from mindspore.ops.operations.array_ops import Expand
|
||||
from mindspore.ops.operations.array_ops import HammingWindow
|
||||
from mindspore.ops.operations.array_ops import AffineGrid
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyMomentum
|
||||
from mindspore.ops.operations.nn_ops import AdaptiveAvgPool3D
|
||||
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool3D
|
||||
|
@ -3676,6 +3677,10 @@ test_case_array_ops = [
|
|||
'block': P.DepthToSpace(2),
|
||||
'desc_inputs': [[1, 12, 1, 1]],
|
||||
'desc_bprop': [[1, 3, 2, 2]]}),
|
||||
('AffineGrid', {
|
||||
'block': AffineGrid(align_corners=False),
|
||||
'desc_inputs': [Tensor(np.random.rand(1, 2, 3), mstype.float32), (1, 1, 1, 2)],
|
||||
'desc_bprop': [Tensor(np.random.rand(1, 1, 2, 2), mstype.float32)]}),
|
||||
('Split', {
|
||||
'block': P.Split(1, 2),
|
||||
'desc_inputs': [Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))],
|
||||
|
|
Loading…
Reference in New Issue