forked from mindspore-Ecosystem/mindspore
!47675 merge canndev operator ScatterNdUpdate and ScatterNd to MindSpore
Merge pull request !47675 from 沈竞兴/canndev_merge3
This commit is contained in:
commit
66400d742d
|
@ -282,6 +282,8 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
|
|||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_set_diag_v3.cc:aicpu::MatrixSetDiagV3CpuKernel::DoCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d.cc:aicpu::MaxUnpool2DCpuKernel::MaxUnpool2DCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_solve_ls.cc:aicpu::MatrixSolveLsCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/tensor_scatter_update.cc:aicpu::TensorScatterUpdateCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd.cc:aicpu::ScatterNdCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/col2im.cc:aicpu::Col2imCpuKernel::Col2imParamCheck
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/csr_sparse_matrix_to_sparse_tensor.cc:aicpu::CSRSparseMatrixToSparseTensorCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cumprod.cc:aicpu::CumprodCpuKernel::CumprodCompute
|
||||
|
@ -302,8 +304,8 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
|
|||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/maxpool_grad.cc:aicpu::SpatialMaxPoolWithArgMaxHelper
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/reduce_prod.cc:aicpu::ReduceProdCpuKernel::ReduceProdCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/reduce_prod.cc:aicpu::ReduceProdCpuKernel::ReduceProdCompute_Complex
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/parameterized_truncated_normal.cc:aicpu::Generate
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd.cc:aicpu::ScatterNdCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/parameterized_truncated_normal.cc:aicpu::Generate
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossComputeFP
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_3d_grad.cc:aicpu::MaxUnpool3DGradCpuKernel::MaxUnpool3DGradCompute
|
||||
|
|
|
@ -324,6 +324,7 @@ constexpr auto kGammaOpName = "Gamma";
|
|||
constexpr auto kGatherDGradV2OpName = "GatherDGradV2";
|
||||
constexpr auto kGatherDOpName = "GatherD";
|
||||
constexpr auto kGatherOpName = "Gather";
|
||||
constexpr auto kGatherNdOpName = "GatherNd";
|
||||
constexpr auto kGatherV2OpName = "GatherV2";
|
||||
constexpr auto kGatherV2DOpName = "GatherV2D";
|
||||
constexpr auto kGeLUOpName = "GeLU";
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kGatherNd = "GatherNd";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t GatherNdCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check GatherNd Input and Output failed.");
|
||||
|
||||
Tensor *input_x = ctx.Input(0);
|
||||
Tensor *input_indices = ctx.Input(1);
|
||||
|
||||
auto shape_x = input_x->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto indices_rank = shape_indices->GetDims();
|
||||
auto indices_nd = shape_indices->GetDimSize(indices_rank - 1);
|
||||
|
||||
if (shape_x->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_x's rank is less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (indices_rank < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank is less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (indices_nd > shape_x->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Slice's length must be less than x rank. ", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto data_type0 = input_x->GetDataType();
|
||||
auto data_type1 = input_indices->GetDataType();
|
||||
|
||||
if (data_type1 != DT_INT32 && data_type1 != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type1).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type0) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type0).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_type>
|
||||
uint32_t GatherNdCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return GatherNdComputeRealKernel<int32_t, data_type>(ctx);
|
||||
case DT_INT64:
|
||||
return GatherNdComputeRealKernel<int64_t, data_type>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename indices_type, typename data_type>
|
||||
uint32_t GatherNdCpuKernel::GatherNdComputeRealKernel(CpuKernelContext &ctx) {
|
||||
auto x_shape = ctx.Input(0)->GetTensorShape();
|
||||
auto indices_shape = ctx.Input(1)->GetTensorShape();
|
||||
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
const int64_t indices_dims = indices_shape->GetDims();
|
||||
int64_t indices_nd = indices_shape->GetDimSize(indices_dims - 1);
|
||||
|
||||
const int64_t params_dims = x_shape->GetDims();
|
||||
|
||||
for (int64_t i = 0; i < indices_dims - 1; ++i) {
|
||||
n_slices *= indices_shape->GetDimSize(i);
|
||||
}
|
||||
for (int64_t i = indices_nd; i < params_dims; ++i) {
|
||||
slice_size *= x_shape->GetDimSize(i);
|
||||
}
|
||||
|
||||
int64_t remain_flat_size = x_shape->NumElements();
|
||||
std::vector<int64_t> dims_to_count = std::vector<int64_t>(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / x_shape->GetDimSize(i);
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
|
||||
auto x_data = reinterpret_cast<data_type *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<data_type *>(ctx.Output(0)->GetData());
|
||||
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t from_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
from_pos += indices_data[i * indices_nd + j] * dims_to_count[j];
|
||||
}
|
||||
std::memcpy(output_data + i * slice_size, x_data + from_pos, sizeof(data_type) * slice_size);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kGatherNd, GatherNdCpuKernel);
|
||||
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_GATHERND_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_GATHERND_H_
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class GatherNdCpuKernel : public CpuKernel {
|
||||
public:
|
||||
GatherNdCpuKernel() = default;
|
||||
~GatherNdCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename data_type>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename indices_type, typename data_type>
|
||||
uint32_t GatherNdComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -207,5 +207,4 @@ uint32_t ScatterNdUpdateCpuKernel::ScatterNdUpdateComputeRealKernel(CpuKernelCon
|
|||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kScatterNdUpdate, ScatterNdUpdateCpuKernel);
|
||||
|
||||
} // namespace aicpu
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensor_scatter_update.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 3;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kTensorScatterUpdate = "TensorScatterUpdate";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t TensorScatterUpdateCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check TensorScatterUpdate Input and Output failed.");
|
||||
|
||||
Tensor *input_var = ctx.Input(0);
|
||||
Tensor *input_indices = ctx.Input(1);
|
||||
Tensor *input_updates = ctx.Input(2);
|
||||
|
||||
auto shape_var = input_var->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto shape_updates = input_updates->GetTensorShape();
|
||||
|
||||
if (shape_var->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_indices->GetDims() < 2) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank less than 2.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_updates->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_updates's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto index_size = shape_indices->GetDims() - 1;
|
||||
auto index_depth = shape_indices->GetDimSize(index_size);
|
||||
|
||||
if (index_depth > shape_var->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var&input_indices ranks mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::vector<int64_t> batch_shape;
|
||||
for (int64_t i = 0; i < index_size; ++i) {
|
||||
batch_shape.push_back(shape_indices->GetDimSize(i));
|
||||
}
|
||||
|
||||
for (int64_t i = index_depth; i <= shape_var->GetDims() - 1; ++i) {
|
||||
batch_shape.push_back(shape_var->GetDimSize(i));
|
||||
}
|
||||
|
||||
if (batch_shape != shape_updates->GetDimSizes()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor indices's & updates' and var's shape are dismatch .", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < index_size; i++) {
|
||||
if (shape_indices->GetDimSize(i) != shape_updates->GetDimSize(i)) {
|
||||
KERNEL_LOG_ERROR("[%s], Tensor indices and updates should have the same batch number.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_type_var = input_var->GetDataType();
|
||||
auto data_type_indices = input_indices->GetDataType();
|
||||
|
||||
if (data_type_indices != DT_INT32 && data_type_indices != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_indices).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type_var) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_var).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename var_type>
|
||||
uint32_t TensorScatterUpdateCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return TensorScatterUpdateComputeRealKernel<var_type, int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return TensorScatterUpdateComputeRealKernel<var_type, int64_t>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t TensorScatterUpdateCpuKernel::TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx) {
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
|
||||
const int64_t indices_dims = ctx.Input(1)->GetTensorShape()->GetDims() - 1;
|
||||
const int64_t indices_nd = ctx.Input(1)->GetTensorShape()->GetDimSize(indices_dims);
|
||||
const int64_t updates_dims = ctx.Input(2)->GetTensorShape()->GetDims();
|
||||
|
||||
auto shape_var = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto shape_indices = ctx.Input(1)->GetTensorShape();
|
||||
auto dims_shape = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
for (int64_t i = 0; i < dims_shape - indices_nd; i++) {
|
||||
if (ctx.Input(2)->GetTensorShape()->GetDimSize(i + shape_indices->GetDims() - 1) != shape_var[i + indices_nd]) {
|
||||
KERNEL_LOG_ERROR("[%s] shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < indices_dims; ++i) {
|
||||
n_slices *= ctx.Input(1)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
for (int i = indices_dims; i < updates_dims; ++i) {
|
||||
slice_size *= ctx.Input(2)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
|
||||
const int64_t var_flat_size = ctx.Input(0)->GetTensorShape()->NumElements();
|
||||
std::vector<int64_t> output_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
|
||||
int64_t remain_flat_size = var_flat_size;
|
||||
std::vector<int64_t> dims_to_count(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / output_shape[i];
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto Var_data = reinterpret_cast<var_type *>(ctx.Input(0)->GetData());
|
||||
auto Indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
|
||||
auto Updates_data = reinterpret_cast<var_type *>(ctx.Input(2)->GetData());
|
||||
auto Output_data = reinterpret_cast<var_type *>(ctx.Output(0)->GetData());
|
||||
|
||||
for (int64_t i = 0; i < var_flat_size; ++i) {
|
||||
Output_data[i] = Var_data[i];
|
||||
}
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t to_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
int64_t idx = Indices_data[i * indices_nd + j];
|
||||
|
||||
if (idx < 0 || idx >= output_shape[j]) {
|
||||
KERNEL_LOG_ERROR("The indices[%d] is so big or small", idx);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
to_pos += idx * dims_to_count[j];
|
||||
}
|
||||
for (int64_t j = 0; j < slice_size; j++) {
|
||||
Output_data[to_pos + j] = Updates_data[i * slice_size + j];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kTensorScatterUpdate, TensorScatterUpdateCpuKernel);
|
||||
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
#include <string.h>
|
||||
|
||||
namespace aicpu {
|
||||
class TensorScatterUpdateCpuKernel : public CpuKernel {
|
||||
public:
|
||||
TensorScatterUpdateCpuKernel() = default;
|
||||
~TensorScatterUpdateCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename var_type>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -84,6 +84,10 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
mindspore::kFFTWithSizeOpName,
|
||||
mindspore::kHistogramDOpName,
|
||||
mindspore::kIm2colOpName,
|
||||
mindspore::kGatherNdOpName,
|
||||
mindspore::kScatterNdOpName,
|
||||
mindspore::kScatterNdUpdateOpName,
|
||||
mindspore::kTensorScatterUpdateOpName,
|
||||
mindspore::kIsInfOpName,
|
||||
mindspore::kIsNanOpName,
|
||||
mindspore::kMatrixDeterminantOpName,
|
||||
|
|
|
@ -40,6 +40,7 @@ from .dynamic_stitch import _dynamic_stitch_aicpu
|
|||
from .get_next import _get_next_aicpu
|
||||
from .print_tensor import _print_aicpu
|
||||
from .topk import _top_k_aicpu
|
||||
from .tensor_scatter_update import _tensor_scatter_update_aicpu
|
||||
from .log1p import _log1p_aicpu
|
||||
from .asin import _asin_aicpu
|
||||
from .is_finite import _is_finite_aicpu
|
||||
|
@ -148,6 +149,8 @@ from .bias_add_grad import _bias_add_grad_aicpu
|
|||
from .grid_sampler_2d import _grid_sampler_2d_aicpu
|
||||
from .grid_sampler_2d_grad import _grid_sampler_2d_grad_aicpu
|
||||
from .sparse_segment_mean_grad import _sparse_segment_mean_grad_aicpu
|
||||
from .scatter_nd import _scatter_nd_aicpu
|
||||
from .scatter_nd_update import _scatter_nd_update_aicpu
|
||||
from .scatter_nd_max import _scatter_nd_max_aicpu
|
||||
from .conj import _conj_aicpu
|
||||
from .scatter_nd_min import _scatter_nd_min_aicpu
|
||||
|
@ -225,8 +228,6 @@ from .rgb_to_hsv import _rgb_to_hsv_aicpu
|
|||
from .rsqrt_grad import _rsqrt_grad_aicpu
|
||||
from .sample_distorted_bounding_box_v2 import _sample_distorted_bounding_box_v2_aicpu
|
||||
from .scale_and_translate_grad import _scale_and_translate_grad_aicpu
|
||||
from .scatter_nd import _scatter_nd_aicpu
|
||||
from .scatter_nd_update import _scatter_nd_update_aicpu
|
||||
from .select import _select_aicpu
|
||||
from .self_adjoint_eig import _self_adjoint_eig_aicpu
|
||||
from .sin import _sin_aicpu
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""TensorScatterUpdate op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
tensor_scatter_update_op_info = AiCPURegOp("TensorScatterUpdate") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "input_x", "required") \
|
||||
.input(1, "indices", "required") \
|
||||
.input(2, "updates", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tensor_scatter_update_op_info)
|
||||
def _tensor_scatter_update_aicpu():
|
||||
"""TensorScatterUpdate AiCPU register"""
|
||||
return
|
Loading…
Reference in New Issue