diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 4638f6abaa4..94cbd35552c 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -348,3 +348,5 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpSpecialComputeComplex
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/sparse_dense_cwise_utils.cc:aicpu::SparseDenseCwiseOpKernel::SparseDenseCwiseOpBcastComputeComplex
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc:aicpu::ResizeBicubicGrad
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc:aicpu::SegmentMaxCpuKernel::SegmentMaxCompute
+mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.cc:aicpu::ExtractGlimpseCpuKernel::Compute
diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h
index 2421ceae6d6..735452065de 100644
--- a/mindspore/ccsrc/include/common/utils/utils.h
+++ b/mindspore/ccsrc/include/common/utils/utils.h
@@ -312,6 +312,7 @@ constexpr auto kExpandDOpName = "ExpandD";
 constexpr auto kExpandDimsOpName = "ExpandDims";
 constexpr auto kExpOpName = "Exp";
 constexpr auto kExtractGlimpse = "ExtractGlimpse";
+constexpr auto kExtractGlimpseOpName = "ExtractGlimpse";
 constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
 constexpr auto kEyeOpName = "Eye";
 constexpr auto kFastGeLUOpName = "FastGeLU";
@@ -468,6 +469,7 @@ constexpr auto kLSTMGradOpName = "LSTMGrad";
 constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
 constexpr auto kLSTMOpName = "LSTM";
 constexpr auto kLstsqOpName = "Lstsq";
+constexpr auto kLuSolveOpName = "LuSolve";
 constexpr auto kLuUnpackOpName = "LuUnpack";
 constexpr auto kLuUnpackGradOpName = "LuUnpackGrad";
 constexpr auto kMaskedFillOpName = "MaskedFill";
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/common/runtime_tensor_desc.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/common/runtime_tensor_desc.h
index 1dcf7c26aae..12142159fd3 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/common/runtime_tensor_desc.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/common/runtime_tensor_desc.h
@@ -19,6 +19,7 @@
 namespace ge {
 constexpr int64_t kMaxDimSize = 32;
+constexpr int64_t DIM_SIZE2 = 2;
 #pragma pack(push, 1)
 struct RuntimeTensorDesc {
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.cc
deleted file mode 100644
index 3038e7fce96..00000000000
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.cc
+++ /dev/null
@@ -1,154 +0,0 @@
-/**
- * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "cache_swap_table.h"
-#include <map>
-#include <securec.h>
-#include "cpu_types.h"
-#include "kernel_log.h"
-#include "status.h"
-#include "utils/sparse_tensor.h"
-#include "utils/kernel_util.h"
-
-namespace {
-const char *const kCacheSwapTable = "CacheSwapTable";
-}
-
-namespace aicpu {
-template <typename T>
-uint32_t CacheSwapTableTask(std::vector<Tensor *> &inputs, std::vector<Tensor *> &outputs, int64_t batch_size,
-                            int64_t output_size, int64_t one_line_col, int type_size) {
-  if (inputs.size() == 0 || outputs.size() == 0) {
-    KERNEL_LOG_ERROR("CacheSwapTable input or output is empty.");
-    return KERNEL_STATUS_PARAM_INVALID;
-  }
-
-  char *cache_table = reinterpret_cast<char *>(inputs[0]->GetData());
-  T *swap_cache_idx = reinterpret_cast<T *>(inputs[1]->GetData());
-  uint64_t swap_cache_idx_size = inputs[1]->GetDataSize();
-  char *miss_value = reinterpret_cast<char *>(inputs[2]->GetData());
-
-  char *old_value = reinterpret_cast<char *>(outputs[0]->GetData());
-
-  errno_t ret = memset_s(old_value, static_cast<size_t>(output_size * type_size), 0x00,
-                         static_cast<size_t>(output_size * type_size));
-  if (ret != EOK) {
-    KERNEL_LOG_ERROR("Memset failed, result[%d]", ret);
-    return KERNEL_STATUS_INNER_ERROR;
-  }
-
-  uint64_t single_copy_size = static_cast<uint64_t>(type_size * one_line_col);
-
-  if (swap_cache_idx_size < static_cast<uint64_t>(batch_size)) {
-    KERNEL_LOG_ERROR(
-      "The value of swap_cache_idx_size:[%llu] must be less than "
-      "batch_size:[%lld]",
-      swap_cache_idx_size, batch_size);
-    return KERNEL_STATUS_INNER_ERROR;
-  }
-
-  uint64_t old_value_size = outputs[0]->GetDataSize();
-  uint64_t cache_table_size = inputs[0]->GetDataSize();
-  for (int64_t i = 0; i < batch_size; ++i) {
-    if (swap_cache_idx[i] < 0) {
-      continue;
-    }
-    ret = memcpy_s(old_value + i * single_copy_size, old_value_size, cache_table + swap_cache_idx[i] * single_copy_size,
-                   single_copy_size);
-    old_value_size -= single_copy_size;
-    if (ret != EOK) {
-      KERNEL_LOG_ERROR("CacheSwapTable memcpy failed, result [%d].", ret);
-      return KERNEL_STATUS_INNER_ERROR;
-    }
-    ret = memcpy_s(cache_table + swap_cache_idx[i] * single_copy_size, cache_table_size,
-                   miss_value + i * single_copy_size, single_copy_size);
-    cache_table_size -= single_copy_size;
-    if (ret != EOK) {
-      KERNEL_LOG_ERROR("CacheSwapTable memcpy failed, result [%d].", ret);
-      return KERNEL_STATUS_INNER_ERROR;
-    }
-  }
-  return KERNEL_STATUS_OK;
-}
-
-uint32_t CacheSwapTableMsCpuKernel::DoCompute() {
-  std::map<int, std::function<uint32_t(std::vector<Tensor *> &, std::vector<Tensor *> &, int64_t &, int64_t &,
-                                       int64_t &, int &)>>
-    calls;
-  calls[DT_INT32] = CacheSwapTableTask<int32_t>;
-  calls[DT_INT64] = CacheSwapTableTask<int64_t>;
-
-  if (calls.find(indices_type_) == calls.end()) {
-    KERNEL_LOG_ERROR(
-      "CacheSwapTableMsCpuKernel op doesn't support indices tensor types: "
-      "[%s]",
-      DTypeStr(indices_type_).c_str());
-    return KERNEL_STATUS_PARAM_INVALID;
-  }
-
-  int type_size = GetSizeByDataType(param_type_);
-  return calls[indices_type_](inputs_, outputs_, batch_size_, output_size_, one_line_col_, type_size);
-}
-
-uint32_t CacheSwapTableMsCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) {
-  KERNEL_LOG_INFO("GetInputAndCheck start!");
-  // get input Tensors
-  const uint32_t kNumInput = 3;
-  for (uint32_t i = 0; i < kNumInput; ++i) {
-    Tensor *tensor = ctx.Input(i);
-    KERNEL_CHECK_NULLPTR(tensor, KERNEL_STATUS_PARAM_INVALID, "Get input tensor[%d] failed", i)
-    inputs_.push_back(tensor);
-  }
-  // get output Tensors
-  const uint32_t kNumOutput = 1;
-  for (uint32_t i = 0; i < kNumOutput; ++i) {
-    Tensor *tensor = ctx.Output(i);
-    KERNEL_CHECK_NULLPTR(tensor, KERNEL_STATUS_PARAM_INVALID, "Get output tensor[%d] failed", i)
-    outputs_.push_back(tensor);
-  }
-  // get param type
-  param_type_ = static_cast<DataType>(inputs_[0]->GetDataType());
-  indices_type_ = static_cast<DataType>(inputs_[1]->GetDataType());
-  KERNEL_LOG_INFO("GetInputAndCheck success!");
-
-  std::shared_ptr<TensorShape> cache_table_shape = ctx.Input(0)->GetTensorShape();
-  std::shared_ptr<TensorShape> indices_shape = ctx.Input(1)->GetTensorShape();
-
-  for (int32_t i = 1; i < cache_table_shape->GetDims(); ++i) {
-    KERNEL_CHECK_ASSIGN_64S_MULTI(one_line_col_, cache_table_shape->GetDimSize(i), one_line_col_,
-                                  KERNEL_STATUS_PARAM_INVALID);
-  }
-  for (int32_t i = 0; i < indices_shape->GetDims(); ++i) {
-    KERNEL_CHECK_ASSIGN_64S_MULTI(batch_size_, indices_shape->GetDimSize(i), batch_size_, KERNEL_STATUS_PARAM_INVALID);
-  }
-  output_size_ = batch_size_ * one_line_col_;
-  return KERNEL_STATUS_OK;
-}
-
-uint32_t CacheSwapTableMsCpuKernel::Compute(CpuKernelContext &ctx) {
-  uint32_t res = GetInputAndCheck(ctx);
-  if (res != KERNEL_STATUS_OK) {
-    return res;
-  }
-
-  res = DoCompute();
-  if (res != KERNEL_STATUS_OK) {
-    KERNEL_LOG_ERROR("Compute failed");
-    return res;
-  }
-  return KERNEL_STATUS_OK;
-}
-REGISTER_CPU_KERNEL(kCacheSwapTable, CacheSwapTableMsCpuKernel);
-}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.cc
new file mode 100644
index 00000000000..f41401ad138
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.cc
@@ -0,0 +1,213 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "extract_glimpse.h" +#include +#include +#include "cpu_kernel_utils.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" +using namespace std; +random_device rd; +mt19937 gen(rd()); +uniform_real_distribution dis_uniform(0.0f, 255.0f); +normal_distribution dis_normal(10, 0.5); +#define SHED 2048 +namespace { +const uint32_t kOutputNum = 1; +const uint32_t kInputNum = 3; +const char *kExtractGlimpse = "ExtractGlimpse"; +} // namespace +namespace aicpu { +uint32_t ExtractGlimpseCpuKernel::Compute(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "ExtractGlimpse check input and output number failed."); + KERNEL_HANDLE_ERROR(ExtractGlimpseCheck(ctx), "ExtractGlimpse check params failed."); + Tensor *x = ctx.Input(0); + Tensor *ss = ctx.Input(1); + Tensor *offsets = ctx.Input(2); + Tensor *y = ctx.Output(0); + AttrValue *centered = ctx.GetAttr("centered"); + AttrValue *normalized = ctx.GetAttr("normalized"); + AttrValue *uniform_noise = ctx.GetAttr("uniform_noise"); + AttrValue *noise = ctx.GetAttr("noise"); + float *x_data = (float *)x->GetData(); + int32_t *ss_data = (int32_t *)ss->GetData(); + float *offsets_data = (float *)offsets->GetData(); + float *y_data = (float *)y->GetData(); + uint64_t offsets_cnt = offsets->GetTensorShape()->GetDimSize(0); + uint64_t batch_cnt = x->GetTensorShape()->GetDimSize(0); + KERNEL_CHECK_FALSE(offsets_cnt == batch_cnt, KERNEL_STATUS_PARAM_INVALID, "offsets should equal to batches") + int64_t image_height = x->GetTensorShape()->GetDimSize(1); + int64_t image_width = x->GetTensorShape()->GetDimSize(2); + int64_t channels = x->GetTensorShape()->GetDimSize(3); + uint64_t g_height = ss_data[0], g_width = ss_data[1]; + uint64_t size1 = image_width * image_height * channels; + uint64_t size2 = image_width * channels; + uint64_t size3 = g_height * g_width * channels; + uint64_t size4 = size3 / g_height; + int64_t g_size = g_width * g_height; + if (batch_cnt > SHED) { + uint32_t min_core = 1; + uint64_t max_core = std::max(min_core, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + max_core = min(max_core, (uint64_t)batch_cnt); + auto fun = [&](size_t st, size_t ed) { + for (auto i = st; i < ed; i++) { + float x = offsets_data[i << 1], y = offsets_data[1 + (i << 1)]; + if (normalized->GetBool()) { + x *= image_height; + y *= image_width; + } + if (centered->GetBool()) { + x /= 2.0f; + y /= 2.0f; + x += image_height / 2.0f; + y += image_width / 2.0f; + } + x -= g_height / 2.0f; + y -= g_width / 2.0f; + for (int64_t v = 0; v < g_size; v++) { + int64_t j = v / g_width, k = v % g_width; + int64_t a = (int64_t)x + j, b = (int64_t)y + k; + uint64_t pos_y = i * size3 + j * size4 + k * channels; + if (a < 0 || a >= image_height || b < 0 || b >= image_width) { + for (int u = 0; u < channels; u++) { + if (uniform_noise->GetBool()) + y_data[pos_y + u] = dis_uniform(gen); + else if (noise->GetString() == "zero") + y_data[pos_y + u] = 0.0f; + else if (noise->GetString() == "gaussian") + y_data[pos_y + u] = max(0.0f, dis_normal(gen)); + else { + KERNEL_LOG_ERROR("noise type [%s] unsupported.", noise->GetString().c_str()); + return KERNEL_STATUS_PARAM_INVALID; + } + } + continue; + } + uint64_t pos_x = i * size1 + a * size2 + b * channels; + for (int u = 0; u < channels; u++) { + y_data[pos_y + u] = x_data[pos_x + u]; + } + } + } + return KERNEL_STATUS_OK; + }; + KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_cnt, batch_cnt / max_core, fun), + "ExtractGlimpse Compute failed."); + } else { + for (uint64_t i 
+      float x = offsets_data[i << 1], y = offsets_data[1 + (i << 1)];
+      if (normalized->GetBool()) {
+        x *= image_height;
+        y *= image_width;
+      }
+      if (centered->GetBool()) {
+        x /= 2.0f;
+        y /= 2.0f;
+        x += image_height / 2.0f;
+        y += image_width / 2.0f;
+      }
+      x -= g_height / 2.0f;
+      y -= g_width / 2.0f;
+      if (g_size < SHED) {
+        for (int64_t v = 0; v < g_size; v++) {
+          int64_t j = v / g_width, k = v % g_width;
+          int64_t a = (int64_t)x + j, b = (int64_t)y + k;
+          uint64_t pos_y = i * size3 + j * size4 + k * channels;
+          if (a < 0 || a >= image_height || b < 0 || b >= image_width) {
+            for (int u = 0; u < channels; u++) {
+              if (uniform_noise->GetBool())
+                y_data[pos_y + u] = dis_uniform(gen);
+              else if (noise->GetString() == "zero")
+                y_data[pos_y + u] = 0.0f;
+              else if (noise->GetString() == "gaussian")
+                y_data[pos_y + u] = max(0.0f, dis_normal(gen));
+              else {
+                KERNEL_LOG_ERROR("noise type [%s] unsupported.", noise->GetString().c_str());
+                return KERNEL_STATUS_PARAM_INVALID;
+              }
+            }
+            continue;
+          }
+          uint64_t pos_x = i * size1 + a * size2 + b * channels;
+          for (int u = 0; u < channels; u++) {
+            y_data[pos_y + u] = x_data[pos_x + u];
+          }
+        }
+      } else {
+        uint32_t min_core = 1;
+        uint64_t max_core = std::max(min_core, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
+        max_core = min(max_core, (uint64_t)g_size);
+        auto fun = [&](size_t st, size_t ed) {
+          for (auto v = st; v < ed; v++) {
+            int64_t j = v / g_width, k = v % g_width;
+            int64_t a = (int64_t)x + j, b = (int64_t)y + k;
+            uint64_t pos_y = i * size3 + j * size4 + k * channels;
+            if (a < 0 || a >= image_height || b < 0 || b >= image_width) {
+              for (int u = 0; u < channels; u++)
+                if (uniform_noise->GetBool())
+                  y_data[pos_y + u] = dis_uniform(gen);
+                else if (noise->GetString() == "zero")
+                  y_data[pos_y + u] = 0.0f;
+                else if (noise->GetString() == "gaussian")
+                  y_data[pos_y + u] = max(0.0f, dis_normal(gen));
+                else {
+                  KERNEL_LOG_ERROR("noise type [%s] unsupported.", noise->GetString().c_str());
+                  return KERNEL_STATUS_PARAM_INVALID;
+                }
+              continue;
+            }
+            uint64_t pos_x = i * size1 + a * size2 + b * channels;
+            for (int u = 0; u < channels; u++) {
+              y_data[pos_y + u] = x_data[pos_x + u];
+            }
+          }
+          return KERNEL_STATUS_OK;
+        };
+        KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, g_size, g_size / max_core, fun),
+                            "ExtractGlimpse Compute failed.");
+      }
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+uint32_t ExtractGlimpseCpuKernel::ExtractGlimpseCheck(CpuKernelContext &ctx) {
+  Tensor *x = ctx.Input(0);
+  Tensor *ss = ctx.Input(1);
+  Tensor *offsets = ctx.Input(2);
+  Tensor *y = ctx.Output(0);
+  AttrValue *centered = ctx.GetAttr("centered");
+  AttrValue *normalized = ctx.GetAttr("normalized");
+  AttrValue *uniform_noise = ctx.GetAttr("uniform_noise");
+  AttrValue *noise = ctx.GetAttr("noise");
+  KERNEL_CHECK_NULLPTR(x, KERNEL_STATUS_PARAM_INVALID, "Get input 0 failed.")
+  KERNEL_CHECK_NULLPTR(ss, KERNEL_STATUS_PARAM_INVALID, "Get input 1 failed.")
+  KERNEL_CHECK_NULLPTR(offsets, KERNEL_STATUS_PARAM_INVALID, "Get input 2 failed.")
+  KERNEL_CHECK_NULLPTR(y, KERNEL_STATUS_PARAM_INVALID, "Get output 0 failed.")
+  KERNEL_CHECK_NULLPTR(centered, KERNEL_STATUS_PARAM_INVALID, "Get attribute centered failed.")
+  KERNEL_CHECK_NULLPTR(normalized, KERNEL_STATUS_PARAM_INVALID, "Get attribute normalized failed.")
+  KERNEL_CHECK_NULLPTR(uniform_noise, KERNEL_STATUS_PARAM_INVALID, "Get attribute uniform_noise failed.")
+  KERNEL_CHECK_NULLPTR(noise, KERNEL_STATUS_PARAM_INVALID, "Get attribute noise failed.")
+  KERNEL_CHECK_NULLPTR(x->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 0 data failed.")
+  KERNEL_CHECK_NULLPTR(ss->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 1 data failed.")
+  KERNEL_CHECK_NULLPTR(offsets->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input 2 data failed.")
+  KERNEL_CHECK_NULLPTR(y->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output 0 data failed.")
+  KERNEL_CHECK_FALSE(x->GetDataType() == DT_FLOAT && ss->GetDataType() == DT_INT32 &&
+                       offsets->GetDataType() == DT_FLOAT && y->GetDataType() == DT_FLOAT,
+                     KERNEL_STATUS_PARAM_INVALID, "data type error.")
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kExtractGlimpse, ExtractGlimpseCpuKernel);
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.h
new file mode 100644
index 00000000000..44426f057ef
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/extract_glimpse.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef AICPU_IMPL_EXTRACT_GLIMPSE_H_
+#define AICPU_IMPL_EXTRACT_GLIMPSE_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+class ExtractGlimpseCpuKernel : public CpuKernel {
+ public:
+  ExtractGlimpseCpuKernel() = default;
+  ~ExtractGlimpseCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  static uint32_t ExtractGlimpseCheck(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
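For orientation, the Python surface of the new kernel in doctest form. This is a minimal sketch with illustrative values; it assumes the primitive is exposed as `ops.ExtractGlimpse`, taking an NHWC float32 image, an int32 glimpse size and per-batch float32 offsets:

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> x = Tensor(np.random.rand(1, 5, 5, 1).astype(np.float32))
>>> size = Tensor(np.array([3, 3]), mstype.int32)
>>> offsets = Tensor(np.array([[0.0, 0.0]]), mstype.float32)
>>> y = ops.ExtractGlimpse(centered=True, normalized=True)(x, size, offsets)
>>> print(y.shape)
(1, 3, 3, 1)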
norm_:[%lf]", - // VectorToString(out_shape).c_str(), normalized, is_reverse, result); std::cout << "result = " << result << std::endl; return result; } @@ -350,14 +329,7 @@ uint32_t FFTWithSizeCpuKernel::FFTWithSizeCompute(CpuKernelContext &ctx, bool on if (is_real) { inverse = real_inverse; } - std::cout << out; - std::cout << "==========="; - // if - // std::vector out_shape(out.dimensions().begin(), - // out.dimensions().end()); - // if (is_real && !inverse) { - // out_shape.back() = x_shape.back(); - // } + std::cout << out; auto cout = x_shape_ptr->NumElements(); auto norm = Getnormalized(cout, normalized, inverse); diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fft_with_size.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fft_with_size.h index 2eed602a998..4e798fdae27 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fft_with_size.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fft_with_size.h @@ -1,18 +1,3 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ #ifndef AICPU_KERNELS_NORMALIZED_FFTWITHSIZE_H_ #define AICPU_KERNELS_NORMALIZED_FFTWITHSIZE_H_ @@ -37,4 +22,4 @@ class FFTWithSizeCpuKernel : public CpuKernel { static double Getnormalized(int64_t n, std::string normalized, bool is_reverse); }; } // namespace aicpu -#endif \ No newline at end of file +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.cc index 23d07481b4a..917ffa0f6ce 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.cc @@ -15,129 +15,160 @@ */ #include "fill.h" +#include "cpu_kernel_utils.h" #include "utils/eigen_tensor.h" #include "utils/kernel_util.h" namespace { -const char *const kFill = "Fill"; -} +const uint32_t kOutputNum = 1; +const uint32_t kInputNum = 2; +const char *kFill = "Fill"; +const char *kFillV2 = "FillV2"; +const int64_t kParallelDataNumCriticalPoint1 = 128 * 1024; +const int64_t kParallelDataNumCriticalPoint2 = 2 * 1024 * 1024; + +#define CALCULATE_DIMS_DTYPE_CASE(DTYPE, TYPE) \ + case (DTYPE): { \ + if (CalculateDims(dims_tensor, dims) != KERNEL_STATUS_OK) { \ + KERNEL_LOG_ERROR("Fill kernel calculate dims failed."); \ + return KERNEL_STATUS_PARAM_INVALID; \ + } \ + break; \ + } + +#define FILL_GENERATE_DTYPE_CASE(DTYPE, TYPE) \ + case (DTYPE): { \ + FillOutput(ctx, value_tensor, output); \ + break; \ + } +} // namespace namespace aicpu { -template -void FillGenerateCase(Tensor *&value_tensor, Tensor *&output) { - auto value = *(reinterpret_cast(value_tensor->GetData())); - if (AddrAlignedCheck(output->GetData())) { - Eigen::TensorMap, Eigen::Aligned> 
 namespace aicpu {
-template <typename T>
-void FillGenerateCase(Tensor *&value_tensor, Tensor *&output) {
-  auto value = *(reinterpret_cast<T *>(value_tensor->GetData()));
-  if (AddrAlignedCheck(output->GetData())) {
-    Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Aligned> eigen_output(static_cast<T *>(output->GetData()),
-                                                                       output->GetTensorShape()->NumElements());
-    eigen_output.setConstant(value);
-  } else {
-    Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> eigen_output(static_cast<T *>(output->GetData()),
-                                                                         output->GetTensorShape()->NumElements());
-    eigen_output.setConstant(value);
-  }
-}
+uint32_t FillCpuKernel::Compute(CpuKernelContext &ctx) {
+  // Check the input/output counts and that the input and output tensor attributes are non-null.
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check input and output number failed.");
-uint32_t FillCpuKernel::GetDimsByType(CpuKernelContext &ctx) {
-  dims.clear();
+  std::vector<int64_t> dims;
   Tensor *dims_tensor = ctx.Input(0);
-  KERNEL_CHECK_NULLPTR(dims_tensor, KERNEL_STATUS_PARAM_INVALID, "Get dims input failed")
-  uint32_t ret;
   auto dims_dtype = dims_tensor->GetDataType();
   switch (dims_dtype) {
-    case (DT_INT32):
-      ret = CalcDims<int32_t>(dims_tensor, dims);
-      break;
-    case (DT_INT64):
-      ret = CalcDims<int64_t>(dims_tensor, dims);
-      break;
+    CALCULATE_DIMS_DTYPE_CASE(DT_INT32, int32_t)
+    CALCULATE_DIMS_DTYPE_CASE(DT_INT64, int64_t)
     default:
-      KERNEL_LOG_ERROR(
-        "Fill kernel dims data_type [%u] not support, support data_types: "
-        "DT_INT32, DT_INT64",
-        dims_dtype);
+      KERNEL_LOG_ERROR("Fill kernel dims data_type [%u] not support, support data_types: DT_INT32, DT_INT64.",
+                       dims_dtype);
       return KERNEL_STATUS_PARAM_INVALID;
   }
-  if (ret != KERNEL_STATUS_OK) {
-    KERNEL_LOG_ERROR("Fill kernel calculate dims failed");
-  }
-  return ret;
-}
-uint32_t FillCpuKernel::Compute(CpuKernelContext &ctx) {
-  uint32_t check = GetDimsByType(ctx);
-  if (check != KERNEL_STATUS_OK) {
-    return check;
-  }
   Tensor *value_tensor = ctx.Input(1);
-  KERNEL_CHECK_NULLPTR(value_tensor, KERNEL_STATUS_PARAM_INVALID, "Get value input failed")
-  KERNEL_CHECK_NULLPTR(value_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get value input data failed")
-  KERNEL_CHECK_NULLPTR(value_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get value input shape failed")
-  if (!value_tensor->GetTensorShape()->GetDimSizes().empty()) {
+  if (value_tensor->NumElements() != 1) {
     KERNEL_LOG_ERROR("Fill kernel value input is not a scalar.");
     return KERNEL_STATUS_PARAM_INVALID;
   }
+
   Tensor *output = ctx.Output(0);
-  KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output failed")
-  KERNEL_CHECK_NULLPTR(output->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get output data failed")
-  KERNEL_CHECK_NULLPTR(output->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get output shape failed")
-  if (output->GetTensorShape()->GetDimSizes() != dims) {
+  if (output->GetTensorShape()->GetDims() != static_cast<int32_t>(dims.size())) {
     KERNEL_LOG_ERROR("Fill kernel output shape not matched.");
     return KERNEL_STATUS_PARAM_INVALID;
   }
+  if (output->GetTensorShape()->GetDimSizes() != dims) {
+    output->GetTensorShape()->SetDimSizes(dims);
+  }
+
   auto input_dtype = value_tensor->GetDataType();
   auto output_dtype = output->GetDataType();
   if (input_dtype != output_dtype) {
-    KERNEL_LOG_ERROR("Fill kernel data type not matched, value input dtype [%u], output dtype [%u].", input_dtype,
-                     output_dtype);
+    KERNEL_LOG_ERROR(
+      "Fill kernel data type not matched, value input dtype [%u], output dtype [%u], support data_types: "
+      "DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, "
+      "DT_UINT32, DT_UINT64, DT_UINT8, DT_BOOL.",
+      input_dtype, output_dtype);
     return KERNEL_STATUS_PARAM_INVALID;
   }
-  std::map<DataType, std::function<void(Tensor *&, Tensor *&)>> calls;
-  calls[DT_INT8] = FillGenerateCase<int8_t>;
-  calls[DT_UINT8] = FillGenerateCase<uint8_t>;
-  calls[DT_INT16] = FillGenerateCase<int16_t>;
-  calls[DT_UINT16] = FillGenerateCase<uint16_t>;
-  calls[DT_INT32] = FillGenerateCase<int32_t>;
-  calls[DT_UINT32] = FillGenerateCase<uint32_t>;
-  calls[DT_INT64] = FillGenerateCase<int64_t>;
-  calls[DT_UINT64] = FillGenerateCase<uint64_t>;
-  calls[DT_BOOL] = FillGenerateCase<bool>;
-  calls[DT_FLOAT16] = FillGenerateCase<Eigen::half>;
-  calls[DT_FLOAT] = FillGenerateCase<float>;
-  calls[DT_DOUBLE] = FillGenerateCase<double>;
-
-  if (calls.find(output_dtype) == calls.end()) {
-    KERNEL_LOG_ERROR("Fill kernel data type [%u] not support", output_dtype);
-    return KERNEL_STATUS_PARAM_INVALID;
+  switch (output_dtype) {
+    FILL_GENERATE_DTYPE_CASE(DT_INT8, int8_t)
+    FILL_GENERATE_DTYPE_CASE(DT_UINT8, uint8_t)
+    FILL_GENERATE_DTYPE_CASE(DT_INT16, int16_t)
+    FILL_GENERATE_DTYPE_CASE(DT_UINT16, uint16_t)
+    FILL_GENERATE_DTYPE_CASE(DT_INT32, int32_t)
+    FILL_GENERATE_DTYPE_CASE(DT_UINT32, uint32_t)
+    FILL_GENERATE_DTYPE_CASE(DT_INT64, int64_t)
+    FILL_GENERATE_DTYPE_CASE(DT_UINT64, uint64_t)
+    FILL_GENERATE_DTYPE_CASE(DT_BOOL, bool)
+    FILL_GENERATE_DTYPE_CASE(DT_FLOAT16, Eigen::half)
+    FILL_GENERATE_DTYPE_CASE(DT_FLOAT, float)
+    FILL_GENERATE_DTYPE_CASE(DT_DOUBLE, double)
+    FILL_GENERATE_DTYPE_CASE(DT_COMPLEX64, std::complex<float>)
+    FILL_GENERATE_DTYPE_CASE(DT_COMPLEX128, std::complex<double>)
+    default:
+      KERNEL_LOG_ERROR(
+        "Fill kernel data type [%u] not support, not support data_types: DT_STRING, DT_DUAL_SUB_INT8, "
+        "DT_DUAL_SUB_UINT8, DT_QUINT8, DT_QINT8, DT_QINT32, DT_QINT16, DT_QUINT16, DT_RESOURCE, DT_STRING_REF, "
+        "DT_DUAL, DT_UNDEFINED.",
+        output_dtype);
+      return KERNEL_STATUS_PARAM_INVALID;
   }
-  calls[output_dtype](value_tensor, output);
+
   return KERNEL_STATUS_OK;
 }
 
 template <typename T>
-uint32_t FillCpuKernel::CalcDims(const Tensor *dims_tensor, std::vector<int64_t> &dim_vec) {
+uint32_t FillCpuKernel::CalculateDims(const Tensor *dims_tensor, std::vector<int64_t> &dims) {
+  // Get the element count of the first input, which is a 1-D tensor (dims_tensor).
   uint64_t data_num = dims_tensor->GetDataSize() / sizeof(T);
-  if (data_num == 0) {
-    KERNEL_LOG_INFO("Fill kernel: dims is empty, fill scalar output.");
-    return KERNEL_STATUS_OK;
-  }
+  auto dims_data = reinterpret_cast<const T *>(dims_tensor->GetData());
-  KERNEL_CHECK_NULLPTR(dims_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get dims data failed")
   for (uint64_t i = 0; i < data_num; i++) {
-    auto dim = *(reinterpret_cast<const T *>(dims_tensor->GetData()) + i);
+    auto dim = *(dims_data + i);
     if (dim < 0) {
-      KERNEL_LOG_ERROR("Fill kernel: input dim [%llu] is negative, value=[%lld]", i, static_cast<int64_t>(dim));
+      KERNEL_LOG_ERROR("dims input dim [%llu] is negative, value=[%lld].", i, static_cast<int64_t>(dim));
      return KERNEL_STATUS_PARAM_INVALID;
    }
-    // zero dim is different from empty dim.
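+    // A zero dim is handled below: it clears the dims collected so far and
+    // stops the scan, which yields an empty output shape.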
     if (dim == 0) {
-      KERNEL_LOG_INFO("Fill kernel: input dim [%llu] is zero", i);
+      KERNEL_LOG_INFO("dims input dim [%llu] is zero.", i);
+      dims.clear();
+      break;
     }
-    dim_vec.emplace_back(dim);
+    dims.emplace_back(dim);
   }
   return KERNEL_STATUS_OK;
 }
 
+template <typename T>
+void FillCpuKernel::FillOutput(CpuKernelContext &ctx, const Tensor *value_tensor, Tensor *output) {
+  auto value = reinterpret_cast<T *>(value_tensor->GetData());
+  auto output_data = reinterpret_cast<T *>(output->GetData());
+  int64_t data_num = output->NumElements();
+
+  if (data_num >= kParallelDataNumCriticalPoint1) {
+    uint32_t min_core_num = 1;
+    uint32_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx));
+
+    if (data_num <= kParallelDataNumCriticalPoint2) {
+      max_core_num = std::min(max_core_num, 4U);
+    }
+
+    if (max_core_num > data_num) {
+      max_core_num = data_num;
+    }
+
+    auto shared_fill = [&](int64_t start, int64_t end) { SpecialFillOutput<T>(start, end, output_data, value); };
+
+    CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shared_fill);
+  } else {
+    SpecialFillOutput<T>(0, data_num, output_data, value);
+  }
+}
+
+template <typename T>
+void FillCpuKernel::SpecialFillOutput(int64_t start, int64_t end, T *output_data, const T *value) {
+  for (int64_t i = start; i < end; i++) {
+    *(output_data + i) = *(value);
+  }
+}
+
 REGISTER_CPU_KERNEL(kFill, FillCpuKernel);
+REGISTER_CPU_KERNEL(kFillV2, FillCpuKernel);
 }  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.h
index db718fec59b..ed592979a0c 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/fill.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef AICPU_KERNELS_NORMALIZED_FILL_H
-#define AICPU_KERNELS_NORMALIZED_FILL_H
+#ifndef AICPU_KERNELS_NORMALIZED_FILL_H_
+#define AICPU_KERNELS_NORMALIZED_FILL_H_
 
 #include "cpu_ops_kernel.h"
 
@@ -23,21 +23,18 @@ namespace aicpu {
 class FillCpuKernel : public CpuKernel {
  public:
   FillCpuKernel() = default;
-  ~FillCpuKernel() override = default;
+  ~FillCpuKernel() = default;
   uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t GetDimsByType(CpuKernelContext &ctx);
-  /**
-   * @brief calc dims from input dims tensor
-   * @param dims_tensor input dims tensor
-   * @param dims output shape dims
-   * @return status if success
-   */
   template <typename T>
-  uint32_t CalcDims(const Tensor *dims_tensor, std::vector<int64_t> &dims);
+  uint32_t CalculateDims(const Tensor *dims_tensor, std::vector<int64_t> &dims);
 
-  std::vector<int64_t> dims;
+  template <typename T>
+  void FillOutput(CpuKernelContext &ctx, const Tensor *value_tensor, Tensor *output);
+
+  template <typename T>
+  void SpecialFillOutput(int64_t start, int64_t end, T *output_data, const T *value);
 };
 }  // namespace aicpu
-#endif  // AICPU_KERNELS_NORMALIZED_FILL_H_
+#endif
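As a quick sanity check of the FillV2 contract this kernel now also serves, in doctest form (a minimal sketch with illustrative values; it assumes the public `ops.FillV2` primitive, which takes a 1-D shape tensor and a 0-D value tensor of the output dtype):

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> shape = Tensor(np.array([2, 3]), mstype.int32)
>>> value = Tensor(1.5, mstype.float32)
>>> print(ops.FillV2()(shape, value))
[[1.5 1.5 1.5]
 [1.5 1.5 1.5]]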
+ */ + +#include "log_normal_reverse.h" +#include +#include +#include "cpu_kernel_utils.h" +#include "cpu_ops_kernel.h" +#include "utils/eigen_tensor.h" +#include "utils/kernel_util.h" +#include +#include + +#include "Eigen/Core" +using namespace std; +using namespace Eigen; + +namespace { +const uint32_t kNumInput = 1; +const uint32_t kNumOutput = 1; + +const char *kLogNormalReverse = "LogNormalReverse"; +const int64_t kParallelDataNumSameShape = 16 * 1024; +const int64_t kParallelDataNumMid = 128 * 1024; +} // namespace +namespace aicpu { +uint32_t LogNormalReverseCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) { + KERNEL_HANDLE_ERROR(NormalCheck(ctx, kNumInput, kNumOutput), "LogNormalReverse check input and output failed."); + // get and check input + Tensor *input = ctx.Input(0); + inputs_.push_back(input); + + // get output Tensors + Tensor *output = ctx.Output(0); + outputs_.push_back(output); + + return KERNEL_STATUS_OK; +} + +template +uint32_t LogNormalReverseCpuKernel::DoCompute(CpuKernelContext &ctx) { + float input_mean = 1.0; + float input_std = 2.0; + + auto mean_value = ctx.GetAttr("mean"); + auto std_value = ctx.GetAttr("std"); + + if (mean_value != nullptr) { + input_mean = mean_value->GetFloat(); + } + if (std_value != nullptr) { + input_std = std_value->GetFloat(); + } + + T *output_y = reinterpret_cast(outputs_[0]->GetData()); + + static default_random_engine random_engine(time(0)); + static std::normal_distribution normal_value(input_mean, input_std); + + int64_t Nums = inputs_[0]->GetTensorShape()->NumElements(); + + int64_t data_num = Nums; + if (data_num >= kParallelDataNumSameShape) { + uint32_t max_core_num = std::max(1U, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum); + + if (data_num <= kParallelDataNumMid) { + max_core_num = std::min(max_core_num, 4U); + } + if (max_core_num > data_num) { + max_core_num = data_num; + } + + auto shared_lognormalreverse = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + output_y[i] = static_cast(std::exp(normal_value(random_engine))); + } + }; + + if (max_core_num == 0) { + max_core_num = 1; + } + CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shared_lognormalreverse); + } else { + for (int64_t i = 0; i < Nums; i++) { + output_y[i] = static_cast(std::exp(normal_value(random_engine))); + } + } + return KERNEL_STATUS_OK; +} + +uint32_t LogNormalReverseCpuKernel::Compute(CpuKernelContext &ctx) { + uint32_t res = GetInputAndCheck(ctx); + if (res != KERNEL_STATUS_OK) { + return res; + } + + DataType input_type{ctx.Input(0)->GetDataType()}; + switch (input_type) { + case (DT_FLOAT16): { + DoCompute(ctx); + break; + } + case (DT_FLOAT): { + DoCompute(ctx); + break; + } + default: + KERNEL_LOG_ERROR("[%s] Data type of input is not support, input data type is [%s].", ctx.GetOpType().c_str(), + DTypeStr(input_type).c_str()); + res = KERNEL_STATUS_PARAM_INVALID; + } + if (res != KERNEL_STATUS_OK) { + KERNEL_LOG_ERROR("log normal reverse failed"); + return res; + } + return KERNEL_STATUS_OK; +} +REGISTER_CPU_KERNEL(kLogNormalReverse, LogNormalReverseCpuKernel); +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/log_normal_reverse.h similarity index 60% rename from mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.h rename to 
index 6fef5ba7bc3..f1e1b60ec7e 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cache_swap_table.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/log_normal_reverse.h
@@ -1,44 +1,38 @@
-/**
- * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef AICPU_KERNELS_NORMALIZED_CACHE_SWAP_TABLE_H
-#define AICPU_KERNELS_NORMALIZED_CACHE_SWAP_TABLE_H
-
-#include <memory>
-#include <vector>
-#include "cpu_ops_kernel.h"
-
-namespace aicpu {
-class CacheSwapTableMsCpuKernel : public CpuKernel {
- public:
-  ~CacheSwapTableMsCpuKernel() = default;
-  uint32_t Compute(CpuKernelContext &ctx) override;
-
- private:
-  uint32_t DoCompute();
-
-  uint32_t GetInputAndCheck(CpuKernelContext &ctx);
-
-  int64_t batch_size_ = 1;
-  int64_t one_line_col_ = 1;
-  int64_t output_size_ = 1;
-
-  std::vector<Tensor *> inputs_;
-  std::vector<Tensor *> outputs_;
-  DataType param_type_ = DT_FLOAT;
-  DataType indices_type_ = DT_INT32;
-};
-}  // namespace aicpu
-#endif
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef AICPU_KERNELS_NORMALIZED_LOGNORMALREVERSE_H_ +#define AICPU_KERNELS_NORMALIZED_LOGNORMALREVERSE_H_ + +#include "cpu_ops_kernel.h" + +namespace aicpu { +class LogNormalReverseCpuKernel : public CpuKernel { + public: + LogNormalReverseCpuKernel() = default; + ~LogNormalReverseCpuKernel() override = default; + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + template + uint32_t DoCompute(CpuKernelContext &ctx); + uint32_t GetInputAndCheck(CpuKernelContext &ctx); + + std::vector inputs_; + std::vector outputs_; +}; +} // namespace aicpu +#endif // AICPU_KERNELS_NORMALIZED_LOGNORMALREVERSE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc index a94fbcc8dea..87d26f96fcb 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc @@ -21,7 +21,6 @@ #include #include #include -#include "kernel_util.h" #include "utils/kernel_util.h" #define NoneN 1000 using namespace Eigen; diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc index ea3dc755ade..c23c2901927 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc @@ -436,4 +436,4 @@ uint32_t ResizeBicubicGradCpuKernel::Compute(CpuKernelContext &ctx) { return KERNEL_STATUS_OK; } REGISTER_CPU_KERNEL(kResizeBicubicGrad, ResizeBicubicGradCpuKernel); -} // namespace aicpu \ No newline at end of file +} // namespace aicpu diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h index d99426c9c66..b681bd25505 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h @@ -35,4 +35,4 @@ class ResizeBicubicGradCpuKernel : public CpuKernel { uint32_t GetInputAndCheck(CpuKernelContext &ctx); }; } // namespace aicpu -#endif \ No newline at end of file +#endif diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc new file mode 100644 index 00000000000..b5f521096e9 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc
index a94fbcc8dea..87d26f96fcb 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/nuclear_norm.cc
@@ -21,7 +21,6 @@
 #include
 #include
 #include
-#include "kernel_util.h"
 #include "utils/kernel_util.h"
 #define NoneN 1000
 using namespace Eigen;
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc
index ea3dc755ade..c23c2901927 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.cc
@@ -436,4 +436,4 @@ uint32_t ResizeBicubicGradCpuKernel::Compute(CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 REGISTER_CPU_KERNEL(kResizeBicubicGrad, ResizeBicubicGradCpuKernel);
-}  // namespace aicpu
\ No newline at end of file
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h
index d99426c9c66..b681bd25505 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/resize_bicubic_grad.h
@@ -35,4 +35,4 @@ class ResizeBicubicGradCpuKernel : public CpuKernel {
   uint32_t GetInputAndCheck(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
-#endif
\ No newline at end of file
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc
new file mode 100644
index 00000000000..b5f521096e9
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.cc
@@ -0,0 +1,217 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "segment_max.h"
+
+#include "cpu_kernel_utils.h"
+#include "utils/eigen_tensor.h"
+#include "utils/kernel_util.h"
+#include "cpu_kernel/common/runtime_tensor_desc.h"
+
+namespace {
+const uint32_t kInputNum = 2;
+const uint32_t kOutputNum = 1;
+const char *kSegmentMax = "SegmentMax";
+const int64_t kDataSize = 2 * 1024;
+
+#define SEGMENTMAX_COMPUTE_CASE(DTYPE, TYPE1, TYPE2, CTX)    \
+  case (DTYPE): {                                            \
+    uint32_t result = SegmentMaxCompute<TYPE1, TYPE2>(CTX);  \
+    if (result != KERNEL_STATUS_OK) {                        \
+      KERNEL_LOG_ERROR("SegmentMax kernel compute failed."); \
+      return result;                                         \
+    }                                                        \
+    break;                                                   \
+  }
+
+#define SEGMENTMAX_COMPUTE_CASE_ALL(TYPE, CTX)                \
+  SEGMENTMAX_COMPUTE_CASE(DT_INT8, int8_t, TYPE, CTX)         \
+  SEGMENTMAX_COMPUTE_CASE(DT_INT16, int16_t, TYPE, CTX)       \
+  SEGMENTMAX_COMPUTE_CASE(DT_INT32, int32_t, TYPE, CTX)       \
+  SEGMENTMAX_COMPUTE_CASE(DT_INT64, int64_t, TYPE, CTX)       \
+  SEGMENTMAX_COMPUTE_CASE(DT_UINT8, uint8_t, TYPE, CTX)       \
+  SEGMENTMAX_COMPUTE_CASE(DT_UINT16, uint16_t, TYPE, CTX)     \
+  SEGMENTMAX_COMPUTE_CASE(DT_UINT32, uint32_t, TYPE, CTX)     \
+  SEGMENTMAX_COMPUTE_CASE(DT_UINT64, uint64_t, TYPE, CTX)     \
+  SEGMENTMAX_COMPUTE_CASE(DT_FLOAT16, Eigen::half, TYPE, CTX) \
+  SEGMENTMAX_COMPUTE_CASE(DT_FLOAT, float, TYPE, CTX)         \
+  SEGMENTMAX_COMPUTE_CASE(DT_DOUBLE, double, TYPE, CTX)
+}  // namespace
+
+namespace aicpu {
+uint32_t SegmentMaxCpuKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "SegmentMax check input and output number failed.");
+  auto data_type = ctx.Input(0)->GetDataType();
+  auto segment_ids_type = ctx.Input(1)->GetDataType();
+  switch (segment_ids_type) {
+    case DT_INT32: {
+      switch (data_type) {
+        SEGMENTMAX_COMPUTE_CASE_ALL(int32_t, ctx)
+        default:
+          KERNEL_LOG_ERROR("Input[0] data type[%s] not supported.", DTypeStr(data_type).c_str());
+          return KERNEL_STATUS_PARAM_INVALID;
+      }
+      break;
+    }
+    case DT_INT64: {
+      switch (data_type) {
+        SEGMENTMAX_COMPUTE_CASE_ALL(int64_t, ctx)
+        default:
+          KERNEL_LOG_ERROR("Input[0] data type[%s] not supported.", DTypeStr(data_type).c_str());
+          return KERNEL_STATUS_PARAM_INVALID;
+      }
+      break;
+    }
+    default: {
+      KERNEL_LOG_ERROR("Input[1] data type[%s] not supported.", DTypeStr(segment_ids_type).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+    }
+  }
+  return KERNEL_STATUS_OK;
+}
+
+template <typename T1, typename T2>
+uint32_t SegmentMaxCpuKernel::SegmentMaxCompute(CpuKernelContext &ctx) {
+  Tensor *input_x_data = ctx.Input(0);
+  auto input_x_addr = reinterpret_cast<T1 *>(input_x_data->GetData());
+  auto input_x_shape = input_x_data->GetTensorShape();
+  auto input_x_dims = input_x_shape->GetDimSizes();
+  int64_t input_x_num = input_x_data->NumElements();
+  Tensor *segment_ids_data = ctx.Input(1);
+  auto segment_ids_data_addr = reinterpret_cast<T2 *>(segment_ids_data->GetData());
+  int64_t segment_ids_data_num = segment_ids_data->NumElements();
+  input_x_dims[0] = segment_ids_data_addr[segment_ids_data_num - 1] + 1;
+  Tensor *output_data = ctx.Output(0);
+  auto output_data_addr = reinterpret_cast<T1 *>(output_data->GetData());
+  auto output_data_shape = output_data->GetTensorShape();
+  if (output_data_shape->GetDimSize(0) < input_x_dims[0]) {
+    KERNEL_LOG_ERROR("The number of segments of the segmentation result of segment_ids is too large.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
+  output_data_shape->SetDimSizes(input_x_dims);
+  if (!output_data->SetTensorShape(output_data_shape.get())) {
+    KERNEL_LOG_ERROR("Set output shape failed.");
+    return KERNEL_STATUS_INNER_ERROR;
+  }
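+  // The leading output dim equals the last segment id plus one; the remaining
+  // dims mirror input_x, as prepared in input_x_dims above.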
failed."); + return KERNEL_STATUS_INNER_ERROR; + } + int64_t output_data_num = output_data->NumElements(); + for (int64_t i = 0; i < output_data_num; i++) { + output_data_addr[i] = static_cast(0); + } + std::vector segments_segment_ids; + if (segment_ids_data_num != (input_x_data->GetTensorShape()->GetDimSize(0))) { + KERNEL_LOG_ERROR("The amount of data for input[1] must be equal to the first dimension of input[0]."); + return KERNEL_STATUS_PARAM_INVALID; + } + if (segment_ids_data_addr[0] < 0) { + KERNEL_LOG_ERROR("Input[1] must be nonnegative data."); + return KERNEL_STATUS_PARAM_INVALID; + } + int64_t seg_tmp = 1; + for (int64_t i = 0; i < segment_ids_data_num - 1; i++) { + if (segment_ids_data_addr[i] > segment_ids_data_addr[i + 1]) { + KERNEL_LOG_ERROR("Input[1] must be an ascending ordered sequence."); + return KERNEL_STATUS_PARAM_INVALID; + } + if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) { + seg_tmp++; + } else { + segments_segment_ids.push_back(seg_tmp); + seg_tmp = 1; + } + if (i == segment_ids_data_num - ge::DIM_SIZE2) { + segments_segment_ids.push_back(seg_tmp); + } + } + const int64_t num_compare_per = input_x_num / (input_x_shape->GetDimSize(0)); + const int64_t num_segments_segment_ids = segments_segment_ids.size(); + if (num_segments_segment_ids < kDataSize) { + for (int64_t i = 0; i < num_segments_segment_ids; i++) { + int64_t count = segments_segment_ids[i]; + int64_t count_no = 0; + for (int64_t j = 0; j < i; j++) { + count_no += segments_segment_ids[j]; + } + int64_t input_addr_base = count_no * num_compare_per; + if (num_compare_per < kDataSize) { + for (int64_t j = 0; j < num_compare_per; j++) { + int64_t max_init_addr = input_addr_base + j; + T1 max_value = input_x_addr[max_init_addr]; + for (int64_t k = 1; k < count; k++) { + int cmp_addr = max_init_addr + k * num_compare_per; + if (max_value < input_x_addr[cmp_addr]) { + max_value = input_x_addr[cmp_addr]; + } + } + output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = max_value; + } + } else { + uint32_t min_core_num = 1; + int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + if (max_core_num > num_compare_per) { + max_core_num = num_compare_per; + } + auto shard_compute = [&](size_t start, size_t end) { + for (size_t j = start; j < end; j++) { + int64_t max_init_addr = input_addr_base + j; + T1 max_value = input_x_addr[max_init_addr]; + for (int64_t k = 1; k < count; k++) { + int cmp_addr = max_init_addr + k * num_compare_per; + if (max_value < input_x_addr[cmp_addr]) { + max_value = input_x_addr[cmp_addr]; + } + } + output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = max_value; + } + }; + KERNEL_HANDLE_ERROR( + CpuKernelUtils::ParallelFor(ctx, num_compare_per, num_compare_per / max_core_num, shard_compute), + "SegmentMax Compute failed."); + } + } + } else { + uint32_t min_core_num_seg = 1; + int64_t max_core_num_seg = std::max(min_core_num_seg, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2); + if (max_core_num_seg > num_segments_segment_ids) { + max_core_num_seg = num_segments_segment_ids; + } + auto shard_compute_seg = [&](size_t start_seg, size_t end_seg) { + for (size_t i = start_seg; i < end_seg; i++) { + int64_t count = segments_segment_ids[i]; + int64_t count_no = 0; + for (size_t j = 0; j < i; j++) { + count_no += segments_segment_ids[j]; + } + int64_t input_addr_base = count_no * num_compare_per; + for (int64_t j = 0; j < num_compare_per; j++) { + int64_t max_init_addr = input_addr_base + j; + T1 max_value = 
+          for (int64_t k = 1; k < count; k++) {
+            int64_t cmp_addr = max_init_addr + k * num_compare_per;
+            if (max_value < input_x_addr[cmp_addr]) {
+              max_value = input_x_addr[cmp_addr];
+            }
+          }
+          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = max_value;
+        }
+      }
+    };
+    KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, num_segments_segment_ids,
+                                                    num_segments_segment_ids / max_core_num_seg, shard_compute_seg),
+                        "SegmentMax Compute failed.");
+  }
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kSegmentMax, SegmentMaxCpuKernel);
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.h
new file mode 100644
index 00000000000..8f5ca43477d
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/segment_max.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AICPU_KERNELS_NORMALIZED_SEGMENTMAX_H_
+#define AICPU_KERNELS_NORMALIZED_SEGMENTMAX_H_
+
+#include "cpu_ops_kernel.h"
+
+namespace aicpu {
+class SegmentMaxCpuKernel : public CpuKernel {
+ public:
+  SegmentMaxCpuKernel() = default;
+  ~SegmentMaxCpuKernel() override = default;
+
+ protected:
+  uint32_t Compute(CpuKernelContext &ctx) override;
+
+ private:
+  template <typename T1, typename T2>
+  static uint32_t SegmentMaxCompute(CpuKernelContext &ctx);
+};
+}  // namespace aicpu
+#endif
diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
index 9528f931e22..b7cf5ad81ca 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc
@@ -54,7 +54,6 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
   static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2DV1OpName,
                                                                mindspore::kAdaptiveAvgPool2DGradV1OpName,
                                                                mindspore::kBucketizeOpName,
-                                                               mindspore::kCacheSwapTableOpName,
                                                                mindspore::kCauchyOpName,
                                                                mindspore::kChannelShuffleOpName,
                                                                mindspore::kCholeskyOpName,
@@ -252,7 +251,10 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
                                                         mindspore::kLogicalXorOpName,
                                                         mindspore::kLogNormalReverseOpName,
                                                         mindspore::kBetaincOpName,
-                                                        mindspore::kLessEqualOpName};
+                                                        mindspore::kLessEqualOpName,
+                                                        mindspore::kHSVToRGBOpName,
+                                                        mindspore::kLuSolveOpName,
+                                                        mindspore::kExtractGlimpseOpName};
   static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
   static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";
diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py
index 66956120330..a0b920e850c 100644
--- a/mindspore/python/mindspore/nn/loss/loss.py
+++ b/mindspore/python/mindspore/nn/loss/loss.py
@@ -1160,7 +1160,7 @@ class PoissonNLLLoss(LossBase):
     Args:
         log_input (bool, optional): Whether use log input. Default: True.
         full (bool, optional): Whether include the Stirling approximation term in the loss calculation. Default: False.
-        eps (float, optional): Lower bound of `x` when calculating logarithms. Default: 1e-8.
+        eps (float, optional): Lower bound of `x` when calculating logarithms. Default: 1e-08.
         reduction (str, optional): Apply specific reduction method to the output: 'none', 'mean', 'sum'.
             Default: 'mean'.
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 9bfb139b47e..da0ffdbc05c 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -350,3 +350,5 @@ from .hsv_to_rgb import _hsv_to_rgb_aicpu
 from .im2col import _im2col_aicpu
 from .lu_solve import _lu_solve_aicpu
 from .relu_grad_v3 import _relu_grad_v3_aicpu
+from .resize_bicubic import _resize_bicubic_aicpu
+from .extract_glimpse import _extract_glimpse_aicpu
diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py
index 97ff1e2570f..56d2888266c 100644
--- a/mindspore/python/mindspore/ops/operations/__init__.py
+++ b/mindspore/python/mindspore/ops/operations/__init__.py
@@ -86,7 +86,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        MatrixLogarithm, MatrixPower, MatrixSolve, MatrixTriangularSolve, ReduceStd, STFT, NextAfter,
                        Orgqr, Qr, RaggedRange, Digamma, Eig, EuclideanNorm, CompareAndBitpack, ComplexAbs,
                        CumulativeLogsumexp, Gcd, Trace, TridiagonalMatMul, TrilIndices, TriuIndices, Zeta,
-                       Roll, Lgamma, Logit)
+                       Roll, Lgamma, Logit, MatrixSolveLs)
 from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
                      ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
                      DepthwiseConv2dNative,
@@ -647,7 +647,8 @@ __all__ = [
     "SparseSlice",
     "ResizeLinear1D",
     "ResizeBicubic",
-    "Logit"
+    "Logit",
+    "MatrixSolveLs"
 ]
 
 __custom__ = [
diff --git a/mindspore/python/mindspore/ops/operations/image_ops.py b/mindspore/python/mindspore/ops/operations/image_ops.py
index 7bc79793707..4581a5514ca 100644
--- a/mindspore/python/mindspore/ops/operations/image_ops.py
+++ b/mindspore/python/mindspore/ops/operations/image_ops.py
@@ -465,10 +465,10 @@ class NonMaxSuppressionWithOverlaps(Primitive):
     Examples:
         >>> overlaps = Tensor(np.array([[0.6964692, 0.28613934, 0.22685145, 0.5513148],
-                                        [0.71946895, 0.42310646, 0.9807642, 0.6848297],
-                                        [0.4809319, 0.39211753, 0.343178, 0.7290497],
-                                        [0.43857226, 0.059677895, 0.39804426, 0.7379954]
-                                        ]), mstype.float32)
+        ...                             [0.71946895, 0.42310646, 0.9807642, 0.6848297],
+        ...                             [0.4809319, 0.39211753, 0.343178, 0.7290497],
+        ...                             [0.43857226, 0.059677895, 0.39804426, 0.7379954]
+        ...                             ]), mstype.float32)
         >>> scores = Tensor(np.array([0.18249173, 0.17545176, 0.53155136, 0.53182757]), mstype.float32)
         >>> max_output_size = Tensor(4, mstype.int32)
         >>> overlap_threshold = Tensor(0.1, mstype.float32)
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index c1e57668c5e..df8412b77c3 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -260,6 +260,7 @@ class Addcdiv(Primitive):
     Raises:
         TypeError: If dtype of `x1`, `x2`, `value`, `input_data` is not tensor.
+        TypeError: If dtypes of `x1`, `x2`, `value` and `input_data` are not the same.
         ValueError: If `x1` could not be broadcast to `x2`.
         ValueError: If `value` could not be broadcast to `x1/x2`.
         ValueError: If `input_data` could not be broadcast to `value*(x1/x2)`.
@@ -303,9 +304,7 @@ class Addcmul(Primitive):
     Raises:
         TypeError: If dtype of `x1`, `x2`, `value`, `input_data` is not tensor.
-        TypeError: If dtype of `input_data` is not one of: float32, float16, int32.
-        TypeError: If dtype of `x1` or `x2` is not one of: float32, float16, int32.
-        TypeError: If dtype of `value` is not one of: float32, float16, int32.
+        TypeError: If dtypes of `x1`, `x2`, `value` and `input_data` are not the same.
         ValueError: If `x1` could not be broadcast to `x2`.
         ValueError: If `value` could not be broadcast to `x1` * `x2`.
         ValueError: If `input_data` could not be broadcast to `value*(x1*x2)`.
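A minimal doctest-style check of the tightened Addcmul contract (illustrative values; all four tensors share one dtype, and output = input_data + value * (x1 * x2)):

>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> input_data = Tensor(np.array([1, 1, 1]), mstype.float32)
>>> x1 = Tensor(np.array([1, 2, 3]), mstype.float32)
>>> x2 = Tensor(np.array([2, 2, 2]), mstype.float32)
>>> value = Tensor(np.array([1]), mstype.float32)
>>> print(ops.Addcmul()(input_data, x1, x2, value))
[3. 5. 7.]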