forked from mindspore-Ecosystem/mindspore
!48986 fix aicpu migration issues
Merge pull request !48986 from 李林杰/0216_fix_aicpu_migration_issues_master
This commit is contained in:
commit
0101b59ec2
|
@ -367,3 +367,5 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
|
|||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_unpool_2d_grad.cc:aicpu::MaxUnpool2DGradCpuKernel::MaxUnpool2DGradCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/max_pool_3d_with_argmax.cc:aicpu::MaxPool3DWithArgmaxCpuKernel::MaxPool3DWithArgmaxCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/layer_norm_grad_grad.cc:aicpu::LayerNormGradGradCpuKernel::LayerNormGradGradCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/batchmatmul.cc:aicpu::BatchMatMulCpuKernel::DoCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_to_dense.cc:aicpu::SparseToDenseCpuKernel::ValidParam
|
||||
|
|
|
@ -79,7 +79,6 @@ MindSpore中 `mindspore.ops.primitive` 接口与上一版本相比,新增、
|
|||
mindspore.ops.MaxUnpool2D
|
||||
mindspore.ops.MirrorPad
|
||||
mindspore.ops.NthElement
|
||||
mindspore.ops.NuclearNorm
|
||||
mindspore.ops.Pad
|
||||
mindspore.ops.Padding
|
||||
mindspore.ops.PadV3
|
||||
|
|
|
@ -78,7 +78,6 @@ Neural Network
|
|||
mindspore.ops.MaxUnpool2D
|
||||
mindspore.ops.MirrorPad
|
||||
mindspore.ops.NthElement
|
||||
mindspore.ops.NuclearNorm
|
||||
mindspore.ops.Pad
|
||||
mindspore.ops.EmbeddingLookup
|
||||
mindspore.ops.Padding
|
||||
|
|
|
@ -758,6 +758,7 @@ constexpr auto kSparseSparseMaximumOpName = "SparseSparseMaximum";
|
|||
constexpr auto kSparseTensorDenseMatMulOpName = "SparseTensorDenseMatMul";
|
||||
constexpr auto kSparseTensorDenseAddOpName = "SparseTensorDenseAdd";
|
||||
constexpr auto kSparseTensorToCSRSparseMatrixOpName = "SparseTensorToCSRSparseMatrix";
|
||||
constexpr auto kSparseToDenseOpName = "SparseToDense";
|
||||
constexpr auto kSplitOpName = "Split";
|
||||
constexpr auto kSplitDOpName = "SplitD";
|
||||
constexpr auto kSplitVOpName = "SplitV";
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "batchmatmul.h"
|
||||
|
||||
#include <complex>
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
const char *kBatchMatmul = "BatchMatMul";
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const int64_t kParallelDataNum = 1024;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
template <typename T>
|
||||
uint32_t BatchMatMulCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
auto input0_tensor = ctx.Input(0);
|
||||
auto input0_tensor_shape = input0_tensor->GetTensorShape();
|
||||
int32_t input0_tensor_dims = input0_tensor_shape->GetDims();
|
||||
KERNEL_CHECK_FALSE((input0_tensor_dims > 1), KERNEL_STATUS_PARAM_INVALID, "Input[x1] must be a matrix or higher.")
|
||||
|
||||
auto input1_tensor = ctx.Input(1);
|
||||
auto input1_tensor_shape = input1_tensor->GetTensorShape();
|
||||
int32_t input1_tensor_dims = input1_tensor_shape->GetDims();
|
||||
KERNEL_CHECK_FALSE((input1_tensor_dims > 1), KERNEL_STATUS_PARAM_INVALID, "Input[x2] must be a matrix or higher.")
|
||||
|
||||
auto output_tensor = ctx.Output(0);
|
||||
DataType input0_data_type = input0_tensor->GetDataType();
|
||||
DataType input1_data_type = input1_tensor->GetDataType();
|
||||
DataType output_data_type = output_tensor->GetDataType();
|
||||
KERNEL_CHECK_FALSE((input0_data_type == input1_data_type), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input[x1] data type[%s] and input[x2] data type[%s] must be same.",
|
||||
DTypeStr(input0_data_type).c_str(), DTypeStr(input1_data_type).c_str())
|
||||
KERNEL_CHECK_FALSE((input0_data_type == output_data_type), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input data type[%s] and output data type[%s] must be same.", DTypeStr(input0_data_type).c_str(),
|
||||
DTypeStr(output_data_type).c_str())
|
||||
bool adj_x = false;
|
||||
bool adj_y = false;
|
||||
auto adj_x1 = ctx.GetAttr("adj_x1");
|
||||
auto adj_x2 = ctx.GetAttr("adj_x2");
|
||||
if (adj_x1 != nullptr) {
|
||||
adj_x = adj_x1->GetBool();
|
||||
}
|
||||
if (adj_x2 != nullptr) {
|
||||
adj_y = adj_x2->GetBool();
|
||||
}
|
||||
KERNEL_LOG_DEBUG(
|
||||
"%s Attr[adj_x1] value[%d], "
|
||||
"Attr[adj_x2] value[%d].",
|
||||
kBatchMatmul, adj_x, adj_y);
|
||||
|
||||
int32_t x1_dim = adj_x ? input0_tensor_dims - 2 : input0_tensor_dims - 1;
|
||||
int32_t x2_dim = adj_y ? input1_tensor_dims - 1 : input1_tensor_dims - 2;
|
||||
KERNEL_CHECK_FALSE((input0_tensor_shape->GetDimSize(x1_dim) == input1_tensor_shape->GetDimSize(x2_dim)),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"Matrix size incompatible, input[x1] dim[%d] value[%lld], "
|
||||
"input[x2] dim[%d] value[%lld]",
|
||||
x1_dim, input0_tensor_shape->GetDimSize(x1_dim), x2_dim, input1_tensor_shape->GetDimSize(x2_dim))
|
||||
KERNEL_CHECK_FALSE((input0_tensor_dims == input1_tensor_dims), KERNEL_STATUS_PARAM_INVALID,
|
||||
"input0_tensor_dims value[%d] is not equal to "
|
||||
"input1_tensor_dims value[%d]",
|
||||
input0_tensor_dims, input1_tensor_dims)
|
||||
|
||||
auto input0_shape = input0_tensor_shape->GetDimSizes();
|
||||
auto input1_shape = input1_tensor_shape->GetDimSizes();
|
||||
auto output_shape = output_tensor->GetTensorShape()->GetDimSizes();
|
||||
|
||||
for (int32_t i = 0; i < input0_tensor_dims - 2; i++) {
|
||||
KERNEL_CHECK_FALSE((input0_shape[i] == input1_shape[i]), KERNEL_STATUS_PARAM_INVALID,
|
||||
"input0_shape dim[%d] value[%lld] is not equal to "
|
||||
"input1_shape dim[%d] value[%lld]",
|
||||
i, input0_shape[i], i, input1_shape[i])
|
||||
|
||||
KERNEL_CHECK_FALSE((input0_shape[i] == output_shape[i]), KERNEL_STATUS_PARAM_INVALID,
|
||||
"input0_shape dim[%d] value[%lld] is not equal to "
|
||||
"output_shape dim[%d] value[%lld]",
|
||||
i, input0_shape[i], i, output_shape[i])
|
||||
}
|
||||
|
||||
int32_t map1_l = input0_shape[input0_tensor_dims - 2];
|
||||
int32_t map1_r = input0_shape[input0_tensor_dims - 1];
|
||||
|
||||
int32_t map2_l = input1_shape[input1_tensor_dims - 2];
|
||||
int32_t map2_r = input1_shape[input1_tensor_dims - 1];
|
||||
|
||||
int32_t rows_dim = adj_x ? input0_tensor_dims - 1 : input0_tensor_dims - 2;
|
||||
int32_t cols_dim = adj_y ? input1_tensor_dims - 2 : input1_tensor_dims - 1;
|
||||
|
||||
int32_t num_rows = input0_tensor_shape->GetDimSize(rows_dim);
|
||||
int32_t num_cols = input1_tensor_shape->GetDimSize(cols_dim);
|
||||
|
||||
uint64_t num_element = output_tensor->NumElements();
|
||||
uint64_t num_batches = num_element / (num_rows * num_cols);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map1(map1_l, map1_r);
|
||||
auto *input0_data = reinterpret_cast<T *>(input0_tensor->GetData());
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map2(map2_l, map2_r);
|
||||
auto *input1_data = reinterpret_cast<T *>(input1_tensor->GetData());
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map_output(num_rows, num_cols);
|
||||
auto *output_data = reinterpret_cast<T *>(output_tensor->GetData());
|
||||
uint64_t input0_size = input0_tensor->GetDataSize();
|
||||
if (input0_size >= kParallelDataNum) {
|
||||
uint32_t min_core_num = 1;
|
||||
uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx));
|
||||
if (max_core_num > num_batches) {
|
||||
max_core_num = num_batches;
|
||||
}
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map1[num_batches];
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map2[num_batches];
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> map_output[num_batches];
|
||||
auto shared_batchmatmul = [&](int64_t start, int64_t end) {
|
||||
for (int64_t batch = start; batch < end; ++batch) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> matrix1(map1_l, map1_r);
|
||||
map1[batch].resize(map1_l, map1_r);
|
||||
for (int64_t i = 0; i < map1_l; i++) {
|
||||
for (int64_t j = 0; j < map1_r; j++) {
|
||||
map1[batch](i, j) = input0_data[batch * map1_l * map1_r + i * map1_r + j];
|
||||
}
|
||||
}
|
||||
map2[batch].resize(map2_l, map2_r);
|
||||
for (int64_t i = 0; i < map2_l; i++) {
|
||||
for (int64_t j = 0; j < map2_r; j++) {
|
||||
map2[batch](i, j) = input1_data[batch * map2_l * map2_r + i * map2_r + j];
|
||||
}
|
||||
}
|
||||
if (adj_x) {
|
||||
if (adj_y) {
|
||||
map_output[batch] = map1[batch].adjoint() * map2[batch].adjoint();
|
||||
} else {
|
||||
map_output[batch] = map1[batch].adjoint() * map2[batch];
|
||||
}
|
||||
} else {
|
||||
if (adj_y) {
|
||||
map_output[batch] = map1[batch] * map2[batch].adjoint();
|
||||
} else {
|
||||
map_output[batch] = map1[batch] * map2[batch];
|
||||
}
|
||||
}
|
||||
map_output[batch].resize(num_rows, num_cols);
|
||||
for (int64_t i = 0; i < num_rows; ++i) {
|
||||
for (int64_t j = 0; j < num_cols; ++j) {
|
||||
output_data[batch * num_rows * num_cols + i * num_cols + j] = map_output[batch](i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_batches, num_batches / max_core_num, shared_batchmatmul),
|
||||
"BatchMatMul Compute failed.");
|
||||
} else {
|
||||
for (uint64_t batch = 0; batch < num_batches; ++batch) {
|
||||
for (int64_t i = 0; i < map1_l; i++) {
|
||||
for (int64_t j = 0; j < map1_r; j++) {
|
||||
map1(i, j) = input0_data[batch * map1_l * map1_r + i * map1_r + j];
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < map2_l; i++) {
|
||||
for (int64_t j = 0; j < map2_r; j++) {
|
||||
map2(i, j) = input1_data[batch * map2_l * map2_r + i * map2_r + j];
|
||||
}
|
||||
}
|
||||
if (adj_x) {
|
||||
if (adj_y) {
|
||||
map_output = map1.adjoint() * map2.adjoint();
|
||||
} else {
|
||||
map_output = map1.adjoint() * map2;
|
||||
}
|
||||
} else {
|
||||
if (adj_y) {
|
||||
map_output = map1 * map2.adjoint();
|
||||
} else {
|
||||
map_output = map1 * map2;
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < num_rows; ++i) {
|
||||
for (int64_t j = 0; j < num_cols; ++j) {
|
||||
output_data[batch * num_rows * num_cols + i * num_cols + j] = map_output(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t BatchMatMulCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check BatchMatMul params failed.");
|
||||
DataType input0_data_type = ctx.Input(0)->GetDataType();
|
||||
KERNEL_LOG_DEBUG("%s op input[x1] data type is [%s].", kBatchMatmul, DTypeStr(input0_data_type).c_str());
|
||||
uint32_t ret = KERNEL_STATUS_OK;
|
||||
switch (input0_data_type) {
|
||||
case DT_FLOAT:
|
||||
ret = DoCompute<float>(ctx);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
ret = DoCompute<double>(ctx);
|
||||
break;
|
||||
case DT_FLOAT16:
|
||||
ret = DoCompute<Eigen::half>(ctx);
|
||||
break;
|
||||
case DT_INT16:
|
||||
ret = DoCompute<int16_t>(ctx);
|
||||
break;
|
||||
case DT_INT32:
|
||||
ret = DoCompute<int32_t>(ctx);
|
||||
break;
|
||||
case DT_INT64:
|
||||
ret = DoCompute<int64_t>(ctx);
|
||||
break;
|
||||
case DT_UINT16:
|
||||
ret = DoCompute<uint16_t>(ctx);
|
||||
break;
|
||||
case DT_UINT32:
|
||||
ret = DoCompute<uint32_t>(ctx);
|
||||
break;
|
||||
case DT_UINT64:
|
||||
ret = DoCompute<uint64_t>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
ret = DoCompute<std::complex<std::float_t>>(ctx);
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
ret = DoCompute<std::complex<std::double_t>>(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Unsupported input[x1] data type[%s]", DTypeStr(input0_data_type).c_str());
|
||||
ret = KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kBatchMatmul, BatchMatMulCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_HOST_BATCHMATMUL_H_
|
||||
#define AICPU_KERNELS_HOST_BATCHMATMUL_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class BatchMatMulCpuKernel : public CpuKernel {
|
||||
public:
|
||||
BatchMatMulCpuKernel() = default;
|
||||
~BatchMatMulCpuKernel() = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,266 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "instance_norm_v2_grad.h"
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "kernel_log.h"
|
||||
#include "securec.h"
|
||||
#include "status.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
|
||||
namespace {
|
||||
const char *const kInstanceNormV2Grad = "InstanceNormV2Grad";
|
||||
constexpr uint32_t kOutputNum = 3;
|
||||
constexpr uint32_t kInputNum = 7;
|
||||
constexpr int64_t int64_init_one = 1;
|
||||
constexpr int64_t kGrainSize = 4 * 1024;
|
||||
constexpr float float_init_zero = 0.0;
|
||||
constexpr float float_init_one = 1.0;
|
||||
constexpr double double_init_zero = 0.0;
|
||||
constexpr auto kDim3 = 3;
|
||||
constexpr auto kDim4 = 4;
|
||||
constexpr auto kDim5 = 5;
|
||||
constexpr auto InstanceNormV2GradInDyIndex = 0;
|
||||
constexpr auto InstanceNormV2GradInXIndex = 1;
|
||||
constexpr auto InstanceNormV2GradInGammaIndex = 2;
|
||||
constexpr auto InstanceNormV2GradInMeanIndex = 3;
|
||||
constexpr auto InstanceNormV2GradInVarianceIndex = 4;
|
||||
constexpr auto InstanceNormV2GradInSaveMeanIndex = 5;
|
||||
constexpr auto InstanceNormV2GradInSaveVarianceIndex = 6;
|
||||
constexpr auto InstanceNormV2GradOutDxIndex = 0;
|
||||
constexpr auto InstanceNormV2GradOutPdGammaIndex = 1;
|
||||
constexpr auto InstanceNormV2GradOutPdBetaIndex = 2;
|
||||
|
||||
inline double FloatToDouble(float v) { return static_cast<double>(v); }
|
||||
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradTypeCheck(const CpuKernelContext &ctx) {
|
||||
auto dy_type = ctx.Input(InstanceNormV2GradInDyIndex)->GetDataType();
|
||||
auto x_type = ctx.Input(InstanceNormV2GradInXIndex)->GetDataType();
|
||||
auto gamma_type = ctx.Input(InstanceNormV2GradInGammaIndex)->GetDataType();
|
||||
auto mean_type = ctx.Input(InstanceNormV2GradInMeanIndex)->GetDataType();
|
||||
auto variance_type = ctx.Input(InstanceNormV2GradInVarianceIndex)->GetDataType();
|
||||
auto save_mean_type = ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetDataType();
|
||||
auto save_variance_type = ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetDataType();
|
||||
|
||||
if (dy_type != x_type) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments dy and x should have the same "
|
||||
"data type, but dy type is [%s], x type is [%s].",
|
||||
kInstanceNormV2Grad, DTypeStr(dy_type).c_str(), DTypeStr(x_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const std::map<std::string, DataType> types = {{"gamma", gamma_type},
|
||||
{"mean", mean_type},
|
||||
{"variance", variance_type},
|
||||
{"save_mean", save_mean_type},
|
||||
{"save_variance", save_variance_type}};
|
||||
return CheckTensorTypeSame(types, DT_FLOAT, kInstanceNormV2Grad);
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradShapeCheck(const CpuKernelContext &ctx) {
|
||||
auto dy_shape_ptr = ctx.Input(InstanceNormV2GradInDyIndex)->GetTensorShape();
|
||||
auto x_shape_ptr = ctx.Input(InstanceNormV2GradInXIndex)->GetTensorShape();
|
||||
auto gamma_shape_ptr = ctx.Input(InstanceNormV2GradInGammaIndex)->GetTensorShape();
|
||||
auto mean_shape_ptr = ctx.Input(InstanceNormV2GradInMeanIndex)->GetTensorShape();
|
||||
auto variance_shape_ptr = ctx.Input(InstanceNormV2GradInVarianceIndex)->GetTensorShape();
|
||||
auto save_mean_shape_ptr = ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetTensorShape();
|
||||
auto save_variance_shape_ptr = ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetTensorShape();
|
||||
auto pd_gamma_shape_ptr = ctx.Output(InstanceNormV2GradOutPdGammaIndex)->GetTensorShape();
|
||||
auto pd_beta_shape_ptr = ctx.Output(InstanceNormV2GradOutPdBetaIndex)->GetTensorShape();
|
||||
|
||||
auto dy_shape = dy_shape_ptr->GetDimSizes();
|
||||
auto res = CheckTensorShapeSame({{"input x", x_shape_ptr}}, dy_shape, kInstanceNormV2Grad);
|
||||
if (res != KERNEL_STATUS_OK) {
|
||||
return res;
|
||||
};
|
||||
auto x_format = x_shape_ptr->GetFormat();
|
||||
std::vector<int64_t> check_shape;
|
||||
check_shape = dy_shape;
|
||||
int64_t image_size = 0;
|
||||
if (dy_shape.size() == kDim4) {
|
||||
if (x_format == FORMAT_NCHW || x_format == FORMAT_ND) {
|
||||
// consider (N, C, H, W) as (N*C, H, W, 1), similar to (N, H, W, C)
|
||||
check_shape[kFormatNCHWIndexH] = int64_init_one;
|
||||
check_shape[kFormatNCHWIndexW] = int64_init_one;
|
||||
image_size = dy_shape[kFormatNCHWIndexH] * dy_shape[kFormatNCHWIndexW];
|
||||
instance_num = dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC];
|
||||
constexpr int64_t kNumberOne = 1;
|
||||
dy_shape_4d_ = {dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC], dy_shape[kFormatNCHWIndexH],
|
||||
dy_shape[kFormatNCHWIndexW], kNumberOne};
|
||||
batch_channels_2d_ = {dy_shape[kFormatNCHWIndexN] * dy_shape[kFormatNCHWIndexC], kNumberOne};
|
||||
} else if (x_format == FORMAT_NHWC) {
|
||||
check_shape[kFormatNHWCIndexH] = int64_init_one;
|
||||
check_shape[kFormatNHWCIndexW] = int64_init_one;
|
||||
image_size = dy_shape[kFormatNHWCIndexH] * dy_shape[kFormatNHWCIndexW];
|
||||
instance_num = dy_shape[kFormatNHWCIndexN] * dy_shape[kFormatNHWCIndexC];
|
||||
dy_shape_4d_ = dy_shape;
|
||||
batch_channels_2d_ = {dy_shape[kFormatNHWCIndexN], dy_shape[kFormatNHWCIndexC]};
|
||||
}
|
||||
} else if (x_format == FORMAT_NC1HWC0 || dy_shape.size() == kDim5) {
|
||||
// consider (N, C1, H, W, C0) as (N*C1, H, W, C0), similar to (N, H, W, C)
|
||||
check_shape[kFormatNC1HWC0IndexH] = int64_init_one;
|
||||
check_shape[kFormatNC1HWC0IndexW] = int64_init_one;
|
||||
image_size = dy_shape[kFormatNC1HWC0IndexH] * dy_shape[kFormatNC1HWC0IndexW];
|
||||
instance_num = dy_shape[kFormatNC1HWC0IndexN] * dy_shape[kFormatNC1HWC0IndexC1] * dy_shape[kFormatNC1HWC0IndexC0];
|
||||
dy_shape_4d_ = {dy_shape[kFormatNC1HWC0IndexN] * dy_shape[kFormatNC1HWC0IndexC1], dy_shape[kFormatNC1HWC0IndexH],
|
||||
dy_shape[kFormatNC1HWC0IndexW], dy_shape[kFormatNC1HWC0IndexC0]};
|
||||
batch_channels_2d_ = {dy_shape[kFormatNC1HWC0IndexN] * dy_shape[kFormatNC1HWC0IndexC1],
|
||||
dy_shape[kFormatNC1HWC0IndexC0]};
|
||||
} else {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s]'s input arguments dy and x only "
|
||||
"support NHWC, NCHW and NC1HWC0, but get data format [%s]",
|
||||
kInstanceNormV2Grad, FormatToSerialString(x_format).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
constexpr int64_t image_min = 1;
|
||||
if (image_size <= image_min) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s], expected more than 1 value per instance, but get "
|
||||
"[%ld] value per instance.",
|
||||
kInstanceNormV2Grad, image_size);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const std::map<std::string, TensorShapePtr> shapes = {{"gamma", gamma_shape_ptr},
|
||||
{"mean", mean_shape_ptr},
|
||||
{"variance", variance_shape_ptr},
|
||||
{"save_mean", save_mean_shape_ptr},
|
||||
{"save_variance", save_variance_shape_ptr},
|
||||
{"pd_gamma", pd_gamma_shape_ptr},
|
||||
{"pd_beta", pd_beta_shape_ptr}};
|
||||
return CheckTensorShapeSame(shapes, check_shape, kInstanceNormV2Grad);
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradAttrCheck(const CpuKernelContext &ctx) {
|
||||
constexpr float epsilon_min = 0.0;
|
||||
constexpr float epsilon_max = 1.0;
|
||||
auto epsilon_ptr = ctx.GetAttr("epsilon");
|
||||
if (epsilon_ptr) {
|
||||
epsilon_ = epsilon_ptr->GetFloat();
|
||||
}
|
||||
if (epsilon_ < epsilon_min || epsilon_ >= epsilon_max) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"For primitive[%s], attr epsilon value should be in [0, 1), but get "
|
||||
"[%f].",
|
||||
kInstanceNormV2Grad, epsilon_);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto is_training_ptr = ctx.GetAttr("is_training");
|
||||
if (is_training_ptr) {
|
||||
is_training_ = is_training_ptr->GetBool();
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2GradCpuKernel::InstanceNormV2GradParamCheck(const CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2GradTypeCheck(ctx), "InstanceNormV2Grad check type failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2GradShapeCheck(ctx), "InstanceNormV2Grad check shape failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2GradAttrCheck(ctx), "InstanceNormV2Grad check attr failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t InstanceNormV2GradCpuKernel::DoCompute(CpuKernelContext &ctx) {
|
||||
const int64_t batch = dy_shape_4d_[kFormatNHWCIndexN];
|
||||
const int64_t channel = dy_shape_4d_[kFormatNHWCIndexC];
|
||||
const int64_t image_size = dy_shape_4d_[kFormatNHWCIndexH] * dy_shape_4d_[kFormatNHWCIndexW];
|
||||
std::vector<int64_t> dy_shape_3d_ = {batch, image_size, channel};
|
||||
auto dy_3d = EigenTensor(dy_shape_3d_, ctx.Input(InstanceNormV2GradInDyIndex)->GetData()).tensor<T, kDim3>();
|
||||
auto in_x_3d = EigenTensor(dy_shape_3d_, ctx.Input(InstanceNormV2GradInXIndex)->GetData()).tensor<T, kDim3>();
|
||||
auto weight_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInGammaIndex)->GetData()).matrix<float>();
|
||||
auto running_mean_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInMeanIndex)->GetData()).matrix<float>();
|
||||
auto running_var_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInVarianceIndex)->GetData()).matrix<float>();
|
||||
auto save_mean_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInSaveMeanIndex)->GetData()).matrix<float>();
|
||||
auto save_invstd_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Input(InstanceNormV2GradInSaveVarianceIndex)->GetData()).matrix<float>();
|
||||
|
||||
auto dx_3d = EigenTensor(dy_shape_3d_, ctx.Output(InstanceNormV2GradOutDxIndex)->GetData()).tensor<T, kDim3>();
|
||||
auto grad_weight_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Output(InstanceNormV2GradOutPdGammaIndex)->GetData()).matrix<float>();
|
||||
auto grad_bias_matrix =
|
||||
EigenTensor(batch_channels_2d_, ctx.Output(InstanceNormV2GradOutPdBetaIndex)->GetData()).matrix<float>();
|
||||
auto loop_batch = [&](int64_t begin, int64_t end) {
|
||||
for (int64_t idx = begin; idx < end; ++idx) {
|
||||
for (int64_t c_idx = 0; c_idx < channel; ++c_idx) {
|
||||
float w = weight_matrix(idx, c_idx);
|
||||
float mean = float_init_zero, invstd = float_init_zero;
|
||||
mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx);
|
||||
float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_);
|
||||
invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_;
|
||||
|
||||
double sum = double_init_zero, dotp = double_init_zero;
|
||||
for (int64_t img_idx = 0; img_idx < image_size; ++img_idx) {
|
||||
sum += static_cast<double>(dy_3d(idx, img_idx, c_idx));
|
||||
dotp += (static_cast<double>(in_x_3d(idx, img_idx, c_idx)) - FloatToDouble(mean)) *
|
||||
static_cast<double>(dy_3d(idx, img_idx, c_idx));
|
||||
}
|
||||
float k = static_cast<float>(dotp * FloatToDouble(invstd) * FloatToDouble(invstd) / LongToDouble(image_size));
|
||||
float grad_mean = static_cast<float>(sum / LongToDouble(image_size));
|
||||
for (int64_t img_idx = 0; img_idx < image_size; ++img_idx) {
|
||||
float _dx_ = (static_cast<float>(in_x_3d(idx, img_idx, c_idx)) - mean) * k;
|
||||
dx_3d(idx, img_idx, c_idx) =
|
||||
is_training_
|
||||
? static_cast<T>((static_cast<float>(dy_3d(idx, img_idx, c_idx)) - grad_mean - _dx_) * invstd * w)
|
||||
: static_cast<T>(static_cast<float>(dy_3d(idx, img_idx, c_idx)) * invstd * w);
|
||||
}
|
||||
grad_weight_matrix(idx, c_idx) = static_cast<float>(dotp * FloatToDouble(invstd));
|
||||
grad_bias_matrix(idx, c_idx) = static_cast<float>(sum);
|
||||
}
|
||||
}
|
||||
};
|
||||
int64_t block_size = std::max(int64_init_one, (kGrainSize / (channel * image_size)));
|
||||
return CpuKernelUtils::ParallelFor(ctx, batch, block_size, loop_batch);
|
||||
}
|
||||
|
||||
uint32_t InstanceNormV2GradCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum),
|
||||
"InstanceNormV2Grad check input and output number failed.");
|
||||
KERNEL_HANDLE_ERROR(InstanceNormV2GradParamCheck(ctx), "InstanceNormV2Grad check params failed.");
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
uint32_t result;
|
||||
switch (data_type) {
|
||||
case (DT_FLOAT16):
|
||||
result = DoCompute<Eigen::half>(ctx);
|
||||
break;
|
||||
case (DT_FLOAT):
|
||||
result = DoCompute<float>(ctx);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_ERROR("InstanceNormV2Grad kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (result != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("InstanceNormV2Grad kernel compute failed.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kInstanceNormV2Grad, InstanceNormV2GradCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_GRAD_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_INSTANCE_NORM_V2_GRAD_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include <vector>
|
||||
|
||||
namespace aicpu {
|
||||
class InstanceNormV2GradCpuKernel : public CpuKernel {
|
||||
public:
|
||||
InstanceNormV2GradCpuKernel() = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t InstanceNormV2GradParamCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2GradShapeCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2GradTypeCheck(const CpuKernelContext &ctx);
|
||||
uint32_t InstanceNormV2GradAttrCheck(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t DoCompute(CpuKernelContext &ctx);
|
||||
|
||||
bool is_training_ = true;
|
||||
float epsilon_ = 0.00001;
|
||||
std::vector<int64_t> dy_shape_4d_;
|
||||
std::vector<int64_t> batch_channels_2d_;
|
||||
int64_t instance_num = 0;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,299 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "sparse_to_dense.h"
|
||||
#include <securec.h>
|
||||
#include "cpu_types.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace aicpu {
|
||||
const char *const SPARSETODENSE = "SparseToDense";
|
||||
constexpr int64_t kParallelDateSize = 16 * 1024;
|
||||
constexpr int64_t kCopyDataSize = 1024;
|
||||
constexpr uint32_t kInput0 = 0;
|
||||
constexpr uint32_t kInput1 = 1;
|
||||
constexpr uint32_t kInput2 = 2;
|
||||
constexpr uint32_t kInput3 = 3;
|
||||
constexpr uint32_t kOutput0 = 0;
|
||||
constexpr int32_t kRank = 2;
|
||||
} // namespace aicpu
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t SparseToDenseCpuKernel::SparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices,
|
||||
Tensor *output) {
|
||||
KERNEL_LOG_INFO("Start to execute SparseToDense");
|
||||
if (indices == nullptr || output == nullptr) {
|
||||
KERNEL_LOG_ERROR("Indices or output tensor is nullptr.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
|
||||
DataType dt = static_cast<DataType>(output->GetDataType());
|
||||
switch (dt) {
|
||||
case DT_INT8:
|
||||
return EigenSparseToDense<int8_t>(ctx, st, indices, output);
|
||||
case DT_UINT8:
|
||||
return EigenSparseToDense<uint8_t>(ctx, st, indices, output);
|
||||
case DT_INT16:
|
||||
return EigenSparseToDense<int16_t>(ctx, st, indices, output);
|
||||
case DT_UINT16:
|
||||
return EigenSparseToDense<uint16_t>(ctx, st, indices, output);
|
||||
case DT_INT32:
|
||||
return EigenSparseToDense<int32_t>(ctx, st, indices, output);
|
||||
case DT_INT64:
|
||||
return EigenSparseToDense<int64_t>(ctx, st, indices, output);
|
||||
case DT_FLOAT16:
|
||||
return EigenSparseToDense<Eigen::half>(ctx, st, indices, output);
|
||||
case DT_FLOAT:
|
||||
return EigenSparseToDense<float>(ctx, st, indices, output);
|
||||
case DT_BOOL:
|
||||
return EigenSparseToDense<bool>(ctx, st, indices, output);
|
||||
case DT_DOUBLE:
|
||||
return EigenSparseToDense<double>(ctx, st, indices, output);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Sparse to dense can't support this data type [%d].", static_cast<int32_t>(dt));
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
}
|
||||
|
||||
KernelStatus SparseToDenseCpuKernel::ValidParam(const CpuKernelContext &ctx) {
|
||||
KERNEL_LOG_INFO("Start to execute ValidParam");
|
||||
// valid input and output nullptr
|
||||
Tensor *indices_tensor = ctx.Input(0);
|
||||
Tensor *shape_tensor = ctx.Input(1);
|
||||
Tensor *sparse_values = ctx.Input(2);
|
||||
Tensor *default_value_tensor = ctx.Input(3);
|
||||
Tensor *output_tensor = ctx.Output(0);
|
||||
bool validNull = ((output_tensor == nullptr) || default_value_tensor == nullptr || (sparse_values == nullptr) ||
|
||||
(indices_tensor == nullptr) || (shape_tensor == nullptr));
|
||||
if (validNull) {
|
||||
KERNEL_LOG_ERROR("Got input or output param is nullptr.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// valid shape nullptr
|
||||
auto output_shape = shape_tensor->GetTensorShape();
|
||||
auto values_shape = sparse_values->GetTensorShape();
|
||||
auto default_value_shape = default_value_tensor->GetTensorShape();
|
||||
auto indices_shape = indices_tensor->GetTensorShape();
|
||||
bool validShapeNull = ((default_value_shape == nullptr) || values_shape == nullptr || (output_shape == nullptr) ||
|
||||
(indices_shape == nullptr));
|
||||
if (validShapeNull) {
|
||||
KERNEL_LOG_ERROR("Got input shape is nullptr.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// sparse_indices
|
||||
if (indices_shape->GetDims() > kRank) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Sparse_indices should be a scalar, vector, or matrix, got dim "
|
||||
"size [%d].",
|
||||
indices_shape->GetDims());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const int64_t elems_num = indices_shape->GetDims() > 0 ? indices_shape->GetDimSize(0) : 1;
|
||||
const int64_t dims_num = indices_shape->GetDims() > 1 ? indices_shape->GetDimSize(1) : 1;
|
||||
|
||||
// output_shape
|
||||
if (output_shape->GetDims() != 1) {
|
||||
KERNEL_LOG_ERROR("Output_shape should be a vector, and got dim size [%d].", output_shape->GetDims());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_tensor->NumElements() != dims_num) {
|
||||
KERNEL_LOG_ERROR("Output_shape has incorrect number of elements [%lld], should be [%lld]",
|
||||
shape_tensor->NumElements(), dims_num);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// valid data type
|
||||
DataType IndiceType = indices_tensor->GetDataType();
|
||||
DataType outShapeType = shape_tensor->GetDataType();
|
||||
bool validIndiceType = ((IndiceType != DT_INT32) && (IndiceType != DT_INT64));
|
||||
bool validShapeType = ((outShapeType != DT_INT32) && (outShapeType != DT_INT64));
|
||||
if (validShapeType || validIndiceType) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Valid indice or output shape data type failed, indiceType [%d], "
|
||||
"shapeType [%d].",
|
||||
static_cast<int>(IndiceType), static_cast<int>(outShapeType));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// sparse_values
|
||||
int32_t values_dims_size = values_shape->GetDims();
|
||||
if ((values_dims_size != 0) && (values_dims_size != 1)) {
|
||||
KERNEL_LOG_ERROR("Values_shape should be a scalar or a vector, got dim size [%d].", values_shape->GetDims());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if ((values_dims_size == 1) && (sparse_values->NumElements() != elems_num)) {
|
||||
KERNEL_LOG_ERROR("Values_shape has incorrect number of elements [%lld], should be [%lld]",
|
||||
sparse_values->NumElements(), elems_num);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// default_value
|
||||
if (default_value_shape->GetDims() != 0) {
|
||||
KERNEL_LOG_ERROR("Default_value should be a scalar, and got dim size [%d].", default_value_shape->GetDims());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_LOG_INFO("Execute ValidParam end.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t SparseToDenseCpuKernel::ParallelSetDefaultValue(const CpuKernelContext &ctx,
|
||||
const Tensor *default_value_tensor,
|
||||
const Tensor *output_tensor, int64_t output_size) {
|
||||
auto type_size = GetSizeByDataType(static_cast<DataType>(output_tensor->GetDataType()));
|
||||
char *default_value_addr = reinterpret_cast<char *>(default_value_tensor->GetData());
|
||||
char *output_addr = reinterpret_cast<char *>(output_tensor->GetData());
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
auto default_value = [&](std::int64_t begin, std::int64_t end) {
|
||||
int64_t total = end - begin;
|
||||
int64_t remainder = total % kCopyDataSize;
|
||||
int64_t piece = total / kCopyDataSize;
|
||||
if (piece == 0) {
|
||||
for (int64_t index = begin; index < end; index++) {
|
||||
(void)memcpy_s(output_addr + (index * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
} else {
|
||||
for (int64_t index = begin; index < begin + kCopyDataSize; index++) {
|
||||
(void)memcpy_s(output_addr + (index * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
char *temp_addr = output_addr + (begin * type_size);
|
||||
size_t data_size = static_cast<size_t>(type_size * kCopyDataSize);
|
||||
for (int64_t loop = 1; loop < piece; loop++) {
|
||||
(void)memcpy_s(temp_addr + (loop * type_size * kCopyDataSize), data_size, temp_addr, data_size);
|
||||
}
|
||||
char *temp_addr1 = output_addr + (begin * type_size) + (piece * type_size * kCopyDataSize);
|
||||
for (int64_t loop1 = 0; loop1 < remainder; loop1++) {
|
||||
(void)memcpy_s(temp_addr1 + (loop1 * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
}
|
||||
};
|
||||
return CpuKernelUtils::ParallelFor(ctx, output_size, output_size / max_core_num, default_value);
|
||||
}
|
||||
uint32_t SparseToDenseCpuKernel::SetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor,
|
||||
const Tensor *output_tensor, int64_t output_size) {
|
||||
auto type_size = GetSizeByDataType(static_cast<DataType>(output_tensor->GetDataType()));
|
||||
if (type_size < 1) {
|
||||
KERNEL_LOG_ERROR("Don't support output tensor types");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
char *default_value_addr = reinterpret_cast<char *>(default_value_tensor->GetData());
|
||||
char *output_addr = reinterpret_cast<char *>(output_tensor->GetData());
|
||||
if (output_size < kParallelDateSize) {
|
||||
int64_t remainder = output_size % kCopyDataSize;
|
||||
int64_t piece = output_size / kCopyDataSize;
|
||||
if (piece == 0) {
|
||||
for (int index = 0; index < output_size; index++) {
|
||||
(void)memcpy_s(output_addr + (index * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < kCopyDataSize; index++) {
|
||||
(void)memcpy_s(output_addr + (index * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
size_t data_size = static_cast<size_t>(type_size * kCopyDataSize);
|
||||
for (int loop = 1; loop < piece; loop++) {
|
||||
(void)memcpy_s(output_addr + (loop * type_size * kCopyDataSize), data_size, output_addr, data_size);
|
||||
}
|
||||
char *temp_addr = output_addr + (piece * type_size * kCopyDataSize);
|
||||
for (int loop1 = 0; loop1 < remainder; loop1++) {
|
||||
(void)memcpy_s(temp_addr + (loop1 * type_size), static_cast<size_t>(type_size), default_value_addr,
|
||||
static_cast<size_t>(type_size));
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
} else {
|
||||
return ParallelSetDefaultValue(ctx, default_value_tensor, output_tensor, output_size);
|
||||
}
|
||||
}
|
||||
uint32_t SparseToDenseCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
if (ValidParam(ctx) != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Valid sparse to dense param error.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
Tensor *indices_tensor = ctx.Input(kInput0);
|
||||
KERNEL_CHECK_NULLPTR(indices_tensor, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Indices_tensor is null")
|
||||
Tensor *shape_tensor = ctx.Input(kInput1);
|
||||
KERNEL_CHECK_NULLPTR(shape_tensor, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Shape_tensor is null")
|
||||
Tensor *sparse_values = ctx.Input(kInput2);
|
||||
KERNEL_CHECK_NULLPTR(sparse_values, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Sparse_values is null")
|
||||
Tensor *default_value_tensor = ctx.Input(kInput3);
|
||||
KERNEL_CHECK_NULLPTR(default_value_tensor, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID),
|
||||
"Default_value_tensor is null")
|
||||
Tensor *output_tensor = ctx.Output(kOutput0);
|
||||
KERNEL_CHECK_NULLPTR(output_tensor, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Output_tensor is null")
|
||||
|
||||
auto output_shape = shape_tensor->GetTensorShape();
|
||||
std::vector<int64_t> dense_shape;
|
||||
std::vector<int64_t> order;
|
||||
int64_t output_size = 1;
|
||||
size_t output_zero_dim_size = static_cast<size_t>(output_shape->GetDimSize(0));
|
||||
for (size_t index = 0; index < output_zero_dim_size; ++index) {
|
||||
if (shape_tensor->GetDataType() == DT_INT32) {
|
||||
int32_t *temp_dim = reinterpret_cast<int32_t *>(shape_tensor->GetData());
|
||||
dense_shape.emplace_back(static_cast<int64_t>(temp_dim[index]));
|
||||
} else {
|
||||
int64_t *temp_dim = reinterpret_cast<int64_t *>(shape_tensor->GetData());
|
||||
dense_shape.emplace_back(temp_dim[index]);
|
||||
}
|
||||
order.push_back(dense_shape[index]);
|
||||
output_size *= dense_shape[index];
|
||||
}
|
||||
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
|
||||
SparseTensor st;
|
||||
if (st.CreateSparseTensor(indices_tensor, sparse_values, dense_shape, order) !=
|
||||
static_cast<uint32_t>(KERNEL_STATUS_OK)) {
|
||||
KERNEL_LOG_ERROR("Create sparse tensor failed.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
AttrValue *validate_indices = ctx.GetAttr("validate_indices");
|
||||
if (validate_indices == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get attr:validate_indices failed.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
if (validate_indices->GetBool()) {
|
||||
if (st.IndicesValid(ctx) != static_cast<uint32_t>(KERNEL_STATUS_OK)) {
|
||||
KERNEL_LOG_ERROR("Indices is valid.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
}
|
||||
|
||||
if (SetDefaultValue(ctx, default_value_tensor, output_tensor, output_size) !=
|
||||
static_cast<uint32_t>(KERNEL_STATUS_OK)) {
|
||||
KERNEL_LOG_ERROR("Sparse_to_dense set default value failed.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
|
||||
if (SparseToDense(ctx, st, indices_tensor, output_tensor) != static_cast<uint32_t>(KERNEL_STATUS_OK)) {
|
||||
KERNEL_LOG_ERROR("Sparse_to_dense execute failed.");
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
return static_cast<uint32_t>(KERNEL_STATUS_OK);
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(SPARSETODENSE, SparseToDenseCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_SPARSETODENSE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_SPARSETODENSE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "utils/sparse_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class SparseToDenseCpuKernel : public CpuKernel {
|
||||
public:
|
||||
~SparseToDenseCpuKernel() = default;
|
||||
|
||||
/*
|
||||
* compute sparse to dense
|
||||
* @param ctx: cpu kernel context
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
protected:
|
||||
/*
|
||||
* valid sparse to dense param
|
||||
* @param st: sparse tensor
|
||||
* @param indices: indices tensor
|
||||
* @param output: output tensor
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
template <typename ValueT>
|
||||
uint32_t EigenSparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices, Tensor *output) {
|
||||
if (indices->GetDataType() == DT_INT32) {
|
||||
return st.ToDense<int32_t, ValueT>(ctx, output);
|
||||
} else {
|
||||
return st.ToDense<int64_t, ValueT>(ctx, output);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* valid sparse to dense param
|
||||
* @param st: sparse tensor
|
||||
* @param indices: indices tensor
|
||||
* @param output: output tensor
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t SparseToDense(const CpuKernelContext &ctx, SparseTensor &st, const Tensor *indices, Tensor *output);
|
||||
|
||||
/*
|
||||
* valid sparse to dense param
|
||||
* @param ctx: cpu kernel context
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
KernelStatus ValidParam(const CpuKernelContext &ctx);
|
||||
|
||||
/*
|
||||
* parallel set default value to dense
|
||||
* @param ctx: cpu kernel context
|
||||
* @param default_value_tensor: default value of dense tensor
|
||||
* @param output_tensor: output tensor
|
||||
* @param output_size: output tensor size
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t ParallelSetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor,
|
||||
const Tensor *output_tensor, int64_t output_size);
|
||||
|
||||
/*
|
||||
* set default value to dense
|
||||
* @param ctx: cpu kernel context
|
||||
* @param default_value_tensor: default value of dense tensor
|
||||
* @param output_tensor: output tensor
|
||||
* @param output_size: output tensor size
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t SetDefaultValue(const CpuKernelContext &ctx, const Tensor *default_value_tensor, const Tensor *output_tensor,
|
||||
int64_t output_size);
|
||||
};
|
||||
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -319,7 +319,11 @@ bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const {
|
|||
mindspore::kMaskedSelectOpName,
|
||||
mindspore::kMaskedSelectGradOpName,
|
||||
mindspore::kMultiMarginLossOpName,
|
||||
mindspore::kMatrixInverseOpName};
|
||||
mindspore::kMatrixInverseOpName,
|
||||
mindspore::kMultiMarginLossGradOpName,
|
||||
mindspore::kSspaddmmOpName,
|
||||
mindspore::kBatchMatMulOpName,
|
||||
mindspore::kSparseToDenseOpName};
|
||||
|
||||
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
|
||||
static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";
|
||||
|
|
|
@ -410,3 +410,7 @@ from .trace import _trace_aicpu
|
|||
from .tracegrad import _tracegrad_aicpu
|
||||
from .tridiagonal_solve import _tridiagonal_solve_aicpu
|
||||
from .truncated_normal import _truncated_normal_aicpu
|
||||
from .glu import _glu_aicpu
|
||||
from .multi_margin_loss import _multi_margin_loss_aicpu
|
||||
from .multi_margin_loss_grad import _multi_margin_loss_grad_aicpu
|
||||
from .sparse_to_dense_v2 import _sparse_to_dense_v2_aicpu
|
||||
|
|
Loading…
Reference in New Issue