forked from mindspore-Ecosystem/mindspore
migrate Conj and other ops
This commit is contained in:
parent
3f5759d0dc
commit
4dfd7b7544
|
@ -327,3 +327,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select
|
|||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss_grad.cc:aicpu::MultiMarginLossGradCpuKernel::MultiMarginLossGradComputeFP16
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss_grad.cc:aicpu::MultiMarginLossGradCpuKernel::MultiMarginLossGradCompute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/multi_margin_loss.cc:aicpu::MultiMarginLossCpuKernel::MultiMarginLossComputeFP16
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/matrix_band_part.cc:aicpu::MatrixBandPartCpuKernel::BandCompute
|
||||
|
|
|
@ -11,9 +11,8 @@ mindspore.ops.remainder
|
|||
out_{i} = input_{i} \text{ % } other_{i}
|
||||
|
||||
.. warning::
|
||||
- 输入数值不支持0。
|
||||
- 当输入元素超过2048时,操作的精确度无法保证mini表格的千分之二的要求。
|
||||
- 由于架构不同,该操作符在NPU和CPU上的计算结果可能不一致。
|
||||
- 当输入元素超过2048时,可能会有京精度问题。
|
||||
- 在Ascend和CPU上的计算结果可能不一致。
|
||||
- 如果shape表示为(D1,D2…Dn),那么D1 \ * D2……\ * DN <= 1000000,n <= 8。
|
||||
|
||||
参数:
|
||||
|
|
|
@ -200,6 +200,7 @@ constexpr auto kConv3DTransposeOpName = "Conv3DTranspose";
|
|||
constexpr auto kDeformableOffsetsOpName = "DeformableOffsets";
|
||||
constexpr auto kCropAndResizeOpName = "CropAndResize";
|
||||
constexpr auto kCropAndResizeDOpName = "CropAndResizeD";
|
||||
constexpr auto kConjOpName = "Conj";
|
||||
constexpr auto kConvBN1OpName = "ConvBN1";
|
||||
constexpr auto kCOO2CSROpName = "COO2CSR";
|
||||
constexpr auto kCosOpName = "Cos";
|
||||
|
@ -414,6 +415,7 @@ constexpr auto kMaskedSelectOpName = "MaskedSelect";
|
|||
constexpr auto kMaskedSelectGradOpName = "MaskedSelectGrad";
|
||||
constexpr auto kMatMulOpName = "MatMul";
|
||||
constexpr auto kMatMulV2OpName = "MatMulV2";
|
||||
constexpr auto kMatrixBandPartOpName = "MatrixBandPart";
|
||||
constexpr auto kMatrixDeterminantOpName = "MatrixDeterminant";
|
||||
constexpr auto kMatrixDiagOpName = "MatrixDiag";
|
||||
constexpr auto kMatrixDiagDOpName = "MatrixDiagD";
|
||||
|
@ -717,6 +719,7 @@ constexpr auto kUpdateStateOpName = "UpdateState";
|
|||
constexpr auto kDynamicBroadcastToOpName = "DynamicBroadcastTo";
|
||||
constexpr auto kCheckValidOpName = "CheckValid";
|
||||
constexpr auto kSoftmaxGradFusionOpName = "SoftmaxGradFusion";
|
||||
constexpr auto kZerosLikeOpName = "ZerosLike";
|
||||
|
||||
// Sequence ops
|
||||
constexpr auto kScalarToTensorOpName = "ScalarToTensor";
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "conj.h"
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 1;
|
||||
const char *const kConj = "Conj";
|
||||
constexpr int64_t kParallelDataNums = 512 * 1024;
|
||||
|
||||
#define CONJ_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ConjCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("Conj kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ConjCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kConj);
|
||||
KERNEL_HANDLE_ERROR(ConjCheck(ctx), "[%s] check params failed.", kConj);
|
||||
DataType dataType = ctx.Input(0)->GetDataType();
|
||||
switch (dataType) {
|
||||
CONJ_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
|
||||
CONJ_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("Conj kernel data type [%s] not support.", DTypeStr(dataType).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ConjCpuKernel::ConjCheck(const CpuKernelContext &ctx) const {
|
||||
auto input = ctx.Input(0);
|
||||
auto output = ctx.Output(0);
|
||||
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
|
||||
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConjCpuKernel::ConjCompute(const CpuKernelContext &ctx) const {
|
||||
auto inputX = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto outputY = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
int64_t dataNum = ctx.Input(0)->NumElements();
|
||||
int64_t dataSize = dataNum * static_cast<int64_t>(sizeof(T));
|
||||
if (dataSize <= kParallelDataNums) {
|
||||
for (int64_t i = 0; i < dataNum; i++) {
|
||||
*(outputY + i) = std::conj(*(inputX + i));
|
||||
}
|
||||
} else {
|
||||
uint32_t minCoreNum = 1;
|
||||
int64_t maxCoreNum = std::max(minCoreNum, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
|
||||
if (maxCoreNum > dataNum) {
|
||||
maxCoreNum = dataNum;
|
||||
}
|
||||
auto shardConj = [&inputX, &outputY](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
*(outputY + i) = std::conj(*(inputX + i));
|
||||
}
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dataNum, dataNum / maxCoreNum, shardConj),
|
||||
"Conj Compute failed.");
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kConj, ConjCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_CONJ_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_CONJ_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ConjCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ConjCpuKernel() = default;
|
||||
~ConjCpuKernel() override = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ConjCheck(const CpuKernelContext &ctx) const;
|
||||
|
||||
template <typename T>
|
||||
uint32_t ConjCompute(const CpuKernelContext &ctx) const;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "matrix_band_part.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <securec.h>
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "kernel_log.h"
|
||||
#include "status.h"
|
||||
|
||||
namespace {
|
||||
const char *kMatrixBandPart = "MatrixBandPart";
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 3;
|
||||
constexpr int64_t kParallelDataNums = 64 * 1024;
|
||||
|
||||
#define BAND_COMPUTE_CASE(DTYPE, TYPE, X, LOWER, UPPER, Y, M, N, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = BandCompute<TYPE>(X, LOWER, UPPER, Y, M, N, CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("MatrixBandPart kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t MatrixBandPartCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "MatrixBandPart check input and output number failed.");
|
||||
Tensor *x = ctx.Input(0);
|
||||
Tensor *num_lower = ctx.Input(1);
|
||||
Tensor *num_upper = ctx.Input(2);
|
||||
Tensor *y = ctx.Output(0);
|
||||
int32_t rank = x->GetTensorShape()->GetDims();
|
||||
KERNEL_CHECK_FALSE((rank >= 2), KERNEL_STATUS_PARAM_INVALID, "Input must be at least 2-dim, but get dims: %d", rank);
|
||||
int64_t m = x->GetTensorShape()->GetDimSize(rank - 2);
|
||||
int64_t n = x->GetTensorShape()->GetDimSize(rank - 1);
|
||||
KERNEL_CHECK_FALSE((num_lower->GetTensorShape()->GetDimSizes().empty()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"num_lower must be scalar.");
|
||||
KERNEL_CHECK_FALSE((num_upper->GetTensorShape()->GetDimSizes().empty()), KERNEL_STATUS_PARAM_INVALID,
|
||||
"num_upper must be scalar.");
|
||||
DataType lower_type = num_lower->GetDataType();
|
||||
KERNEL_CHECK_FALSE((lower_type == DT_INT32 || lower_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Unsupported num_lower data_type[%s], "
|
||||
"only support DT_INT32 and DT_INT64.",
|
||||
DTypeStr(lower_type).c_str());
|
||||
DataType upper_type = num_upper->GetDataType();
|
||||
KERNEL_CHECK_FALSE((upper_type == DT_INT32 || upper_type == DT_INT64), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Unsupported num_upper data_type[%s], "
|
||||
"only support DT_INT32 and DT_INT64.",
|
||||
DTypeStr(upper_type).c_str());
|
||||
int32_t *lower_data = reinterpret_cast<int32_t *>(num_lower->GetData());
|
||||
KERNEL_CHECK_NULLPTR(lower_data, KERNEL_STATUS_PARAM_INVALID, "Get num_lower data failed.");
|
||||
int32_t *upper_data = reinterpret_cast<int32_t *>(num_upper->GetData());
|
||||
KERNEL_CHECK_NULLPTR(upper_data, KERNEL_STATUS_PARAM_INVALID, "Get num_upper data failed.");
|
||||
int64_t lower = *lower_data;
|
||||
int64_t upper = *upper_data;
|
||||
KERNEL_CHECK_FALSE((lower <= m), KERNEL_STATUS_PARAM_INVALID,
|
||||
"num_lower must be negative or less or equal to number "
|
||||
"of rows [%d], got: [%d]",
|
||||
m, lower);
|
||||
KERNEL_CHECK_FALSE((upper <= n), KERNEL_STATUS_PARAM_INVALID,
|
||||
"num_lower must be negative or less or equal to number "
|
||||
"of cols [%d], got: [%d]",
|
||||
n, upper);
|
||||
uint64_t input_size = x->GetDataSize();
|
||||
uint64_t output_size = y->GetDataSize();
|
||||
KERNEL_CHECK_FALSE((input_size == output_size), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input data size[%llu] is not equal to output data size[%llu].", input_size, output_size);
|
||||
DataType data_type = x->GetDataType();
|
||||
switch (data_type) {
|
||||
BAND_COMPUTE_CASE(DT_INT8, int8_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_INT16, int16_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_INT32, int32_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_INT64, int64_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_UINT8, uint8_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_UINT16, uint16_t, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_FLOAT16, Eigen::half, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_FLOAT, float, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_DOUBLE, double, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_BOOL, bool, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, x, lower, upper, y, m, n, ctx)
|
||||
BAND_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, x, lower, upper, y, m, n, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("MatrixBandPart kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t MatrixBandPartCpuKernel::BandCompute(Tensor *x, int64_t lower, int64_t upper, Tensor *y, int64_t rows,
|
||||
int64_t cols, CpuKernelContext &ctx) {
|
||||
T *x_addrs = reinterpret_cast<T *>(x->GetData());
|
||||
KERNEL_CHECK_NULLPTR(x_addrs, KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");
|
||||
T *y_addrs = reinterpret_cast<T *>(y->GetData());
|
||||
KERNEL_CHECK_NULLPTR(y_addrs, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.");
|
||||
|
||||
T zero = static_cast<T>(0);
|
||||
int64_t data_num = x->GetDataSize() / sizeof(T);
|
||||
int64_t matrix_size = rows * cols;
|
||||
int64_t total_rows = data_num / cols;
|
||||
bool same_addr = (x_addrs == y_addrs);
|
||||
if (data_num < kParallelDataNums) {
|
||||
if (!same_addr) {
|
||||
std::fill_n(y_addrs, data_num, zero);
|
||||
}
|
||||
int64_t batch_end = (total_rows + rows - 1) / rows;
|
||||
for (int64_t i = 0; i < batch_end; i++) {
|
||||
int64_t row_begin = 0 > i * rows ? 0 % rows : 0;
|
||||
int64_t row_end = total_rows < (i + 1) * rows ? total_rows % rows : rows;
|
||||
for (int64_t m = row_begin; m < row_end; m++) {
|
||||
int64_t base_index = i * matrix_size + m * cols;
|
||||
int64_t band_start = lower < 0 ? 0 : std::min(cols, std::max(int64_t{0}, int64_t{m - lower}));
|
||||
int64_t band_end = upper < 0 ? cols : std::min(cols, int64_t{m + upper + 1});
|
||||
if (same_addr) {
|
||||
if (band_start > 0) {
|
||||
std::fill((y_addrs + base_index), (y_addrs + base_index + band_start), zero);
|
||||
}
|
||||
if (band_end < cols) {
|
||||
std::fill((y_addrs + base_index + band_end), (y_addrs + base_index + cols), zero);
|
||||
}
|
||||
} else {
|
||||
if (band_start < band_end) {
|
||||
(void)memcpy_s((y_addrs + base_index + band_start), (band_end - band_start) * sizeof(T),
|
||||
(x_addrs + base_index + band_start), (band_end - band_start) * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
uint32_t min_core_num = 1;
|
||||
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx));
|
||||
if (max_core_num > total_rows) {
|
||||
max_core_num = total_rows;
|
||||
}
|
||||
auto shard_band = [&](int64_t start, int64_t end) {
|
||||
if (!same_addr) {
|
||||
std::fill(y_addrs + start * cols, y_addrs + end * cols, zero);
|
||||
}
|
||||
int64_t batch_begin = start / rows;
|
||||
int64_t batch_end = (end + rows - 1) / rows;
|
||||
for (int64_t i = batch_begin; i < batch_end; i++) {
|
||||
int64_t row_begin = start > i * rows ? start % rows : 0;
|
||||
int64_t row_end = end < (i + 1) * rows ? end % rows : rows;
|
||||
for (int64_t m = row_begin; m < row_end; m++) {
|
||||
int64_t base_index = i * matrix_size + m * cols;
|
||||
int64_t band_start = lower < 0 ? 0 : std::min(cols, std::max(int64_t{0}, int64_t{m - lower}));
|
||||
int64_t band_end = upper < 0 ? cols : std::min(cols, int64_t{m + upper + 1});
|
||||
if (same_addr) {
|
||||
if (band_start > 0) {
|
||||
std::fill((y_addrs + base_index), (y_addrs + base_index + band_start), zero);
|
||||
}
|
||||
if (band_end < cols) {
|
||||
std::fill((y_addrs + base_index + band_end), (y_addrs + base_index + cols), zero);
|
||||
}
|
||||
} else {
|
||||
if (band_start < band_end) {
|
||||
(void)memcpy_s((y_addrs + base_index + band_start), (band_end - band_start) * sizeof(T),
|
||||
(x_addrs + base_index + band_start), (band_end - band_start) * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, total_rows, total_rows / max_core_num, shard_band),
|
||||
"MatrixBandPart Compute failed.");
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kMatrixBandPart, MatrixBandPartCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_MATRIX_BAND_PART_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_MATRIX_BAND_PART_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
|
||||
namespace aicpu {
|
||||
class MatrixBandPartCpuKernel : public CpuKernel {
|
||||
public:
|
||||
MatrixBandPartCpuKernel() = default;
|
||||
~MatrixBandPartCpuKernel() = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
static uint32_t BandParamCheck(Tensor *x, Tensor *num_lower, Tensor *num_upper, Tensor *y);
|
||||
|
||||
template <typename T>
|
||||
static uint32_t BandCompute(Tensor *x, int64_t lower, int64_t upper, Tensor *y, int64_t rows, int64_t cols,
|
||||
CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "zeroslike.h"
|
||||
#include <cstring>
|
||||
|
||||
#include "cpu_kernel_utils.h"
|
||||
#include "utils/eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "Eigen/Dense"
|
||||
#include "securec.h"
|
||||
namespace {
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputNum = 1;
|
||||
const char *const kZerosLike = "ZerosLike";
|
||||
|
||||
#define ZEROSLIKE_COMPUTE_CASE(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ZerosLikePartCompute<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("ZerosLike kernel compute failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ZerosLikeCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
// check params
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.", kZerosLike);
|
||||
KERNEL_HANDLE_ERROR(ZerosLikeCheck(ctx), "[%s] check params failed.", kZerosLike);
|
||||
auto data_type = ctx.Input(0)->GetDataType();
|
||||
switch (data_type) {
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_BOOL, bool, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_INT8, int8_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_INT16, int16_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_INT32, int32_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_INT64, int64_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_UINT8, uint8_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_UINT16, uint16_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_UINT32, uint32_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_UINT64, uint64_t, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_FLOAT, float, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_DOUBLE, double, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_COMPLEX64, std::complex<float>, ctx)
|
||||
ZEROSLIKE_COMPUTE_CASE(DT_COMPLEX128, std::complex<double>, ctx)
|
||||
default:
|
||||
KERNEL_LOG_ERROR("ZerosLike kernel data type [%s] not support.", DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ZerosLikeCpuKernel::ZerosLikeCheck(const CpuKernelContext &ctx) const {
|
||||
auto input = ctx.Input(0);
|
||||
auto output = ctx.Output(0);
|
||||
KERNEL_CHECK_NULLPTR(input->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.")
|
||||
KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
|
||||
KERNEL_CHECK_NULLPTR(input->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.")
|
||||
std::vector<int64_t> shape_x = input->GetTensorShape()->GetDimSizes();
|
||||
size_t shape_size = shape_x.size();
|
||||
KERNEL_CHECK_FALSE((shape_size > 0), KERNEL_STATUS_PARAM_INVALID, "Input must be at least rank 1, got [%zu].",
|
||||
shape_x.size())
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint32_t ZerosLikeCpuKernel::ZerosLikePartCompute(const CpuKernelContext &ctx) {
|
||||
size_t data_num = static_cast<size_t>(ctx.Input(0)->NumElements());
|
||||
Tensor *y = ctx.Output(0);
|
||||
auto y_addr = y->GetData();
|
||||
auto ret = memset_s(y_addr, data_num * sizeof(T), 0, data_num * sizeof(T));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("memset_s error, ret=%d", ret);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
REGISTER_CPU_KERNEL(kZerosLike, ZerosLikeCpuKernel);
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_ZEROSLIKE_H
|
||||
#define AICPU_KERNELS_NORMALIZED_ZEROSLIKE_H
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ZerosLikeCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ZerosLikeCpuKernel() = default;
|
||||
~ZerosLikeCpuKernel() override = default;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ZerosLikeCheck(const CpuKernelContext &ctx) const;
|
||||
|
||||
template <typename T>
|
||||
uint32_t ZerosLikePartCompute(const CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -138,7 +138,10 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
mindspore::kCheckNumericsOpName,
|
||||
mindspore::kFloorDivOpName,
|
||||
mindspore::kLog1pOpName,
|
||||
mindspore::kMulOpName};
|
||||
mindspore::kMulOpName,
|
||||
mindspore::kConjOpName,
|
||||
mindspore::kZerosLikeOpName,
|
||||
mindspore::kMatrixBandPartOpName};
|
||||
|
||||
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
|
||||
static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";
|
||||
|
|
|
@ -8471,10 +8471,8 @@ def remainder(x, y):
|
|||
out_{i} = input_{i} \text{ % } other_{i}
|
||||
|
||||
.. warning::
|
||||
- The input data does not support 0.
|
||||
- When the elements of input exceed 2048, the accuracy of operator cannot guarantee the requirement of
|
||||
double thousandths in the mini form.
|
||||
- Due to different architectures, the calculation results of this operator on NPU and CPU may be inconsistent.
|
||||
- When the elements of input exceed 2048, there might be accuracy problems.
|
||||
- The calculation results of this operator on Ascend and CPU might be inconsistent.
|
||||
- If shape is expressed as (D1,D2... ,Dn), then D1\*D2... \*DN<=1000000,n<=8.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue