From 4aac1ea35f1850fad205db2c4b63974d599164b1 Mon Sep 17 00:00:00 2001 From: lilinjie Date: Wed, 15 Feb 2023 14:08:55 +0800 Subject: [PATCH] fix aicpu migration issues --- .jenkins/check/config/whitelizard.txt | 2 + .../api_python/mindspore.ops.primitive.rst | 1 - .../api_python_en/mindspore.ops.primitive.rst | 1 - mindspore/ccsrc/include/common/utils/utils.h | 1 + .../cpu_kernel/ms_kernel/batchmatmul.cc | 251 +++++++++++++++ .../cpu_kernel/ms_kernel/batchmatmul.h | 34 ++ .../ms_kernel/instance_norm_v2_grad.cc | 266 ++++++++++++++++ .../ms_kernel/instance_norm_v2_grad.h | 46 +++ .../cpu_kernel/ms_kernel/sparse_to_dense.cc | 299 ++++++++++++++++++ .../cpu_kernel/ms_kernel/sparse_to_dense.h | 92 ++++++ .../optimizer/mindir/aicpu_lib_select.cc | 6 +- .../mindspore/ops/_op_impl/aicpu/__init__.py | 4 + 12 files changed, 1000 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.h create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.h create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.h diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index c5a85699d86..77a207acf9a 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -366,3 +366,5 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d_grad.cc:aicpu::MaxUnpool2DGradCpuKernel::MaxUnpool2DGradCompute mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_pool_3d_with_argmax.cc:aicpu::MaxPool3DWithArgmaxCpuKernel::MaxPool3DWithArgmaxCompute mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/layer_norm_grad_grad.cc:aicpu::LayerNormGradGradCpuKernel::LayerNormGradGradCompute +mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc:aicpu::BatchMatMulCpuKernel::DoCompute +mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc:aicpu::SparseToDenseCpuKernel::ValidParam diff --git a/docs/api/api_python/mindspore.ops.primitive.rst b/docs/api/api_python/mindspore.ops.primitive.rst index fd0d15645c4..ddb04a4936e 100644 --- a/docs/api/api_python/mindspore.ops.primitive.rst +++ b/docs/api/api_python/mindspore.ops.primitive.rst @@ -79,7 +79,6 @@ MindSpore中 `mindspore.ops.primitive` 接口与上一版本相比,新增、 mindspore.ops.MaxUnpool2D mindspore.ops.MirrorPad mindspore.ops.NthElement - mindspore.ops.NuclearNorm mindspore.ops.Pad mindspore.ops.Padding mindspore.ops.PadV3 diff --git a/docs/api/api_python_en/mindspore.ops.primitive.rst b/docs/api/api_python_en/mindspore.ops.primitive.rst index 9d65c128b8f..550e1dae18c 100644 --- a/docs/api/api_python_en/mindspore.ops.primitive.rst +++ b/docs/api/api_python_en/mindspore.ops.primitive.rst @@ -78,7 +78,6 @@ Neural Network mindspore.ops.MaxUnpool2D 
mindspore.ops.MirrorPad mindspore.ops.NthElement - mindspore.ops.NuclearNorm mindspore.ops.Pad mindspore.ops.EmbeddingLookup mindspore.ops.Padding diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index 983954df953..0d681d934b8 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -756,6 +756,7 @@ constexpr auto kSparseSparseMaximumOpName = "SparseSparseMaximum"; constexpr auto kSparseTensorDenseMatMulOpName = "SparseTensorDenseMatMul"; constexpr auto kSparseTensorDenseAddOpName = "SparseTensorDenseAdd"; constexpr auto kSparseTensorToCSRSparseMatrixOpName = "SparseTensorToCSRSparseMatrix"; +constexpr auto kSparseToDenseOpName = "SparseToDense"; constexpr auto kSplitOpName = "Split"; constexpr auto kSplitDOpName = "SplitD"; constexpr auto kSplitVOpName = "SplitV"; diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc new file mode 100644 index 00000000000..3141c368ec6 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc @@ -0,0 +1,251 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "batchmatmul.h" + +#include +#include "unsupported/Eigen/CXX11/Tensor" + +#include "cpu_kernel_utils.h" +#include "utils/kernel_util.h" +#include "kernel_log.h" +#include "status.h" +#include + +using namespace std; + +namespace { +const char *kBatchMatmul = "BatchMatMul"; +const uint32_t kInputNum = 2; +const uint32_t kOutputNum = 1; +const int64_t kParallelDataNum = 1024; +} // namespace + +namespace aicpu { +template +uint32_t BatchMatMulCpuKernel::DoCompute(CpuKernelContext &ctx) { + auto input0_tensor = ctx.Input(0); + auto input0_tensor_shape = input0_tensor->GetTensorShape(); + int32_t input0_tensor_dims = input0_tensor_shape->GetDims(); + KERNEL_CHECK_FALSE((input0_tensor_dims > 1), KERNEL_STATUS_PARAM_INVALID, "Input[x1] must be a matrix or higher.") + + auto input1_tensor = ctx.Input(1); + auto input1_tensor_shape = input1_tensor->GetTensorShape(); + int32_t input1_tensor_dims = input1_tensor_shape->GetDims(); + KERNEL_CHECK_FALSE((input1_tensor_dims > 1), KERNEL_STATUS_PARAM_INVALID, "Input[x2] must be a matrix or higher.") + + auto output_tensor = ctx.Output(0); + DataType input0_data_type = input0_tensor->GetDataType(); + DataType input1_data_type = input1_tensor->GetDataType(); + DataType output_data_type = output_tensor->GetDataType(); + KERNEL_CHECK_FALSE((input0_data_type == input1_data_type), KERNEL_STATUS_PARAM_INVALID, + "Input[x1] data type[%s] and input[x2] data type[%s] must be same.", + DTypeStr(input0_data_type).c_str(), DTypeStr(input1_data_type).c_str()) + KERNEL_CHECK_FALSE((input0_data_type == output_data_type), KERNEL_STATUS_PARAM_INVALID, + "Input data type[%s] and output data type[%s] must be same.", DTypeStr(input0_data_type).c_str(), + DTypeStr(output_data_type).c_str()) + bool adj_x = false; + bool adj_y = false; + auto adj_x1 = ctx.GetAttr("adj_x1"); + auto adj_x2 = ctx.GetAttr("adj_x2"); + if (adj_x1 != nullptr) { + adj_x = adj_x1->GetBool(); + } + if (adj_x2 != nullptr) { + adj_y = adj_x2->GetBool(); + } + KERNEL_LOG_DEBUG( + "%s Attr[adj_x1] value[%d], " + "Attr[adj_x2] value[%d].", + kBatchMatmul, adj_x, adj_y); + + int32_t x1_dim = adj_x ? input0_tensor_dims - 2 : input0_tensor_dims - 1; + int32_t x2_dim = adj_y ? 
input1_tensor_dims - 1 : input1_tensor_dims - 2; + KERNEL_CHECK_FALSE((input0_tensor_shape->GetDimSize(x1_dim) == input1_tensor_shape->GetDimSize(x2_dim)), + KERNEL_STATUS_PARAM_INVALID, + "Matrix size incompatible, input[x1] dim[%d] value[%lld], " + "input[x2] dim[%d] value[%lld]", + x1_dim, input0_tensor_shape->GetDimSize(x1_dim), x2_dim, input1_tensor_shape->GetDimSize(x2_dim)) + KERNEL_CHECK_FALSE((input0_tensor_dims == input1_tensor_dims), KERNEL_STATUS_PARAM_INVALID, + "input0_tensor_dims value[%d] is not equal to " + "input1_tensor_dims value[%d]", + input0_tensor_dims, input1_tensor_dims) + + auto input0_shape = input0_tensor_shape->GetDimSizes(); + auto input1_shape = input1_tensor_shape->GetDimSizes(); + auto output_shape = output_tensor->GetTensorShape()->GetDimSizes(); + + for (int32_t i = 0; i < input0_tensor_dims - 2; i++) { + KERNEL_CHECK_FALSE((input0_shape[i] == input1_shape[i]), KERNEL_STATUS_PARAM_INVALID, + "input0_shape dim[%d] value[%lld] is not equal to " + "input1_shape dim[%d] value[%lld]", + i, input0_shape[i], i, input1_shape[i]) + + KERNEL_CHECK_FALSE((input0_shape[i] == output_shape[i]), KERNEL_STATUS_PARAM_INVALID, + "input0_shape dim[%d] value[%lld] is not equal to " + "output_shape dim[%d] value[%lld]", + i, input0_shape[i], i, output_shape[i]) + } + + int32_t map1_l = input0_shape[input0_tensor_dims - 2]; + int32_t map1_r = input0_shape[input0_tensor_dims - 1]; + + int32_t map2_l = input1_shape[input1_tensor_dims - 2]; + int32_t map2_r = input1_shape[input1_tensor_dims - 1]; + + int32_t rows_dim = adj_x ? input0_tensor_dims - 1 : input0_tensor_dims - 2; + int32_t cols_dim = adj_y ? input1_tensor_dims - 2 : input1_tensor_dims - 1; + + int32_t num_rows = input0_tensor_shape->GetDimSize(rows_dim); + int32_t num_cols = input1_tensor_shape->GetDimSize(cols_dim); + + uint64_t num_element = output_tensor->NumElements(); + uint64_t num_batches = num_element / (num_rows * num_cols); + Eigen::Matrix map1(map1_l, map1_r); + auto *input0_data = reinterpret_cast(input0_tensor->GetData()); + Eigen::Matrix map2(map2_l, map2_r); + auto *input1_data = reinterpret_cast(input1_tensor->GetData()); + Eigen::Matrix map_output(num_rows, num_cols); + auto *output_data = reinterpret_cast(output_tensor->GetData()); + uint64_t input0_size = input0_tensor->GetDataSize(); + if (input0_size >= kParallelDataNum) { + uint32_t min_core_num = 1; + uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx)); + if (max_core_num > num_batches) { + max_core_num = num_batches; + } + Eigen::Matrix map1[num_batches]; + Eigen::Matrix map2[num_batches]; + Eigen::Matrix map_output[num_batches]; + auto shared_batchmatmul = [&](int64_t start, int64_t end) { + for (int64_t batch = start; batch < end; ++batch) { + Eigen::Matrix matrix1(map1_l, map1_r); + map1[batch].resize(map1_l, map1_r); + for (int64_t i = 0; i < map1_l; i++) { + for (int64_t j = 0; j < map1_r; j++) { + map1[batch](i, j) = input0_data[batch * map1_l * map1_r + i * map1_r + j]; + } + } + map2[batch].resize(map2_l, map2_r); + for (int64_t i = 0; i < map2_l; i++) { + for (int64_t j = 0; j < map2_r; j++) { + map2[batch](i, j) = input1_data[batch * map2_l * map2_r + i * map2_r + j]; + } + } + if (adj_x) { + if (adj_y) { + map_output[batch] = map1[batch].adjoint() * map2[batch].adjoint(); + } else { + map_output[batch] = map1[batch].adjoint() * map2[batch]; + } + } else { + if (adj_y) { + map_output[batch] = map1[batch] * map2[batch].adjoint(); + } else { + map_output[batch] = map1[batch] * map2[batch]; + } + } + 
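+          // write this batch's product back into the flat output buffer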
map_output[batch].resize(num_rows, num_cols); + for (int64_t i = 0; i < num_rows; ++i) { + for (int64_t j = 0; j < num_cols; ++j) { + output_data[batch * num_rows * num_cols + i * num_cols + j] = map_output[batch](i, j); + } + } + } + }; + KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_batches, num_batches / max_core_num, shared_batchmatmul), + "BatchMatMul Compute failed."); + } else { + for (uint64_t batch = 0; batch < num_batches; ++batch) { + for (int64_t i = 0; i < map1_l; i++) { + for (int64_t j = 0; j < map1_r; j++) { + map1(i, j) = input0_data[batch * map1_l * map1_r + i * map1_r + j]; + } + } + for (int64_t i = 0; i < map2_l; i++) { + for (int64_t j = 0; j < map2_r; j++) { + map2(i, j) = input1_data[batch * map2_l * map2_r + i * map2_r + j]; + } + } + if (adj_x) { + if (adj_y) { + map_output = map1.adjoint() * map2.adjoint(); + } else { + map_output = map1.adjoint() * map2; + } + } else { + if (adj_y) { + map_output = map1 * map2.adjoint(); + } else { + map_output = map1 * map2; + } + } + for (int64_t i = 0; i < num_rows; ++i) { + for (int64_t j = 0; j < num_cols; ++j) { + output_data[batch * num_rows * num_cols + i * num_cols + j] = map_output(i, j); + } + } + } + } + return KERNEL_STATUS_OK; +} + +uint32_t BatchMatMulCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check BatchMatMul params failed."); + DataType input0_data_type = ctx.Input(0)->GetDataType(); + KERNEL_LOG_DEBUG("%s op input[x1] data type is [%s].", kBatchMatmul, DTypeStr(input0_data_type).c_str()); + uint32_t ret = KERNEL_STATUS_OK; + switch (input0_data_type) { + case DT_FLOAT: + ret = DoCompute(ctx); + break; + case DT_DOUBLE: + ret = DoCompute(ctx); + break; + case DT_FLOAT16: + ret = DoCompute(ctx); + break; + case DT_INT16: + ret = DoCompute(ctx); + break; + case DT_INT32: + ret = DoCompute(ctx); + break; + case DT_INT64: + ret = DoCompute(ctx); + break; + case DT_UINT16: + ret = DoCompute(ctx); + break; + case DT_UINT32: + ret = DoCompute(ctx); + break; + case DT_UINT64: + ret = DoCompute(ctx); + break; + case DT_COMPLEX64: + ret = DoCompute>(ctx); + break; + case DT_COMPLEX128: + ret = DoCompute>(ctx); + break; + default: + KERNEL_LOG_ERROR("Unsupported input[x1] data type[%s]", DTypeStr(input0_data_type).c_str()); + ret = KERNEL_STATUS_PARAM_INVALID; + } + return ret; +} +REGISTER_CPU_KERNEL(kBatchMatmul, BatchMatMulCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.h new file mode 100644 index 00000000000..529949032ba --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef AICPU_KERNELS_HOST_BATCHMATMUL_H_ +#define AICPU_KERNELS_HOST_BATCHMATMUL_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class BatchMatMulCpuKernel : public CpuKernel { + public: + BatchMatMulCpuKernel() = default; + ~BatchMatMulCpuKernel() = default; + + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + template + uint32_t DoCompute(CpuKernelContext &ctx); +}; +} // namespace aicpu +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.cc new file mode 100644 index 00000000000..5793e7a7559 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "instance_norm_v2_grad.h" + +#include "Eigen/Dense" +#include "cpu_kernel_utils.h" +#include "cpu_types.h" +#include "utils/kernel_util.h" +#include "kernel_log.h" +#include "securec.h" +#include "status.h" +#include "utils/eigen_tensor.h" +#include +#include +#include + +namespace { +const char *const kInstanceNormV2Grad = "InstanceNormV2Grad"; +constexpr uint32_t kOutputNum = 3; +constexpr uint32_t kInputNum = 7; +constexpr int64_t int64_init_one = 1; +constexpr int64_t kGrainSize = 4 * 1024; +constexpr float float_init_zero = 0.0; +constexpr float float_init_one = 1.0; +constexpr double double_init_zero = 0.0; +constexpr auto kDim3 = 3; +constexpr auto kDim4 = 4; +constexpr auto kDim5 = 5; +constexpr auto InstanceNormV2GradInDyIndex = 0; +constexpr auto InstanceNormV2GradInXIndex = 1; +constexpr auto InstanceNormV2GradInGammaIndex = 2; +constexpr auto InstanceNormV2GradInMeanIndex = 3; +constexpr auto InstanceNormV2GradInVarianceIndex = 4; +constexpr auto InstanceNormV2GradInSaveMeanIndex = 5; +constexpr auto InstanceNormV2GradInSaveVarianceIndex = 6; +constexpr auto InstanceNormV2GradOutDxIndex = 0; +constexpr auto InstanceNormV2GradOutPdGammaIndex = 1; +constexpr auto InstanceNormV2GradOutPdBetaIndex = 2; + +inline double FloatToDouble(float v) { return static_cast(v); } +inline double LongToDouble(int64_t v) { return static_cast(v); } +} // namespace + +namespace aicpu { +uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradTypeCheck(const CpuKernelContext &ctx) { + auto dy_type = ctx.Input(InstanceNormV2GradInDyIndex)->GetDataType(); + auto x_type = ctx.Input(InstanceNormV2GradInXIndex)->GetDataType(); + auto gamma_type = ctx.Input(InstanceNormV2GradInGammaIndex)->GetDataType(); + auto mean_type = ctx.Input(InstanceNormV2GradInMeanIndex)->GetDataType(); + auto variance_type = ctx.Input(InstanceNormV2GradInVarianceIndex)->GetDataType(); + auto save_mean_type = ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetDataType(); + auto save_variance_type = ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetDataType(); + + if 
(dy_type != x_type) { + KERNEL_LOG_ERROR( + "For primitive[%s]'s input arguments dy and x should have the same " + "data type, but dy type is [%s], x type is [%s].", + kInstanceNormV2Grad, DTypeStr(dy_type).c_str(), DTypeStr(x_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + const std::map types = {{"gamma", gamma_type}, + {"mean", mean_type}, + {"variance", variance_type}, + {"save_mean", save_mean_type}, + {"save_variance", save_variance_type}}; + return CheckTensorTypeSame(types, DT_FLOAT, kInstanceNormV2Grad); +} + +uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradShapeCheck(const CpuKernelContext &ctx) { + auto dy_shape_ptr = ctx.Input(InstanceNormV2GradInDyIndex)->GetTensorShape(); + auto x_shape_ptr = ctx.Input(InstanceNormV2GradInXIndex)->GetTensorShape(); + auto gamma_shape_ptr = ctx.Input(InstanceNormV2GradInGammaIndex)->GetTensorShape(); + auto mean_shape_ptr = ctx.Input(InstanceNormV2GradInMeanIndex)->GetTensorShape(); + auto variance_shape_ptr = ctx.Input(InstanceNormV2GradInVarianceIndex)->GetTensorShape(); + auto save_mean_shape_ptr = ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetTensorShape(); + auto save_variance_shape_ptr = ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetTensorShape(); + auto pd_gamma_shape_ptr = ctx.Output(InstanceNormV2GradOutPdGammaIndex)->GetTensorShape(); + auto pd_beta_shape_ptr = ctx.Output(InstanceNormV2GradOutPdBetaIndex)->GetTensorShape(); + + auto dy_shape = dy_shape_ptr->GetDimSizes(); + auto res = CheckTensorShapeSame({{"input x", x_shape_ptr}}, dy_shape, kInstanceNormV2Grad); + if (res != KERNEL_STATUS_OK) { + return res; + }; + auto x_format = x_shape_ptr->GetFormat(); + std::vector check_shape; + check_shape = dy_shape; + int64_t image_size = 0; + if (dy_shape.size() == kDim4) { + if (x_format == FORMAT_NCHW || x_format == FORMAT_ND) { + // consider (N, C, H, W) as (N*C, H, W, 1), similar to (N, H, W, C) + check_shape[kFormatNCHWIndexH] = int64_init_one; + check_shape[kFormatNCHWIndexW] = int64_init_one; + image_size = dy_shape[kFormatNCHWIndexH] * dy_shape[kFormatNCHWIndexW]; + instance_num = dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC]; + constexpr int64_t kNumberOne = 1; + dy_shape_4d_ = {dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC], dy_shape[kFormatNCHWIndexH], + dy_shape[kFormatNCHWIndexW], kNumberOne}; + batch_channels_2d_ = {dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC], kNumberOne}; + } else if (x_format == FORMAT_NHWC) { + check_shape[kFormatNHWCIndexH] = int64_init_one; + check_shape[kFormatNHWCIndexW] = int64_init_one; + image_size = dy_shape[kFormatNHWCIndexH] * dy_shape[kFormatNHWCIndexW]; + instance_num = dy_shape[kFormatNHWCIndexN] * dy_shape[kFormatNHWCIndexC]; + dy_shape_4d_ = dy_shape; + batch_channels_2d_ = {dy_shape[kFormatNHWCIndexN], dy_shape[kFormatNHWCIndexC]}; + } + } else if (x_format == FORMAT_NC1HWC0 || dy_shape.size() == kDim5) { + // consider (N, C1, H, W, C0) as (N*C1, H, W, C0), similar to (N, H, W, C) + check_shape[kFormatNC1HWC0IndexH] = int64_init_one; + check_shape[kFormatNC1HWC0IndexW] = int64_init_one; + image_size = dy_shape[kFormatNC1HWC0IndexH] * dy_shape[kFormatNC1HWC0IndexW]; + instance_num = dy_shape[kFormatNC1HWC0IndexN] * dy_shape[kFormatNC1HWC0IndexC1] * dy_shape[kFormatNC1HWC0IndexC0]; + dy_shape_4d_ = {dy_shape[kFormatNC1HWC0IndexN] * dy_shape[kFormatNC1HWC0IndexC1], dy_shape[kFormatNC1HWC0IndexH], + dy_shape[kFormatNC1HWC0IndexW], dy_shape[kFormatNC1HWC0IndexC0]}; + batch_channels_2d_ = {dy_shape[kFormatNC1HWC0IndexN] * 
dy_shape[kFormatNC1HWC0IndexC1], + dy_shape[kFormatNC1HWC0IndexC0]}; + } else { + KERNEL_LOG_ERROR( + "For primitive[%s]'s input arguments dy and x only " + "support NHWC, NCHW and NC1HWC0, but get data format [%s]", + kInstanceNormV2Grad, FormatToSerialString(x_format).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + constexpr int64_t image_min = 1; + if (image_size <= image_min) { + KERNEL_LOG_ERROR( + "For primitive[%s], expected more than 1 value per instance, but get " + "[%ld] value per instance.", + kInstanceNormV2Grad, image_size); + return KERNEL_STATUS_PARAM_INVALID; + } + const std::map shapes = {{"gamma", gamma_shape_ptr}, + {"mean", mean_shape_ptr}, + {"variance", variance_shape_ptr}, + {"save_mean", save_mean_shape_ptr}, + {"save_variance", save_variance_shape_ptr}, + {"pd_gamma", pd_gamma_shape_ptr}, + {"pd_beta", pd_beta_shape_ptr}}; + return CheckTensorShapeSame(shapes, check_shape, kInstanceNormV2Grad); +} + +uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradAttrCheck(const CpuKernelContext &ctx) { + constexpr float epsilon_min = 0.0; + constexpr float epsilon_max = 1.0; + auto epsilon_ptr = ctx.GetAttr("epsilon"); + if (epsilon_ptr) { + epsilon_ = epsilon_ptr->GetFloat(); + } + if (epsilon_ < epsilon_min || epsilon_ >= epsilon_max) { + KERNEL_LOG_ERROR( + "For primitive[%s], attr epsilon value should be in [0, 1), but get " + "[%f].", + kInstanceNormV2Grad, epsilon_); + return KERNEL_STATUS_PARAM_INVALID; + } + auto is_training_ptr = ctx.GetAttr("is_training"); + if (is_training_ptr) { + is_training_ = is_training_ptr->GetBool(); + } + return KERNEL_STATUS_OK; +} + +uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradParamCheck(const CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(InstanceNormV2GradTypeCheck(ctx), "InstanceNormV2Grad check type failed."); + KERNEL_HANDLE_ERROR(InstanceNormV2GradShapeCheck(ctx), "InstanceNormV2Grad check shape failed."); + KERNEL_HANDLE_ERROR(InstanceNormV2GradAttrCheck(ctx), "InstanceNormV2Grad check attr failed."); + return KERNEL_STATUS_OK; +} + +template +uint32_t InstanceNormV2GradCpuKernel::DoCompute(CpuKernelContext &ctx) { + const int64_t batch = dy_shape_4d_[kFormatNHWCIndexN]; + const int64_t channel = dy_shape_4d_[kFormatNHWCIndexC]; + const int64_t image_size = dy_shape_4d_[kFormatNHWCIndexH] * dy_shape_4d_[kFormatNHWCIndexW]; + std::vector dy_shape_3d_ = {batch, image_size, channel}; + auto dy_3d = EigenTensor(dy_shape_3d_, ctx.Input(InstanceNormV2GradInDyIndex)->GetData()).tensor(); + auto in_x_3d = EigenTensor(dy_shape_3d_, ctx.Input(InstanceNormV2GradInXIndex)->GetData()).tensor(); + auto weight_matrix = + EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInGammaIndex)->GetData()).matrix(); + auto running_mean_matrix = + EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInMeanIndex)->GetData()).matrix(); + auto running_var_matrix = + EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInVarianceIndex)->GetData()).matrix(); + auto save_mean_matrix = + EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetData()).matrix(); + auto save_invstd_matrix = + EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetData()).matrix(); + + auto dx_3d = EigenTensor(dy_shape_3d_, ctx.Output(InstanceNormV2GradOutDxIndex)->GetData()).tensor(); + auto grad_weight_matrix = + EigenTensor(batch_channels_2d_, ctx.Output(InstanceNormV2GradOutPdGammaIndex)->GetData()).matrix(); + auto grad_bias_matrix = + EigenTensor(batch_channels_2d_, 
ctx.Output(InstanceNormV2GradOutPdBetaIndex)->GetData()).matrix(); + auto loop_batch = [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + for (int64_t c_idx = 0; c_idx < channel; ++c_idx) { + float w = weight_matrix(idx, c_idx); + float mean = float_init_zero, invstd = float_init_zero; + mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx); + float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_); + invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_; + + double sum = double_init_zero, dotp = double_init_zero; + for (int64_t img_idx = 0; img_idx < image_size; ++img_idx) { + sum += static_cast(dy_3d(idx, img_idx, c_idx)); + dotp += (static_cast(in_x_3d(idx, img_idx, c_idx)) - FloatToDouble(mean)) * + static_cast(dy_3d(idx, img_idx, c_idx)); + } + float k = static_cast(dotp * FloatToDouble(invstd) * FloatToDouble(invstd) / LongToDouble(image_size)); + float grad_mean = static_cast(sum / LongToDouble(image_size)); + for (int64_t img_idx = 0; img_idx < image_size; ++img_idx) { + float _dx_ = (static_cast(in_x_3d(idx, img_idx, c_idx)) - mean) * k; + dx_3d(idx, img_idx, c_idx) = + is_training_ + ? static_cast((static_cast(dy_3d(idx, img_idx, c_idx)) - grad_mean - _dx_) * invstd * w) + : static_cast(static_cast(dy_3d(idx, img_idx, c_idx)) * invstd * w); + } + grad_weight_matrix(idx, c_idx) = static_cast(dotp * FloatToDouble(invstd)); + grad_bias_matrix(idx, c_idx) = static_cast(sum); + } + } + }; + int64_t block_size = std::max(int64_init_one, (kGrainSize / (channel * image_size))); + return CpuKernelUtils::ParallelFor(ctx, batch, block_size, loop_batch); +} + +uint32_t InstanceNormV2GradCpuKernel::Compute(CpuKernelContext &ctx) { + // check params + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), + "InstanceNormV2Grad check input and output number failed."); + KERNEL_HANDLE_ERROR(InstanceNormV2GradParamCheck(ctx), "InstanceNormV2Grad check params failed."); + auto data_type = ctx.Input(0)->GetDataType(); + uint32_t result; + switch (data_type) { + case (DT_FLOAT16): + result = DoCompute(ctx); + break; + case (DT_FLOAT): + result = DoCompute(ctx); + break; + default: + KERNEL_LOG_ERROR("InstanceNormV2Grad kernel data type [%s] not support.", DTypeStr(data_type).c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + if (result != KERNEL_STATUS_OK) { + KERNEL_LOG_ERROR("InstanceNormV2Grad kernel compute failed."); + } + return result; +} + +REGISTER_CPU_KERNEL(kInstanceNormV2Grad, InstanceNormV2GradCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.h new file mode 100644 index 00000000000..b220a1dd952 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/instance_norm_v2_grad.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_GRAD_H_ +#define AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_GRAD_H_ + +#include "cpu_ops_kernel.h" +#include + +namespace aicpu { +class InstanceNormV2GradCpuKernel : public CpuKernel { + public: + InstanceNormV2GradCpuKernel() = default; + + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + uint32_t InstanceNormV2GradParamCheck(const CpuKernelContext &ctx); + uint32_t InstanceNormV2GradShapeCheck(const CpuKernelContext &ctx); + uint32_t InstanceNormV2GradTypeCheck(const CpuKernelContext &ctx); + uint32_t InstanceNormV2GradAttrCheck(const CpuKernelContext &ctx); + + template + uint32_t DoCompute(CpuKernelContext &ctx); + + bool is_training_ = true; + float epsilon_ = 0.00001; + std::vector dy_shape_4d_; + std::vector batch_channels_2d_; + int64_t instance_num = 0; +}; +} // namespace aicpu +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc new file mode 100644 index 00000000000..60834fd9bae --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc @@ -0,0 +1,299 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "sparse_to_dense.h" +#include +#include "cpu_types.h" +#include "kernel_log.h" +#include "status.h" +#include "unsupported/Eigen/CXX11/Tensor" +#include "cpu_kernel_utils.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" + +namespace aicpu { +const char *const SPARSETODENSE = "SparseToDense"; +constexpr int64_t kParallelDateSize = 16 * 1024; +constexpr int64_t kCopyDataSize = 1024; +constexpr uint32_t kInput0 = 0; +constexpr uint32_t kInput1 = 1; +constexpr uint32_t kInput2 = 2; +constexpr uint32_t kInput3 = 3; +constexpr uint32_t kOutput0 = 0; +constexpr int32_t kRank = 2; +} // namespace aicpu + +namespace aicpu { +uint32_t SparseToDenseCpuKernel::SparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices, + Tensor *output) { + KERNEL_LOG_INFO("Start to execute SparseToDense"); + if (indices == nullptr || output == nullptr) { + KERNEL_LOG_ERROR("Indices or output tensor is nullptr."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + + DataType dt = static_cast(output->GetDataType()); + switch (dt) { + case DT_INT8: + return EigenSparseToDense(ctx, st, indices, output); + case DT_UINT8: + return EigenSparseToDense(ctx, st, indices, output); + case DT_INT16: + return EigenSparseToDense(ctx, st, indices, output); + case DT_UINT16: + return EigenSparseToDense(ctx, st, indices, output); + case DT_INT32: + return EigenSparseToDense(ctx, st, indices, output); + case DT_INT64: + return EigenSparseToDense(ctx, st, indices, output); + case DT_FLOAT16: + return EigenSparseToDense(ctx, st, indices, output); + case DT_FLOAT: + return EigenSparseToDense(ctx, st, indices, output); + case DT_BOOL: + return EigenSparseToDense(ctx, st, indices, output); + case DT_DOUBLE: + return EigenSparseToDense(ctx, st, indices, output); + default: + KERNEL_LOG_ERROR("Sparse to dense can't support this data type [%d].", static_cast(dt)); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } +} + +KernelStatus SparseToDenseCpuKernel::ValidParam(const CpuKernelContext &ctx) { + KERNEL_LOG_INFO("Start to execute ValidParam"); + // valid input and output nullptr + Tensor *indices_tensor = ctx.Input(0); + Tensor *shape_tensor = ctx.Input(1); + Tensor *sparse_values = ctx.Input(2); + Tensor *default_value_tensor = ctx.Input(3); + Tensor *output_tensor = ctx.Output(0); + bool validNull = ((output_tensor == nullptr) || default_value_tensor == nullptr || (sparse_values == nullptr) || + (indices_tensor == nullptr) || (shape_tensor == nullptr)); + if (validNull) { + KERNEL_LOG_ERROR("Got input or output param is nullptr."); + return KERNEL_STATUS_PARAM_INVALID; + } + + // valid shape nullptr + auto output_shape = shape_tensor->GetTensorShape(); + auto values_shape = sparse_values->GetTensorShape(); + auto default_value_shape = default_value_tensor->GetTensorShape(); + auto indices_shape = indices_tensor->GetTensorShape(); + bool validShapeNull = ((default_value_shape == nullptr) || values_shape == nullptr || (output_shape == nullptr) || + (indices_shape == nullptr)); + if (validShapeNull) { + KERNEL_LOG_ERROR("Got input shape is nullptr."); + return KERNEL_STATUS_PARAM_INVALID; + } + + // sparse_indices + if (indices_shape->GetDims() > kRank) { + KERNEL_LOG_ERROR( + "Sparse_indices should be a scalar, vector, or matrix, got dim " + "size [%d].", + indices_shape->GetDims()); + return KERNEL_STATUS_PARAM_INVALID; + } + const int64_t elems_num = indices_shape->GetDims() > 0 ? indices_shape->GetDimSize(0) : 1; + const int64_t dims_num = indices_shape->GetDims() > 1 ? 
indices_shape->GetDimSize(1) : 1; + + // output_shape + if (output_shape->GetDims() != 1) { + KERNEL_LOG_ERROR("Output_shape should be a vector, and got dim size [%d].", output_shape->GetDims()); + return KERNEL_STATUS_PARAM_INVALID; + } + if (shape_tensor->NumElements() != dims_num) { + KERNEL_LOG_ERROR("Output_shape has incorrect number of elements [%lld], should be [%lld]", + shape_tensor->NumElements(), dims_num); + return KERNEL_STATUS_PARAM_INVALID; + } + + // valid data type + DataType IndiceType = indices_tensor->GetDataType(); + DataType outShapeType = shape_tensor->GetDataType(); + bool validIndiceType = ((IndiceType != DT_INT32) && (IndiceType != DT_INT64)); + bool validShapeType = ((outShapeType != DT_INT32) && (outShapeType != DT_INT64)); + if (validShapeType || validIndiceType) { + KERNEL_LOG_ERROR( + "Valid indice or output shape data type failed, indiceType [%d], " + "shapeType [%d].", + static_cast(IndiceType), static_cast(outShapeType)); + return KERNEL_STATUS_PARAM_INVALID; + } + + // sparse_values + int32_t values_dims_size = values_shape->GetDims(); + if ((values_dims_size != 0) && (values_dims_size != 1)) { + KERNEL_LOG_ERROR("Values_shape should be a scalar or a vector, got dim size [%d].", values_shape->GetDims()); + return KERNEL_STATUS_PARAM_INVALID; + } + if ((values_dims_size == 1) && (sparse_values->NumElements() != elems_num)) { + KERNEL_LOG_ERROR("Values_shape has incorrect number of elements [%lld], should be [%lld]", + sparse_values->NumElements(), elems_num); + return KERNEL_STATUS_PARAM_INVALID; + } + + // default_value + if (default_value_shape->GetDims() != 0) { + KERNEL_LOG_ERROR("Default_value should be a scalar, and got dim size [%d].", default_value_shape->GetDims()); + return KERNEL_STATUS_PARAM_INVALID; + } + KERNEL_LOG_INFO("Execute ValidParam end."); + return KERNEL_STATUS_OK; +} + +uint32_t SparseToDenseCpuKernel::ParallelSetDefaultValue(const CpuKernelContext &ctx, + const Tensor *default_value_tensor, + const Tensor *output_tensor, int64_t output_size) { + auto type_size = GetSizeByDataType(static_cast(output_tensor->GetDataType())); + char *default_value_addr = reinterpret_cast(default_value_tensor->GetData()); + char *output_addr = reinterpret_cast(output_tensor->GetData()); + uint32_t min_core_num = 1; + int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum); + auto default_value = [&](std::int64_t begin, std::int64_t end) { + int64_t total = end - begin; + int64_t remainder = total % kCopyDataSize; + int64_t piece = total / kCopyDataSize; + if (piece == 0) { + for (int64_t index = begin; index < end; index++) { + (void)memcpy_s(output_addr + (index * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + } else { + for (int64_t index = begin; index < begin + kCopyDataSize; index++) { + (void)memcpy_s(output_addr + (index * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + char *temp_addr = output_addr + (begin * type_size); + size_t data_size = static_cast(type_size * kCopyDataSize); + for (int64_t loop = 1; loop < piece; loop++) { + (void)memcpy_s(temp_addr + (loop * type_size * kCopyDataSize), data_size, temp_addr, data_size); + } + char *temp_addr1 = output_addr + (begin * type_size) + (piece * type_size * kCopyDataSize); + for (int64_t loop1 = 0; loop1 < remainder; loop1++) { + (void)memcpy_s(temp_addr1 + (loop1 * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + } + }; + return 
CpuKernelUtils::ParallelFor(ctx, output_size, output_size / max_core_num, default_value); +} +uint32_t SparseToDenseCpuKernel::SetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor, + const Tensor *output_tensor, int64_t output_size) { + auto type_size = GetSizeByDataType(static_cast(output_tensor->GetDataType())); + if (type_size < 1) { + KERNEL_LOG_ERROR("Don't support output tensor types"); + return KERNEL_STATUS_PARAM_INVALID; + } + char *default_value_addr = reinterpret_cast(default_value_tensor->GetData()); + char *output_addr = reinterpret_cast(output_tensor->GetData()); + if (output_size < kParallelDateSize) { + int64_t remainder = output_size % kCopyDataSize; + int64_t piece = output_size / kCopyDataSize; + if (piece == 0) { + for (int index = 0; index < output_size; index++) { + (void)memcpy_s(output_addr + (index * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + } else { + for (int index = 0; index < kCopyDataSize; index++) { + (void)memcpy_s(output_addr + (index * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + size_t data_size = static_cast(type_size * kCopyDataSize); + for (int loop = 1; loop < piece; loop++) { + (void)memcpy_s(output_addr + (loop * type_size * kCopyDataSize), data_size, output_addr, data_size); + } + char *temp_addr = output_addr + (piece * type_size * kCopyDataSize); + for (int loop1 = 0; loop1 < remainder; loop1++) { + (void)memcpy_s(temp_addr + (loop1 * type_size), static_cast(type_size), default_value_addr, + static_cast(type_size)); + } + } + return KERNEL_STATUS_OK; + } else { + return ParallelSetDefaultValue(ctx, default_value_tensor, output_tensor, output_size); + } +} +uint32_t SparseToDenseCpuKernel::Compute(CpuKernelContext &ctx) { + if (ValidParam(ctx) != KERNEL_STATUS_OK) { + KERNEL_LOG_ERROR("Valid sparse to dense param error."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + Tensor *indices_tensor = ctx.Input(kInput0); + KERNEL_CHECK_NULLPTR(indices_tensor, static_cast(KERNEL_STATUS_PARAM_INVALID), "Indices_tensor is null") + Tensor *shape_tensor = ctx.Input(kInput1); + KERNEL_CHECK_NULLPTR(shape_tensor, static_cast(KERNEL_STATUS_PARAM_INVALID), "Shape_tensor is null") + Tensor *sparse_values = ctx.Input(kInput2); + KERNEL_CHECK_NULLPTR(sparse_values, static_cast(KERNEL_STATUS_PARAM_INVALID), "Sparse_values is null") + Tensor *default_value_tensor = ctx.Input(kInput3); + KERNEL_CHECK_NULLPTR(default_value_tensor, static_cast(KERNEL_STATUS_PARAM_INVALID), + "Default_value_tensor is null") + Tensor *output_tensor = ctx.Output(kOutput0); + KERNEL_CHECK_NULLPTR(output_tensor, static_cast(KERNEL_STATUS_PARAM_INVALID), "Output_tensor is null") + + auto output_shape = shape_tensor->GetTensorShape(); + std::vector dense_shape; + std::vector order; + int64_t output_size = 1; + size_t output_zero_dim_size = static_cast(output_shape->GetDimSize(0)); + for (size_t index = 0; index < output_zero_dim_size; ++index) { + if (shape_tensor->GetDataType() == DT_INT32) { + int32_t *temp_dim = reinterpret_cast(shape_tensor->GetData()); + dense_shape.emplace_back(static_cast(temp_dim[index])); + } else { + int64_t *temp_dim = reinterpret_cast(shape_tensor->GetData()); + dense_shape.emplace_back(temp_dim[index]); + } + order.push_back(dense_shape[index]); + output_size *= dense_shape[index]; + } + + std::iota(order.begin(), order.end(), 0); + + SparseTensor st; + if (st.CreateSparseTensor(indices_tensor, sparse_values, dense_shape, order) != + 
static_cast(KERNEL_STATUS_OK)) { + KERNEL_LOG_ERROR("Create sparse tensor failed."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + AttrValue *validate_indices = ctx.GetAttr("validate_indices"); + if (validate_indices == nullptr) { + KERNEL_LOG_ERROR("Get attr:validate_indices failed."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + if (validate_indices->GetBool()) { + if (st.IndicesValid(ctx) != static_cast(KERNEL_STATUS_OK)) { + KERNEL_LOG_ERROR("Indices is valid."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + } + + if (SetDefaultValue(ctx, default_value_tensor, output_tensor, output_size) != + static_cast(KERNEL_STATUS_OK)) { + KERNEL_LOG_ERROR("Sparse_to_dense set default value failed."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + + if (SparseToDense(ctx, st, indices_tensor, output_tensor) != static_cast(KERNEL_STATUS_OK)) { + KERNEL_LOG_ERROR("Sparse_to_dense execute failed."); + return static_cast(KERNEL_STATUS_PARAM_INVALID); + } + return static_cast(KERNEL_STATUS_OK); +} + +REGISTER_CPU_KERNEL(SPARSETODENSE, SparseToDenseCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.h new file mode 100644 index 00000000000..335f0b419c7 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.h @@ -0,0 +1,92 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef AICPU_KERNELS_NORMALIZED_SPARSETODENSE_H_ +#define AICPU_KERNELS_NORMALIZED_SPARSETODENSE_H_ + +#include "cpu_ops_kernel.h" +#include "utils/sparse_tensor.h" + +namespace aicpu { + +class SparseToDenseCpuKernel : public CpuKernel { + public: + ~SparseToDenseCpuKernel() = default; + + /* + * compute sparse to dense + * @param ctx: cpu kernel context + * @return uint32_t: 0->success other->failed + */ + uint32_t Compute(CpuKernelContext &ctx) override; + + protected: + /* + * valid sparse to dense param + * @param st: sparse tensor + * @param indices: indices tensor + * @param output: output tensor + * @return uint32_t: 0->success other->failed + */ + template + uint32_t EigenSparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices, Tensor *output) { + if (indices->GetDataType() == DT_INT32) { + return st.ToDense(ctx, output); + } else { + return st.ToDense(ctx, output); + } + } + + /* + * valid sparse to dense param + * @param st: sparse tensor + * @param indices: indices tensor + * @param output: output tensor + * @return uint32_t: 0->success other->failed + */ + uint32_t SparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices, Tensor *output); + + /* + * valid sparse to dense param + * @param ctx: cpu kernel context + * @return uint32_t: 0->success other->failed + */ + KernelStatus ValidParam(const CpuKernelContext &ctx); + + /* + * parallel set default value to dense + * @param ctx: cpu kernel context + * @param default_value_tensor: default value of dense tensor + * @param output_tensor: output tensor + * @param output_size: output tensor size + * @return uint32_t: 0->success other->failed + */ + uint32_t ParallelSetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor, + const Tensor *output_tensor, int64_t output_size); + + /* + * set default value to dense + * @param ctx: cpu kernel context + * @param default_value_tensor: default value of dense tensor + * @param output_tensor: output tensor + * @param output_size: output tensor size + * @return uint32_t: 0->success other->failed + */ + uint32_t SetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor, const Tensor *output_tensor, + int64_t output_size); +}; + +} // namespace aicpu +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc index 6be41432af7..1b21786c221 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc @@ -317,7 +317,11 @@ bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const { mindspore::kMaskedSelectOpName, mindspore::kMaskedSelectGradOpName, mindspore::kMultiMarginLossOpName, - mindspore::kMatrixInverseOpName}; + mindspore::kMatrixInverseOpName, + mindspore::kMultiMarginLossGradOpName, + mindspore::kSspaddmmOpName, + mindspore::kBatchMatMulOpName, + mindspore::kSparseToDenseOpName}; static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels"; static const std::string kCpuKernelSoName = "mindspore_cpu_kernels"; diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index fad75bacc11..e38d08daa19 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -410,3 +410,7 @@ from .trace import _trace_aicpu from 
.tracegrad import _tracegrad_aicpu from .tridiagonal_solve import _tridiagonal_solve_aicpu from .truncated_normal import _truncated_normal_aicpu +from .glu import _glu_aicpu +from .multi_margin_loss import _multi_margin_loss_aicpu +from .multi_margin_loss_grad import _multi_margin_loss_grad_aicpu +from .sparse_to_dense_v2 import _sparse_to_dense_v2_aicpu
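
For reference, a minimal usage sketch (not part of the patch) of how the newly registered AICPU kernels can be exercised from the Python API; the Ascend context setting, shapes, and values below are illustrative assumptions:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    # assumes an Ascend target so the ops dispatch to the AICPU kernels added here
    ms.set_context(device_target="Ascend")

    # BatchMatMul: batched (2, 3, 4) x (2, 4, 5) -> (2, 3, 5)
    x1 = Tensor(np.random.rand(2, 3, 4).astype(np.float32))
    x2 = Tensor(np.random.rand(2, 4, 5).astype(np.float32))
    print(ops.BatchMatMul()(x1, x2).shape)

    # SparseToDense: scatter two values into a 3 x 4 dense tensor
    indices = Tensor(np.array([[0, 1], [1, 2]], dtype=np.int32))
    values = Tensor(np.array([7.0, 8.0], dtype=np.float32))
    print(ops.SparseToDense()(indices, values, (3, 4)))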