!47675 merge canndev operator ScatterNdUpdate and ScatterNd to MindSpore

Merge pull request !47675 from 沈竞兴/canndev_merge3
This commit is contained in:
i-robot 2023-01-14 11:39:24 +00:00 committed by Gitee
commit 66400d742d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 519 additions and 4 deletions

View File

@ -282,6 +282,8 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_set_diag_v3.cc:aicpu::MatrixSetDiagV3CpuKernel::DoCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d.cc:aicpu::MaxUnpool2DCpuKernel::MaxUnpool2DCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_solve_ls.cc:aicpu::MatrixSolveLsCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/tensor_scatter_update.cc:aicpu::TensorScatterUpdateCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd.cc:aicpu::ScatterNdCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/col2im.cc:aicpu::Col2imCpuKernel::Col2imParamCheck
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc:aicpu::CSRSparseMatrixToSparseTensorCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc:aicpu::CumprodCpuKernel::CumprodCompute
@ -302,8 +304,8 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/maxpool_grad.cc:aicpu::SpatialMaxPoolWithArgMaxHelper
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/reduce_prod.cc:aicpu::ReduceProdCpuKernel::ReduceProdCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/reduce_prod.cc:aicpu::ReduceProdCpuKernel::ReduceProdCompute_Complex
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/parameterized_truncated_normal.cc:aicpu::Generate
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd.cc:aicpu::ScatterNdCpuKernel::Compute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/parameterized_truncated_normal.cc:aicpu::Generate
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossComputeFP
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_3d_grad.cc:aicpu::MaxUnpool3DGradCpuKernel::MaxUnpool3DGradCompute

View File

@ -324,6 +324,7 @@ constexpr auto kGammaOpName = "Gamma";
constexpr auto kGatherDGradV2OpName = "GatherDGradV2";
constexpr auto kGatherDOpName = "GatherD";
constexpr auto kGatherOpName = "Gather";
constexpr auto kGatherNdOpName = "GatherNd";
constexpr auto kGatherV2OpName = "GatherV2";
constexpr auto kGatherV2DOpName = "GatherV2D";
constexpr auto kGeLUOpName = "GeLU";

View File

@ -0,0 +1,159 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gather_nd.h"
#include <string.h>
#include <algorithm>
#include <complex>
#include <iostream>
#include <map>
#include "eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kInputNum = 2;
const uint32_t kOutputNum = 1;
const char *kGatherNd = "GatherNd";
} // namespace
namespace aicpu {
uint32_t GatherNdCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check GatherNd Input and Output failed.");
Tensor *input_x = ctx.Input(0);
Tensor *input_indices = ctx.Input(1);
auto shape_x = input_x->GetTensorShape();
auto shape_indices = input_indices->GetTensorShape();
auto indices_rank = shape_indices->GetDims();
auto indices_nd = shape_indices->GetDimSize(indices_rank - 1);
if (shape_x->GetDims() < 1) {
KERNEL_LOG_ERROR("[%s] Tensor input_x's rank is less than 1.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (indices_rank < 1) {
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank is less than 1.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (indices_nd > shape_x->GetDims()) {
KERNEL_LOG_ERROR("[%s] Slice's length must be less than x rank. ", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto data_type0 = input_x->GetDataType();
auto data_type1 = input_indices->GetDataType();
if (data_type1 != DT_INT32 && data_type1 != DT_INT64) {
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type1).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
switch (data_type0) {
case DT_INT8:
return DTYPE_CHOOSE<int8_t>(ctx);
case DT_INT16:
return DTYPE_CHOOSE<int16_t>(ctx);
case DT_INT32:
return DTYPE_CHOOSE<int32_t>(ctx);
case DT_INT64:
return DTYPE_CHOOSE<int64_t>(ctx);
case DT_UINT8:
return DTYPE_CHOOSE<uint8_t>(ctx);
case DT_UINT16:
return DTYPE_CHOOSE<uint16_t>(ctx);
case DT_UINT32:
return DTYPE_CHOOSE<uint32_t>(ctx);
case DT_UINT64:
return DTYPE_CHOOSE<uint64_t>(ctx);
case DT_FLOAT16:
return DTYPE_CHOOSE<Eigen::half>(ctx);
case DT_FLOAT:
return DTYPE_CHOOSE<float>(ctx);
case DT_DOUBLE:
return DTYPE_CHOOSE<double>(ctx);
case DT_COMPLEX64:
return DTYPE_CHOOSE<std::complex<float>>(ctx);
case DT_COMPLEX128:
return DTYPE_CHOOSE<std::complex<double>>(ctx);
default:
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type0).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
template <typename data_type>
uint32_t GatherNdCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
switch (indices_type) {
case DT_INT32:
return GatherNdComputeRealKernel<int32_t, data_type>(ctx);
case DT_INT64:
return GatherNdComputeRealKernel<int64_t, data_type>(ctx);
default:
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
DTypeStr(indices_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename indices_type, typename data_type>
uint32_t GatherNdCpuKernel::GatherNdComputeRealKernel(CpuKernelContext &ctx) {
auto x_shape = ctx.Input(0)->GetTensorShape();
auto indices_shape = ctx.Input(1)->GetTensorShape();
int64_t n_slices = 1;
int64_t slice_size = 1;
const int64_t indices_dims = indices_shape->GetDims();
int64_t indices_nd = indices_shape->GetDimSize(indices_dims - 1);
const int64_t params_dims = x_shape->GetDims();
for (int64_t i = 0; i < indices_dims - 1; ++i) {
n_slices *= indices_shape->GetDimSize(i);
}
for (int64_t i = indices_nd; i < params_dims; ++i) {
slice_size *= x_shape->GetDimSize(i);
}
int64_t remain_flat_size = x_shape->NumElements();
std::vector<int64_t> dims_to_count = std::vector<int64_t>(indices_nd, 0);
for (int64_t i = 0; i < indices_nd; ++i) {
dims_to_count[i] = remain_flat_size / x_shape->GetDimSize(i);
remain_flat_size = dims_to_count[i];
}
auto indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
auto x_data = reinterpret_cast<data_type *>(ctx.Input(0)->GetData());
auto output_data = reinterpret_cast<data_type *>(ctx.Output(0)->GetData());
for (int64_t i = 0; i < n_slices; ++i) {
int64_t from_pos = 0;
for (int64_t j = 0; j < indices_nd; ++j) {
from_pos += indices_data[i * indices_nd + j] * dims_to_count[j];
}
std::memcpy(output_data + i * slice_size, x_data + from_pos, sizeof(data_type) * slice_size);
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kGatherNd, GatherNdCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_GATHERND_H_
#define AICPU_KERNELS_NORMALIZED_GATHERND_H_
#include <string.h>
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
namespace aicpu {
class GatherNdCpuKernel : public CpuKernel {
public:
GatherNdCpuKernel() = default;
~GatherNdCpuKernel() override = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename data_type>
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
template <typename indices_type, typename data_type>
uint32_t GatherNdComputeRealKernel(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -207,5 +207,4 @@ uint32_t ScatterNdUpdateCpuKernel::ScatterNdUpdateComputeRealKernel(CpuKernelCon
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kScatterNdUpdate, ScatterNdUpdateCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,211 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensor_scatter_update.h"
#include <string.h>
#include <algorithm>
#include <complex>
#include <iostream>
#include <map>
#include "eigen_tensor.h"
#include "utils/kernel_util.h"
namespace {
const uint32_t kInputNum = 3;
const uint32_t kOutputNum = 1;
const char *kTensorScatterUpdate = "TensorScatterUpdate";
} // namespace
namespace aicpu {
uint32_t TensorScatterUpdateCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check TensorScatterUpdate Input and Output failed.");
Tensor *input_var = ctx.Input(0);
Tensor *input_indices = ctx.Input(1);
Tensor *input_updates = ctx.Input(2);
auto shape_var = input_var->GetTensorShape();
auto shape_indices = input_indices->GetTensorShape();
auto shape_updates = input_updates->GetTensorShape();
if (shape_var->GetDims() < 1) {
KERNEL_LOG_ERROR("[%s] Tensor input_var's rank less than 1.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (shape_indices->GetDims() < 2) {
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank less than 2.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
if (shape_updates->GetDims() < 1) {
KERNEL_LOG_ERROR("[%s] Tensor input_updates's rank less than 1.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto index_size = shape_indices->GetDims() - 1;
auto index_depth = shape_indices->GetDimSize(index_size);
if (index_depth > shape_var->GetDims()) {
KERNEL_LOG_ERROR("[%s] Tensor input_var&input_indices ranks mismatch.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
std::vector<int64_t> batch_shape;
for (int64_t i = 0; i < index_size; ++i) {
batch_shape.push_back(shape_indices->GetDimSize(i));
}
for (int64_t i = index_depth; i <= shape_var->GetDims() - 1; ++i) {
batch_shape.push_back(shape_var->GetDimSize(i));
}
if (batch_shape != shape_updates->GetDimSizes()) {
KERNEL_LOG_ERROR("[%s] Tensor indices's & updates' and var's shape are dismatch .", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
for (int64_t i = 0; i < index_size; i++) {
if (shape_indices->GetDimSize(i) != shape_updates->GetDimSize(i)) {
KERNEL_LOG_ERROR("[%s], Tensor indices and updates should have the same batch number.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
auto data_type_var = input_var->GetDataType();
auto data_type_indices = input_indices->GetDataType();
if (data_type_indices != DT_INT32 && data_type_indices != DT_INT64) {
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_indices).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
switch (data_type_var) {
case DT_INT8:
return DTYPE_CHOOSE<int8_t>(ctx);
case DT_INT16:
return DTYPE_CHOOSE<int16_t>(ctx);
case DT_INT32:
return DTYPE_CHOOSE<int32_t>(ctx);
case DT_INT64:
return DTYPE_CHOOSE<int64_t>(ctx);
case DT_UINT8:
return DTYPE_CHOOSE<uint8_t>(ctx);
case DT_UINT16:
return DTYPE_CHOOSE<uint16_t>(ctx);
case DT_UINT32:
return DTYPE_CHOOSE<uint32_t>(ctx);
case DT_UINT64:
return DTYPE_CHOOSE<uint64_t>(ctx);
case DT_FLOAT16:
return DTYPE_CHOOSE<Eigen::half>(ctx);
case DT_FLOAT:
return DTYPE_CHOOSE<float>(ctx);
case DT_DOUBLE:
return DTYPE_CHOOSE<double>(ctx);
case DT_COMPLEX64:
return DTYPE_CHOOSE<std::complex<float>>(ctx);
case DT_COMPLEX128:
return DTYPE_CHOOSE<std::complex<double>>(ctx);
default:
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_var).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
template <typename var_type>
uint32_t TensorScatterUpdateCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
switch (indices_type) {
case DT_INT32:
return TensorScatterUpdateComputeRealKernel<var_type, int32_t>(ctx);
case DT_INT64:
return TensorScatterUpdateComputeRealKernel<var_type, int64_t>(ctx);
default:
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
DTypeStr(indices_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
template <typename var_type, typename indices_type>
uint32_t TensorScatterUpdateCpuKernel::TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx) {
int64_t n_slices = 1;
int64_t slice_size = 1;
const int64_t indices_dims = ctx.Input(1)->GetTensorShape()->GetDims() - 1;
const int64_t indices_nd = ctx.Input(1)->GetTensorShape()->GetDimSize(indices_dims);
const int64_t updates_dims = ctx.Input(2)->GetTensorShape()->GetDims();
auto shape_var = ctx.Input(0)->GetTensorShape()->GetDimSizes();
auto shape_indices = ctx.Input(1)->GetTensorShape();
auto dims_shape = ctx.Input(0)->GetTensorShape()->GetDims();
for (int64_t i = 0; i < dims_shape - indices_nd; i++) {
if (ctx.Input(2)->GetTensorShape()->GetDimSize(i + shape_indices->GetDims() - 1) != shape_var[i + indices_nd]) {
KERNEL_LOG_ERROR("[%s] shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
}
for (int64_t i = 0; i < indices_dims; ++i) {
n_slices *= ctx.Input(1)->GetTensorShape()->GetDimSize(i);
}
for (int i = indices_dims; i < updates_dims; ++i) {
slice_size *= ctx.Input(2)->GetTensorShape()->GetDimSize(i);
}
const int64_t var_flat_size = ctx.Input(0)->GetTensorShape()->NumElements();
std::vector<int64_t> output_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
int64_t remain_flat_size = var_flat_size;
std::vector<int64_t> dims_to_count(indices_nd, 0);
for (int64_t i = 0; i < indices_nd; ++i) {
dims_to_count[i] = remain_flat_size / output_shape[i];
remain_flat_size = dims_to_count[i];
}
auto Var_data = reinterpret_cast<var_type *>(ctx.Input(0)->GetData());
auto Indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
auto Updates_data = reinterpret_cast<var_type *>(ctx.Input(2)->GetData());
auto Output_data = reinterpret_cast<var_type *>(ctx.Output(0)->GetData());
for (int64_t i = 0; i < var_flat_size; ++i) {
Output_data[i] = Var_data[i];
}
for (int64_t i = 0; i < n_slices; ++i) {
int64_t to_pos = 0;
for (int64_t j = 0; j < indices_nd; ++j) {
int64_t idx = Indices_data[i * indices_nd + j];
if (idx < 0 || idx >= output_shape[j]) {
KERNEL_LOG_ERROR("The indices[%d] is so big or small", idx);
return KERNEL_STATUS_PARAM_INVALID;
}
to_pos += idx * dims_to_count[j];
}
for (int64_t j = 0; j < slice_size; j++) {
Output_data[to_pos + j] = Updates_data[i * slice_size + j];
}
}
return KERNEL_STATUS_OK;
}
REGISTER_CPU_KERNEL(kTensorScatterUpdate, TensorScatterUpdateCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
#define AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
#include "utils/bcast.h"
#include <string.h>
namespace aicpu {
class TensorScatterUpdateCpuKernel : public CpuKernel {
public:
TensorScatterUpdateCpuKernel() = default;
~TensorScatterUpdateCpuKernel() override = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
template <typename var_type>
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
template <typename var_type, typename indices_type>
uint32_t TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -84,6 +84,10 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
mindspore::kFFTWithSizeOpName,
mindspore::kHistogramDOpName,
mindspore::kIm2colOpName,
mindspore::kGatherNdOpName,
mindspore::kScatterNdOpName,
mindspore::kScatterNdUpdateOpName,
mindspore::kTensorScatterUpdateOpName,
mindspore::kIsInfOpName,
mindspore::kIsNanOpName,
mindspore::kMatrixDeterminantOpName,

View File

@ -40,6 +40,7 @@ from .dynamic_stitch import _dynamic_stitch_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
from .tensor_scatter_update import _tensor_scatter_update_aicpu
from .log1p import _log1p_aicpu
from .asin import _asin_aicpu
from .is_finite import _is_finite_aicpu
@ -148,6 +149,8 @@ from .bias_add_grad import _bias_add_grad_aicpu
from .grid_sampler_2d import _grid_sampler_2d_aicpu
from .grid_sampler_2d_grad import _grid_sampler_2d_grad_aicpu
from .sparse_segment_mean_grad import _sparse_segment_mean_grad_aicpu
from .scatter_nd import _scatter_nd_aicpu
from .scatter_nd_update import _scatter_nd_update_aicpu
from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu
from .scatter_nd_min import _scatter_nd_min_aicpu
@ -225,8 +228,6 @@ from .rgb_to_hsv import _rgb_to_hsv_aicpu
from .rsqrt_grad import _rsqrt_grad_aicpu
from .sample_distorted_bounding_box_v2 import _sample_distorted_bounding_box_v2_aicpu
from .scale_and_translate_grad import _scale_and_translate_grad_aicpu
from .scatter_nd import _scatter_nd_aicpu
from .scatter_nd_update import _scatter_nd_update_aicpu
from .select import _select_aicpu
from .self_adjoint_eig import _self_adjoint_eig_aicpu
from .sin import _sin_aicpu

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
tensor_scatter_update_op_info = AiCPURegOp("TensorScatterUpdate") \
.fusion_type("OPAQUE") \
.input(0, "input_x", "required") \
.input(1, "indices", "required") \
.input(2, "updates", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(tensor_scatter_update_op_info)
def _tensor_scatter_update_aicpu():
"""TensorScatterUpdate AiCPU register"""
return