Add GatherDGradV2 AICPU op.

linqingke 2022-07-11 14:31:27 +08:00
parent 0f0263820f
commit 0757f21796
11 changed files with 425 additions and 6 deletions

@@ -19,6 +19,7 @@
"mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"

@@ -368,6 +368,7 @@ constexpr auto kEnvironCreateOpName = "EnvironCreate";
constexpr auto kEnvironSetOpName = "EnvironSet";
constexpr auto kEnvironGetOpName = "EnvironGet";
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
+constexpr auto kGatherDGradV2OpName = "GatherDGradV2";
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
constexpr auto kResizeBicubicOpName = "ResizeBicubic";
constexpr auto kUpdateStateOpName = "UpdateState";

@@ -29,6 +29,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_pulse.cc
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_sharder.cc
${CMAKE_CURRENT_SOURCE_DIR}/random_choice_with_mask_kernels.cc
+${CMAKE_CURRENT_SOURCE_DIR}/gather_grad_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/environ/aicpu_environ_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_create.cc
${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_set.cc

@@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_
#define AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_

#include <cstdint>

#include "common/kernel_log.h"

namespace aicpu {
template <typename T, typename S>
void AtomicAddTask(T *const address, const T val) {
  auto *address_as_ull = reinterpret_cast<S *>(address);
  S old = *address_as_ull;
  S assumed;
  T desired;
  T *assumed_t = nullptr;
  S *desired_s = nullptr;
  // Classic CAS retry loop: recompute the sum and retry whenever another
  // thread modified the word between our load and the compare-and-swap.
  do {
    assumed = old;
    assumed_t = reinterpret_cast<T *>(&assumed);
    desired_s = reinterpret_cast<S *>(&desired);
    desired = *assumed_t + static_cast<T>(val);
    old = __sync_val_compare_and_swap(address_as_ull, assumed, *desired_s);
  } while (assumed != old);
}

template <typename T>
void AtomicAdd(T *const address, const T val) {
  switch (sizeof(T)) {
    case sizeof(uint8_t): {
      AtomicAddTask<T, uint8_t>(address, val);
      break;
    }
    case sizeof(int16_t): {
      AtomicAddTask<T, int16_t>(address, val);
      break;
    }
    case sizeof(int32_t): {
      AtomicAddTask<T, int32_t>(address, val);
      break;
    }
    case sizeof(int64_t): {
      AtomicAddTask<T, int64_t>(address, val);
      break;
    }
    default:
      AICPU_LOGE("Unsupported aicpu atomic add format!");
  }
}
}  // namespace aicpu
#endif  // AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_
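For readers outside the AICPU tree: __sync_val_compare_and_swap is a GCC/Clang builtin that only operates on integral words, which is why the header type-puns each value through a same-sized integer before the compare-and-swap. A minimal standalone sketch of the same retry loop for float (our illustration, not part of the commit; file name and driver code are assumptions; compile with g++ -O2 -pthread atomic_demo.cc):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <thread>
#include <vector>

// Same idea as AtomicAddTask<float, uint32_t>: CAS on the 32-bit word that
// holds the float, retrying until no other thread raced in between.
void AtomicAddFloat(float *address, float val) {
  auto *word = reinterpret_cast<uint32_t *>(address);
  uint32_t old_word = *word;
  uint32_t assumed;
  do {
    assumed = old_word;
    float desired;
    std::memcpy(&desired, &assumed, sizeof(desired));  // punned read of the snapshot
    desired += val;
    uint32_t desired_word;
    std::memcpy(&desired_word, &desired, sizeof(desired_word));
    old_word = __sync_val_compare_and_swap(word, assumed, desired_word);
  } while (assumed != old_word);
}

int main() {
  float sum = 0.0f;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&sum] {
      for (int i = 0; i < 100000; ++i) AtomicAddFloat(&sum, 1.0f);
    });
  }
  for (auto &w : workers) w.join();
  std::printf("sum = %.0f (expected 400000)\n", sum);
  return 0;
}

The sketch uses std::memcpy for the punned reads where the header uses reinterpret_cast; both express the same bit reinterpretation, but memcpy stays clear of strict-aliasing undefined behavior in standalone code.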

@@ -0,0 +1,225 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./gather_grad_kernels.h"
#include <Eigen/Dense>
#include <map>
#include <thread>
#include <numeric>
#include <vector>
#include <functional>
#include "utils/convert_utils_base.h"
#include "common/atomic_op.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "proto/aicpu_tensor.pb.h"
namespace aicpu {
namespace {
constexpr auto kDim = "dim";
constexpr auto kAddressSize = 4;
constexpr auto kDim0 = 0;
constexpr auto kDim1 = 1;
constexpr auto kDim2 = 2;
constexpr auto kDim3 = 3;

template <typename T, typename S>
static uint32_t GatherGrad(const T *index, const S *grad, S *output, const int64_t dim_before_axis,
                           const int64_t dim_at_axis_index, const int64_t dim_at_axis_output,
                           const int64_t dim_after_axis) {
  if (dim_after_axis == 0) {
    AICPU_LOGE("dim_after_axis cannot be 0.");
    return AICPU_KERNEL_STATE_FAILED;
  }
  int64_t number = dim_before_axis * dim_at_axis_index * dim_after_axis;
  bool status = false;
  auto shard_gather_grad = [&](size_t start, size_t end) {
    int64_t dim_input = dim_at_axis_index * dim_after_axis;
    int64_t dim_output = dim_at_axis_output * dim_after_axis;
    for (size_t id = start; id < end; ++id) {
      T j_read = index[id];
      auto max_index = static_cast<T>(dim_at_axis_output);
      if (j_read >= max_index || j_read < -max_index) {
        AICPU_LOGE("The value of 'index' should be in [%ld, %ld), but got %ld", static_cast<int64_t>(-max_index),
                   static_cast<int64_t>(max_index), static_cast<int64_t>(j_read));
        AtomicAdd<bool>(&status, true);
        return;
      }
      if (j_read < 0) {
        j_read += max_index;
      }
      // Keep the before-axis and after-axis coordinates of the flat position id,
      // and substitute the gathered index j_read on the target axis.
      int64_t read_id = id / dim_input * dim_output + j_read * dim_after_axis + id % dim_after_axis;
      AtomicAdd<S>(output + read_id, grad[id]);
    }
  };
  const int64_t thread_num = static_cast<int64_t>(std::thread::hardware_concurrency());
  const int64_t per_unit_size = number / (thread_num > 0 ? thread_num : 1);  // hardware_concurrency() may be 0
  SharderNonBlock::GetInstance().ParallelFor(number, per_unit_size, shard_gather_grad);
  if (status) {
    return AICPU_KERNEL_STATE_FAILED;
  }
  return AICPU_KERNEL_STATE_SUCCESS;
}
}  // namespace

template <typename T, typename S>
uint32_t GatherDGradV2Kernel::GatherDGradV2Task() {
  if (io_addrs_.size() != kAddressSize) {
    AICPU_LOGE("GatherDGradV2Kernel's address is invalid");
    return AICPU_KERNEL_STATE_FAILED;
  }
  T *index = reinterpret_cast<T *>(io_addrs_[kDim1]);
  S *grad = reinterpret_cast<S *>(io_addrs_[kDim2]);
  S *output = reinterpret_cast<S *>(io_addrs_[kDim3]);

  int64_t output_rank = static_cast<int64_t>(output_shape_.size());
  if (dim_ >= output_rank || dim_ < -output_rank) {
    AICPU_LOGE("The value of 'dim' should be in [%ld, %ld), but got %ld", -output_rank, output_rank, dim_);
    return AICPU_KERNEL_STATE_FAILED;
  }
  if (dim_ < 0) {
    dim_ = dim_ + output_rank;
  }
  int64_t grad_rank = static_cast<int64_t>(grad_shape_.size());
  if (dim_ >= grad_rank) {
    AICPU_LOGE("The normalized value of 'dim' should be in [0, %ld), but got %ld", grad_rank, dim_);
    return AICPU_KERNEL_STATE_FAILED;
  }
  int64_t dim_before_axis =
    std::accumulate(output_shape_.begin(), output_shape_.begin() + dim_, int64_t{1}, std::multiplies<int64_t>());
  int64_t dim_at_axis_grad = grad_shape_[dim_];
  int64_t dim_at_axis_output = output_shape_[dim_];
  int64_t dim_after_axis =
    std::accumulate(output_shape_.begin() + dim_ + 1, output_shape_.end(), int64_t{1}, std::multiplies<int64_t>());
  // The output buffer holds S (value-typed) elements, so size it with sizeof(S), not sizeof(T).
  int64_t output_size = dim_before_axis * dim_at_axis_output * dim_after_axis * static_cast<int64_t>(sizeof(S));
  if (memset_s(output, output_size, 0x0, output_size)) {
    AICPU_LOGE("memset_s failed!");
    return AICPU_KERNEL_STATE_FAILED;
  }
  return GatherGrad(index, grad, output, dim_before_axis, dim_at_axis_grad, dim_at_axis_output, dim_after_axis);
}

uint32_t GatherDGradV2Kernel::ParseKernelParam() {
  // original input
  aicpuops::Tensor input_tensor = node_def_.inputs(kDim0);
  const auto &input_shape = input_tensor.tensor_shape();
  for (int i = 0; i < input_shape.dim_size(); ++i) {
    (void)input_shape_.emplace_back(input_shape.dim(i).size());
  }
  // index input
  input_tensor = node_def_.inputs(kDim1);
  index_type_ = static_cast<aicpuops::DataType>(input_tensor.tensor_type());
  const auto &index_shape = input_tensor.tensor_shape();
  for (int i = 0; i < index_shape.dim_size(); ++i) {
    (void)index_shape_.emplace_back(index_shape.dim(i).size());
  }
  // grad input
  input_tensor = node_def_.inputs(kDim2);
  grad_type_ = static_cast<aicpuops::DataType>(input_tensor.tensor_type());
  const auto &grad_shape = input_tensor.tensor_shape();
  for (int i = 0; i < grad_shape.dim_size(); ++i) {
    (void)grad_shape_.emplace_back(grad_shape.dim(i).size());
  }
  if (index_shape_ != grad_shape_) {
    AICPU_LOGE("The shapes of index and grad should be the same!");
    return AICPU_KERNEL_STATE_PARAM_INVALID;
  }
  // output
  aicpuops::Tensor output_tensor = node_def_.outputs(kDim0);
  const auto &output_shape = output_tensor.tensor_shape();
  for (int i = 0; i < output_shape.dim_size(); ++i) {
    (void)output_shape_.emplace_back(output_shape.dim(i).size());
  }
  if (output_shape_ != input_shape_) {
    AICPU_LOGE("The shapes of input and output should be the same!");
    return AICPU_KERNEL_STATE_PARAM_INVALID;
  }
  auto node_def_attrs = node_def_.attrs();
  dim_ = node_def_attrs[kDim].i();
  return AICPU_KERNEL_STATE_SUCCESS;
}

uint32_t GatherDGradV2Kernel::DoCompute() {
  std::map<int, std::map<int, std::function<uint32_t()>>> calls;
  // index int32
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int8_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int16_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int32_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int64_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, Eigen::half>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, float>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, double>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint8_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint16_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint32_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint64_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_BOOL] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, bool>, this);
  // index int64
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int8_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int16_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int32_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int64_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, Eigen::half>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, float>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, double>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint8_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint16_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint32_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint64_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_BOOL] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, bool>, this);
  auto index_calls = calls.find(index_type_);
  if (index_calls == calls.end() || index_calls->second.find(grad_type_) == index_calls->second.end()) {
    AICPU_LOGE("GatherDGradV2 op doesn't support index type %d with grad type %d", static_cast<int>(index_type_),
               static_cast<int>(grad_type_));
    return AICPU_KERNEL_STATE_FAILED;
  }
  return index_calls->second[grad_type_]();
}
}  // namespace aicpu

extern "C" {
__attribute__((visibility("default"))) uint32_t GatherDGradV2(void *param) {
  aicpu::GatherDGradV2Kernel gatherd_grad_v2_kernel;
  return gatherd_grad_v2_kernel.Compute(param);
}
}
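To make the index arithmetic in shard_gather_grad concrete: each flat grad position id decomposes into (before-axis, axis, after-axis) coordinates, and read_id keeps the outer coordinates while substituting the gathered index on the target axis. A single-threaded reference version with a worked (2, 3) example follows; this is our illustration of the mapping, not code from the commit:

#include <cstdint>
#include <cstdio>
#include <vector>

// Single-threaded reference for the scatter in shard_gather_grad (our sketch):
// grad has the index's shape; output has the forward input's shape.
std::vector<float> GatherDGradV2Ref(const std::vector<int64_t> &index, const std::vector<float> &grad,
                                    int64_t dim_before_axis, int64_t dim_at_axis_index,
                                    int64_t dim_at_axis_output, int64_t dim_after_axis) {
  std::vector<float> output(dim_before_axis * dim_at_axis_output * dim_after_axis, 0.0f);
  const int64_t dim_input = dim_at_axis_index * dim_after_axis;
  const int64_t dim_output = dim_at_axis_output * dim_after_axis;
  const int64_t number = dim_before_axis * dim_input;
  for (int64_t id = 0; id < number; ++id) {
    int64_t j = index[id];
    if (j < 0) j += dim_at_axis_output;  // negative indices wrap, as in the kernel
    // Same mapping as read_id in the kernel: keep outer coordinates, swap in j.
    const int64_t read_id = id / dim_input * dim_output + j * dim_after_axis + id % dim_after_axis;
    output[read_id] += grad[id];  // the kernel does this with AtomicAdd<S>
  }
  return output;
}

int main() {
  // Forward x had shape (2, 3) and dim = 1; index/grad have shape (2, 2).
  std::vector<int64_t> index = {0, 2, 1, 1};
  std::vector<float> grad = {1, 2, 3, 4};
  auto out = GatherDGradV2Ref(index, grad, 2, 2, 3, 1);
  for (float v : out) std::printf("%g ", v);  // prints: 1 0 2 0 7 0
  std::printf("\n");
  return 0;
}

Row 0 scatters grads 1 and 2 to columns 0 and 2; row 1 scatters 3 and 4 to the same column 1, which is exactly the duplicate-index accumulation that forces the kernel to use AtomicAdd.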

@@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_
#define AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_

#include <vector>

#include "common/kernel_base.h"

namespace aicpu {
class GatherDGradV2Kernel : public KernelBase {
 public:
  GatherDGradV2Kernel() : KernelBase("GatherDGradV2Kernel") {}
  ~GatherDGradV2Kernel() = default;

 protected:
  uint32_t ParseKernelParam() override;
  uint32_t DoCompute() override;

  template <typename T, typename S>
  uint32_t GatherDGradV2Task();

  int64_t dim_{0};
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> index_shape_;
  std::vector<int64_t> grad_shape_;
  std::vector<int64_t> output_shape_;
  aicpuops::DataType index_type_{aicpuops::DataType::MS_UNKNOWN};
  aicpuops::DataType grad_type_{aicpuops::DataType::MS_UNKNOWN};
};
}  // namespace aicpu
#endif  // AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_
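The DoCompute declared above is implemented (in gather_grad_kernels.cc) as a two-level map from index dtype and value dtype to a bound GatherDGradV2Task instantiation. A stripped-down sketch of that dispatch pattern, with hypothetical names rather than MindSpore APIs:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>

enum class DType { kInt32, kInt64, kFloat32 };

class Kernel {
 public:
  uint32_t DoCompute(DType index_type, DType value_type) {
    // One registration line per supported (index dtype, value dtype) pair.
    std::map<DType, std::map<DType, std::function<uint32_t()>>> calls;
    calls[DType::kInt32][DType::kFloat32] = std::bind(&Kernel::Task<int32_t, float>, this);
    calls[DType::kInt64][DType::kFloat32] = std::bind(&Kernel::Task<int64_t, float>, this);
    auto row = calls.find(index_type);
    if (row == calls.end() || row->second.find(value_type) == row->second.end()) {
      return 1;  // unsupported combination
    }
    return row->second[value_type]();
  }

 private:
  template <typename T, typename S>
  uint32_t Task() {
    std::printf("running Task<%zu-byte index, %zu-byte value>\n", sizeof(T), sizeof(S));
    return 0;
  }
};

int main() {
  Kernel k;
  return static_cast<int>(k.DoCompute(DType::kInt64, DType::kFloat32));
}

The pattern trades a little map-lookup overhead for one registration line per supported type pair, which keeps the 24 combinations in the real DoCompute readable.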

@@ -274,7 +274,7 @@ uint32_t RandomChoiceWithMaskKernel::ParseKernelParam() {
extern "C" {
__attribute__((visibility("default"))) uint32_t RandomChoiceWithMask(void *param) {
-  aicpu::RandomChoiceWithMaskKernel randomChoiceWithMaskKernel;
-  return randomChoiceWithMaskKernel.Compute(param);
+  aicpu::RandomChoiceWithMaskKernel random_choice_with_mask_kernel;
+  return random_choice_with_mask_kernel.Compute(param);
}
}
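Both this entry point and GatherDGradV2 above are exported as unmangled, default-visibility C symbols so the AICPU runtime can resolve them by name from the kernels shared object. A hedged sketch of such a name-based lookup (the library file name is inferred from the kEnvOpSoNames value later in this commit, and the loader shown is illustrative, not MindSpore's):

#include <dlfcn.h>

#include <cstdint>
#include <cstdio>

using KernelFn = uint32_t (*)(void *);

int main() {
  // Library name inferred from kEnvOpSoNames ("mindspore_aicpu_kernels"); path assumed.
  void *handle = dlopen("libmindspore_aicpu_kernels.so", RTLD_NOW);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // The unmangled lookup only works because the kernel entry is extern "C".
  auto fn = reinterpret_cast<KernelFn>(dlsym(handle, "GatherDGradV2"));
  if (fn == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }
  // A real invocation would pass a populated kernel-param block, not nullptr.
  dlclose(handle);
  return 0;
}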

@@ -63,6 +63,7 @@ constexpr auto kHistogram = "Histogram";
constexpr auto kIdentity = "Identity";
constexpr auto kIdentityN = "IdentityN";
constexpr auto kRandomChoiceWithMask = "RandomChoiceWithMask";
+constexpr auto kGatherDGradV2 = "GatherDGradV2";
constexpr auto kResizeNearestNeighborV2 = "ResizeNearestNeighborV2";
constexpr auto kResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad";
constexpr auto kUpdateCache = "UpdateCache";
@@ -110,7 +111,8 @@ const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
kPriorityReplayBufferPush,
kPriorityReplayBufferSample,
kPriorityReplayBufferUpdate,
-kPriorityReplayBufferDestroy};
+kPriorityReplayBufferDestroy,
+kGatherDGradV2};
const std::set<std::string> kDynamicInputOps{
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName,
kStackPushOpName, kStackPopOpName, kDynamicStitch, kPriorityReplayBufferPush, kPriorityReplayBufferSample,

@@ -30,9 +30,9 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
MS_EXCEPTION_IF_NULL(equiv);
static const std::set<std::string> kAICpuOpNames = {
-    kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName,
-    kEnvironDestroyAllOpName, kPriorityReplayBufferCreate, kPriorityReplayBufferPush,
-    kPriorityReplayBufferSample, kPriorityReplayBufferUpdate, kPriorityReplayBufferDestroy};
+    kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName, kEnvironDestroyAllOpName,
+    kPriorityReplayBufferCreate, kPriorityReplayBufferPush, kPriorityReplayBufferSample, kPriorityReplayBufferUpdate,
+    kPriorityReplayBufferDestroy, kGatherDGradV2OpName};
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
if (!node->isa<CNode>()) {

@@ -33,6 +33,7 @@ from .log import _log_aicpu
from .padding import _padding_aicpu
from .gather import _gather_aicpu
from .gather_grad import _gather_grad_aicpu
+from .gather_d_grad_v2 import _gather_d_grad_v2_aicpu
from .gather_d import _gather_d_aicpu
from .scatter import _scatter_aicpu
from .exp import _exp_aicpu

@@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GatherDGradV2 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

gather_d_grad_v2_op_info = AiCPURegOp("GatherDGradV2") \
    .attr("dim", "int") \
    .input(0, "x", "required") \
    .input(1, "index", "required") \
    .input(2, "grad", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(gather_d_grad_v2_op_info)
def _gather_d_grad_v2_aicpu():
    """GatherDGradV2 aicpu register"""
    return