Add gatherd_v2_grad (GatherDGradV2) AICPU op.
parent 0f0263820f
commit 0757f21796
@@ -19,6 +19,7 @@
 "mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
 "mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
 "mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"

 # Modelzoo
 "mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
@@ -368,6 +368,7 @@ constexpr auto kEnvironCreateOpName = "EnvironCreate";
 constexpr auto kEnvironSetOpName = "EnvironSet";
 constexpr auto kEnvironGetOpName = "EnvironGet";
 constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
+constexpr auto kGatherDGradV2OpName = "GatherDGradV2";
 constexpr auto kNonDeterministicInts = "NonDeterministicInts";
 constexpr auto kResizeBicubicOpName = "ResizeBicubic";
 constexpr auto kUpdateStateOpName = "UpdateState";
@@ -29,6 +29,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
     ${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_pulse.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_sharder.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/random_choice_with_mask_kernels.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/gather_grad_kernels.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/environ/aicpu_environ_manager.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_create.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_set.cc
@@ -0,0 +1,64 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_
#define AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_

#include <cstdint>
#include "common/kernel_log.h"

namespace aicpu {
template <typename T, typename S>
void AtomicAddTask(T *const address, const T val) {
  auto *address_as_ull = reinterpret_cast<S *>(address);
  S old = *address_as_ull;
  S assumed;
  T desired;
  T *assumed_t = nullptr;
  S *desired_s = nullptr;
  do {
    assumed = old;
    assumed_t = reinterpret_cast<T *>(&assumed);
    desired_s = reinterpret_cast<S *>(&desired);
    desired = *assumed_t + static_cast<T>(val);
    old = __sync_val_compare_and_swap(address_as_ull, assumed, *desired_s);
  } while (assumed != old);
}

template <typename T>
void AtomicAdd(T *const address, const T val) {
  switch (sizeof(T)) {
    case sizeof(uint8_t): {
      AtomicAddTask<T, uint8_t>(address, val);
      break;
    }
    case sizeof(int16_t): {
      AtomicAddTask<T, int16_t>(address, val);
      break;
    }
    case sizeof(int32_t): {
      AtomicAddTask<T, int32_t>(address, val);
      break;
    }
    case sizeof(int64_t): {
      AtomicAddTask<T, int64_t>(address, val);
      break;
    }
    default:
      AICPU_LOGE("Unsupported aicpu atomic add format!");
  }
}
}  // namespace aicpu
#endif  // AICPU_OPS_AICPU_COMMON_ATOMIC_OP_H_
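For reference, a minimal standalone sketch (not part of the commit) of how this CAS-based add behaves under contention. It uses plain std::thread in place of the AICPU sharder, and the driver in main() is hypothetical; only the compare-and-swap loop mirrors the header above.

// Sketch: lock-free float accumulation via a 32-bit compare-and-swap, assuming
// sizeof(T) == sizeof(S), same as the AtomicAddTask pattern in atomic_op.h.
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

template <typename T, typename S>
void AtomicAddSketch(T *address, T val) {
  auto *address_as_s = reinterpret_cast<S *>(address);
  S old = *address_as_s;
  S assumed;
  do {
    assumed = old;
    T desired = *reinterpret_cast<T *>(&assumed) + val;  // add on a private copy
    old = __sync_val_compare_and_swap(address_as_s, assumed,
                                      *reinterpret_cast<S *>(&desired));  // publish only if unchanged
  } while (assumed != old);
}

int main() {
  float sum = 0.0f;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&sum] {
      for (int i = 0; i < 100000; ++i) {
        AtomicAddSketch<float, uint32_t>(&sum, 1.0f);  // float add through a 32-bit CAS
      }
    });
  }
  for (auto &w : workers) w.join();
  std::printf("sum = %.0f (expected 400000)\n", sum);
  return 0;
}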
@@ -0,0 +1,225 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "./gather_grad_kernels.h"
#include <Eigen/Dense>
#include <map>
#include <thread>
#include <numeric>
#include <vector>
#include <functional>
#include "utils/convert_utils_base.h"
#include "common/atomic_op.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "proto/aicpu_tensor.pb.h"

namespace aicpu {
namespace {
constexpr auto kDim = "dim";
constexpr auto kAddressSize = 4;
constexpr auto kDim0 = 0;
constexpr auto kDim1 = 1;
constexpr auto kDim2 = 2;
constexpr auto kDim3 = 3;

template <typename T, typename S>
static uint32_t GatherGrad(const T *index, const S *grad, S *output, const int64_t dim_before_axis,
                           const int64_t dim_at_axis_index, const int64_t dim_at_axis_output,
                           const int64_t dim_after_axis) {
  if (dim_after_axis == 0) {
    AICPU_LOGE("dim_after_axis cannot be 0.");
    return AICPU_KERNEL_STATE_FAILED;
  }
  int64_t number = dim_before_axis * dim_at_axis_index * dim_after_axis;
  bool status = false;
  auto shard_gather_grad = [&](size_t start, size_t end) {
    int64_t dim_input = dim_at_axis_index * dim_after_axis;
    int64_t dim_output = dim_at_axis_output * dim_after_axis;
    for (size_t id = start; id < end; ++id) {
      T j_read = index[id];
      auto max_index = static_cast<T>(dim_at_axis_output);
      if (j_read >= max_index || j_read < -max_index) {
        AICPU_LOGE("The value of 'index' should be in [%d, %d), but got %d", -max_index, max_index, j_read);
        AtomicAdd<bool>(&status, true);
        return;
      }
      if (j_read < 0) {
        j_read += max_index;
      }

      int64_t read_id = id / dim_input * dim_output + j_read * dim_after_axis + id % dim_after_axis;
      AtomicAdd<S>(output + read_id, grad[id]);
    }
  };

  const int64_t per_unit_size = number / std::thread::hardware_concurrency();
  SharderNonBlock::GetInstance().ParallelFor(number, per_unit_size, shard_gather_grad);
  if (status) {
    return AICPU_KERNEL_STATE_FAILED;
  }
  return AICPU_KERNEL_STATE_SUCCESS;
}
}  // namespace

template <typename T, typename S>
uint32_t GatherDGradV2Kernel::GatherDGradV2Task() {
  if (io_addrs_.size() != kAddressSize) {
    AICPU_LOGE("GatherDGradV2Kernel's address number is invalid.");
    return AICPU_KERNEL_STATE_FAILED;
  }
  T *index = reinterpret_cast<T *>(io_addrs_[kDim1]);
  S *grad = reinterpret_cast<S *>(io_addrs_[kDim2]);
  S *output = reinterpret_cast<S *>(io_addrs_[kDim3]);

  int64_t output_rank = static_cast<int64_t>(output_shape_.size());
  if (dim_ >= output_rank || dim_ < -output_rank) {
    AICPU_LOGE("The value of 'dim' should be in [%d, %d), but got %d", -output_rank, output_rank, dim_);
    return AICPU_KERNEL_STATE_FAILED;
  }
  if (dim_ < 0) {
    dim_ = dim_ + output_rank;
  }
  int64_t grad_rank = static_cast<int64_t>(grad_shape_.size());
  if (dim_ >= grad_rank) {
    AICPU_LOGE("The value of 'dim' should be in [%d, %d), but got %d", -grad_rank, grad_rank, dim_);
    return AICPU_KERNEL_STATE_FAILED;
  }

  int64_t dim_before_axis =
    std::accumulate(output_shape_.begin(), output_shape_.begin() + dim_, 1, std::multiplies<int64_t>());
  int64_t dim_at_axis_grad = grad_shape_[dim_];
  int64_t dim_at_axis_output = output_shape_[dim_];
  int64_t dim_after_axis =
    std::accumulate(output_shape_.begin() + dim_ + 1, output_shape_.end(), 1, std::multiplies<int64_t>());
  int64_t output_size = dim_before_axis * dim_at_axis_output * dim_after_axis * sizeof(S);
  if (memset_s(output, output_size, 0x0, output_size)) {
    AICPU_LOGE("memset_s failed!");
    return AICPU_KERNEL_STATE_FAILED;
  }
  return GatherGrad(index, grad, output, dim_before_axis, dim_at_axis_grad, dim_at_axis_output, dim_after_axis);
}

uint32_t GatherDGradV2Kernel::ParseKernelParam() {
  // ori input
  aicpuops::Tensor input_tensor = node_def_.inputs(kDim0);
  const auto &input_shape = input_tensor.tensor_shape();
  for (int i = 0; i < input_shape.dim_size(); ++i) {
    (void)input_shape_.emplace_back(input_shape.dim(i).size());
  }

  // index input
  input_tensor = node_def_.inputs(kDim1);
  index_type_ = static_cast<aicpuops::DataType>(input_tensor.tensor_type());
  const auto &index_shape = input_tensor.tensor_shape();
  for (int i = 0; i < index_shape.dim_size(); ++i) {
    (void)index_shape_.emplace_back(index_shape.dim(i).size());
  }

  // grad input
  input_tensor = node_def_.inputs(kDim2);
  grad_type_ = static_cast<aicpuops::DataType>(input_tensor.tensor_type());
  const auto &grad_shape = input_tensor.tensor_shape();
  for (int i = 0; i < grad_shape.dim_size(); ++i) {
    (void)grad_shape_.emplace_back(grad_shape.dim(i).size());
  }
  if (index_shape_ != grad_shape_) {
    AICPU_LOGE("The shape of index and grad should be the same!");
    return AICPU_KERNEL_STATE_PARAM_INVALID;
  }

  // output
  aicpuops::Tensor output_tensor = node_def_.outputs(kDim0);
  const auto &output_shape = output_tensor.tensor_shape();
  for (int i = 0; i < output_shape.dim_size(); ++i) {
    (void)output_shape_.emplace_back(output_shape.dim(i).size());
  }
  if (output_shape_ != input_shape_) {
    AICPU_LOGE("The shape of input and output should be the same!");
    return AICPU_KERNEL_STATE_PARAM_INVALID;
  }

  auto node_def_attrs = node_def_.attrs();
  dim_ = node_def_attrs[kDim].i();

  return AICPU_KERNEL_STATE_SUCCESS;
}

uint32_t GatherDGradV2Kernel::DoCompute() {
  std::map<int, std::map<int, std::function<uint32_t()>>> calls;
  // index int32
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int8_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int16_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int32_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_INT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, int64_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, Eigen::half>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, float>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_FLOAT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, double>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint8_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint16_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint32_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_UINT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, uint64_t>, this);
  calls[aicpuops::DataType::MS_INT32][aicpuops::DataType::MS_BOOL] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int32_t, bool>, this);
  // index int64
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int8_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int16_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int32_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_INT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, int64_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, Eigen::half>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, float>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_FLOAT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, double>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT8] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint8_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT16] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint16_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT32] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint32_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_UINT64] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, uint64_t>, this);
  calls[aicpuops::DataType::MS_INT64][aicpuops::DataType::MS_BOOL] =
    std::bind(&GatherDGradV2Kernel::GatherDGradV2Task<int64_t, bool>, this);

  if (calls.find(index_type_) == calls.end()) {
    AICPU_LOGE("GatherDGradV2 op doesn't support index tensor type: %s", typeid(index_type_).name());
    return AICPU_KERNEL_STATE_FAILED;
  }
  return calls[index_type_][grad_type_]();
}
}  // namespace aicpu

extern "C" {
__attribute__((visibility("default"))) uint32_t GatherDGradV2(void *param) {
  aicpu::GatherDGradV2Kernel gatherd_grad_v2_kernel;
  return gatherd_grad_v2_kernel.Compute(param);
}
}
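For reference, a minimal host-side sketch (not part of the commit) of the read_id arithmetic in GatherGrad above, for a 2-D case with dim = 1, i.e. output[i][index[i][j]] += grad[i][j]. The shapes and values are made up, and the sharder plus AtomicAdd are replaced by a serial loop.

// Sketch: 2 x 3 grad/index scattered into a 2 x 4 output along dim = 1.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t dim_before_axis = 2;     // product of dims before 'dim'
  const int64_t dim_at_axis_index = 3;   // index/grad extent at 'dim'
  const int64_t dim_at_axis_output = 4;  // output extent at 'dim'
  const int64_t dim_after_axis = 1;      // product of dims after 'dim'

  std::vector<int32_t> index = {0, 1, 1, 3, 3, 2};  // 2 x 3
  std::vector<float> grad = {1, 2, 3, 4, 5, 6};     // 2 x 3
  std::vector<float> output(dim_before_axis * dim_at_axis_output * dim_after_axis, 0.0f);  // 2 x 4

  const int64_t dim_input = dim_at_axis_index * dim_after_axis;
  const int64_t dim_output = dim_at_axis_output * dim_after_axis;
  const int64_t number = dim_before_axis * dim_at_axis_index * dim_after_axis;
  for (int64_t id = 0; id < number; ++id) {
    int64_t j = index[id];  // target position along 'dim' in the output
    int64_t read_id = id / dim_input * dim_output + j * dim_after_axis + id % dim_after_axis;
    output[read_id] += grad[id];  // AtomicAdd in the real kernel
  }

  for (int64_t i = 0; i < dim_before_axis; ++i) {
    for (int64_t k = 0; k < dim_at_axis_output; ++k) std::printf("%.0f ", output[i * dim_output + k]);
    std::printf("\n");
  }
  // Expected: row 0 -> 1 5 0 0, row 1 -> 0 0 6 9
  return 0;
}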
@@ -0,0 +1,45 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_
#define AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_

#include <vector>
#include "common/kernel_base.h"

namespace aicpu {
class GatherDGradV2Kernel : public KernelBase {
 public:
  GatherDGradV2Kernel() : KernelBase("GatherDGradV2Kernel") {}
  ~GatherDGradV2Kernel() = default;

 protected:
  uint32_t ParseKernelParam() override;
  uint32_t DoCompute() override;

  template <typename T, typename S>
  uint32_t GatherDGradV2Task();

  int64_t dim_{0};
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> index_shape_;
  std::vector<int64_t> grad_shape_;
  std::vector<int64_t> output_shape_;

  aicpuops::DataType index_type_{aicpuops::DataType::MS_UNKNOWN};
  aicpuops::DataType grad_type_{aicpuops::DataType::MS_UNKNOWN};
};
}  // namespace aicpu
#endif  // AICPU_OPS_AICPU_GATHER_GRAD_KERNELS_H_
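This class selects a templated GatherDGradV2Task instantiation in DoCompute() through a nested map keyed by index dtype and grad dtype. Below is a standalone sketch (not part of the commit) of that dispatch pattern; the enum values, class, and method names are placeholders, not the real aicpuops::DataType or KernelBase API, and it additionally checks the value dtype before calling.

// Sketch: two-level dtype dispatch to a templated member function.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>

enum DataType { kInt32 = 0, kInt64 = 1, kFloat32 = 2 };

class DemoKernel {
 public:
  uint32_t DoCompute(DataType index_type, DataType value_type) {
    std::map<int, std::map<int, std::function<uint32_t()>>> calls;
    calls[kInt32][kFloat32] = std::bind(&DemoKernel::Task<int32_t, float>, this);
    calls[kInt64][kFloat32] = std::bind(&DemoKernel::Task<int64_t, float>, this);
    if (calls.find(index_type) == calls.end() ||
        calls[index_type].find(value_type) == calls[index_type].end()) {
      return 1;  // unsupported (index dtype, value dtype) combination
    }
    return calls[index_type][value_type]();
  }

 private:
  template <typename T, typename S>
  uint32_t Task() {
    std::printf("running Task<%zu-byte index, %zu-byte value>\n", sizeof(T), sizeof(S));
    return 0;
  }
};

int main() {
  DemoKernel kernel;
  return static_cast<int>(kernel.DoCompute(kInt64, kFloat32));
}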
@@ -274,7 +274,7 @@ uint32_t RandomChoiceWithMaskKernel::ParseKernelParam() {

 extern "C" {
 __attribute__((visibility("default"))) uint32_t RandomChoiceWithMask(void *param) {
-  aicpu::RandomChoiceWithMaskKernel randomChoiceWithMaskKernel;
-  return randomChoiceWithMaskKernel.Compute(param);
+  aicpu::RandomChoiceWithMaskKernel random_choice_with_mask_kernel;
+  return random_choice_with_mask_kernel.Compute(param);
 }
 }
@@ -63,6 +63,7 @@ constexpr auto kHistogram = "Histogram";
 constexpr auto kIdentity = "Identity";
 constexpr auto kIdentityN = "IdentityN";
 constexpr auto kRandomChoiceWithMask = "RandomChoiceWithMask";
+constexpr auto kGatherDGradV2 = "GatherDGradV2";
 constexpr auto kResizeNearestNeighborV2 = "ResizeNearestNeighborV2";
 constexpr auto kResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad";
 constexpr auto kUpdateCache = "UpdateCache";
@@ -110,7 +111,8 @@ const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
                                               kPriorityReplayBufferPush,
                                               kPriorityReplayBufferSample,
                                               kPriorityReplayBufferUpdate,
-                                              kPriorityReplayBufferDestroy};
+                                              kPriorityReplayBufferDestroy,
+                                              kGatherDGradV2};
 const std::set<std::string> kDynamicInputOps{
   kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName,
   kStackPushOpName, kStackPopOpName, kDynamicStitch, kPriorityReplayBufferPush, kPriorityReplayBufferSample,
@@ -30,9 +30,9 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
   MS_EXCEPTION_IF_NULL(equiv);

   static const std::set<std::string> kAICpuOpNames = {
-    kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName,
-    kEnvironDestroyAllOpName, kPriorityReplayBufferCreate, kPriorityReplayBufferPush,
-    kPriorityReplayBufferSample, kPriorityReplayBufferUpdate, kPriorityReplayBufferDestroy};
+    kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName, kEnvironDestroyAllOpName,
+    kPriorityReplayBufferCreate, kPriorityReplayBufferPush, kPriorityReplayBufferSample, kPriorityReplayBufferUpdate,
+    kPriorityReplayBufferDestroy, kGatherDGradV2OpName};
   static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";

   if (!node->isa<CNode>()) {
@@ -33,6 +33,7 @@ from .log import _log_aicpu
 from .padding import _padding_aicpu
 from .gather import _gather_aicpu
 from .gather_grad import _gather_grad_aicpu
+from .gather_d_grad_v2 import _gather_d_grad_v2_aicpu
 from .gather_d import _gather_d_aicpu
 from .scatter import _scatter_aicpu
 from .exp import _exp_aicpu
@@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""GatherDGradV2 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

gather_d_grad_v2_op_info = AiCPURegOp("GatherDGradV2") \
    .attr("dim", "int") \
    .input(0, "x", "required") \
    .input(1, "index", "required") \
    .input(2, "grad", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(gather_d_grad_v2_op_info)
def _gather_d_grad_v2_aicpu():
    """GatherDGradV2 aicpu register"""
    return