migrate Conj and other ops

This commit is contained in:
lilinjie 2023-01-13 11:34:18 +08:00
parent 3f5759d0dc
commit 4dfd7b7544
11 changed files with 502 additions and 8 deletions

View File

@ -327,3 +327,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss_grad.cc:aicpu::MultiMarginLossGradCpuKernel::MultiMarginLossGradComputeFP16
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss_grad.cc:aicpu::MultiMarginLossGradCpuKernel::MultiMarginLossGradCompute
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossComputeFP16
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_band_part.cc:aicpu::MatrixBandPartCpuKernel::BandCompute

View File

@ -11,9 +11,8 @@ mindspore.ops.remainder
out_{i} = input_{i} \text{ % } other_{i}
.. warning::
- 输入数值不支持0。
- 当输入元素超过2048时操作的精确度无法保证mini表格的千分之二的要求。
- 由于架构不同该操作符在NPU和CPU上的计算结果可能不一致。
- 当输入元素超过2048时，可能会有精度问题。
- 在Ascend和CPU上的计算结果可能不一致。
- 如果shape表示为(D1,D2…Dn)，那么D1 \ * D2……\ * DN <= 1000000，n <= 8。
参数:

View File

@ -200,6 +200,7 @@ constexpr auto kConv3DTransposeOpName = "Conv3DTranspose";
// Operator-name constants: each kXxxOpName identifier holds that operator's name as a string literal.
constexpr auto kDeformableOffsetsOpName = "DeformableOffsets";
constexpr auto kCropAndResizeOpName = "CropAndResize";
constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
constexpr auto kConjOpName = "Conj";
constexpr auto kConvBN1OpName = "ConvBN1";
constexpr auto kCOO2CSROpName = "COO2CSR";
constexpr auto kCosOpName = "Cos";
@ -414,6 +415,7 @@ constexpr auto kMaskedSelectOpName = "MaskedSelect";
// Operator-name constants: each kXxxOpName identifier holds that operator's name as a string literal.
constexpr auto kMaskedSelectGradOpName = "MaskedSelectGrad";
constexpr auto kMatMulOpName = "MatMul";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kMatrixBandPartOpName = "MatrixBandPart";
constexpr auto kMatrixDeterminantOpName = "MatrixDeterminant";
constexpr auto kMatrixDiagOpName = "MatrixDiag";
constexpr auto kMatrixDiagDOpName = "MatrixDiagD";
@ -717,6 +719,7 @@ constexpr auto kUpdateStateOpName = "UpdateState";
// Operator-name constants: each kXxxOpName identifier holds that operator's name as a string literal.
constexpr auto kDynamicBroadcastToOpName = "DynamicBroadcastTo";
constexpr auto kCheckValidOpName = "CheckValid";
constexpr auto kSoftmaxGradFusionOpName = "SoftmaxGradFusion";
constexpr auto kZerosLikeOpName = "ZerosLike";
// Sequence ops
constexpr auto kScalarToTensorOpName = "ScalarToTensor";

View File

@ -0,0 +1,92 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "conj.h"
#include <complex>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
// File-local constants and the dtype-dispatch macro used by the Conj kernel.
namespace {
// Expected number of output tensors, passed to NormalCheck in Compute.
const uint32_t kOutputNum = 1;
// Expected number of input tensors, passed to NormalCheck in Compute.
const uint32_t kInputNum = 1;
// Op name used in log messages and as the key given to REGISTER_CPU_KERNEL.
const char *const kConj = "Conj";
// Threshold for switching to ParallelFor. Despite the name, ConjCompute compares it
// against the input size in bytes (element count * sizeof(T)), not the element count.
constexpr int64_t kParallelDataNums = 512 * 1024;
// Expands to one `case` of the dtype switch in Compute: runs ConjCompute<TYPE>(CTX),
// logs and returns the status from the enclosing function on failure, otherwise breaks.
#define CONJ_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = ConjCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("Conj kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
// Kernel entry point for Conj.
// Validates the input/output counts and data pointers, then runs the typed
// computation that matches the input dtype (complex64 or complex128).
// Returns KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID for an
// unsupported dtype, or the failing status of a check / the typed computation.
uint32_t ConjCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kConj);
  KERNEL_HANDLE_ERROR(ConjCheck(ctx), "[%s] check params failed.", kConj);

  const DataType inputType = ctx.Input(0)->GetDataType();
  uint32_t status = KERNEL_STATUS_OK;
  switch (inputType) {
    case DT_COMPLEX64:
      status = ConjCompute<std::complex<float>>(ctx);
      break;
    case DT_COMPLEX128:
      status = ConjCompute<std::complex<double>>(ctx);
      break;
    default:
      KERNEL_LOG_ERROR("Conj kernel data type [%s] not support.", DTypeStr(inputType).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }

  // Single failure path shared by every supported dtype.
  if (status != KERNEL_STATUS_OK) {
    KERNEL_LOG_ERROR("Conj kernel compute failed.");
    return status;
  }
  return KERNEL_STATUS_OK;
}
uint32_t ConjCpuKernel::ConjCheck(const CpuKernelContext &ctx) const {
auto input = ctx.Input(0);
auto output = ctx.Output(0);
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
return KERNEL_STATUS_OK;
}
// Writes the complex conjugate of every input element to the output.
//
// T is the complex element type (std::complex<float> or std::complex<double>).
// Inputs whose byte size is at most kParallelDataNums are handled in a single
// loop; larger inputs are split across cores with CpuKernelUtils::ParallelFor.
// Returns KERNEL_STATUS_OK, or the failing status reported through
// KERNEL_HANDLE_ERROR when ParallelFor fails.
template <typename T>
uint32_t ConjCpuKernel::ConjCompute(const CpuKernelContext &ctx) const {
  auto inputX = reinterpret_cast<T *>(ctx.Input(0)->GetData());
  auto outputY = reinterpret_cast<T *>(ctx.Output(0)->GetData());
  const int64_t dataNum = ctx.Input(0)->NumElements();
  const int64_t dataSize = dataNum * static_cast<int64_t>(sizeof(T));

  // Shared by the serial and the sharded path: conjugates elements [start, end).
  auto conjRange = [inputX, outputY](int64_t start, int64_t end) {
    for (int64_t i = start; i < end; ++i) {
      outputY[i] = std::conj(inputX[i]);
    }
  };

  if (dataSize <= kParallelDataNums) {
    conjRange(0, dataNum);
    return KERNEL_STATUS_OK;
  }

  // The subtraction used to be evaluated in uint32_t, so a CPU count below
  // kResvCpuNum wrapped around to a huge core count. Do it in signed 64-bit
  // arithmetic and fall back to a single core instead.
  const int64_t cpuNum = static_cast<int64_t>(aicpu::CpuKernelUtils::GetCPUNum(ctx));
  const int64_t reservedNum = static_cast<int64_t>(kResvCpuNum);
  int64_t maxCoreNum = (cpuNum > reservedNum) ? (cpuNum - reservedNum) : 1;
  if (maxCoreNum > dataNum) {
    maxCoreNum = dataNum;
  }
  KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dataNum, dataNum / maxCoreNum, conjRange),
                      "Conj Compute failed.");
  return KERNEL_STATUS_OK;
}
// Associates ConjCpuKernel with the op name stored in kConj ("Conj").
// NOTE(review): the macro is defined outside this file; presumably it registers the kernel with the AICPU framework.
REGISTER_CPU_KERNEL(kConj, ConjCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,36 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_CONJ_H_
#define AICPU_KERNELS_NORMALIZED_CONJ_H_
#include "cpu_ops_kernel.h"
namespace aicpu {
// AICPU kernel for the Conj op. The implementation in conj.cc supports
// DT_COMPLEX64 and DT_COMPLEX128 inputs and rejects every other dtype.
class ConjCpuKernel : public CpuKernel {
public:
ConjCpuKernel() = default;
~ConjCpuKernel() override = default;
// Kernel entry point: validates the context and dispatches on the input dtype.
// Returns a KERNEL_STATUS_* code.
uint32_t Compute(CpuKernelContext &ctx) override;
private:
// Checks that the first input and first output have non-null data pointers.
uint32_t ConjCheck(const CpuKernelContext &ctx) const;
// Computes the element-wise complex conjugate for element type T.
template <typename T>
uint32_t ConjCompute(const CpuKernelContext &ctx) const;
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,194 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "matrix_band_part.h"
#include <algorithm>
#include <vector>
#include <securec.h>
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include "kernel_log.h"
#include "status.h"
// File-local constants and the dtype-dispatch macro used by the MatrixBandPart kernel.
namespace {
// Op name given to REGISTER_CPU_KERNEL. Declared `const char *const` (pointer and
// pointee both immutable), matching how the sibling kernels declare their names.
const char *const kMatrixBandPart = "MatrixBandPart";
// Expected number of output tensors, passed to NormalCheck in Compute.
const uint32_t kOutputNum = 1;
// Expected number of input tensors (x, num_lower, num_upper), passed to NormalCheck in Compute.
const uint32_t kInputNum = 3;
// Element-count threshold: BandCompute runs serially below it and uses ParallelFor otherwise.
constexpr int64_t kParallelDataNums = 64 * 1024;
// Expands to one `case` of the dtype switch in Compute: runs BandCompute<TYPE>(...),
// logs and returns the status from the enclosing function on failure, otherwise breaks.
#define BAND_COMPUTE_CASE(DTYPE, TYPE, X, LOWER, UPPER, Y, M, N, CTX) \
  case (DTYPE): {                                                       \
    uint32_t result = BandCompute<TYPE>(X, LOWER, UPPER, Y, M, N, CTX); \
    if (result != KERNEL_STATUS_OK) {                                   \
      KERNEL_LOG_ERROR("MatrixBandPart kernel compute failed.");        \
      return result;                                                    \
    }                                                                   \
    break;                                                              \
  }
}  // namespace
namespace aicpu {
// Kernel entry point for MatrixBandPart.
//
// Inputs: x (rank >= 2), num_lower and num_upper (scalar DT_INT32 or DT_INT64).
// Output: y, which must have the same data size as x.
// Validates the arguments, reads the two band widths and dispatches to the
// typed BandCompute for the dtype of x.
// Returns KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID when a check
// fails or the dtype is unsupported, or the failing status of BandCompute.
uint32_t MatrixBandPartCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "MatrixBandPart check input and output number failed.");
  Tensor *x = ctx.Input(0);
  Tensor *num_lower = ctx.Input(1);
  Tensor *num_upper = ctx.Input(2);
  Tensor *y = ctx.Output(0);

  int32_t rank = x->GetTensorShape()->GetDims();
  KERNEL_CHECK_FALSE((rank >= 2), KERNEL_STATUS_PARAM_INVALID, "Input must be at least 2-dim, but get dims: %d", rank);
  // The band is applied to the two innermost dimensions.
  int64_t m = x->GetTensorShape()->GetDimSize(rank - 2);
  int64_t n = x->GetTensorShape()->GetDimSize(rank - 1);

  KERNEL_CHECK_FALSE((num_lower->GetTensorShape()->GetDimSizes().empty()), KERNEL_STATUS_PARAM_INVALID,
                     "num_lower must be scalar.");
  KERNEL_CHECK_FALSE((num_upper->GetTensorShape()->GetDimSizes().empty()), KERNEL_STATUS_PARAM_INVALID,
                     "num_upper must be scalar.");

  DataType lower_type = num_lower->GetDataType();
  KERNEL_CHECK_FALSE((lower_type == DT_INT32 || lower_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
                     "Unsupported num_lower data_type[%s], "
                     "only support DT_INT32 and DT_INT64.",
                     DTypeStr(lower_type).c_str());
  DataType upper_type = num_upper->GetDataType();
  KERNEL_CHECK_FALSE((upper_type == DT_INT32 || upper_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
                     "Unsupported num_upper data_type[%s], "
                     "only support DT_INT32 and DT_INT64.",
                     DTypeStr(upper_type).c_str());

  auto *lower_raw = num_lower->GetData();
  KERNEL_CHECK_NULLPTR(lower_raw, KERNEL_STATUS_PARAM_INVALID, "Get num_lower data failed.");
  auto *upper_raw = num_upper->GetData();
  KERNEL_CHECK_NULLPTR(upper_raw, KERNEL_STATUS_PARAM_INVALID, "Get num_upper data failed.");
  // Read each scalar with the width announced by its dtype. Always reading through
  // int32_t * truncated DT_INT64 values (and picks the wrong half on big-endian hosts).
  const int64_t lower = (lower_type == DT_INT64) ? *reinterpret_cast<int64_t *>(lower_raw)
                                                 : static_cast<int64_t>(*reinterpret_cast<int32_t *>(lower_raw));
  const int64_t upper = (upper_type == DT_INT64) ? *reinterpret_cast<int64_t *>(upper_raw)
                                                 : static_cast<int64_t>(*reinterpret_cast<int32_t *>(upper_raw));

  // Negative values mean "keep the whole triangle", so only the upper limit is checked.
  KERNEL_CHECK_FALSE((lower <= m), KERNEL_STATUS_PARAM_INVALID,
                     "num_lower must be negative or less or equal to number "
                     "of rows [%lld], got: [%lld]",
                     static_cast<long long>(m), static_cast<long long>(lower));
  KERNEL_CHECK_FALSE((upper <= n), KERNEL_STATUS_PARAM_INVALID,
                     "num_upper must be negative or less or equal to number "
                     "of cols [%lld], got: [%lld]",
                     static_cast<long long>(n), static_cast<long long>(upper));

  uint64_t input_size = x->GetDataSize();
  uint64_t output_size = y->GetDataSize();
  KERNEL_CHECK_FALSE((input_size == output_size), KERNEL_STATUS_PARAM_INVALID,
                     "Input data size[%llu] is not equal to output data size[%llu].", input_size, output_size);

  DataType data_type = x->GetDataType();
  switch (data_type) {
    BAND_COMPUTE_CASE(DT_INT8, int8_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_INT16, int16_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_INT32, int32_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_INT64, int64_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_UINT8, uint8_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_UINT16, uint16_t, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_FLOAT16, Eigen::half, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_FLOAT, float, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_DOUBLE, double, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_BOOL, bool, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, x, lower, upper, y, m, n, ctx)
    BAND_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, x, lower, upper, y, m, n, ctx)
    default:
      KERNEL_LOG_ERROR("MatrixBandPart kernel data type [%s] not support.", DTypeStr(data_type).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
  }
  return KERNEL_STATUS_OK;
}
// Applies the band mask to the flattened rows [start, end) of a batch of
// rows x cols matrices stored contiguously.
//
// For the row with in-matrix index m, columns in
// [max(0, m - lower), min(cols, m + upper + 1)) are kept and all others are
// zeroed; a negative lower / upper keeps the whole lower / upper side.
// When x_addrs == y_addrs the data is masked in place; otherwise the output
// rows are zero-filled first and the kept band is copied over.
template <typename T>
static void BandFilterRows(const T *x_addrs, T *y_addrs, int64_t start, int64_t end, int64_t rows, int64_t cols,
                           int64_t lower, int64_t upper) {
  const T zero = static_cast<T>(0);
  const bool same_addr = (x_addrs == y_addrs);
  if (!same_addr) {
    std::fill(y_addrs + start * cols, y_addrs + end * cols, zero);
  }
  for (int64_t row = start; row < end; ++row) {
    const int64_t m = row % rows;  // row index inside its own matrix
    const int64_t band_start = lower < 0 ? 0 : std::min(cols, std::max(int64_t{0}, m - lower));
    const int64_t band_end = upper < 0 ? cols : std::min(cols, m + upper + 1);
    T *y_row = y_addrs + row * cols;
    if (same_addr) {
      if (band_start > 0) {
        std::fill(y_row, y_row + band_start, zero);
      }
      if (band_end < cols) {
        std::fill(y_row + band_end, y_row + cols, zero);
      }
    } else if (band_start < band_end) {
      // std::copy has no size limit and no error code to drop, unlike the
      // previously unchecked memcpy_s call.
      const T *x_row = x_addrs + row * cols;
      std::copy(x_row + band_start, x_row + band_end, y_row + band_start);
    }
  }
}

// Typed implementation of MatrixBandPart for element type T.
//
// x / y are the input and output tensors, rows / cols the sizes of the two
// innermost dimensions, lower / upper the number of sub- / super-diagonals to
// keep (negative keeps the whole triangle).
// Fewer than kParallelDataNums elements are processed serially; larger inputs
// are sharded by row through CpuKernelUtils::ParallelFor.
// Returns KERNEL_STATUS_OK, KERNEL_STATUS_PARAM_INVALID when a data pointer is
// null, or the failing status reported when ParallelFor fails.
template <typename T>
uint32_t MatrixBandPartCpuKernel::BandCompute(Tensor *x, int64_t lower, int64_t upper, Tensor *y, int64_t rows,
                                              int64_t cols, CpuKernelContext &ctx) {
  T *x_addrs = reinterpret_cast<T *>(x->GetData());
  KERNEL_CHECK_NULLPTR(x_addrs, KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");
  T *y_addrs = reinterpret_cast<T *>(y->GetData());
  KERNEL_CHECK_NULLPTR(y_addrs, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.");

  const int64_t data_num = static_cast<int64_t>(x->GetDataSize() / sizeof(T));
  // An empty matrix dimension means there is nothing to write; returning here
  // also avoids the divisions and modulo by rows / cols below.
  if (rows <= 0 || cols <= 0 || data_num <= 0) {
    return KERNEL_STATUS_OK;
  }
  const int64_t total_rows = data_num / cols;

  if (data_num < kParallelDataNums) {
    BandFilterRows<T>(x_addrs, y_addrs, 0, total_rows, rows, cols, lower, upper);
    return KERNEL_STATUS_OK;
  }

  int64_t max_core_num = std::max<int64_t>(1, static_cast<int64_t>(aicpu::CpuKernelUtils::GetCPUNum(ctx)));
  if (max_core_num > total_rows) {
    max_core_num = total_rows;
  }
  auto shard_band = [=](int64_t start, int64_t end) {
    BandFilterRows<T>(x_addrs, y_addrs, start, end, rows, cols, lower, upper);
  };
  KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, total_rows, total_rows / max_core_num, shard_band),
                      "MatrixBandPart Compute failed.");
  return KERNEL_STATUS_OK;
}
// Associates MatrixBandPartCpuKernel with the op name stored in kMatrixBandPart ("MatrixBandPart").
// NOTE(review): the macro is defined outside this file; presumably it registers the kernel with the AICPU framework.
REGISTER_CPU_KERNEL(kMatrixBandPart, MatrixBandPartCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_MATRIX_BAND_PART_H_
#define AICPU_KERNELS_NORMALIZED_MATRIX_BAND_PART_H_
#include "cpu_ops_kernel.h"
#include "cpu_types.h"
namespace aicpu {
class MatrixBandPartCpuKernel : public CpuKernel {
public:
MatrixBandPartCpuKernel() = default;
~MatrixBandPartCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
private:
static uint32_t BandParamCheck(Tensor *x, Tensor *num_lower, Tensor *num_upper, Tensor *y);
template <typename T>
static uint32_t BandCompute(Tensor *x, int64_t lower, int64_t upper, Tensor *y, int64_t rows, int64_t cols,
CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -0,0 +1,94 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "zeroslike.h"
#include <cstring>
#include "cpu_kernel_utils.h"
#include "utils/eigen_tensor.h"
#include "utils/kernel_util.h"
#include "Eigen/Dense"
#include "securec.h"
// File-local constants and the dtype-dispatch macro used by the ZerosLike kernel.
namespace {
// Expected number of output tensors, passed to NormalCheck in Compute.
const uint32_t kOutputNum = 1;
// Expected number of input tensors, passed to NormalCheck in Compute.
const uint32_t kInputNum = 1;
// Op name used in log messages and as the key given to REGISTER_CPU_KERNEL.
const char *const kZerosLike = "ZerosLike";
// Expands to one `case` of the dtype switch in Compute: runs ZerosLikePartCompute<TYPE>(CTX),
// logs and returns the status from the enclosing function on failure, otherwise breaks.
#define ZEROSLIKE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = ZerosLikePartCompute<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("ZerosLike kernel compute failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
// Kernel entry point for ZerosLike.
// Validates the input/output counts and parameters, then dispatches to the typed
// ZerosLikePartCompute matching the input dtype.
// Returns KERNEL_STATUS_OK on success, KERNEL_STATUS_PARAM_INVALID for an unsupported
// dtype, or the failing status of a check / the typed computation.
uint32_t ZerosLikeCpuKernel::Compute(CpuKernelContext &ctx) {
// Validate tensor counts first, then the kernel-specific parameters.
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kZerosLike);
KERNEL_HANDLE_ERROR(ZerosLikeCheck(ctx), "[%s] check params failed.", kZerosLike);
auto data_type = ctx.Input(0)->GetDataType();
// Each ZEROSLIKE_COMPUTE_CASE returns from this function when its computation fails.
switch (data_type) {
ZEROSLIKE_COMPUTE_CASE(DT_BOOL, bool, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_INT8, int8_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_INT16, int16_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_FLOAT, float, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
ZEROSLIKE_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
default:
KERNEL_LOG_ERROR("ZerosLike kernel data type [%s] not support.", DTypeStr(data_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
return KERNEL_STATUS_OK;
}
uint32_t ZerosLikeCpuKernel::ZerosLikeCheck(const CpuKernelContext &ctx) const {
auto input = ctx.Input(0);
auto output = ctx.Output(0);
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
KERNEL_CHECK_NULLPTR(input->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.")
std::vector<int64_t> shape_x = input->GetTensorShape()->GetDimSizes();
size_t shape_size = shape_x.size();
KERNEL_CHECK_FALSE((shape_size > 0), KERNEL_STATUS_PARAM_INVALID, "Input must be at least rank 1, got [%zu].",
shape_x.size())
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ZerosLikeCpuKernel::ZerosLikePartCompute(const CpuKernelContext &ctx) {
size_t data_num = static_cast<size_t>(ctx.Input(0)->NumElements());
Tensor *y = ctx.Output(0);
auto y_addr = y->GetData();
auto ret = memset_s(y_addr, data_num * sizeof(T), 0, data_num * sizeof(T));
if (ret != EOK) {
KERNEL_LOG_ERROR("memset_s error, ret=%d", ret);
}
return KERNEL_STATUS_OK;
}
// Associates ZerosLikeCpuKernel with the op name stored in kZerosLike ("ZerosLike").
// NOTE(review): the macro is defined outside this file; presumably it registers the kernel with the AICPU framework.
REGISTER_CPU_KERNEL(kZerosLike, ZerosLikeCpuKernel);
} // namespace aicpu

View File

@ -0,0 +1,36 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AICPU_KERNELS_NORMALIZED_ZEROSLIKE_H
#define AICPU_KERNELS_NORMALIZED_ZEROSLIKE_H
#include "cpu_ops_kernel.h"
namespace aicpu {
// AICPU kernel for the ZerosLike op: writes zeros into the output tensor,
// sized from the input tensor's element count and dtype.
class ZerosLikeCpuKernel : public CpuKernel {
public:
ZerosLikeCpuKernel() = default;
~ZerosLikeCpuKernel() override = default;
// Kernel entry point: validates the context and dispatches on the input dtype.
// Returns a KERNEL_STATUS_* code.
uint32_t Compute(CpuKernelContext &ctx) override;
private:
// Checks the input/output data pointers and that the input has rank >= 1.
uint32_t ZerosLikeCheck(const CpuKernelContext &ctx) const;
// Zero-fills the output for element type T.
template <typename T>
uint32_t ZerosLikePartCompute(const CpuKernelContext &ctx);
};
} // namespace aicpu
#endif

View File

@ -138,7 +138,10 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
mindspore::kCheckNumericsOpName,
mindspore::kFloorDivOpName,
mindspore::kLog1pOpName,
mindspore::kMulOpName};
mindspore::kMulOpName,
mindspore::kConjOpName,
mindspore::kZerosLikeOpName,
mindspore::kMatrixBandPartOpName};
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";

View File

@ -8471,10 +8471,8 @@ def remainder(x, y):
out_{i} = input_{i} \text{ % } other_{i}
.. warning::
- The input data does not support 0.
- When the elements of input exceed 2048, the accuracy of operator cannot guarantee the requirement of
double thousandths in the mini form.
- Due to different architectures, the calculation results of this operator on NPU and CPU may be inconsistent.
- When the elements of input exceed 2048, there might be accuracy problems.
- The calculation results of this operator on Ascend and CPU might be inconsistent.
- If shape is expressed as (D1,D2... ,Dn), then D1\*D2... \*DN<=1000000,n<=8.
Args: