forked from mindspore-Ecosystem/mindspore
migrate 3 aicpu ops to branch r1.9
This commit is contained in:
parent
e68cd66dac
commit
a8f754ddc7
|
@ -68,3 +68,29 @@
|
|||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "knownConditionTrueFalse"
|
||||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "shadowVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_utils.cc" "knownConditionTrueFalse"
|
||||
|
||||
# AICPU migration
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "redundantAssignment"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "constArgument"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "unknownMacro"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unreadVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "useStlAlgorithm"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "constParameter"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "truncLongCastAssignment"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "knownConditionTrueFalse"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "passedByValue"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "uninitMemberVar"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unsignedPositive"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "uninitvar"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "shadowVariable"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unsignedPositive"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "zerodivcond"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "noConstructor"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "noExplicitConstructor"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "identicalConditionAfterEarlyExit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "uninitMemberVar"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "redundantInitialization"
|
||||
|
|
|
@ -78,3 +78,34 @@
|
|||
"mindspore/mindspore/core/mindrt/include/async/try.h" "runtime/explicit"
|
||||
"mindspore/mindspore/core/mindrt/include/async/failure.h" "runtime/explicit"
|
||||
"mindspore/mindspore/core/mindrt/include/async/defer.h" "runtime/explicit"
|
||||
|
||||
# AICPU migration
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include_subdir"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/indent"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/ending_newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/braces"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/namespace"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/braces"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/end_of_line"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/casting"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/namespaces"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/multiline_comment"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/parens"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/alt_tokens"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comments"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/string"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/arrays"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "legal/copyright"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/inheritance"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/int"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/empty_if_body"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/operators"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comma"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/indentation_namespace"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/line_length"
|
||||
|
||||
|
|
|
@ -207,4 +207,9 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_
|
|||
mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run
|
||||
mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
||||
|
||||
# AICPI migration
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd_update.cc:aicpu::ScatterNdUpdateCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/tensor_scatter_update.cc:aicpu::TensorScatterUpdateCpuKernel::Compute
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/scatter_nd.cc:aicpu::ScatterNdCpuKernel::Compute
|
|
@ -210,6 +210,7 @@ constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBack
|
|||
constexpr auto kGammaOpName = "Gamma";
|
||||
constexpr auto kGatherDGradV2OpName = "GatherDGradV2";
|
||||
constexpr auto kGatherDOpName = "GatherD";
|
||||
constexpr auto kGatherNdOpName = "GatherNd";
|
||||
constexpr auto kGatherOpName = "Gather";
|
||||
constexpr auto kGatherV2OpName = "Gather";
|
||||
constexpr auto kDeformableOffsetsGradOpName = "DeformableOffsetsGrad";
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/aicpu/aicpu_input_to_attr_registry.h"
|
||||
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
/*
|
||||
* Parameter is attr in AICPU, but is input in graph.
|
||||
* {
|
||||
* {op_name, {{pos_index, data_type}, ...},
|
||||
* ...
|
||||
* }
|
||||
*/
|
||||
std::map<string, std::map<size_t, std::string>> AicpuOpInputToAttrMap = {
|
||||
{kStridedSliceOpName, {{1, "listInt"}, {2, "listInt"}, {3, "listInt"}}}, {kExpandDimsOpName, {{1, "int"}}}};
|
||||
|
||||
bool GetAicpuOpInputToAttrInfo(const CNodePtr &kernel_node, std::map<size_t, std::string> *input_to_attr_info) {
|
||||
std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (AicpuOpInputToAttrMap.find(op_name) == AicpuOpInputToAttrMap.end()) {
|
||||
return false;
|
||||
} else {
|
||||
*input_to_attr_info = AicpuOpInputToAttrMap[op_name];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_INPUT_TO_ATTR_REGISTRY_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_INPUT_TO_ATTR_REGISTRY_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "kernel/kernel.h"
|
||||
#include "utils/hash_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool GetAicpuOpInputToAttrInfo(const CNodePtr &kernel_node, std::map<size_t, std::string> *input_to_attr_info);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_INPUT_TO_ATTR_REGISTRY_H
|
|
@ -23,7 +23,6 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
|
|||
set(AICPU_SRC
|
||||
${PROTO_SRCS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/kernel_base.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/kernel_log.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_async_event.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_pulse.cc
|
||||
|
@ -76,4 +75,12 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
|
|||
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
|
||||
)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpu_kernel/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpu_kernel/common)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpu_kernel/cpu_proto)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpu_kernel/utils)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpu_kernel/)
|
||||
add_subdirectory(cpu_kernel)
|
||||
endif()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,50 +21,32 @@
|
|||
#include <iostream>
|
||||
#include <utility>
|
||||
#include "common/kernel_errcode.h"
|
||||
#include "toolchain/slog.h"
|
||||
|
||||
inline int64_t GetTid(void) {
|
||||
thread_local static const int64_t tid = syscall(__NR_gettid);
|
||||
return tid;
|
||||
}
|
||||
static const int LOG_COUNT = 0;
|
||||
|
||||
namespace aicpu {
|
||||
#define AICPU_LOG_DEBUG 0
|
||||
#define AICPU_LOG_INFO 1
|
||||
#define AICPU_LOG_WARN 2
|
||||
#define AICPU_LOG_ERROR 3
|
||||
#define AICPU_LOG_EVENT 0x10
|
||||
#define AICPU_MODULE_NAME static_cast<int32_t>(AICPU)
|
||||
#define KERNEL_MODULE "AICPU"
|
||||
|
||||
inline void PrintLog(const int level) { std::cerr << level << std::endl; }
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline void PrintLog(const int level, T &&head, Args &&... tail) {
|
||||
std::cerr << std::forward<T>(head) << " ";
|
||||
PrintLog(level, std::forward<Args>(tail)...);
|
||||
}
|
||||
|
||||
int LogSetLevel(int level);
|
||||
|
||||
int LogGetLevel(void);
|
||||
|
||||
bool CheckLogLevel(int log_level_check);
|
||||
|
||||
#define AICPU_LOGD(fmt, ...) \
|
||||
AICPU_LOG(AICPU_LOG_DEBUG, "%s:%s:%d[tid:%lu]:" #fmt, __FUNCTION__, __FILE__, __LINE__, GetTid(), ##__VA_ARGS__);
|
||||
#define AICPU_LOGI(fmt, ...) \
|
||||
AICPU_LOG(AICPU_LOG_INFO, "%s:%s:%d[tid:%lu]:" #fmt, __FUNCTION__, __FILE__, __LINE__, GetTid(), ##__VA_ARGS__);
|
||||
#define AICPU_LOGW(fmt, ...) \
|
||||
AICPU_LOG(AICPU_LOG_WARN, "%s:%s:%d[tid:%lu]:" #fmt, __FUNCTION__, __FILE__, __LINE__, GetTid(), ##__VA_ARGS__);
|
||||
#define AICPU_LOGE(fmt, ...) \
|
||||
AICPU_LOG(AICPU_LOG_ERROR, "%s:%s:%d[tid:%lu]:" #fmt, __FUNCTION__, __FILE__, __LINE__, GetTid(), ##__VA_ARGS__);
|
||||
#define AICPU_LOGEVENT(fmt, ...) \
|
||||
AICPU_LOG(AICPU_LOG_EVENT, "%s:%s:%d[tid:%lu]:" #fmt, __FUNCTION__, __FILE__, __LINE__, GetTid(), ##__VA_ARGS__);
|
||||
#define AICPU_LOG(level, fmt, ...) \
|
||||
do { \
|
||||
if (aicpu::CheckLogLevel(level)) { \
|
||||
aicpu::PrintLog(level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (LOG_COUNT != 0)
|
||||
#define AICPU_LOGD(fmt, ...) \
|
||||
dlog_debug(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define AICPU_LOGI(fmt, ...) \
|
||||
dlog_info(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define AICPU_LOGW(fmt, ...) \
|
||||
dlog_warn(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define AICPU_LOGE(fmt, ...) \
|
||||
dlog_error(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define AICPU_LOGEVENT(fmt, ...) \
|
||||
dlog_event(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
|
||||
#define AICPU_CHK_STATUS_RET(expr...) \
|
||||
do { \
|
||||
|
@ -91,5 +73,69 @@ bool CheckLogLevel(int log_level_check);
|
|||
AICPU_LOGE(logText); \
|
||||
return errorCode; \
|
||||
}
|
||||
|
||||
#define KERNEL_LOG_DEBUG(fmt, ...) \
|
||||
dlog_debug(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define KERNEL_LOG_INFO(fmt, ...) \
|
||||
dlog_info(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define KERNEL_LOG_WARN(fmt, ...) \
|
||||
dlog_warn(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define KERNEL_LOG_ERROR(fmt, ...) \
|
||||
dlog_error(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
#define KERNEL_LOG_EVENT(fmt, ...) \
|
||||
dlog_event(AICPU_MODULE_NAME, "[%s][%s:%d][tid:%lu]:" fmt, KERNEL_MODULE, __FUNCTION__, __LINE__, GetTid(), \
|
||||
##__VA_ARGS__);
|
||||
|
||||
#define KERNEL_CHECK_NULLPTR_VOID(value, logText...) \
|
||||
if (value == nullptr) { \
|
||||
AICPU_LOGE(logText); \
|
||||
return; \
|
||||
}
|
||||
|
||||
#define KERNEL_CHECK_FALSE(condition, errorCode, logText...) \
|
||||
if (!(condition)) { \
|
||||
AICPU_LOGE(logText); \
|
||||
return errorCode; \
|
||||
}
|
||||
|
||||
#define KERNEL_CHECK_NULLPTR(value, errorCode, logText...) \
|
||||
if (value == nullptr) { \
|
||||
AICPU_LOGE(logText); \
|
||||
return errorCode; \
|
||||
}
|
||||
|
||||
#define KERNEL_CHECK_ASSIGN_64S_MULTI(A, B, result, errorCode) \
|
||||
do { \
|
||||
if ((A) != 0 && (B) != 0 && ((INT64_MAX) / (A)) <= (B)) { \
|
||||
AICPU_LOGE("Integer reversed multiA: %llu * multiB: %llu", (A), (B)); \
|
||||
return errorCode; \
|
||||
} \
|
||||
(result) = ((A) * (B)); \
|
||||
} while (0)
|
||||
|
||||
#define KERNEL_CHECK_FALSE_VOID(condition, logText...) \
|
||||
if (!(condition)) { \
|
||||
AICPU_LOGE(logText); \
|
||||
return; \
|
||||
}
|
||||
|
||||
#define KERNEL_HANDLE_ERROR(expression, logText...) \
|
||||
; \
|
||||
do { \
|
||||
uint32_t ret = expression; \
|
||||
if (ret != static_cast<uint32_t>(KERNEL_STATUS_OK)) { \
|
||||
AICPU_LOGE(logText); \
|
||||
return ret; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define KERNEL_CHECK_FALSE_EXEC(condition, execExpr...) \
|
||||
if (!(condition)) { \
|
||||
execExpr; \
|
||||
}
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_OPS_AICPU_COMMON_KERNEL_LOG_H_
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
set(CPU_PROTO_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_proto/proto/cpu_attr.proto
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_proto/proto/cpu_node_def.proto
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_proto/proto/cpu_tensor_shape.proto
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_proto/proto/cpu_tensor.proto
|
||||
)
|
||||
|
||||
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${CPU_PROTO_SRC})
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/common COMMON_LISTS)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_proto CPU_PROTO_LISTS)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/utils UTILS_LISTS)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/format_transfer FORMAT_TRANSFER_LISTS)
|
||||
set(CPU_SRC
|
||||
${COMMON_LISTS}
|
||||
${CPU_PROTO_LISTS}
|
||||
${UTILS_LISTS}
|
||||
${FORMAT_TRANSFER_LISTS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../aicpu_sharder/aicpu_context.cc
|
||||
)
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/ms_kernel MS_KERNELS)
|
||||
set(CPU_OPS_SRC
|
||||
${MS_KERNELS}
|
||||
)
|
||||
|
||||
add_library(mindspore_cpu_kernels SHARED
|
||||
${PROTO_SRCS}
|
||||
${CPU_SRC}
|
||||
${CPU_OPS_SRC}
|
||||
)
|
||||
|
||||
target_compile_options(mindspore_cpu_kernels PRIVATE
|
||||
-march=armv8-a
|
||||
-O2
|
||||
-fvisibility-inlines-hidden
|
||||
-fvisibility=hidden
|
||||
-fno-strict-aliasing
|
||||
-fno-common
|
||||
)
|
||||
|
||||
target_link_libraries(mindspore_cpu_kernels PRIVATE
|
||||
-ldl
|
||||
-shared
|
||||
PUBLIC
|
||||
${SECUREC_ARM_LIBRARY}
|
||||
-Wl,--whole-archive
|
||||
-Wl,--no-whole-archive
|
||||
-Wl,-Bsymbolic
|
||||
-rdynamic
|
||||
mindspore::protobuf_arm
|
||||
-pthread
|
||||
)
|
||||
|
||||
set(INSTALL_LIBRARY_DIR lib)
|
||||
install(TARGETS mindspore_cpu_kernels OPTIONAL
|
||||
EXPORT mindspore_cpu_kernels-targets
|
||||
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/async_cpu_kernel.h"
|
||||
#include "cpu_kernel/common/notification.h"
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t AsyncCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
Notification n;
|
||||
uint32_t ret = ComputeAsync(ctx, [&n](uint32_t status) { n.Notify(); });
|
||||
n.WaitForNotification();
|
||||
return ret;
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ASYNC_CPU_KERNEL_H
|
||||
#define ASYNC_CPU_KERNEL_H
|
||||
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AICPU_VISIBILITY AsyncCpuKernel : public CpuKernel {
|
||||
public:
|
||||
using CpuKernel::CpuKernel;
|
||||
|
||||
using DoneCallback = std::function<void(uint32_t status)>;
|
||||
|
||||
virtual uint32_t ComputeAsync(CpuKernelContext &ctx, DoneCallback done) = 0;
|
||||
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // ASYNC_CPU_KERNEL_H
|
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cpu_kernel/common/async_event_util.h"
|
||||
#include <dlfcn.h>
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
|
||||
namespace {
|
||||
const char *kSharderPath = "/usr/lib64/libaicpu_sharder.so";
|
||||
const char *kNotifyWaitFunc = "AicpuNotifyWait";
|
||||
const char *kRegEventCbFunc = "AicpuRegEventCb";
|
||||
const char *kRegEventCbWithTimesFunc = "AicpuRegEventCbWithTimes";
|
||||
const char *kUnregEventCbFunc = "AicpuUnregEventCb";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
AsyncEventUtil &AsyncEventUtil::GetInstance() {
|
||||
static AsyncEventUtil async_event_util;
|
||||
return async_event_util;
|
||||
}
|
||||
|
||||
void AsyncEventUtil::InitEventUtil() {
|
||||
notify_wait_func_ = reinterpret_cast<NotifyWaitFunc>(dlsym(sharder_, kNotifyWaitFunc));
|
||||
if (notify_wait_func_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get Function[%s] address failed, error[%s]", kNotifyWaitFunc, dlerror());
|
||||
}
|
||||
reg_event_cb_func_ = reinterpret_cast<RegEventCbFunc>(dlsym(sharder_, kRegEventCbFunc));
|
||||
if (reg_event_cb_func_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get Function[%s] address failed, error[%s]", kRegEventCbFunc, dlerror());
|
||||
}
|
||||
reg_event_cb_with_times_func_ = reinterpret_cast<RegEventCbWithTimesFunc>(dlsym(sharder_, kRegEventCbWithTimesFunc));
|
||||
if (reg_event_cb_with_times_func_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get Function[%s] address failed, error[%s]", kRegEventCbWithTimesFunc, dlerror());
|
||||
}
|
||||
unreg_event_cb_func_ = reinterpret_cast<UnregEventCbFunc>(dlsym(sharder_, kUnregEventCbFunc));
|
||||
if (unreg_event_cb_func_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get Function[%s] address failed, error[%s]", kUnregEventCbFunc, dlerror());
|
||||
}
|
||||
}
|
||||
|
||||
AsyncEventUtil::AsyncEventUtil() {
|
||||
sharder_ = dlopen(kSharderPath, RTLD_LAZY | RTLD_GLOBAL);
|
||||
if (sharder_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Device sharder dlopen so [%s] failed, error[%s]", kSharderPath, dlerror());
|
||||
notify_wait_func_ = nullptr;
|
||||
reg_event_cb_func_ = nullptr;
|
||||
reg_event_cb_with_times_func_ = nullptr;
|
||||
unreg_event_cb_func_ = nullptr;
|
||||
} else {
|
||||
InitEventUtil();
|
||||
KERNEL_LOG_INFO("Device sharder dlopen so[%s] success.", kSharderPath);
|
||||
}
|
||||
}
|
||||
|
||||
AsyncEventUtil::~AsyncEventUtil() {
|
||||
if (sharder_ != nullptr) {
|
||||
(void)dlclose(sharder_);
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncEventUtil::NotifyWait(void *notify_param, const uint32_t param_len) const {
|
||||
if (notify_wait_func_ != nullptr) {
|
||||
notify_wait_func_(notify_param, param_len);
|
||||
return;
|
||||
}
|
||||
KERNEL_LOG_WARN("Function[%s] is null", kNotifyWaitFunc);
|
||||
}
|
||||
|
||||
bool AsyncEventUtil::RegEventCb(const uint32_t event_id, const uint32_t sub_event_id,
|
||||
const std::function<void(void *)> &cb) {
|
||||
if (reg_event_cb_func_ != nullptr) {
|
||||
return reg_event_cb_func_(event_id, sub_event_id, cb);
|
||||
}
|
||||
KERNEL_LOG_WARN("Function[%s] is null.", kRegEventCbFunc);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AsyncEventUtil::RegEventCb(const uint32_t event_id, const uint32_t sub_event_id,
|
||||
const std::function<void(void *)> &cb, const int32_t times) {
|
||||
if (reg_event_cb_with_times_func_ != nullptr) {
|
||||
return reg_event_cb_with_times_func_(event_id, sub_event_id, cb, times);
|
||||
}
|
||||
KERNEL_LOG_WARN("Function[%s] is null.", kRegEventCbWithTimesFunc);
|
||||
return false;
|
||||
}
|
||||
|
||||
void AsyncEventUtil::UnregEventCb(const uint32_t event_id, const uint32_t sub_event_id) {
|
||||
if (unreg_event_cb_func_ != nullptr) {
|
||||
return unreg_event_cb_func_(event_id, sub_event_id);
|
||||
}
|
||||
KERNEL_LOG_WARN("Function[%s] is null.", kUnregEventCbFunc);
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_CONTEXT_COMMON_ASYNC_EVENT_H_
|
||||
#define AICPU_CONTEXT_COMMON_ASYNC_EVENT_H_
|
||||
|
||||
#include <functional>
|
||||
#include "aicpu_sharder/aicpu_context.h"
|
||||
|
||||
namespace aicpu {
|
||||
typedef void (*NotifyWaitFunc)(void *notify_param, const uint32_t param_len);
|
||||
typedef bool (*RegEventCbFunc)(const uint32_t event_id, const uint32_t sub_event_id,
|
||||
const std::function<void(void *)> &cb);
|
||||
typedef bool (*RegEventCbWithTimesFunc)(const uint32_t event_id, const uint32_t sub_event_id,
|
||||
const std::function<void(void *)> &cb, const int32_t times);
|
||||
typedef void (*UnregEventCbFunc)(const uint32_t event_id, const uint32_t sub_event_id);
|
||||
|
||||
class AsyncEventUtil {
|
||||
public:
|
||||
static AsyncEventUtil &GetInstance();
|
||||
|
||||
void NotifyWait(void *notify_param, const uint32_t param_len) const;
|
||||
|
||||
bool RegEventCb(const uint32_t event_id, const uint32_t sub_event_id, const std::function<void(void *)> &cb);
|
||||
|
||||
bool RegEventCb(const uint32_t event_id, const uint32_t sub_event_id, const std::function<void(void *)> &cb,
|
||||
const int32_t times);
|
||||
|
||||
void UnregEventCb(const uint32_t event_id, const uint32_t sub_event_id);
|
||||
|
||||
private:
|
||||
AsyncEventUtil();
|
||||
~AsyncEventUtil();
|
||||
void InitEventUtil();
|
||||
|
||||
private:
|
||||
void *sharder_;
|
||||
NotifyWaitFunc notify_wait_func_;
|
||||
RegEventCbFunc reg_event_cb_func_;
|
||||
RegEventCbWithTimesFunc reg_event_cb_with_times_func_;
|
||||
UnregEventCbFunc unreg_event_cb_func_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_ASYNC_EVENT_H_
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/inc/cpu_context.h"
|
||||
#include "aicpu_sharder/aicpu_context.h"
|
||||
#include "cpu_kernel/common/cpu_node_def.h"
|
||||
#include "cpu_kernel/common/device.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "proto/cpu_attr.pb.h"
|
||||
#include "proto/cpu_node_def.pb.h"
|
||||
#include "cpu_kernel/common/sharder.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace aicpu {
|
||||
CpuKernelContext::CpuKernelContext(DeviceType type) {
|
||||
Device *device = new (std::nothrow) Device(type);
|
||||
if (device != nullptr) {
|
||||
device_.reset(device);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CpuKernelContext::Init(NodeDef *node_def) {
|
||||
KERNEL_CHECK_NULLPTR(node_def, KERNEL_STATUS_PARAM_INVALID, "Node def is null.")
|
||||
op_ = node_def->GetOpType();
|
||||
KERNEL_LOG_DEBUG("Construct the ctx of the op[%s] begin.", op_.c_str());
|
||||
for (int32_t i = 0; i < node_def->InputsSize(); i++) {
|
||||
auto input = node_def->MutableInputs(i);
|
||||
KERNEL_CHECK_NULLPTR(input, KERNEL_STATUS_PARAM_INVALID, "Get input[%d] tensor failed in op[%s].", i, op_.c_str())
|
||||
inputs_.emplace_back(std::move(input));
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < node_def->OutputsSize(); i++) {
|
||||
auto output = node_def->MutableOutputs(i);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%d] tensor failed in op[%s].", i, op_.c_str())
|
||||
outputs_.emplace_back(std::move(output));
|
||||
}
|
||||
|
||||
auto attrMap = node_def->Attrs();
|
||||
for (auto iter = attrMap.begin(); iter != attrMap.end(); ++iter) {
|
||||
auto attr_value_ptr = iter->second;
|
||||
KERNEL_CHECK_NULLPTR(attr_value_ptr, KERNEL_STATUS_PARAM_INVALID, "Get attr[%s] failed in op[%s].",
|
||||
iter->first.c_str(), op_.c_str())
|
||||
auto ret = attrs_.insert(std::make_pair(iter->first, std::move(attr_value_ptr)));
|
||||
if (!ret.second) {
|
||||
KERNEL_LOG_ERROR("Insert attr[%s] failed in op[%s].", iter->first.c_str(), op_.c_str());
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
KERNEL_LOG_DEBUG("Construct the ctx of the op[%s] success.", op_.c_str());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* get op type.
|
||||
* @return string: op type
|
||||
*/
|
||||
std::string CpuKernelContext::GetOpType() const { return op_; }
|
||||
|
||||
/*
|
||||
* get input tensor.
|
||||
* @return Tensor *: not null->success, null->failed
|
||||
*/
|
||||
Tensor *CpuKernelContext::Input(uint32_t index) const {
|
||||
if (index >= inputs_.size()) {
|
||||
KERNEL_LOG_WARN(
|
||||
"Input index[%u] should be less than input tensors total "
|
||||
"size[%zu].",
|
||||
index, inputs_.size());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return inputs_[index].get();
|
||||
}
|
||||
|
||||
/*
|
||||
* get output tensor.
|
||||
* @return Tensor *: not null->success, null->failed
|
||||
*/
|
||||
Tensor *CpuKernelContext::Output(uint32_t index) const {
|
||||
if (index >= outputs_.size()) {
|
||||
KERNEL_LOG_WARN(
|
||||
"Output index[%u] should be less than output tensors total "
|
||||
"size[%zu].",
|
||||
index, outputs_.size());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return outputs_[index].get();
|
||||
}
|
||||
|
||||
/*
|
||||
* get attr.
|
||||
* @return AttrValue *: not null->success, null->failed
|
||||
*/
|
||||
AttrValue *CpuKernelContext::GetAttr(std::string name) const {
|
||||
auto it = attrs_.find(name);
|
||||
if (it == attrs_.end()) {
|
||||
KERNEL_LOG_WARN("Attr[%s] is not exist.", name.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return (it->second).get();
|
||||
}
|
||||
|
||||
/*
|
||||
* get input size.
|
||||
* @return uint32_t: input size
|
||||
*/
|
||||
uint32_t CpuKernelContext::GetInputsSize() const { return inputs_.size(); }
|
||||
|
||||
/*
|
||||
* get output size.
|
||||
* @return uint32_t: output size
|
||||
*/
|
||||
uint32_t CpuKernelContext::GetOutputsSize() const { return outputs_.size(); }
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,637 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/cpu_kernel_cache.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cce/aicpu_engine_struct.h"
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
#include "cpu_kernel/common/cpu_kernel_register.h"
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/runtime_tensor_desc.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace {
|
||||
// max io address number limit is 1024
|
||||
constexpr uint32_t kMaxIoAddrNumParamLen = 1024;
|
||||
// max LRU cache number is 256
|
||||
constexpr uint32_t kMaxLRUCacheNum = 256;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* Init kernel cache.
|
||||
*/
|
||||
int32_t CpuKernelCache::InitParameter() {
|
||||
if (!GetSessionFlag()) {
|
||||
SetCapacity(kMaxLRUCacheNum);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* update framework output tensor shape.
|
||||
*/
|
||||
uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg &ext_info_msg, const CpuKernelContext &ctx) const {
|
||||
if (ext_info_msg.unknown_shape) {
|
||||
for (size_t i = 0; i < ctx.GetOutputsSize(); ++i) {
|
||||
Tensor *output = ctx.Output(i);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)
|
||||
auto shape = output->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
|
||||
|
||||
for (int32_t index = 0; index < shape->GetDims(); ++index) {
|
||||
ext_info_msg.output_shape_and_type[i]->dims[index] = shape->GetDimSize(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto it = ext_info_msg.unknown_shape_output_index_addr.begin();
|
||||
it != ext_info_msg.unknown_shape_output_index_addr.end(); ++it) {
|
||||
Tensor *output = ctx.Output(it->first);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%u] failed.", it->first)
|
||||
auto shape = output->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%u] shape failed.", it->first)
|
||||
ge::RuntimeTensorDesc *tensor_desc = reinterpret_cast<ge::RuntimeTensorDesc *>(static_cast<uintptr_t>(it->second));
|
||||
KERNEL_CHECK_FALSE((shape->GetDims() <= ge::kMaxDimSize), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Max shape size[32], but got output[%u] shape size[%d]", it->first, shape->GetDims())
|
||||
tensor_desc->shape[0] = shape->GetDims();
|
||||
for (int32_t index = 0; index < shape->GetDims(); ++index) {
|
||||
tensor_desc->shape[index + 1] = shape->GetDimSize(index);
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* get shape information from framework.
|
||||
*/
|
||||
void CpuKernelCache::GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *shape_and_type,
|
||||
std::vector<int64_t> &dims) const {
|
||||
for (uint32_t index = 0; index < FWKAdapter::kMaxShapeDims; ++index) {
|
||||
// LLONG_MIN for dim end flag
|
||||
if (shape_and_type->dims[index] == LLONG_MIN) {
|
||||
break;
|
||||
}
|
||||
int64_t dim_value = shape_and_type->dims[index];
|
||||
KERNEL_LOG_INFO("Get extend shape[%u] is [%ld]", index, dim_value);
|
||||
dims.emplace_back(dim_value);
|
||||
}
|
||||
}
|
||||
|
||||
void CpuKernelCache::GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> &dims) const {
|
||||
for (size_t index = 0; index < len; ++index) {
|
||||
KERNEL_LOG_INFO("Get arrays shape[%zu] is [%ld]", index, shape[index]);
|
||||
dims.emplace_back(shape[index]);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* update tensor information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg &ext_info_msg,
|
||||
CpuKernelContext &ctx) const {
|
||||
KERNEL_LOG_INFO("Update tensor info begin.");
|
||||
if (io_addrs.size() != ctx.GetInputsSize() + ctx.GetOutputsSize()) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Addr number[%zu] is not equal to the sum of inputs[%zu] and "
|
||||
"output[%zu].",
|
||||
io_addrs.size(), ctx.GetInputsSize(), ctx.GetOutputsSize());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if ((ext_info_msg.unknown_shape) && ((ext_info_msg.input_shape_and_type.size() != ctx.GetInputsSize()) ||
|
||||
(ext_info_msg.output_shape_and_type.size() != ctx.GetOutputsSize()))) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Input shape_and_type size error, input size[%zu], input "
|
||||
"shape_and_type "
|
||||
"size[%zu], output size[%zu], output shape_and_type size[%zu].",
|
||||
ctx.GetInputsSize(), ext_info_msg.input_shape_and_type.size(), ctx.GetOutputsSize(),
|
||||
ext_info_msg.output_shape_and_type.size());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
size_t addr_index = 0;
|
||||
for (size_t i = 0; i < ctx.GetInputsSize(); ++i, ++addr_index) {
|
||||
Tensor *input = ctx.Input(i);
|
||||
KERNEL_CHECK_NULLPTR(input, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] failed.", i)
|
||||
auto iter = ext_info_msg.unknown_shape_input_index_addr.find(static_cast<uint32_t>(i));
|
||||
if (iter != ext_info_msg.unknown_shape_input_index_addr.end()) {
|
||||
iter->second = io_addrs[addr_index];
|
||||
ge::RuntimeTensorDesc *tensor_desc =
|
||||
reinterpret_cast<ge::RuntimeTensorDesc *>(static_cast<uintptr_t>(io_addrs[addr_index]));
|
||||
std::vector<int64_t> dims;
|
||||
KERNEL_CHECK_FALSE((tensor_desc->shape[0] <= ge::kMaxDimSize), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Max shape size[%lld], but got input[%zu] shape size[%lld]", ge::kMaxDimSize, i,
|
||||
tensor_desc->shape[0])
|
||||
GetDimsFromArrays(&(tensor_desc->shape[1]), static_cast<size_t>(tensor_desc->shape[0]), dims);
|
||||
auto shape = input->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] shape failed.", i)
|
||||
shape->SetDimSizes(dims);
|
||||
input->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(tensor_desc->data_addr)));
|
||||
} else {
|
||||
input->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(io_addrs[addr_index])));
|
||||
}
|
||||
|
||||
if (ext_info_msg.unknown_shape) {
|
||||
std::vector<int64_t> dims;
|
||||
GetDimsFromShapeAndType(ext_info_msg.input_shape_and_type[i], dims);
|
||||
auto shape = input->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] shape failed.", i)
|
||||
shape->SetDimSizes(dims);
|
||||
}
|
||||
|
||||
KERNEL_CHECK_FALSE((input->NumElements() >= 0), KERNEL_STATUS_PARAM_INVALID,
|
||||
"Input[%zu] data elements number must be >= 0, "
|
||||
"got size[%lld].",
|
||||
i, input->NumElements());
|
||||
input->SetDataSize(std::max(uint64_t(0), static_cast<uint64_t>(input->CalcDataSizeByShape())));
|
||||
KERNEL_LOG_INFO("Set input[%zu] addr[%lu] success.", i, io_addrs[addr_index]);
|
||||
}
|
||||
|
||||
bool no_tiling = ext_info_msg.unknown_shape_output_index_addr.empty();
|
||||
|
||||
for (size_t i = 0; i < ctx.GetOutputsSize(); i++, addr_index++) {
|
||||
Tensor *output = ctx.Output(i);
|
||||
KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)
|
||||
auto iter = ext_info_msg.unknown_shape_output_index_addr.find(static_cast<uint32_t>(i));
|
||||
if (iter != ext_info_msg.unknown_shape_output_index_addr.end()) {
|
||||
iter->second = io_addrs[addr_index];
|
||||
ge::RuntimeTensorDesc *tensor_desc =
|
||||
reinterpret_cast<ge::RuntimeTensorDesc *>(static_cast<uintptr_t>(io_addrs[addr_index]));
|
||||
output->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(tensor_desc->data_addr)));
|
||||
} else {
|
||||
output->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(io_addrs[addr_index])));
|
||||
}
|
||||
|
||||
if (ext_info_msg.unknown_shape) {
|
||||
std::vector<int64_t> dims;
|
||||
GetDimsFromShapeAndType(ext_info_msg.output_shape_and_type[i], dims);
|
||||
auto shape = output->GetTensorShape();
|
||||
KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
|
||||
shape->SetDimSizes(dims);
|
||||
}
|
||||
|
||||
KERNEL_CHECK_FALSE((ext_info_msg.unknown_shape || (!no_tiling) || (output->NumElements() >= 0)),
|
||||
KERNEL_STATUS_PARAM_INVALID,
|
||||
"Output[%zu] data elements number must be >= 0 "
|
||||
"when known shape, got size[%lld].",
|
||||
i, output->NumElements());
|
||||
output->SetDataSize(std::max(uint64_t(0), static_cast<uint64_t>(output->CalcDataSizeByShape())));
|
||||
KERNEL_LOG_INFO("Set output[%zu] addr[%lu] success.", i, io_addrs[addr_index]);
|
||||
}
|
||||
KERNEL_LOG_INFO("Update tensor info success.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* parse extend tensor shape types information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) const {
|
||||
if (ext_info->infoLen != sizeof(int32_t)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse extend shape type failed, as info length must be [%zu], but got "
|
||||
"[%u].",
|
||||
sizeof(int32_t), ext_info->infoLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
unknown_shape = true;
|
||||
KERNEL_LOG_INFO("Kernel has unknown shape.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* parse extend tensor shape and types information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseExtShapeAndType(bool unknown_shape, FWKAdapter::ExtInfo *ext_info,
|
||||
std::vector<FWKAdapter::ShapeAndType *> &shape_and_type) const {
|
||||
if (!unknown_shape) {
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t size = (ext_info->infoLen) / sizeof(FWKAdapter::ShapeAndType);
|
||||
KERNEL_LOG_INFO("Parse extend shape and type, size[%u].", size);
|
||||
uint32_t check = (ext_info->infoLen) % sizeof(FWKAdapter::ShapeAndType);
|
||||
if (check != 0) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse extend info length[%u] failed, must be integer multiple of the "
|
||||
"[%zu].",
|
||||
ext_info->infoLen, sizeof(FWKAdapter::ShapeAndType));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto shapes = reinterpret_cast<FWKAdapter::ShapeAndType *>(ext_info->infoMsg);
|
||||
for (uint32_t index = 0; index < size; ++index) {
|
||||
shape_and_type.emplace_back(&shapes[index]);
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* parse extend session information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t &kernel_id) const {
|
||||
// no overflow
|
||||
KERNEL_LOG_INFO("Parse extend session info.");
|
||||
auto need_len = sizeof(SessionInfo);
|
||||
if (ext_info->infoLen != need_len) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse extend session info failed, as info length must be "
|
||||
"[%zu], but got [%u].",
|
||||
sizeof(SessionInfo), ext_info->infoLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto session = reinterpret_cast<SessionInfo *>(ext_info->infoMsg);
|
||||
kernel_id = session->kernelId;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* get bit status.
|
||||
*/
|
||||
bool CpuKernelCache::GetBitStatus(uint64_t num, uint64_t pos) { return ((num & (1 << pos)) != 0); }
|
||||
|
||||
/*
|
||||
* parse bitmap information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) {
|
||||
if (ext_info->infoLen != sizeof(int64_t)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse extend bitmap failed, as info length must be [%zu], but got "
|
||||
"[%u].",
|
||||
sizeof(int64_t), ext_info->infoLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint64_t bit_map = *(reinterpret_cast<const int64_t *>(ext_info->infoMsg));
|
||||
unknown_shape = (!GetBitStatus(bit_map, 0));
|
||||
KERNEL_LOG_INFO("Unknown_shape_ is [%d].", unknown_shape);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
// parse async wait info
|
||||
uint32_t CpuKernelCache::ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t &wait_type, uint32_t &wait_id) const {
|
||||
if (ext_info->infoLen != sizeof(FWKAdapter::AsyncWait)) {
|
||||
KERNEL_LOG_ERROR("Parse extend async wait failed, as info length must be [%zu], but got [%u].",
|
||||
sizeof(FWKAdapter::AsyncWait), ext_info->infoLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
FWKAdapter::AsyncWait *async_info = reinterpret_cast<FWKAdapter::AsyncWait *>(ext_info->infoMsg);
|
||||
wait_type = async_info->waitType;
|
||||
wait_id = async_info->waitId;
|
||||
KERNEL_LOG_INFO("async wait type [%u], notify_id[%u].", wait_type, wait_id);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelCache::ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info,
|
||||
std::map<uint32_t, uint64_t> &unknown_shape_index_addr) const {
|
||||
if (ext_info->infoLen % sizeof(uint32_t) != 0) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse unknown shape index extend info length[%u] failed, must be "
|
||||
"integer multiple of the [%zu].",
|
||||
ext_info->infoLen, sizeof(uint32_t));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
uint32_t size = ext_info->infoLen / sizeof(uint32_t);
|
||||
KERNEL_LOG_INFO("Parse extend unknown shape index, size[%u].", size);
|
||||
auto indexes = reinterpret_cast<uint32_t *>(ext_info->infoMsg);
|
||||
for (uint32_t i = 0U; i < size; ++i) {
|
||||
unknown_shape_index_addr[indexes[i]] = 0U;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* parse extend information.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg &ext_info_msg) {
|
||||
KERNEL_LOG_INFO("Parse extend info and update shape begin.");
|
||||
uint32_t offset = 0;
|
||||
ext_info_msg.async_flag = false;
|
||||
FWKAdapter::ExtInfo *ext_info = nullptr;
|
||||
char *extInfo_buf = reinterpret_cast<char *>(static_cast<uintptr_t>(param_head->extInfoAddr));
|
||||
while (offset + sizeof(FWKAdapter::ExtInfo) <= param_head->extInfoLength) {
|
||||
ext_info = reinterpret_cast<FWKAdapter::ExtInfo *>(extInfo_buf + offset);
|
||||
if (ext_info == nullptr) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Extend info is nullptr, extInfo length[%u], extend info addr[%p], "
|
||||
"offset[%u].",
|
||||
param_head->extInfoLength, param_head->extInfoAddr, offset);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint32_t ret = KERNEL_STATUS_OK;
|
||||
switch (ext_info->infoType) {
|
||||
case FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE:
|
||||
ret = ParseExtShapeType(ext_info, ext_info_msg.unknown_shape);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_INPUT_SHAPE:
|
||||
ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.input_shape_and_type);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_OUTPUT_SHAPE:
|
||||
ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.output_shape_and_type);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_SESSION_INFO:
|
||||
ext_info_msg.has_sess_info = true;
|
||||
ret = ParseExtSessionInfo(ext_info, ext_info_msg.kernel_id);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_BITMAP:
|
||||
ret = ParseExtBitMap(ext_info, ext_info_msg.unknown_shape);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT: {
|
||||
ret = ParseAsyncWait(ext_info, ext_info_msg.wait_type, ext_info_msg.wait_id);
|
||||
bool flag = ((ret == KERNEL_STATUS_OK) &&
|
||||
(ext_info_msg.wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_NULL) &&
|
||||
(ext_info_msg.wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_INVALID));
|
||||
if (flag) {
|
||||
ext_info_msg.async_flag = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_INPUT_INDEX:
|
||||
ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_input_index_addr);
|
||||
break;
|
||||
case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_OUTPUT_INDEX:
|
||||
ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_output_index_addr);
|
||||
break;
|
||||
default:
|
||||
KERNEL_LOG_INFO("Ignore infoType[%d], infoLen[%u].", ext_info->infoType, ext_info->infoLen);
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// not overflow
|
||||
offset += FWKAdapter::kExtInfoHeadSize;
|
||||
offset += ext_info->infoLen;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* parse io address.
|
||||
*/
|
||||
uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> &io_addrs, char *&nodedef,
|
||||
uint32_t &nodedef_len) const {
|
||||
auto param_base = reinterpret_cast<char *>(param_head);
|
||||
char *extend_param_base = param_base + sizeof(AicpuParamHead);
|
||||
uint32_t extend_param_len = param_head->length - sizeof(AicpuParamHead);
|
||||
|
||||
if (param_head->ioAddrNum > 0) {
|
||||
if (param_head->ioAddrNum > kMaxIoAddrNumParamLen) {
|
||||
KERNEL_LOG_ERROR("Param ioAddrNum[%u] is over %u.", param_head->ioAddrNum, kMaxIoAddrNumParamLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint32_t addr_len = param_head->ioAddrNum * sizeof(uint64_t);
|
||||
if (extend_param_len < addr_len) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Extend param is not enough for io addr, ioAddrNum[%u], "
|
||||
"extend_param_len[%u].",
|
||||
param_head->ioAddrNum, extend_param_len);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto io_addr_base = reinterpret_cast<uint64_t *>(extend_param_base);
|
||||
for (uint32_t i = 0; i < param_head->ioAddrNum; ++i) {
|
||||
io_addrs.push_back(io_addr_base[i]);
|
||||
}
|
||||
extend_param_base = extend_param_base + addr_len;
|
||||
extend_param_len -= addr_len;
|
||||
}
|
||||
|
||||
if (extend_param_len < sizeof(uint32_t)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Extend param is not enough for addr, needLen[%zu], "
|
||||
"extend_param_len[%u].",
|
||||
sizeof(uint32_t), extend_param_len);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
nodedef_len = *reinterpret_cast<uint32_t *>(extend_param_base);
|
||||
extend_param_base += sizeof(uint32_t);
|
||||
nodedef = extend_param_base;
|
||||
KERNEL_LOG_INFO("Parse io addr success, io number[%zu], nodedef length[%u].", io_addrs.size(), nodedef_len);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
/*
|
||||
* get cpu kernel context from cache
|
||||
*/
|
||||
std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContext(bool has_sess_info, uint64_t kernel_id,
|
||||
const char *nodedef, uint32_t nodedef_len,
|
||||
std::shared_ptr<NodeDef> &nodedef_proto) {
|
||||
std::shared_ptr<CpuKernelContext> ctx = nullptr;
|
||||
KERNEL_LOG_INFO("Get cpu kernel context begin, kernel id[%lu].", kernel_id);
|
||||
if (has_sess_info) {
|
||||
CpuCacheData *cache = GetCache(kernel_id);
|
||||
if (cache != nullptr) {
|
||||
KERNEL_LOG_INFO("Get kernel from cache success.");
|
||||
return cache->context;
|
||||
}
|
||||
}
|
||||
|
||||
std::string str_data(nodedef, nodedef_len);
|
||||
nodedef_proto = CpuKernelUtils::CreateNodeDef();
|
||||
KERNEL_CHECK_NULLPTR(nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr), "Create node def failed.")
|
||||
if (!nodedef_proto->ParseFromString(str_data)) {
|
||||
return std::shared_ptr<CpuKernelContext>(nullptr);
|
||||
}
|
||||
|
||||
CpuKernelContext *tmp = new (std::nothrow) CpuKernelContext(DEVICE);
|
||||
KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context failed.")
|
||||
ctx = std::shared_ptr<CpuKernelContext>(tmp);
|
||||
uint32_t ret = ctx->Init(nodedef_proto.get());
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return std::shared_ptr<CpuKernelContext>(nullptr);
|
||||
}
|
||||
|
||||
if (has_sess_info) {
|
||||
CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
|
||||
KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
|
||||
std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
|
||||
SetCache(kernel_id, cache_shared);
|
||||
KERNEL_LOG_INFO("Cache cpu kernel data success, kernel id[%lu].", kernel_id);
|
||||
}
|
||||
KERNEL_LOG_INFO("Get cpu kernel context success, kernel id[%lu].", kernel_id);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
/*
|
||||
* run kernel.
|
||||
*/
|
||||
int32_t CpuKernelCache::RunKernel(void *param) {
|
||||
AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
|
||||
std::vector<uint64_t> io_addrs;
|
||||
char *nodedef = nullptr;
|
||||
uint32_t nodedef_len = 0;
|
||||
uint32_t ret = ParseIoAddr(param_head, io_addrs, nodedef, nodedef_len);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
std::shared_ptr<ExtInfoMsg> ext_info_msg = nullptr;
|
||||
try {
|
||||
ext_info_msg = std::make_shared<ExtInfoMsg>();
|
||||
} catch (std::bad_alloc &) {
|
||||
KERNEL_LOG_ERROR("Create ExtInfoMsg failed");
|
||||
return -1;
|
||||
}
|
||||
ret = ParseExtMsg(param_head, *ext_info_msg);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::shared_ptr<NodeDef> nodedef_proto = nullptr;
|
||||
auto ctx =
|
||||
GetCpuKernelContext(ext_info_msg->has_sess_info, ext_info_msg->kernel_id, nodedef, nodedef_len, nodedef_proto);
|
||||
KERNEL_CHECK_NULLPTR(ctx, KERNEL_STATUS_INNER_ERROR, "Get cpu kernel context from buff failed.")
|
||||
|
||||
ret = UpdateTensor(io_addrs, *ext_info_msg, *ctx);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (ext_info_msg->async_flag) {
|
||||
ret = CpuKernelRegister::Instance().RunCpuKernelAsync(
|
||||
*ctx, ext_info_msg->wait_type, ext_info_msg->wait_id,
|
||||
[&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(*ext_info_msg, *ctx); });
|
||||
} else {
|
||||
ret = CpuKernelRegister::Instance().RunCpuKernel(*ctx);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
ret = UpdateFWKOutputShape(*ext_info_msg, *ctx);
|
||||
}
|
||||
if (ret == KERNEL_STATUS_END_OF_SEQUENCE) {
|
||||
return ret;
|
||||
}
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* run kernel with blockdim info.
|
||||
*/
|
||||
int32_t CpuKernelCache::RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info) {
|
||||
AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
|
||||
std::vector<uint64_t> io_addrs;
|
||||
char *nodedef = nullptr;
|
||||
uint32_t nodedef_len = 0;
|
||||
uint32_t ret = ParseIoAddr(param_head, io_addrs, nodedef, nodedef_len);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
std::shared_ptr<ExtInfoMsg> ext_info_msg = nullptr;
|
||||
try {
|
||||
ext_info_msg = std::make_shared<ExtInfoMsg>();
|
||||
} catch (std::bad_alloc &) {
|
||||
KERNEL_LOG_ERROR("Create ExtInfoMsg failed");
|
||||
return -1;
|
||||
}
|
||||
ret = ParseExtMsg(param_head, *ext_info_msg);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::shared_ptr<NodeDef> nodedef_proto = nullptr;
|
||||
auto ctx = GetCpuKernelContextWithBlock(ext_info_msg, nodedef, nodedef_len, nodedef_proto, blkdim_info);
|
||||
KERNEL_CHECK_NULLPTR(ctx, KERNEL_STATUS_INNER_ERROR, "Get cpu kernel context from buff failed.")
|
||||
|
||||
ret = UpdateTensor(io_addrs, *ext_info_msg, *ctx);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (ext_info_msg->async_flag) {
|
||||
ret = CpuKernelRegister::Instance().RunCpuKernelAsync(
|
||||
*ctx, ext_info_msg->wait_type, ext_info_msg->wait_id,
|
||||
[&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(*ext_info_msg, *ctx); });
|
||||
} else {
|
||||
ret = CpuKernelRegister::Instance().RunCpuKernel(*ctx);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
ret = UpdateFWKOutputShape(*ext_info_msg, *ctx);
|
||||
}
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
/*
|
||||
* get cpu kernel context from cache
|
||||
*/
|
||||
std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContextWithBlock(std::shared_ptr<ExtInfoMsg> extInfoMsg,
|
||||
const char *nodedef,
|
||||
uint32_t nodedef_len,
|
||||
std::shared_ptr<NodeDef> &nodedef_proto,
|
||||
struct BlkDimInfo *blkdim_info) {
|
||||
std::shared_ptr<CpuKernelContext> ctx = nullptr;
|
||||
KERNEL_LOG_INFO("Get cpu kernel context with block info begin. kernel id[%lu]", extInfoMsg->kernel_id);
|
||||
if (extInfoMsg->has_sess_info && blkdim_info->blockNum == 1) {
|
||||
CpuCacheData *cache = GetCache(extInfoMsg->kernel_id);
|
||||
if (cache != nullptr) {
|
||||
KERNEL_LOG_INFO("Get kernel from cache success.");
|
||||
return cache->context;
|
||||
}
|
||||
}
|
||||
std::string str_data(nodedef, nodedef_len);
|
||||
nodedef_proto = CpuKernelUtils::CreateNodeDef();
|
||||
KERNEL_CHECK_NULLPTR(nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr),
|
||||
"Create node def with block info failed.")
|
||||
if (!nodedef_proto->ParseFromString(str_data)) {
|
||||
return std::shared_ptr<CpuKernelContext>(nullptr);
|
||||
}
|
||||
|
||||
if (blkdim_info->blockNum != 1) {
|
||||
auto blockNum = CpuKernelUtils::CreateAttrValue();
|
||||
blockNum->SetInt(blkdim_info->blockNum);
|
||||
nodedef_proto->AddAttrs("block_num", blockNum.get());
|
||||
|
||||
auto blockid = CpuKernelUtils::CreateAttrValue();
|
||||
blockid->SetInt(blkdim_info->blockId);
|
||||
nodedef_proto->AddAttrs("block_id", blockid.get());
|
||||
KERNEL_LOG_INFO("AddAttrs block info , blockNum[%u] blockId[%u].", blkdim_info->blockNum, blkdim_info->blockId);
|
||||
}
|
||||
|
||||
CpuKernelContext *tmp = new (std::nothrow) CpuKernelContext(DEVICE);
|
||||
KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context with block info failed.")
|
||||
ctx = std::shared_ptr<CpuKernelContext>(tmp);
|
||||
uint32_t ret = ctx->Init(nodedef_proto.get());
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return std::shared_ptr<CpuKernelContext>(nullptr);
|
||||
}
|
||||
|
||||
if (extInfoMsg->has_sess_info) {
|
||||
CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
|
||||
KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
|
||||
std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
|
||||
SetCache(extInfoMsg->kernel_id, cache_shared);
|
||||
KERNEL_LOG_INFO("Cache cpu kernel data success. kernel id[%lu]", extInfoMsg->kernel_id);
|
||||
}
|
||||
KERNEL_LOG_INFO("Get cpu kernel context success. kernel id[%lu]", extInfoMsg->kernel_id);
|
||||
return ctx;
|
||||
}
|
||||
} // namespace aicpu
@ -0,0 +1,205 @@
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CPU_KERNEL_CACHE_H_
|
||||
#define AICPU_CPU_KERNEL_CACHE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "aicpu/common/aicpu_task_struct.h"
|
||||
#include "cce/fwk_adpt_struct.h"
|
||||
#include "cpu_kernel/inc/cpu_context.h"
|
||||
#include "cpu_kernel/common/cpu_node_def.h"
|
||||
#include "cpu_kernel/common/kernel_cache.h"
|
||||
#include "cpu_kernel/common/device_cpu_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
struct ExtInfoMsg {
|
||||
bool has_sess_info = false;
|
||||
uint64_t kernel_id = 0U;
|
||||
bool unknown_shape = false;
|
||||
bool async_flag = false;
|
||||
uint8_t wait_type = 0U;
|
||||
uint32_t wait_id = 0U;
|
||||
std::vector<FWKAdapter::ShapeAndType *> input_shape_and_type;
|
||||
std::vector<FWKAdapter::ShapeAndType *> output_shape_and_type;
|
||||
std::map<uint32_t, uint64_t> unknown_shape_input_index_addr;
|
||||
std::map<uint32_t, uint64_t> unknown_shape_output_index_addr;
|
||||
};
|
||||
|
||||
struct CpuCacheData {
|
||||
std::shared_ptr<NodeDef> proto = nullptr;
|
||||
std::shared_ptr<CpuKernelContext> context = nullptr;
|
||||
CpuCacheData(std::shared_ptr<NodeDef> proto, std::shared_ptr<CpuKernelContext> context)
|
||||
: proto(proto), context(context) {}
|
||||
};
|
||||
|
||||
class CpuKernelCache : public KernelCache<CpuCacheData> {
|
||||
public:
|
||||
CpuKernelCache() = default;
|
||||
~CpuKernelCache() = default;
|
||||
|
||||
/*
|
||||
* Init kernel cache.
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
int32_t InitParameter() override;
|
||||
|
||||
/*
|
||||
* run kernel.
|
||||
* @param param: kernel context
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
int32_t RunKernel(void *param) override;
|
||||
|
||||
/*
|
||||
* run kernel with blockDimInfo.
|
||||
* @param param: kernel context and blkDimInfo
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
int32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info) override;
|
||||
|
||||
private:
|
||||
CpuKernelCache(const CpuKernelCache &) = delete;
|
||||
CpuKernelCache(CpuKernelCache &&) = delete;
|
||||
CpuKernelCache &operator=(const CpuKernelCache &) = delete;
|
||||
CpuKernelCache &operator=(CpuKernelCache &&) = delete;
|
||||
|
||||
/*
|
||||
* update framework output tensor shape.
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t UpdateFWKOutputShape(ExtInfoMsg &ext_info_msg, const CpuKernelContext &ctx) const;
|
||||
|
||||
/*
|
||||
* get shape information from framework.
|
||||
* @param dims: shape information
|
||||
*/
|
||||
void GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *shape_and_type, std::vector<int64_t> &dims) const;
|
||||
|
||||
/*
|
||||
* get shape information from arrays.
|
||||
* @param dims: shape information
|
||||
*/
|
||||
void GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> &dims) const;
|
||||
|
||||
/*
|
||||
* update tensor information.
|
||||
* @param ctx: kernel context
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg &ext_info_msg, CpuKernelContext &ctx) const;
|
||||
|
||||
/*
|
||||
* parse extend tensor shape types information.
|
||||
* @param ext_info: extend information
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) const;
|
||||
|
||||
/*
|
||||
* parse extend tensor bitmap information.
|
||||
* @param ext_info: extend information
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape);
|
||||
|
||||
/*
|
||||
* parse extend tensor shape and types information.
|
||||
* @param ext_info: extend information
|
||||
* @param shape_and_type: shape and types from extend information
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtShapeAndType(bool unknown_shape, FWKAdapter::ExtInfo *ext_info,
|
||||
std::vector<FWKAdapter::ShapeAndType *> &shape_and_type) const;
|
||||
|
||||
/*
|
||||
* parse extend unknown shape index information.
|
||||
* @param ext_info: extend information
|
||||
* @param unknown_shape_index_addr: unknown shape index and addr map
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info,
|
||||
std::map<uint32_t, uint64_t> &unknown_shape_index_addr) const;
|
||||
|
||||
/*
|
||||
* parse extend session information.
|
||||
* @param ext_info: extend information
|
||||
* @param kernel_id: kernel id from extend information
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t &kernel_id) const;
|
||||
|
||||
/*
|
||||
* parse extend async wait info
|
||||
* @param ext_info : extend information
|
||||
* @param wait_type: event wait type
|
||||
* @param wait_id : event wait id
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t &wait_type, uint32_t &wait_id) const;
|
||||
|
||||
/*
|
||||
* parse extend information.
|
||||
* @param param_head: kernel context
|
||||
* @param ext_info_msg: extend info msg
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg &ext_info_msg);
|
||||
|
||||
/*
|
||||
* parse io address.
|
||||
* @param param_head: kernel context
|
||||
* @param io_addrs: kernel inputs and outputs address
|
||||
* @param nodedef: kernel node def
|
||||
* @param nodedef_len: kernel node def length
|
||||
* @return uint32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
uint32_t ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> &io_addrs, char *&nodedef,
|
||||
uint32_t &nodedef_len) const;
|
||||
|
||||
/*
|
||||
* get cpu kernel context from cache
|
||||
* @param has_sess_info: whether has session info
|
||||
* @param kernel_id: kernel id, the key of cache
|
||||
* @return std::shared_ptr<CpuKernelContext>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<CpuKernelContext> GetCpuKernelContext(bool has_sess_info, uint64_t kernel_id, const char *nodedef,
|
||||
uint32_t nodedef_len, std::shared_ptr<NodeDef> &nodedef_proto);
|
||||
|
||||
/*
|
||||
* get cpu kernel context from cache
|
||||
* @param has_sess_info: whether has session info
|
||||
* @param kernel_id: kernel id, the key of cache
|
||||
* @param blkDimInfo: kernel blockdim info
|
||||
* @return std::shared_ptr<CpuKernelContext>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<CpuKernelContext> GetCpuKernelContextWithBlock(std::shared_ptr<ExtInfoMsg> extInfoMsg,
|
||||
const char *nodedef, uint32_t nodedef_len,
|
||||
std::shared_ptr<NodeDef> &nodedef_proto,
|
||||
struct BlkDimInfo *blkdim_info);
|
||||
|
||||
/*
|
||||
* get bit status on pos
|
||||
* @param num: input number
|
||||
* @param pos: bit pos
|
||||
* @return bool: bit is 1 or 0
|
||||
*/
|
||||
bool GetBitStatus(uint64_t num, uint64_t pos);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CPU_KERNEL_CACHE_H_
@ -0,0 +1,190 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/cpu_kernel_register.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "aicpu_sharder/aicpu_context.h"
|
||||
#include "aicpu_sharder/aicpu_async_event.h"
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
#include "cpu_kernel/common/async_event_util.h"
|
||||
#include "cpu_kernel/common/async_cpu_kernel.h"
|
||||
|
||||
namespace {
|
||||
#define TYPE_REGISTAR(type, fun) type##Registerar(type, fun)
|
||||
// protect creatorMap_
|
||||
std::mutex g_mutex;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* register kernel.
|
||||
*/
|
||||
bool RegistCpuKernel(const std::string &type, const KERNEL_CREATOR_FUN &fun) {
|
||||
CpuKernelRegister::Registerar TYPE_REGISTAR(type, fun);
|
||||
return true;
|
||||
}
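// Illustrative sketch only, not part of this change: how an op could register itself through
// RegistCpuKernel. The op name "MyHypotheticalOp" and the kernel class are assumptions for
// illustration; KERNEL_CREATOR_FUN is assumed to be a no-argument callable returning
// std::shared_ptr<CpuKernel>, as implied by GetCpuKernel calling iter->second() below.
#if 0
class MyHypotheticalOpKernel : public CpuKernel {
 public:
  // Compute signature follows kernel->Compute(ctx) used in RunCpuKernel below.
  uint32_t Compute(CpuKernelContext &ctx) override { return KERNEL_STATUS_OK; }
};

static const bool g_my_hypothetical_op_registered = RegistCpuKernel(
  "MyHypotheticalOp", []() -> std::shared_ptr<CpuKernel> { return std::make_shared<MyHypotheticalOpKernel>(); });
#endif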
|
||||
|
||||
/*
|
||||
* get instance.
|
||||
* @return CpuKernelRegister &: CpuKernelRegister instance
|
||||
*/
|
||||
CpuKernelRegister &CpuKernelRegister::Instance() {
|
||||
static CpuKernelRegister instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
/*
|
||||
* get cpu kernel.
|
||||
* param opType: the op type of kernel
|
||||
* @return shared_ptr<CpuKernel>: cpu kernel ptr
|
||||
*/
|
||||
std::shared_ptr<CpuKernel> CpuKernelRegister::GetCpuKernel(const std::string &opType) {
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
auto iter = creatorMap_.find(opType);
|
||||
if (iter != creatorMap_.end()) {
|
||||
return iter->second();
|
||||
}
|
||||
KERNEL_LOG_WARN("The kernel[%s] is not registered.", opType.c_str());
|
||||
return std::shared_ptr<CpuKernel>(nullptr);
|
||||
}
|
||||
|
||||
/*
|
||||
* get all cpu kernel registered op types.
|
||||
* @return std::vector<string>: all cpu kernel registered op type
|
||||
*/
|
||||
std::vector<std::string> CpuKernelRegister::GetAllRegisteredOpTypes() const {
|
||||
std::vector<std::string> ret;
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
for (auto iter = creatorMap_.begin(); iter != creatorMap_.end(); ++iter) {
|
||||
ret.push_back(iter->first);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* run cpu kernel.
|
||||
* param ctx: context of kernel
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t CpuKernelRegister::RunCpuKernel(CpuKernelContext &ctx) {
|
||||
std::string type = ctx.GetOpType();
|
||||
KERNEL_LOG_INFO("RunCpuKernel[%s] begin.", type.c_str());
|
||||
auto kernel = GetCpuKernel(type);
|
||||
if (kernel == nullptr) {
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (aicpu::SetThreadLocalCtx != nullptr) {
|
||||
if (aicpu::SetThreadLocalCtx(aicpu::kContextKeyOpName, type) != aicpu::AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("Set kernel name[%s] to context failed.", type.c_str());
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
if (aicpu::SetOpname != nullptr) {
|
||||
(void)aicpu::SetOpname(type);
|
||||
}
|
||||
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
uint32_t ret = kernel->Compute(ctx);
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
double dr_us = std::chrono::duration<double, std::micro>(end - start).count();
|
||||
KERNEL_LOG_EVENT("RunCpuKernel[%s], run time is [%lf] us.", type.c_str(), dr_us);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
KERNEL_LOG_INFO("RunCpuKernel[%s] success.", type.c_str());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelRegister::RunCpuKernelAsync(CpuKernelContext &ctx, const uint8_t wait_type, const uint32_t wait_id,
|
||||
std::function<uint32_t()> cb) {
|
||||
std::string type = ctx.GetOpType();
|
||||
KERNEL_LOG_INFO("RunCpuKernelAsync[%s] begin.", type.c_str());
|
||||
auto kernel = GetCpuKernel(type);
|
||||
if (kernel == nullptr) {
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
AsyncCpuKernel *async_kernel = dynamic_cast<AsyncCpuKernel *>(kernel.get());
|
||||
if (async_kernel == nullptr) {
|
||||
KERNEL_LOG_ERROR("kernel name[%s] does not hava async impl.", type.c_str());
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (aicpu::SetThreadLocalCtx != nullptr) {
|
||||
if (aicpu::SetThreadLocalCtx(aicpu::kContextKeyOpName, type) != aicpu::AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("Set kernel name[%s] to context failed.", type.c_str());
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (aicpu::SetThreadLocalCtx(aicpu::kContextKeyWaitType, std::to_string(wait_type)) != aicpu::AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("Set wait type to context failed.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
if (aicpu::SetThreadLocalCtx(aicpu::kContextKeyWaitId, std::to_string(wait_id)) != aicpu::AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("Set wait id to context failed.");
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
if (aicpu::SetOpname != nullptr) {
|
||||
(void)aicpu::SetOpname(type);
|
||||
}
|
||||
std::shared_ptr<AsyncNotifyInfo> notify_info = std::make_shared<AsyncNotifyInfo>();
|
||||
aicpu::GetTaskAndStreamId(¬ify_info->task_id, ¬ify_info->stream_id);
|
||||
(void)aicpu::aicpuGetContext(¬ify_info->ctx);
|
||||
notify_info->wait_type = wait_type;
|
||||
notify_info->wait_id = wait_id;
|
||||
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
auto done = [notify_info, kernel, type, cb, start](uint32_t status) {
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
double dr_us = std::chrono::duration<double, std::micro>(end - start).count();
|
||||
KERNEL_LOG_EVENT("RunCpuKernel[%s], run time is [%lf] us.", type.c_str(), dr_us);
|
||||
if (status == KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_INFO("RunCpuKernel[%s] success.", type.c_str());
|
||||
status = cb();
|
||||
}
|
||||
notify_info->ret_code = status;
|
||||
void *param = reinterpret_cast<void *>(notify_info.get());
|
||||
KERNEL_LOG_INFO(
|
||||
"RunCpuKernelAsync notify event wait, wait_type[%u], "
|
||||
"wait_id[%u], task_id[%u], stream_id[%u], status[%u].",
|
||||
notify_info->wait_type, notify_info->wait_id, notify_info->task_id, notify_info->stream_id,
|
||||
notify_info->ret_code);
|
||||
AsyncEventUtil::GetInstance().NotifyWait(param, sizeof(AsyncNotifyInfo));
|
||||
};
|
||||
return async_kernel->ComputeAsync(ctx, done);
|
||||
}
|
||||
|
||||
CpuKernelRegister::Registerar::Registerar(const std::string &type, const KERNEL_CREATOR_FUN &fun) {
|
||||
CpuKernelRegister::Instance().Register(type, fun);
|
||||
}
|
||||
|
||||
// register creator, this function will be called in the constructor
|
||||
void CpuKernelRegister::Register(const std::string &type, const KERNEL_CREATOR_FUN &fun) {
|
||||
std::unique_lock<std::mutex> lock(g_mutex);
|
||||
std::map<std::string, KERNEL_CREATOR_FUN>::iterator iter = creatorMap_.find(type);
|
||||
if (iter != creatorMap_.end()) {
|
||||
KERNEL_LOG_WARN("Register[%s] creator already exist", type.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
creatorMap_[type] = fun;
|
||||
KERNEL_LOG_DEBUG("Kernel[%s] register successfully", type.c_str());
|
||||
}
|
||||
} // namespace aicpu
@ -0,0 +1,96 @@
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_INC_REGISTAR_H_
|
||||
#define AICPU_CONTEXT_INC_REGISTAR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_context.h"
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AICPU_VISIBILITY CpuKernelRegister {
|
||||
public:
|
||||
/*
|
||||
* get instance.
|
||||
* @return CpuKernelRegister &: CpuKernelRegister instance
|
||||
*/
|
||||
static CpuKernelRegister &Instance();
|
||||
|
||||
/*
|
||||
* get cpu kernel.
|
||||
* param op_type: the op type of kernel
|
||||
* @return shared_ptr<CpuKernel>: cpu kernel ptr
|
||||
*/
|
||||
std::shared_ptr<CpuKernel> GetCpuKernel(const std::string &opType);
|
||||
|
||||
/*
|
||||
* get all cpu kernel registered op types.
|
||||
* @return std::vector<string>: all cpu kernel registered op type
|
||||
*/
|
||||
std::vector<std::string> GetAllRegisteredOpTypes() const;
|
||||
|
||||
/*
|
||||
* run cpu kernel.
|
||||
* param ctx: context of kernel
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t RunCpuKernel(CpuKernelContext &ctx);
|
||||
|
||||
/*
|
||||
* run async cpu kernel.
|
||||
* @param ctx: context of kernel
|
||||
* @param wait_type : event wait type
|
||||
* @param wait_id : event wait id
|
||||
* @param cb : callback function
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t RunCpuKernelAsync(CpuKernelContext &ctx, const uint8_t wait_type, const uint32_t wait_id,
|
||||
std::function<uint32_t()> cb);
|
||||
|
||||
// CpuKernel registration function to register different types of kernel to
|
||||
// the factory
|
||||
class Registerar {
|
||||
public:
|
||||
Registerar(const std::string &type, const KERNEL_CREATOR_FUN &fun);
|
||||
~Registerar() = default;
|
||||
|
||||
Registerar(const Registerar &) = delete;
|
||||
Registerar(Registerar &&) = delete;
|
||||
Registerar &operator=(const Registerar &) = delete;
|
||||
Registerar &operator=(Registerar &&) = delete;
|
||||
};
|
||||
|
||||
protected:
|
||||
CpuKernelRegister() = default;
|
||||
~CpuKernelRegister() = default;
|
||||
|
||||
CpuKernelRegister(const CpuKernelRegister &) = delete;
|
||||
CpuKernelRegister(CpuKernelRegister &&) = delete;
|
||||
CpuKernelRegister &operator=(const CpuKernelRegister &) = delete;
|
||||
CpuKernelRegister &operator=(CpuKernelRegister &&) = delete;
|
||||
|
||||
// register creator, this function will be called in the constructor
|
||||
void Register(const std::string &type, const KERNEL_CREATOR_FUN &fun);
|
||||
|
||||
private:
|
||||
std::map<std::string, KERNEL_CREATOR_FUN> creatorMap_; // kernel map
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_INC_REGISTAR_H_
@ -0,0 +1,206 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
|
||||
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
|
||||
#include "cpu_kernel/common/device.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/cpu_proto/node_def_impl.h"
|
||||
#include "cpu_kernel/common/sharder.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_impl.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* construct Tensor for memory self-management.
|
||||
*/
|
||||
std::shared_ptr<Tensor> CpuKernelUtils::CreateTensor() {
|
||||
auto proto_ptr = new (std::nothrow) aicpuops::Tensor();
|
||||
KERNEL_CHECK_NULLPTR(proto_ptr, std::shared_ptr<Tensor>(nullptr), "New Tensor proto failed.")
|
||||
|
||||
auto wrapper_ptr = new (std::nothrow) TensorImpl(proto_ptr, [](aicpuops::Tensor *p) { delete p; });
|
||||
if (wrapper_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("New TensorProto failed");
|
||||
delete proto_ptr;
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto class_ptr = new (std::nothrow) Tensor(wrapper_ptr);
|
||||
if (class_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("New Tensor failed");
|
||||
delete wrapper_ptr;
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
return std::shared_ptr<Tensor>(class_ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> CpuKernelUtils::CreateTensor(TensorImpl *tensor) {
|
||||
KERNEL_CHECK_NULLPTR(tensor, std::shared_ptr<Tensor>(nullptr), "Tensor is null.")
|
||||
auto class_ptr = new (std::nothrow) Tensor(tensor);
|
||||
KERNEL_CHECK_NULLPTR(class_ptr, std::shared_ptr<Tensor>(nullptr), "New Tensor failed.")
|
||||
return std::shared_ptr<Tensor>(class_ptr);
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor impl.
|
||||
*/
|
||||
std::shared_ptr<TensorImpl> CpuKernelUtils::GetImpl(const Tensor *tensor) { return tensor->impl_; }
|
||||
|
||||
/*
|
||||
* get tensor name.
|
||||
*/
|
||||
std::string CpuKernelUtils::GetTensorName(const Tensor *tensor) {
|
||||
auto impl = GetImpl(tensor);
|
||||
KERNEL_CHECK_NULLPTR(impl, std::string(), "Get Tensor impl failed.")
|
||||
return impl->GetName();
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor name.
|
||||
*/
|
||||
void CpuKernelUtils::SetTensorName(const std::string &name, std::shared_ptr<Tensor> &tensor) {
|
||||
KERNEL_LOG_INFO("Set tensor name[%s]", name.c_str());
|
||||
auto impl = GetImpl(tensor.get());
|
||||
KERNEL_CHECK_NULLPTR_VOID(impl, "Get Tensor impl failed.")
|
||||
impl->SetName(name);
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorShape> CpuKernelUtils::CreateTensorShape() {
|
||||
auto proto_ptr = new (std::nothrow) aicpuops::TensorShape();
|
||||
KERNEL_CHECK_NULLPTR(proto_ptr, std::shared_ptr<TensorShape>(nullptr), "New TensorShape proto failed.")
|
||||
|
||||
auto wrapper_ptr = new (std::nothrow) TensorShapeImpl(proto_ptr, [](aicpuops::TensorShape *p) { delete p; });
|
||||
if (wrapper_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new TensorShapeImpl failed");
|
||||
delete proto_ptr;
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
auto class_ptr = new (std::nothrow) TensorShape(wrapper_ptr);
|
||||
if (class_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new TensorShape failed");
|
||||
delete wrapper_ptr;
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
return std::shared_ptr<TensorShape>(class_ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorShape> CpuKernelUtils::CreateTensorShape(TensorShapeImpl *tensor_shape) {
|
||||
KERNEL_CHECK_NULLPTR(tensor_shape, std::shared_ptr<TensorShape>(nullptr), "Tensor shape proto is null.")
|
||||
auto class_ptr = new (std::nothrow) TensorShape(tensor_shape);
|
||||
KERNEL_CHECK_NULLPTR(class_ptr, std::shared_ptr<TensorShape>(nullptr), "New TensorShape failed.")
|
||||
return std::shared_ptr<TensorShape>(class_ptr);
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor shape impl.
|
||||
*/
|
||||
std::shared_ptr<TensorShapeImpl> CpuKernelUtils::GetImpl(const TensorShape *tensor_shape) {
|
||||
return tensor_shape->impl_;
|
||||
}
|
||||
|
||||
/*
|
||||
* construct AttrValue for memory self-management.
|
||||
*/
|
||||
std::shared_ptr<AttrValue> CpuKernelUtils::CreateAttrValue() {
|
||||
auto proto_ptr = new (std::nothrow) aicpuops::AttrValue();
|
||||
KERNEL_CHECK_NULLPTR(proto_ptr, std::shared_ptr<AttrValue>(nullptr), "New AttrValue proto failed.")
|
||||
|
||||
auto wrapper_ptr = new (std::nothrow) AttrValueImpl(proto_ptr, [](aicpuops::AttrValue *p) { delete p; });
|
||||
if (wrapper_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new AttrValueImpl failed");
|
||||
delete proto_ptr;
|
||||
return std::shared_ptr<AttrValue>(nullptr);
|
||||
}
|
||||
|
||||
auto class_ptr = new (std::nothrow) AttrValue(wrapper_ptr);
|
||||
if (class_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new AttrValue failed");
|
||||
delete wrapper_ptr;
|
||||
return std::shared_ptr<AttrValue>(nullptr);
|
||||
}
|
||||
|
||||
return std::shared_ptr<AttrValue>(class_ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<AttrValue> CpuKernelUtils::CreateAttrValue(AttrValueImpl *impl) {
|
||||
KERNEL_CHECK_NULLPTR(impl, std::shared_ptr<AttrValue>(nullptr), "Impl is null.")
|
||||
auto class_ptr = new (std::nothrow) AttrValue(impl);
|
||||
KERNEL_CHECK_NULLPTR(class_ptr, std::shared_ptr<AttrValue>(nullptr), "New AttrValue failed.")
|
||||
return std::shared_ptr<AttrValue>(class_ptr);
|
||||
}
|
||||
|
||||
/*
|
||||
* get attr value impl.
|
||||
*/
|
||||
std::shared_ptr<AttrValueImpl> CpuKernelUtils::GetImpl(const AttrValue *attr_value) { return attr_value->impl_; }
|
||||
|
||||
/*
|
||||
* construct NodeDef for memory self-management.
|
||||
*/
|
||||
std::shared_ptr<NodeDef> CpuKernelUtils::CreateNodeDef() {
|
||||
auto proto_ptr = new (std::nothrow) aicpuops::NodeDef();
|
||||
KERNEL_CHECK_NULLPTR(proto_ptr, std::shared_ptr<NodeDef>(nullptr), "New NodeDef proto failed.")
|
||||
|
||||
auto wrapper_ptr = new (std::nothrow) NodeDefImpl(proto_ptr, [](aicpuops::NodeDef *p) { delete p; });
|
||||
if (wrapper_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new NodeDefImpl failed");
|
||||
delete proto_ptr;
|
||||
return std::shared_ptr<NodeDef>(nullptr);
|
||||
}
|
||||
|
||||
auto class_ptr = new (std::nothrow) NodeDef(wrapper_ptr);
|
||||
if (class_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("new NodeDef failed");
|
||||
delete wrapper_ptr;
|
||||
return std::shared_ptr<NodeDef>(nullptr);
|
||||
}
|
||||
|
||||
return std::shared_ptr<NodeDef>(class_ptr);
|
||||
}
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
uint32_t CpuKernelUtils::ParallelFor(const CpuKernelContext &ctx, int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) {
|
||||
KERNEL_CHECK_NULLPTR(ctx.device_, KERNEL_STATUS_INNER_ERROR, "Device is null.")
|
||||
|
||||
const Sharder *sharder = ctx.device_->GetSharder();
|
||||
KERNEL_CHECK_NULLPTR(sharder, KERNEL_STATUS_INNER_ERROR, "Get sharder is null.")
|
||||
|
||||
sharder->ParallelFor(total, perUnitSize, work);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
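// Illustrative sketch only, not part of this change: a typical caller shards element-wise work
// through ParallelFor. Here `ctx` is the CpuKernelContext available inside a kernel's Compute,
// and `data` stands for a buffer already obtained from the kernel's tensors (assumptions).
#if 0
uint32_t ExampleShardedCompute(const CpuKernelContext &ctx, std::vector<float> &data) {
  int64_t total = static_cast<int64_t>(data.size());
  auto work = [&data](int64_t begin, int64_t end) {
    // Each shard processes the half-open element range [begin, end).
    for (int64_t i = begin; i < end; ++i) {
      data[i] *= 2.0f;
    }
  };
  return CpuKernelUtils::ParallelFor(ctx, total, /*perUnitSize=*/1024, work);
}
#endif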
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @return CPU number
|
||||
*/
|
||||
uint32_t CpuKernelUtils::GetCPUNum(const CpuKernelContext &ctx) {
|
||||
KERNEL_CHECK_NULLPTR(ctx.device_, 0, "Device is null.")
|
||||
|
||||
const Sharder *sharder = ctx.device_->GetSharder();
|
||||
KERNEL_CHECK_NULLPTR(sharder, 0, "Get sharder is null.")
|
||||
|
||||
return sharder->GetCPUNum();
|
||||
}
|
||||
} // namespace aicpu
@ -0,0 +1,119 @@
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_INC_UTILS_H_
|
||||
#define AICPU_CONTEXT_INC_UTILS_H_
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_attr_value.h"
|
||||
#include "cpu_kernel/inc/cpu_context.h"
|
||||
#include "cpu_kernel/common/cpu_node_def.h"
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AICPU_VISIBILITY CpuKernelUtils {
|
||||
public:
|
||||
/*
|
||||
* create Tensor.
|
||||
* @return std::shared_ptr<Tensor>: Tensor ptr
|
||||
*/
|
||||
static std::shared_ptr<Tensor> CreateTensor();
|
||||
|
||||
/*
|
||||
* create Tensor.
|
||||
* @param tensor: Tensor impl
|
||||
* @return std::shared_ptr<Tensor>: Tensor ptr
|
||||
*/
|
||||
static std::shared_ptr<Tensor> CreateTensor(TensorImpl *tensor);
|
||||
|
||||
/*
|
||||
* get tensor impl.
|
||||
*/
|
||||
static std::shared_ptr<TensorImpl> GetImpl(const Tensor *tensor);
|
||||
|
||||
/*
|
||||
* get tensor name.
|
||||
*/
|
||||
static std::string GetTensorName(const Tensor *tensor);
|
||||
|
||||
/*
|
||||
* set tensor name.
|
||||
*/
|
||||
static void SetTensorName(const std::string &name, std::shared_ptr<Tensor> &tensor);
|
||||
|
||||
/*
|
||||
* create Tensor shape.
|
||||
* @return std::shared_ptr<TensorShape>: TensorShape ptr
|
||||
*/
|
||||
static std::shared_ptr<TensorShape> CreateTensorShape();
|
||||
|
||||
/*
|
||||
* create Tensor Shape.
|
||||
* @param tensorShape: Tensor shape impl
|
||||
* @return std::shared_ptr<TensorShape>: TensorShape ptr
|
||||
*/
|
||||
static std::shared_ptr<TensorShape> CreateTensorShape(TensorShapeImpl *tensor_shape);
|
||||
|
||||
/*
|
||||
* get tensor shape impl.
|
||||
*/
|
||||
static std::shared_ptr<TensorShapeImpl> GetImpl(const TensorShape *tensorShape);
|
||||
|
||||
/*
|
||||
* create attr value.
|
||||
* @return std::shared_ptr<AttrValue>: attr value ptr
|
||||
*/
|
||||
static std::shared_ptr<AttrValue> CreateAttrValue();
|
||||
|
||||
/*
|
||||
* create attr value.
|
||||
* @param attr_value: attr value impl
|
||||
* @return std::shared_ptr<AttrValue>: attr value ptr
|
||||
*/
|
||||
static std::shared_ptr<AttrValue> CreateAttrValue(AttrValueImpl *attr_value);
|
||||
|
||||
/*
|
||||
* get attr value impl.
|
||||
*/
|
||||
static std::shared_ptr<AttrValueImpl> GetImpl(const AttrValue *attr_value);
|
||||
|
||||
/*
|
||||
* create node def.
|
||||
* @return std::shared_ptr<NodeDef>: node def ptr
|
||||
*/
|
||||
static std::shared_ptr<NodeDef> CreateNodeDef();
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
* @param ctx: context info of kernel
|
||||
* @param total: size of total work
|
||||
* @param per_unit_size: expect size of per unit work
|
||||
* @param work: process of per unit work
|
||||
* @return uint32_t: 0->success other->failed
|
||||
*/
|
||||
static uint32_t ParallelFor(const CpuKernelContext &ctx, int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work);
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @param ctx: context info of kernel
|
||||
* @return CPU number
|
||||
*/
|
||||
static uint32_t GetCPUNum(const CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_INC_UTILS_H_
@ -0,0 +1,118 @@
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_INC_NODE_DEF_H_
|
||||
#define AICPU_CONTEXT_INC_NODE_DEF_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_attr_value.h"
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class NodeDefImpl;
|
||||
class AICPU_VISIBILITY NodeDef {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
NodeDef() = delete;
|
||||
~NodeDef() = default;
|
||||
|
||||
/*
|
||||
* parse parameter from string.
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool ParseFromString(const std::string &str);
|
||||
|
||||
/*
|
||||
* serialize node def to string.
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool SerializeToString(std::string &str) const;
|
||||
|
||||
/*
|
||||
* set op type to node def.
|
||||
* @param op: op type
|
||||
*/
|
||||
void SetOpType(const std::string &op);
|
||||
|
||||
/*
|
||||
* get op type of node def.
|
||||
* @return string: op type
|
||||
*/
|
||||
std::string GetOpType() const;
|
||||
|
||||
/*
|
||||
* add input tensor to node def.
|
||||
* @return shared_ptr<Tensor>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<Tensor> AddInputs();
|
||||
|
||||
/*
|
||||
* add output tensor to node def.
|
||||
* @return shared_ptr<Tensor>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<Tensor> AddOutputs();
|
||||
|
||||
/*
|
||||
* add attr to node def.
|
||||
* @param name: attr name
|
||||
* @param attr: attr need to add
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool AddAttrs(const std::string &name, const AttrValue *attr);
|
||||
|
||||
/*
|
||||
* get input tensor size of node def.
|
||||
* @return int32_t: input tensor size of node def
|
||||
*/
|
||||
int32_t InputsSize() const;
|
||||
|
||||
/*
|
||||
* get output tensor size of node def.
|
||||
* @return int32_t: output tensor size of node def
|
||||
*/
|
||||
int32_t OutputsSize() const;
|
||||
|
||||
/*
|
||||
* get input tensor of node def.
|
||||
* @param index: index of input tensor
|
||||
* @return shared_ptr<Tensor>: input tensor ptr of node def
|
||||
*/
|
||||
std::shared_ptr<Tensor> MutableInputs(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get output tensor of node def.
|
||||
* @param index: index of output tensor
|
||||
* @return shared_ptr<Tensor>: output tensor ptr of node def
|
||||
*/
|
||||
std::shared_ptr<Tensor> MutableOutputs(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get attr of node def.
|
||||
* @return unordered_map<std::string, std::shared_ptr<AttrValue>>: attrs of
|
||||
* node def
|
||||
*/
|
||||
std::unordered_map<std::string, std::shared_ptr<AttrValue> > Attrs() const;
|
||||
|
||||
private:
|
||||
explicit NodeDef(NodeDefImpl *impl);
|
||||
|
||||
private:
|
||||
std::shared_ptr<NodeDefImpl> impl_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_INC_NODE_DEF_H_
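// Illustrative sketch only, not part of this change: building a NodeDef with one attribute and
// round-tripping it through the serialized form that CpuKernelCache::GetCpuKernelContext parses.
// The op name "MyHypotheticalOp" is an assumption; the calls mirror APIs declared above and in
// cpu_kernel/common/cpu_kernel_utils.h.
#if 0
void ExampleNodeDefRoundTrip() {
  std::shared_ptr<aicpu::NodeDef> node_def = aicpu::CpuKernelUtils::CreateNodeDef();
  node_def->SetOpType("MyHypotheticalOp");

  auto block_num = aicpu::CpuKernelUtils::CreateAttrValue();
  block_num->SetInt(1);
  (void)node_def->AddAttrs("block_num", block_num.get());

  std::string serialized;
  (void)node_def->SerializeToString(serialized);  // bytes carried in the kernel's extend params

  std::shared_ptr<aicpu::NodeDef> parsed = aicpu::CpuKernelUtils::CreateNodeDef();
  (void)parsed->ParseFromString(serialized);
}
#endif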
@ -0,0 +1,62 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cpu_kernel/common/device.h"
|
||||
|
||||
#include <new>
|
||||
|
||||
#include "cpu_kernel/common/device_sharder.h"
|
||||
#include "cpu_kernel/common/host_sharder.h"
|
||||
|
||||
namespace aicpu {
|
||||
Device::Device(DeviceType device) : device_(device), sharder_(InitSharder(device)) {}
|
||||
|
||||
Device::~Device() {
|
||||
if (sharder_ != nullptr) {
|
||||
delete sharder_;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get device type.
|
||||
* @return DeviceType: HOST/DEVICE
|
||||
*/
|
||||
DeviceType Device::GetDeviceType() const { return device_; }
|
||||
|
||||
/*
|
||||
* get sharder.
|
||||
* @return Sharder *: host or device sharder
|
||||
*/
|
||||
const Sharder *Device::GetSharder() const {
|
||||
if (sharder_ != nullptr) {
|
||||
return sharder_;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/*
|
||||
* init sharder.
|
||||
* param device: type of device
|
||||
* @return Sharder *: not null->success, null->failed
|
||||
*/
|
||||
Sharder *Device::InitSharder(DeviceType device_type) const {
|
||||
if (device_type == DEVICE) {
|
||||
return new (std::nothrow) DeviceSharder(device_type);
|
||||
} else {
|
||||
return new (std::nothrow) HostSharder(device_type);
|
||||
}
|
||||
}
|
||||
} // namespace aicpu
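// Illustrative sketch only, not part of this change: how a Device is typically queried for its
// sharder, mirroring what CpuKernelUtils::ParallelFor does with ctx.device_->GetSharder().
#if 0
void ExampleDeviceUsage() {
  aicpu::Device device(aicpu::HOST);  // HOST selects HostSharder, DEVICE selects DeviceSharder
  const aicpu::Sharder *sharder = device.GetSharder();
  if (sharder != nullptr) {
    sharder->ParallelFor(1000, 100, [](int64_t begin, int64_t end) {
      // process the element range [begin, end)
    });
  }
}
#endif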
@ -0,0 +1,58 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_DEVICE_H
|
||||
#define AICPU_CONTEXT_COMMON_DEVICE_H
|
||||
|
||||
#include "cpu_kernel/common/sharder.h"
|
||||
|
||||
namespace aicpu {
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(DeviceType device);
|
||||
|
||||
~Device();
|
||||
|
||||
/*
|
||||
* get device type.
|
||||
* @return DeviceType: HOST/DEVICE
|
||||
*/
|
||||
DeviceType GetDeviceType() const;
|
||||
|
||||
/*
|
||||
* get sharder.
|
||||
* @return Sharder *: host or device sharder
|
||||
*/
|
||||
const Sharder *GetSharder() const;
|
||||
|
||||
private:
|
||||
Device(const Device &) = delete;
|
||||
Device(Device &&) = delete;
|
||||
Device &operator=(const Device &) = delete;
|
||||
Device &operator=(Device &&) = delete;
|
||||
|
||||
/*
|
||||
* init sharder.
|
||||
* param device_type: type of device
|
||||
* @return Sharder *: not null->success, null->failed
|
||||
*/
|
||||
Sharder *InitSharder(DeviceType device_type) const;
|
||||
|
||||
private:
|
||||
DeviceType device_; // type of device
|
||||
Sharder *sharder_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_DEVICE_H
@ -0,0 +1,170 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/device_cpu_kernel.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "aicpu_sharder/aicpu_context.h"
|
||||
#include "cce/aicpu_engine_struct.h"
|
||||
#include "cce/fwk_adpt_struct.h"
|
||||
#include "cpu_kernel/common/cpu_kernel_cache.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/session_cache.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
using namespace aicpu;
|
||||
namespace {
|
||||
// max param len limit 10k.
|
||||
constexpr uint32_t kMaxParamLen = 10240;
|
||||
// max extend info len limit 20k.
|
||||
constexpr uint32_t kMaxExtendLen = 20480;
|
||||
const std::string kContextKeyStreamId = "streamId";
|
||||
|
||||
uint32_t ParseExtSessionInfo(AicpuParamHead *param_head, SessionInfo *&session) {
|
||||
KERNEL_LOG_INFO("Parse extend session info begin.");
|
||||
uint32_t offset = 0;
|
||||
FWKAdapter::ExtInfo *ext_info = nullptr;
|
||||
char *ext_info_buf = reinterpret_cast<char *>(static_cast<uintptr_t>(param_head->extInfoAddr));
|
||||
while (offset + sizeof(FWKAdapter::ExtInfo) <= param_head->extInfoLength) {
|
||||
ext_info = reinterpret_cast<FWKAdapter::ExtInfo *>(ext_info_buf + offset);
|
||||
if (ext_info == nullptr) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Extend info is nullptr, extend info length[%u], extend info "
|
||||
"offset[%u].",
|
||||
param_head->extInfoLength, offset);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (ext_info->infoType == FWKAdapter::FWK_ADPT_EXT_SESSION_INFO) {
|
||||
auto need_len = sizeof(SessionInfo);
|
||||
if (ext_info->infoLen != need_len) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Parse extend session info failed, as info length must be "
|
||||
"[%zu], but %u.",
|
||||
sizeof(SessionInfo), ext_info->infoLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
session = reinterpret_cast<SessionInfo *>(ext_info->infoMsg);
|
||||
KERNEL_LOG_INFO("Parse extend session info success.");
|
||||
}
|
||||
|
||||
// offset cannot overflow here: extInfoLength is checked against kMaxExtendLen by the caller
|
||||
offset += FWKAdapter::kExtInfoHeadSize;
|
||||
offset += ext_info->infoLen;
|
||||
}
|
||||
|
||||
KERNEL_LOG_INFO("Parse extend session info end.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
__attribute__((visibility("default"))) uint32_t RunCpuKernel(void *param) {
|
||||
KERNEL_LOG_INFO("RunCpuKernel C begin");
|
||||
if (param == nullptr) {
|
||||
KERNEL_LOG_ERROR("Param is null.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// parse param_len
|
||||
AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
|
||||
if ((param_head->length < sizeof(AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
|
||||
(param_head->extInfoLength > kMaxExtendLen)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Param length[%u] not in [%zu, %u] or extend info length[%u] is "
|
||||
"greater "
|
||||
"than the limit[%u].",
|
||||
param_head->length, sizeof(AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
SessionInfo *session = nullptr;
|
||||
uint32_t ret = ParseExtSessionInfo(param_head, session);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (session == nullptr) {
|
||||
KERNEL_LOG_INFO("RunCpuKernel directly.");
|
||||
CpuKernelCache cache;
|
||||
cache.Init(false);
|
||||
return cache.RunKernel(param);
|
||||
}
|
||||
|
||||
std::string stream_id_value = "0";
|
||||
auto status = GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
|
||||
if (status != AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("GetThreadLocalCtx failed, ret[%d].", status);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
uint64_t stream_id = atoi(stream_id_value.c_str());
|
||||
KERNEL_LOG_INFO(
|
||||
"RunCpuKernel from cache, stream id[%lu], session id[%lu], session "
|
||||
"flag[%d].",
|
||||
stream_id, session->sessionId, session->sessFlag);
|
||||
return SessionCache<CpuCacheData>::Instance().RunKernel<CpuKernelCache>(param, session->sessionId, stream_id,
|
||||
session->sessFlag);
|
||||
}
|
||||
|
||||
__attribute__((visibility("default"))) uint32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info) {
|
||||
KERNEL_LOG_INFO("RunCpuKernelWithBlock C begin. blockid[%u], blockdim[%u].", blkdim_info->blockId,
|
||||
blkdim_info->blockNum);
|
||||
if (param == nullptr || blkdim_info == nullptr) {
|
||||
KERNEL_LOG_ERROR("Param is null.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// parse param_len
|
||||
AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
|
||||
if ((param_head->length < sizeof(AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
|
||||
(param_head->extInfoLength > kMaxExtendLen)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Param length[%u] not in [%zu, %u] or extend info length[%u] is "
|
||||
"greater "
|
||||
"than the limit[%u].",
|
||||
param_head->length, sizeof(AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
SessionInfo *session = nullptr;
|
||||
uint32_t ret = ParseExtSessionInfo(param_head, session);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (session == nullptr) {
|
||||
KERNEL_LOG_INFO("RunCpuKernelWithBlock directly.");
|
||||
CpuKernelCache cache;
|
||||
cache.Init(false);
|
||||
return cache.RunCpuKernelWithBlock(param, blkdim_info);
|
||||
}
|
||||
|
||||
std::string stream_id_value = "0";
|
||||
auto status = GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
|
||||
if (status != AICPU_ERROR_NONE) {
|
||||
KERNEL_LOG_ERROR("GetThreadLocalCtx failed, ret[%d].", status);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
uint64_t stream_id = atoi(stream_id_value.c_str());
|
||||
KERNEL_LOG_INFO(
|
||||
"RunCpuKernel from cache, stream id[%lu], session id[%lu], session "
|
||||
"flag[%d].",
|
||||
stream_id, session->sessionId, session->sessFlag);
|
||||
return SessionCache<CpuCacheData>::Instance().RunCpuKernelWithBlock<CpuKernelCache>(
|
||||
param, session->sessionId, stream_id, session->sessFlag, blkdim_info);
|
||||
}
|
||||
}
@ -0,0 +1,29 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_DEVICE_CPU_KERNEL_H
|
||||
#define AICPU_CONTEXT_COMMON_DEVICE_CPU_KERNEL_H
|
||||
#include <cstdint>
|
||||
|
||||
struct BlkDimInfo {
|
||||
uint32_t blockNum; // blockdim_num
|
||||
uint32_t blockId; // blockid
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
uint32_t RunCpuKernel(void *param);
|
||||
uint32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info);
|
||||
}
|
||||
#endif // AICPU_CONTEXT_COMMON_DEVICE_CPU_KERNEL_H
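// Illustrative sketch only, not part of this change: the expected calling pattern for the block
// variant. `param` stands for the AicpuParamHead buffer prepared by the AICPU framework
// (assumption); here a task is split into four blocks and the entry point is invoked per block.
#if 0
uint32_t ExampleRunAllBlocks(void *param) {
  uint32_t status = 0;
  for (uint32_t block_id = 0; block_id < 4; ++block_id) {
    struct BlkDimInfo blk_info = {/*blockNum=*/4, /*blockId=*/block_id};
    status = RunCpuKernelWithBlock(param, &blk_info);
  }
  return status;
}
#endif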
@ -0,0 +1,81 @@
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/device_sharder.h"
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
|
||||
namespace {
|
||||
const char *kSharderPath = "/usr/lib64/libaicpu_sharder.so";
|
||||
const char *kParallelForFunc = "ParallelFor";
|
||||
const char *kGetCPUNumFunc = "GetCPUNum";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
DeviceSharder::DeviceSharder(DeviceType device) : Sharder(device) {
|
||||
sharder_ = dlopen(kSharderPath, RTLD_LAZY | RTLD_GLOBAL);
|
||||
if (sharder_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Device sharder dlopen so[%s] failed, error[%s]", kSharderPath, dlerror());
|
||||
parallel_for_ = nullptr;
|
||||
get_cpu_num_ = nullptr;
|
||||
} else {
|
||||
parallel_for_ = reinterpret_cast<ParallelForFunc>(dlsym(sharder_, kParallelForFunc));
|
||||
if (parallel_for_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get function[%s] address failed, error[%s]", kParallelForFunc, dlerror());
|
||||
}
|
||||
|
||||
get_cpu_num_ = reinterpret_cast<GetCPUNumFunc>(dlsym(sharder_, kGetCPUNumFunc));
|
||||
if (get_cpu_num_ == nullptr) {
|
||||
KERNEL_LOG_WARN("Get function[%s] address failed, error[%s]", kGetCPUNumFunc, dlerror());
|
||||
}
|
||||
KERNEL_LOG_INFO("Device sharder dlopen so[%s] success", kSharderPath);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceSharder::~DeviceSharder() {
|
||||
if (sharder_ != nullptr) {
|
||||
(void)dlclose(sharder_);
|
||||
sharder_ = nullptr;
|
||||
}
|
||||
parallel_for_ = nullptr;
get_cpu_num_ = nullptr;
|
||||
}
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
*/
|
||||
void DeviceSharder::ParallelFor(int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) const {
|
||||
if (parallel_for_ != nullptr) {
|
||||
parallel_for_(total, perUnitSize, work);
|
||||
return;
|
||||
}
|
||||
|
||||
KERNEL_LOG_WARN("Function[%s] is null", kParallelForFunc);
|
||||
work(0, total);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
*/
|
||||
uint32_t DeviceSharder::GetCPUNum() const {
|
||||
if (get_cpu_num_ != nullptr) {
|
||||
return get_cpu_num_();
|
||||
}
|
||||
|
||||
KERNEL_LOG_WARN("Function[%s] is null", kGetCPUNumFunc);
|
||||
return 1;
|
||||
}
|
||||
} // namespace aicpu
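// Illustrative sketch only, not part of this change: the optional-symbol pattern used above,
// reduced to its core. The library path and symbol name are placeholders, not real dependencies.
#if 0
uint32_t ExampleResolveOptionalSymbol() {
  using ExampleGetNumFunc = uint32_t (*)();
  void *handle = dlopen("/usr/lib64/libexample.so", RTLD_LAZY | RTLD_GLOBAL);
  ExampleGetNumFunc get_num = nullptr;
  if (handle != nullptr) {
    get_num = reinterpret_cast<ExampleGetNumFunc>(dlsym(handle, "ExampleGetNum"));
  }
  // Fall back to a safe default when the symbol cannot be resolved, as GetCPUNum above does.
  return (get_num != nullptr) ? get_num() : 1U;
}
#endif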
@ -0,0 +1,56 @@
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_DEVICE_SHARDER_H
|
||||
#define AICPU_CONTEXT_COMMON_DEVICE_SHARDER_H
|
||||
#include "cpu_kernel/common/sharder.h"
|
||||
|
||||
namespace aicpu {
|
||||
using ParallelForFunc = void (*)(int64_t total, int64_t perUnitSize, const std::function<void(int64_t, int64_t)> &work);
|
||||
using GetCPUNumFunc = uint32_t (*)();
|
||||
class DeviceSharder : public Sharder {
|
||||
public:
|
||||
explicit DeviceSharder(DeviceType device);
|
||||
|
||||
~DeviceSharder() override;
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
* @param total: size of total work
|
||||
* @param perUnitSize: expect size of per unit work
|
||||
* @param work: process of per unit work
|
||||
*/
|
||||
void ParallelFor(int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) const override;
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @return CPU number
|
||||
*/
|
||||
uint32_t GetCPUNum() const override;
|
||||
|
||||
private:
|
||||
DeviceSharder(const DeviceSharder &) = delete;
|
||||
DeviceSharder(DeviceSharder &&) = delete;
|
||||
DeviceSharder &operator=(const DeviceSharder &) = delete;
|
||||
DeviceSharder &operator=(DeviceSharder &&) = delete;
|
||||
|
||||
private:
|
||||
void *sharder_;
|
||||
ParallelForFunc parallel_for_;
|
||||
GetCPUNumFunc get_cpu_num_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_DEVICE_SHARDER_H
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/eigen_threadpool.h"
|
||||
|
||||
#include <sys/sysinfo.h>
|
||||
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kTaskSize = 40000;
|
||||
const uint32_t kMaxOverShardingFactor = 4;
|
||||
const uint32_t kTotalCostFactor = 210000;
|
||||
constexpr uint32_t kMaxTaskSize = kTaskSize * kMaxOverShardingFactor;
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
std::mutex EigenThreadPool::mutex_;
|
||||
bool EigenThreadPool::init_flag_(false);
|
||||
int32_t EigenThreadPool::core_num_(0);
|
||||
std::unique_ptr<Eigen::ThreadPool> EigenThreadPool::eigen_threadpool_(nullptr);
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> EigenThreadPool::threadpool_device_(nullptr);
|
||||
|
||||
EigenThreadPool *EigenThreadPool::GetInstance() {
|
||||
KERNEL_LOG_INFO("EigenThreadPool GetInstance begin");
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (!init_flag_) {
|
||||
core_num_ = get_nprocs(); // obtains the number of CPU cores that can be
|
||||
// used by users.
|
||||
if (core_num_ <= 0) {
|
||||
KERNEL_LOG_INFO(
|
||||
"Get the number of CPU cores that can be used failed, core "
|
||||
"number[%d]",
|
||||
core_num_);
|
||||
return nullptr;
|
||||
}
|
||||
eigen_threadpool_.reset(new Eigen::ThreadPool(core_num_));
|
||||
threadpool_device_.reset(new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), core_num_));
|
||||
init_flag_ = true;
|
||||
KERNEL_LOG_INFO("EigenThreadPool init success, core number[%d]", core_num_);
|
||||
}
|
||||
}
|
||||
|
||||
static EigenThreadPool instance;
|
||||
KERNEL_LOG_INFO("EigenThreadPool GetInstance success");
|
||||
return &instance;
|
||||
}
|
||||
|
||||
void EigenThreadPool::ParallelFor(int64_t total, int64_t per_unit_size, const SharderWork &work) const {
|
||||
KERNEL_LOG_INFO("Eigen threadpool parallel for begin, total[%ld], per_unit_size[%ld]", total, per_unit_size);
|
||||
if ((total <= 0) || (work == nullptr) || (per_unit_size <= 0)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Invalid param: total[%ld] <= 0 or per_unit_size[%ld] <= 0 or work "
|
||||
"is "
|
||||
"nullptr",
|
||||
total, per_unit_size);
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t total_check = static_cast<int64_t>(static_cast<Eigen::Index>(total));
|
||||
if (total_check != total) {
|
||||
KERNEL_LOG_ERROR("Invalid param: total[%ld], value[%ld] after eigen conversion", total, total_check);
|
||||
return;
|
||||
}
|
||||
|
||||
double per_unit_cost = 1.0;
|
||||
if (per_unit_size >= total) {
|
||||
// use the current thread to process the task
|
||||
per_unit_cost = 1.0 * kTaskSize / total;
|
||||
} else if ((per_unit_size) <= (total / core_num_)) {
|
||||
// run tasks with the maximum number of threads, maximum =
|
||||
// kMaxOverShardingFactor * core_num_
|
||||
per_unit_cost = (1.0 * kMaxTaskSize * core_num_ / total) > (1.0 * kTotalCostFactor / total)
|
||||
? (1.0 * kMaxTaskSize * core_num_ / total)
|
||||
: (1.0 * kTotalCostFactor / total);
|
||||
} else {
|
||||
// the task is fragmented based on the number of data slices.
|
||||
per_unit_cost = 1.0 * kMaxTaskSize / per_unit_size;
|
||||
}
|
||||
|
||||
KERNEL_LOG_INFO("Eigen threadpool parallel for, per_unit_cost[%.6f]", per_unit_cost);
|
||||
|
||||
threadpool_device_->parallelFor(total, Eigen::TensorOpCost(0, 0, per_unit_cost),
|
||||
[&work](Eigen::Index first, Eigen::Index last) { work(first, last); });
|
||||
KERNEL_LOG_INFO("Eigen threadpool parallel for success");
|
||||
}
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
*/
|
||||
uint32_t EigenThreadPool::GetCPUNum() const { return static_cast<uint32_t>(core_num_); }
|
||||
} // namespace aicpu
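The per_unit_cost handed to Eigen controls how finely the range is split: a larger cost per unit makes Eigen create more, smaller blocks. A standalone sketch of the same selection logic, with constants copied from the anonymous namespace above and an illustrative worked number in the comments:

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of the cost selection above; useful for reasoning about how Eigen
// will split the range. Not part of the original sources.
double PerUnitCost(int64_t total, int64_t per_unit_size, int32_t core_num) {
  const double task = 40000.0;         // kTaskSize
  const double max_task = task * 4;    // kTaskSize * kMaxOverShardingFactor
  const double total_cost = 210000.0;  // kTotalCostFactor
  if (per_unit_size >= total) {
    return task / total;               // single-thread path
  }
  if (per_unit_size <= total / core_num) {
    // e.g. total = 1'000'000, core_num = 8: max(160000*8/1e6, 210000/1e6) = 1.28
    return std::max(max_task * core_num / total, total_cost / total);
  }
  return max_task / per_unit_size;     // shard by the requested unit size
}
```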
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
|
||||
#define AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <unsupported/Eigen/CXX11/Tensor>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace aicpu {
|
||||
using SharderWork = std::function<void(int64_t, int64_t)>;
|
||||
|
||||
class EigenThreadPool {
|
||||
public:
|
||||
static EigenThreadPool *GetInstance();
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
*/
|
||||
void ParallelFor(int64_t total, int64_t per_unit_size, const SharderWork &work) const;
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @return CPU number
|
||||
*/
|
||||
uint32_t GetCPUNum() const;
|
||||
|
||||
private:
|
||||
EigenThreadPool() = default;
|
||||
~EigenThreadPool() = default;
|
||||
|
||||
EigenThreadPool(const EigenThreadPool &) = delete;
|
||||
EigenThreadPool(EigenThreadPool &&) = delete;
|
||||
EigenThreadPool &operator=(const EigenThreadPool &) = delete;
|
||||
EigenThreadPool &operator=(EigenThreadPool &&) = delete;
|
||||
|
||||
private:
|
||||
static std::mutex mutex_; // protect init_flag_
|
||||
static bool init_flag_; // true means initialized
|
||||
static int32_t core_num_; // the number of CPU cores that can be used by users
|
||||
static std::unique_ptr<Eigen::ThreadPool> eigen_threadpool_;
|
||||
static std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/host_sharder.h"
|
||||
|
||||
#include "cpu_kernel/common/eigen_threadpool.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
*/
|
||||
void HostSharder::ParallelFor(int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) const {
|
||||
EigenThreadPool *threadpool = EigenThreadPool::GetInstance();
|
||||
if (threadpool == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get eigen thread pool failed");
|
||||
return;
|
||||
}
|
||||
|
||||
threadpool->ParallelFor(total, perUnitSize, work);
|
||||
}
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
*/
|
||||
uint32_t HostSharder::GetCPUNum() const {
|
||||
EigenThreadPool *threadpool = EigenThreadPool::GetInstance();
|
||||
if (threadpool == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get eigen thread pool failed");
|
||||
return 0;
|
||||
}
|
||||
|
||||
return threadpool->GetCPUNum();
|
||||
}
|
||||
} // namespace aicpu
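A hypothetical usage sketch of HostSharder: each shard receives a disjoint [begin, end) range, so the writes into the output vector never overlap. The enumerator DeviceType::HOST is an assumption based on the HOST/DEVICE comment in sharder.h.

```cpp
#include <cstdint>
#include <vector>

#include "cpu_kernel/common/host_sharder.h"

void FillSquares(std::vector<int64_t> &out) {
  aicpu::HostSharder sharder(aicpu::DeviceType::HOST);  // assumed enumerator name
  const int64_t total = static_cast<int64_t>(out.size());
  sharder.ParallelFor(total, 1024, [&out](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      out[i] = i * i;  // each shard writes its own [begin, end) slice
    }
  });
}
```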
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_HOST_SHARDER_H
|
||||
#define AICPU_CONTEXT_COMMON_HOST_SHARDER_H
|
||||
#include "cpu_kernel/common/sharder.h"
|
||||
|
||||
namespace aicpu {
|
||||
class HostSharder : public Sharder {
|
||||
public:
|
||||
explicit HostSharder(DeviceType device) : Sharder(device){};
|
||||
|
||||
~HostSharder() = default;
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
* @param total: size of total work
|
||||
* @param perUnitSize: expect size of per unit work
|
||||
* @param work: process of per unit work
|
||||
*/
|
||||
void ParallelFor(int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) const override;
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @return CPU number
|
||||
*/
|
||||
uint32_t GetCPUNum() const override;
|
||||
|
||||
private:
|
||||
HostSharder(const HostSharder &) = delete;
|
||||
HostSharder(HostSharder &&) = delete;
|
||||
HostSharder &operator=(const HostSharder &) = delete;
|
||||
HostSharder &operator=(HostSharder &&) = delete;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_HOST_SHARDER_H
|
|
@ -0,0 +1,166 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_KERNEL_CACHE_H
|
||||
#define AICPU_CONTEXT_COMMON_KERNEL_CACHE_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <mutex>
|
||||
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/device_cpu_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
template <class T>
|
||||
class KernelCache {
|
||||
public:
|
||||
KernelCache() : sess_flag_(false), capacity_(1) {}
|
||||
virtual ~KernelCache() = default;
|
||||
|
||||
/*
|
||||
* Init kernel cache.
|
||||
* @param sess_flag: whether it's a session scene, false need to support LRU
|
||||
* algorithm
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
int32_t Init(bool sess_flag) {
|
||||
sess_flag_ = sess_flag;
|
||||
return InitParameter();
|
||||
}
|
||||
|
||||
/*
|
||||
* run kernel.
|
||||
* @param param: kernel context
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
virtual int32_t RunKernel(void *param) = 0;
|
||||
|
||||
/*
|
||||
* run kernel with blockDimInfo.
|
||||
* @param param: kernel context and blkDimInfo
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
virtual int32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkDimInfo) = 0;
|
||||
/*
|
||||
* get kernel cache, the lru algorithm is supported in non-session scenarios
|
||||
* @param key: kernel id
|
||||
* @return T *: cache content pointer
|
||||
*/
|
||||
T *GetCache(uint64_t key) {
|
||||
KERNEL_LOG_DEBUG("GetCache begin, key[%llu].", key);
|
||||
T *ret = nullptr;
|
||||
std::unique_lock<std::mutex> lock(kernel_mutex_);
|
||||
auto it = kernel_cache_iter_.find(key);
|
||||
if (it != kernel_cache_iter_.end()) {
|
||||
KERNEL_LOG_DEBUG("GetCache success, key[%llu].", key);
|
||||
ret = it->second->second.get();
|
||||
if (!sess_flag_) {
|
||||
auto pair_iter = it->second;
|
||||
std::pair<uint64_t, std::shared_ptr<T>> pair = *pair_iter;
|
||||
kernel_cache_.erase(pair_iter);
|
||||
kernel_cache_.push_front(pair);
|
||||
kernel_cache_iter_[key] = kernel_cache_.begin();
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* set kernel cache, the lru algorithm is supported in non-session scenarios
|
||||
* @param key: kernel id
|
||||
* @param value: cache content
|
||||
*/
|
||||
void SetCache(uint64_t key, std::shared_ptr<T> value) {
|
||||
KERNEL_LOG_DEBUG("SetCache begin, key[%llu].", key);
|
||||
std::unique_lock<std::mutex> lock(kernel_mutex_);
|
||||
auto iter = kernel_cache_iter_.find(key);
|
||||
if (iter != kernel_cache_iter_.end()) {
|
||||
KERNEL_LOG_DEBUG("SetCache update cache, key[%llu].", key);
|
||||
auto pair_iter = iter->second;
|
||||
pair_iter->second = value;
|
||||
if (!sess_flag_) {
|
||||
std::pair<uint64_t, std::shared_ptr<T>> pair = *pair_iter;
|
||||
kernel_cache_.erase(pair_iter);
|
||||
kernel_cache_.push_front(pair);
|
||||
kernel_cache_iter_[key] = kernel_cache_.begin();
|
||||
}
|
||||
} else {
|
||||
std::pair<uint64_t, std::shared_ptr<T>> pair = std::make_pair(key, value);
|
||||
if ((capacity_ < kernel_cache_.size()) && (!sess_flag_)) {
|
||||
uint64_t del_key = kernel_cache_.back().first;
|
||||
KERNEL_LOG_DEBUG(
|
||||
"SetCache is full, pop last element, capacity[%u], delete "
|
||||
"key[%llu].",
|
||||
capacity_, del_key);
|
||||
kernel_cache_.pop_back();
|
||||
auto del_iter = kernel_cache_iter_.find(del_key);
|
||||
if (del_iter != kernel_cache_iter_.end()) {
|
||||
kernel_cache_iter_.erase(del_iter);
|
||||
}
|
||||
}
|
||||
KERNEL_LOG_DEBUG("SetCache success, key[%llu].", key);
|
||||
kernel_cache_.push_front(pair);
|
||||
kernel_cache_iter_[key] = kernel_cache_.begin();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get session flag, true means session scene
|
||||
* @return bool: whether it's a session scene
|
||||
*/
|
||||
bool GetSessionFlag() const { return sess_flag_; }
|
||||
|
||||
/*
|
||||
* get kernel cache capacity
|
||||
* @return uint32_t: lru capacity
|
||||
*/
|
||||
uint32_t GetCapacity() { return capacity_; }
|
||||
|
||||
/*
|
||||
* set kernel cache capacity
|
||||
* @param capacity: lru capacity
|
||||
*/
|
||||
void SetCapacity(uint32_t capacity) { capacity_ = capacity; }
|
||||
|
||||
/*
|
||||
* get all kernel cache
|
||||
* @return std::list<std::pair<uint64_t, std::shared_ptr<T>>>: all cache,
|
||||
* pair<kernel id, cache>
|
||||
*/
|
||||
std::list<std::pair<uint64_t, std::shared_ptr<T>>> GetAllKernelCache() { return kernel_cache_; }
|
||||
|
||||
protected:
|
||||
virtual int32_t InitParameter() = 0;
|
||||
|
||||
private:
|
||||
KernelCache(const KernelCache &) = delete;
|
||||
KernelCache(KernelCache &&) = delete;
|
||||
KernelCache &operator=(const KernelCache &) = delete;
|
||||
KernelCache &operator=(KernelCache &&) = delete;
|
||||
|
||||
bool sess_flag_; // whether it's a session scene, false need to support LRU
|
||||
uint32_t capacity_; // lru capacity
|
||||
std::mutex kernel_mutex_;
|
||||
std::list<std::pair<uint64_t, std::shared_ptr<T>>> kernel_cache_; // all kernel cache, key is kernel id
|
||||
std::unordered_map<uint64_t, typename std::list<std::pair<uint64_t, std::shared_ptr<T>>>::iterator>
|
||||
kernel_cache_iter_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_KERNEL_CACHE_H
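GetCache/SetCache above implement LRU with a std::list of entries plus an unordered_map of list iterators, so lookup, touch-to-front, and eviction are all O(1). A self-contained sketch of the same idiom, independent of the kernel types:

```cpp
#include <cstdint>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

class LruDemo {
 public:
  explicit LruDemo(size_t capacity) : capacity_(capacity) {}

  void Put(uint64_t key, const std::string &value) {
    auto it = index_.find(key);
    if (it != index_.end()) {
      items_.erase(it->second);  // drop the old position, re-insert at the front
    } else if (items_.size() >= capacity_ && !items_.empty()) {
      index_.erase(items_.back().first);  // evict the least-recently-used entry
      items_.pop_back();
    }
    items_.emplace_front(key, value);
    index_[key] = items_.begin();
  }

  const std::string *Get(uint64_t key) {
    auto it = index_.find(key);
    if (it == index_.end()) {
      return nullptr;
    }
    items_.splice(items_.begin(), items_, it->second);  // move the hit to the front
    return &it->second->second;
  }

 private:
  size_t capacity_;
  std::list<std::pair<uint64_t, std::string>> items_;
  std::unordered_map<uint64_t, std::list<std::pair<uint64_t, std::string>>::iterator> index_;
};
```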
|
|
@ -0,0 +1,191 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019. All rights reserved.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
*
|
||||
* Description: tensorflow's kernel info
|
||||
*/
|
||||
#include "cpu_kernel/common/node_def_builder.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
|
||||
namespace aicpu {
|
||||
std::shared_ptr<NodeDef> NodeDefBuilder::CreateNodeDef() {
|
||||
return CpuKernelUtils::CreateNodeDef();
|
||||
}
|
||||
|
||||
NodeDefBuilder::NodeDefBuilder(NodeDef *nodeDef, std::string name, std::string opName) {
|
||||
nodeDef_ = nodeDef;
|
||||
name_ = name;
|
||||
nodeDef_->SetOpType(opName);
|
||||
}
|
||||
|
||||
void NodeDefBuilder::BuildNodeFromInputOutputNode(const InputOutputNode& node, bool isInput) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
if (isInput) {
|
||||
tensor = nodeDef_->AddInputs();
|
||||
} else {
|
||||
tensor = nodeDef_->AddOutputs();
|
||||
}
|
||||
aicpu::CpuKernelUtils::SetTensorName(node.node, tensor);
|
||||
tensor->SetDataType(node.dType);
|
||||
auto shape = tensor->GetTensorShape();
|
||||
shape->SetDimSizes(node.dims);
|
||||
shape->SetFormat(node.format);
|
||||
int64_t dataSize = 1;
|
||||
for (size_t i = 0; i < node.dims.size(); i++) {
|
||||
dataSize = dataSize * node.dims[i];
|
||||
}
|
||||
dataSize = dataSize * GetSizeByDataType(node.dType);
|
||||
if (node.dims.empty()) {
|
||||
dataSize = GetSizeByDataType(node.dType);
|
||||
}
|
||||
if (node.data == nullptr) {
|
||||
dataSize = 0;
|
||||
}
|
||||
tensor->SetDataSize(dataSize);
|
||||
tensor->SetData(node.data);
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Input(const InputOutputNode& input) {
|
||||
BuildNodeFromInputOutputNode(input, true);
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Output(const InputOutputNode& output) {
|
||||
BuildNodeFromInputOutputNode(output, false);
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int32_t value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetInt(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int64_t value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetInt(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, float value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetFloat(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, double value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetFloat(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, bool value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetBool(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::DataType value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetDataType(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<bool> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListBool(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::string &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetString(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::string> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListString(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListInt(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListListInt(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<float> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListFloat(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<aicpu::DataType> &value) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListDataType(value);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &dims, std::string type) {
|
||||
if (type == "shape") {
|
||||
auto shape = CpuKernelUtils::CreateAttrValue();
|
||||
auto value = CpuKernelUtils::CreateTensorShape();
|
||||
value->SetDimSizes(dims);
|
||||
shape->SetTensorShape(value.get());
|
||||
nodeDef_->AddAttrs(name, shape.get());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists,
|
||||
std::string type) {
|
||||
if (type == "shape_list") {
|
||||
auto shapeItems = CpuKernelUtils::CreateAttrValue();
|
||||
for (size_t i = 0; i < shapeLists.size(); i++) {
|
||||
auto value = shapeItems->AddListTensorShape();
|
||||
value->SetDimSizes(shapeLists[i]);
|
||||
}
|
||||
nodeDef_->AddAttrs(name, shapeItems.get());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::Tensor *tensor) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetTensor(tensor);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(std::string name, std::vector<aicpu::Tensor *> &tensors) {
|
||||
auto attr = CpuKernelUtils::CreateAttrValue();
|
||||
attr->SetListTensor(tensors);
|
||||
nodeDef_->AddAttrs(name, attr.get());
|
||||
return *this;
|
||||
}
|
||||
} // namespace aicpu
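A hypothetical usage of the builder above: one input, one output, and one attribute, chained fluently. DT_FLOAT and FORMAT_ND are assumed enumerator spellings from cpu_types.h and are not verified here.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "cpu_kernel/common/node_def_builder.h"

std::shared_ptr<aicpu::NodeDef> BuildDemoNode(float *in, float *out, int64_t n) {
  auto node_def = aicpu::NodeDefBuilder::CreateNodeDef();
  aicpu::NodeDefBuilder builder(node_def.get(), "demo_node", "Square");
  // Field order follows InputOutputNode: node, dType, dims, data, format.
  aicpu::NodeDefBuilder::InputOutputNode x{"x", aicpu::DT_FLOAT, {n}, in, aicpu::FORMAT_ND};
  aicpu::NodeDefBuilder::InputOutputNode y{"y", aicpu::DT_FLOAT, {n}, out, aicpu::FORMAT_ND};
  builder.Input(x).Output(y).Attr("keep_dims", true);
  return node_def;
}
```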
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019. All rights reserved.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
*
|
||||
* Description: tensorflow's kernel info
|
||||
*/
|
||||
#ifndef NODE_DEF_BUILDER_H
|
||||
#define NODE_DEF_BUILDER_H
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "cpu_kernel/inc/cpu_ops_kernel.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
#include "cpu_kernel/common/cpu_kernel_register.h"
|
||||
#include "aicpu/common/aicpu_task_struct.h"
|
||||
#include "cpu_kernel/common/device_cpu_kernel.h"
|
||||
|
||||
namespace aicpu {
|
||||
class NodeDefBuilder {
|
||||
public:
|
||||
struct InputOutputNode {
|
||||
std::string node;
|
||||
aicpu::DataType dType;
|
||||
std::vector<int64_t> dims;
|
||||
void *data;
|
||||
aicpu::Format format;
|
||||
};
|
||||
|
||||
static std::shared_ptr<NodeDef> CreateNodeDef();
|
||||
|
||||
NodeDefBuilder(NodeDef *nodeDef, std::string name, std::string opName);
|
||||
|
||||
NodeDefBuilder &Input(const InputOutputNode &input);
|
||||
|
||||
NodeDefBuilder &Output(const InputOutputNode &output);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, int32_t value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, int64_t value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, float value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, double value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, bool value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, aicpu::DataType value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<bool> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::string &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<std::string> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<std::vector<int64_t>> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<float> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<aicpu::DataType> &value);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &dims, std::string type);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists, std::string type);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, aicpu::Tensor *tensor);
|
||||
|
||||
NodeDefBuilder &Attr(std::string name, std::vector<aicpu::Tensor *> &tensors);
|
||||
|
||||
private:
|
||||
void BuildNodeFromInputOutputNode(const InputOutputNode &node, bool isInput);
|
||||
|
||||
NodeDef *nodeDef_;
|
||||
|
||||
std::string name_;
|
||||
|
||||
std::string opName_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
|
||||
#endif
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_CONTEXT_COMMON_NOTIFICATION_H
|
||||
#define AICPU_CONTEXT_COMMON_NOTIFICATION_H
|
||||
#include <cassert>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
class Notification {
|
||||
public:
|
||||
Notification() : notified_(0) {}
|
||||
~Notification() { std::unique_lock<std::mutex> l(mu_); }
|
||||
|
||||
void Notify() {
|
||||
std::unique_lock<std::mutex> l(mu_);
|
||||
if (!HasBeenNotified()) {
|
||||
notified_.store(true, std::memory_order_release);
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
bool HasBeenNotified() const { return notified_.load(std::memory_order_acquire); }
|
||||
|
||||
void WaitForNotification() {
|
||||
if (!HasBeenNotified()) {
|
||||
std::unique_lock<std::mutex> l(mu_);
|
||||
while (!HasBeenNotified()) {
|
||||
cv_.wait(l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mu_; // protects mutations of notified_
|
||||
std::condition_variable cv_; // signaled when notified_ becomes non-zero
|
||||
std::atomic<bool> notified_; // mutations under mu_
|
||||
};
|
||||
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_NOTIFICATION_H
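A minimal usage sketch of Notification: one thread blocks until another signals that shared state is ready. The include path is an assumption based on the other common headers in this directory.

```cpp
#include <thread>

#include "cpu_kernel/common/notification.h"  // assumed header path for the class above

int main() {
  aicpu::Notification ready;
  int value = 0;
  std::thread producer([&] {
    value = 42;      // published before Notify(); the release store plus acquire load orders this
    ready.Notify();
  });
  ready.WaitForNotification();
  // value is guaranteed to be 42 here
  producer.join();
  return value == 42 ? 0 : 1;
}
```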
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef INC_GE_RUNTIME_TENSOR_DESC_H_
|
||||
#define INC_GE_RUNTIME_TENSOR_DESC_H_
|
||||
|
||||
namespace ge {
|
||||
constexpr int64_t kMaxDimSize = 32;
|
||||
|
||||
#pragma pack(push, 1)
|
||||
struct RuntimeTensorDesc {
|
||||
uint64_t data_addr;
|
||||
int64_t data_offset_size;
|
||||
int64_t dtype;
|
||||
int64_t shape[kMaxDimSize + 1]; // shape:Dim_Num|DIM0|DIM1|...|DIM31
|
||||
int64_t original_shape[kMaxDimSize + 1]; // original_shape:Dim_Num|DIM0|DIM1|...|DIM31
|
||||
int64_t format;
|
||||
int64_t sub_format;
|
||||
uint8_t reserved[456]; // padding to 1024 bytes
|
||||
};
|
||||
#pragma pack(pop)
|
||||
} // namespace ge
|
||||
|
||||
#endif // INC_GE_RUNTIME_TENSOR_DESC_H_
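The reserved array is sized so the packed struct is exactly 1 KiB: five 8-byte scalar fields, two (kMaxDimSize + 1)-element int64 arrays, and 456 padding bytes (40 + 528 + 456 = 1024). A one-line check that could live in any translation unit including this header keeps that invariant visible:

```cpp
// Sketch: keep the 1 KiB layout invariant visible at compile time.
static_assert(sizeof(ge::RuntimeTensorDesc) == 1024, "RuntimeTensorDesc must stay 1024 bytes");
```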
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_SESSION_CACHE_H
|
||||
#define AICPU_CONTEXT_COMMON_SESSION_CACHE_H
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include "cpu_kernel/common/kernel_cache.h"
|
||||
|
||||
namespace aicpu {
|
||||
template <class C>
|
||||
class SessionCache {
|
||||
public:
|
||||
static SessionCache<C> &Instance() {
|
||||
static SessionCache<C> instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
/*
|
||||
* run and cache kernel.
|
||||
* @param param: kernel context
|
||||
* @param session_id: session id
|
||||
* @param stream_id: stream id
|
||||
* @param sess_flag: whether it's a session scene, true uses session id, false
* uses stream id
* @param blkdim_info: Op's blkdim_info
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
template <class T>
|
||||
int32_t RunCpuKernelWithBlock(void *param, uint64_t session_id, uint64_t stream_id, bool sess_flag,
|
||||
struct BlkDimInfo *blkdim_info) {
|
||||
std::shared_ptr<KernelCache<C>> kernel = nullptr;
|
||||
if (sess_flag) {
|
||||
KERNEL_LOG_DEBUG("SessionCache KernelCache from session, id[%llu].", session_id);
|
||||
std::unique_lock<std::mutex> lock(session_mutex_);
|
||||
int32_t ret = GetOrCreateKernelCache<T>(session_kernel_cache_, session_id, sess_flag, kernel);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
KERNEL_LOG_DEBUG("SessionCache KernelCache from stream, id[%llu].", stream_id);
|
||||
std::unique_lock<std::mutex> lock(stream_mutex_);
|
||||
int32_t ret = GetOrCreateKernelCache<T>(stream_kernel_cache_, stream_id, sess_flag, kernel);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return kernel->RunCpuKernelWithBlock(param, blkdim_info);
|
||||
}
|
||||
|
||||
/*
|
||||
* run and cache kernel.
|
||||
* @param param: kernel context
|
||||
* @param session_id: session id
|
||||
* @param stream_id: stream id
|
||||
* @param sess_flag: whether it's a session scene, true uses session id, false
|
||||
* use stream id
|
||||
* @return int32_t: 0 indicates success, while the others fail
|
||||
*/
|
||||
template <class T>
|
||||
int32_t RunKernel(void *param, uint64_t session_id, uint64_t stream_id, bool sess_flag) {
|
||||
std::shared_ptr<KernelCache<C>> kernel = nullptr;
|
||||
if (sess_flag) {
|
||||
KERNEL_LOG_DEBUG("SessionCache KernelCache from session, id[%llu].", session_id);
|
||||
std::unique_lock<std::mutex> lock(session_mutex_);
|
||||
int32_t ret = GetOrCreateKernelCache<T>(session_kernel_cache_, session_id, sess_flag, kernel);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
KERNEL_LOG_DEBUG("SessionCache KernelCache from stream, id[%llu].", stream_id);
|
||||
std::unique_lock<std::mutex> lock(stream_mutex_);
|
||||
int32_t ret = GetOrCreateKernelCache<T>(stream_kernel_cache_, stream_id, sess_flag, kernel);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return kernel->RunKernel(param);
|
||||
}
|
||||
|
||||
private:
|
||||
SessionCache() = default;
|
||||
~SessionCache() = default;
|
||||
SessionCache(const SessionCache &) = delete;
|
||||
SessionCache(SessionCache &&) = delete;
|
||||
SessionCache &operator=(const SessionCache &) = delete;
|
||||
SessionCache &operator=(SessionCache &&) = delete;
|
||||
|
||||
template <class T>
|
||||
int32_t GetOrCreateKernelCache(std::map<uint64_t, std::shared_ptr<KernelCache<C>>> &kernel_map, uint64_t id,
|
||||
bool sess_flag, std::shared_ptr<KernelCache<C>> &kernel) {
|
||||
auto iter = kernel_map.find(id);
|
||||
if (iter != kernel_map.end()) {
|
||||
KERNEL_LOG_DEBUG("Get kernel from cache success, id[%llu].", id);
|
||||
kernel = iter->second;
|
||||
} else {
|
||||
KernelCache<C> *cache = new (std::nothrow) T();
|
||||
if (cache == nullptr) {
|
||||
KERNEL_LOG_DEBUG("Create kernel cache failed, id[%llu].", id);
|
||||
return -1;
|
||||
}
|
||||
kernel = std::shared_ptr<KernelCache<C>>(cache);
|
||||
int32_t ret = kernel->Init(sess_flag);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
kernel_map.insert(std::make_pair(id, kernel));
|
||||
KERNEL_LOG_DEBUG("Create kernel cache, id[%llu].", id);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex stream_mutex_;
|
||||
std::map<uint64_t, std::shared_ptr<KernelCache<C>>> stream_kernel_cache_; // key is stream id
|
||||
std::mutex session_mutex_;
|
||||
std::map<uint64_t, std::shared_ptr<KernelCache<C>>> session_kernel_cache_; // key is session id
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_SESSION_CACHE_H
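A hypothetical sketch (not in these sources) of how SessionCache is parameterised: a concrete KernelCache subclass supplies InitParameter/RunKernel, and sess_flag selects the session-keyed or stream-keyed map. DemoState and DemoKernelCache are illustrative names only.

```cpp
#include <cstdint>

#include "cpu_kernel/common/session_cache.h"

namespace aicpu {
struct DemoState {};  // the cached payload type C

// Hypothetical KernelCache subclass; a real kernel would parse `param` here.
class DemoKernelCache : public KernelCache<DemoState> {
 public:
  int32_t RunKernel(void *param) override {
    (void)param;
    return 0;  // 0 indicates success, as documented above
  }
  int32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkDimInfo) override {
    (void)param;
    (void)blkDimInfo;
    return 0;
  }

 protected:
  int32_t InitParameter() override { return 0; }
};

int32_t RunOnce(void *param, uint64_t stream_id) {
  // sess_flag = false keys the cache by stream id and enables the LRU path
  return SessionCache<DemoState>::Instance().RunKernel<DemoKernelCache>(param, 0, stream_id, false);
}
}  // namespace aicpu
```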
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_SHARDER_H
|
||||
#define AICPU_CONTEXT_COMMON_SHARDER_H
|
||||
#include <functional>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
|
||||
namespace aicpu {
|
||||
class Sharder {
|
||||
public:
|
||||
explicit Sharder(DeviceType device) : device_(device) {}
|
||||
|
||||
virtual ~Sharder() = default;
|
||||
|
||||
/*
|
||||
* ParallelFor shards the "total" units of work.
|
||||
* @param total: size of total work
|
||||
* @param perUnitSize: expect size of per unit work
|
||||
* @param work: process of per unit work
|
||||
*/
|
||||
virtual void ParallelFor(int64_t total, int64_t perUnitSize,
|
||||
const std::function<void(int64_t, int64_t)> &work) const = 0;
|
||||
|
||||
/*
|
||||
* Get CPU number
|
||||
* @return CPU number
|
||||
*/
|
||||
virtual uint32_t GetCPUNum() const = 0;
|
||||
|
||||
private:
|
||||
Sharder(const Sharder &) = delete;
|
||||
Sharder(Sharder &&) = delete;
|
||||
Sharder &operator=(const Sharder &) = delete;
|
||||
Sharder &operator=(Sharder &&) = delete;
|
||||
|
||||
private:
|
||||
DeviceType device_; // device type, HOST/DEVICE
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_SHARDER_H
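A hedged sketch of the smallest possible Sharder implementation, which runs every shard on the calling thread; it is meant only to illustrate the interface contract (disjoint [begin, end) ranges, GetCPUNum as a parallelism hint), not to replace the host or device sharders above.

```cpp
#include "cpu_kernel/common/sharder.h"

namespace aicpu {
class SerialSharder : public Sharder {
 public:
  explicit SerialSharder(DeviceType device) : Sharder(device) {}
  ~SerialSharder() override = default;

  void ParallelFor(int64_t total, int64_t perUnitSize,
                   const std::function<void(int64_t, int64_t)> &work) const override {
    (void)perUnitSize;  // a serial sharder ignores the requested unit size
    work(0, total);     // single shard covering the whole range
  }

  uint32_t GetCPUNum() const override { return 1; }
};
}  // namespace aicpu
```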
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_STATUS_H
|
||||
#define AICPU_CONTEXT_COMMON_STATUS_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* status code
|
||||
*/
|
||||
enum KernelStatus : uint32_t {
|
||||
// 0-3 is fixed error code, runtime need interpret 0-3 error codes
|
||||
KERNEL_STATUS_OK = 0,
|
||||
KERNEL_STATUS_PARAM_INVALID = 1,
|
||||
KERNEL_STATUS_INNER_ERROR = 2,
|
||||
KERNEL_STATUS_TIMEOUT = 3,
|
||||
KERNEL_STATUS_PROTOBUF_ERROR = 4,
|
||||
KERNEL_STATUS_SHARDER_ERROR = 5,
|
||||
KERNEL_STATUS_END_OF_SEQUENCE = 201
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_STATUS_H
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_COMMON_THREAD_CTX_H_
|
||||
#define AICPU_CONTEXT_COMMON_THREAD_CTX_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
#include "aicpu_sharder/aicpu_context.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ThreadCtx {
|
||||
public:
|
||||
explicit ThreadCtx(DeviceType device) : device_(device) {}
|
||||
|
||||
virtual ~ThreadCtx() = default;
|
||||
|
||||
virtual uint32_t SetThreadCtxInfo(CtxType type, const std::string &key, const std::string &value) const = 0;
|
||||
|
||||
virtual uint32_t GetThreadCtxInfo(CtxType type, const std::string &key, std::string &value) const = 0;
|
||||
|
||||
virtual uint32_t RemoveThreadCtxInfo(CtxType type, const std::string &key) const = 0;
|
||||
|
||||
private:
|
||||
ThreadCtx(const ThreadCtx &) = delete;
|
||||
ThreadCtx(ThreadCtx &&) = delete;
|
||||
ThreadCtx &operator=(const ThreadCtx &) = delete;
|
||||
ThreadCtx &operator=(ThreadCtx &&) = delete;
|
||||
|
||||
private:
|
||||
DeviceType device_; // device type, HOST/DEVICE
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_COMMON_THREAD_CTX_H_
|
|
@ -0,0 +1,243 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
|
||||
#include "cpu_kernel/inc/cpu_attr_value.h"
|
||||
|
||||
namespace aicpu {
|
||||
AttrValue::AttrValue(AttrValueImpl *impl) : impl_(impl) {}
|
||||
|
||||
/*
|
||||
* get string value of attr.
|
||||
*/
|
||||
std::string AttrValue::GetString() const { return impl_->GetString(); }
|
||||
|
||||
/*
|
||||
* get string list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListStringSize() const { return impl_->ListStringSize(); }
|
||||
|
||||
/*
|
||||
* get string list value of attr.
|
||||
*/
|
||||
std::vector<std::string> AttrValue::GetListString() const { return impl_->GetListString(); }
|
||||
|
||||
/*
|
||||
* set string list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListString(const std::vector<std::string> &bytes) { impl_->SetListString(bytes); }
|
||||
|
||||
/*
|
||||
* set string value to attr.
|
||||
*/
|
||||
void AttrValue::SetString(const std::string &byte) { impl_->SetString(byte); }
|
||||
|
||||
/*
|
||||
* attr add string value to list.
|
||||
*/
|
||||
void AttrValue::AddListString(const std::string &str) { impl_->AddListString(str); }
|
||||
|
||||
/*
|
||||
* get int value of attr.
|
||||
*/
|
||||
int64_t AttrValue::GetInt() const { return impl_->GetInt(); }
|
||||
|
||||
/*
|
||||
* get int list value of attr.
|
||||
*/
|
||||
std::vector<int64_t> AttrValue::GetListInt() const { return impl_->GetListInt(); }
|
||||
|
||||
/*
|
||||
* get int list list value of attr.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> AttrValue::GetListListInt() const { return impl_->GetListListInt(); }
|
||||
|
||||
/*
|
||||
* attr add int value to list.
|
||||
*/
|
||||
void AttrValue::AddListInt(int64_t i) { impl_->AddListInt(i); }
|
||||
|
||||
/*
|
||||
* get int list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListIntSize() const { return impl_->ListIntSize(); }
|
||||
|
||||
/*
|
||||
* set int value to attr.
|
||||
*/
|
||||
void AttrValue::SetInt(int64_t i) { impl_->SetInt(i); }
|
||||
|
||||
/*
|
||||
* set int list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListInt(const std::vector<int64_t> &i) { impl_->SetListInt(i); }
|
||||
|
||||
/*
|
||||
* set int list list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListListInt(const std::vector<std::vector<int64_t>> &i) { impl_->SetListListInt(i); }
|
||||
|
||||
/*
|
||||
* get float value of attr.
|
||||
*/
|
||||
float AttrValue::GetFloat() const { return impl_->GetFloat(); }
|
||||
|
||||
/*
|
||||
* get float list value of attr.
|
||||
*/
|
||||
std::vector<float> AttrValue::GetListFloat() const { return impl_->GetListFloat(); }
|
||||
|
||||
/*
|
||||
* attr add float value to list.
|
||||
*/
|
||||
void AttrValue::AddListFloat(float f) { impl_->AddListFloat(f); }
|
||||
|
||||
/*
|
||||
* set float value to attr.
|
||||
*/
|
||||
void AttrValue::SetFloat(float f) { impl_->SetFloat(f); }
|
||||
|
||||
/*
|
||||
* get float list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListFloatSize() const { return impl_->ListFloatSize(); }
|
||||
|
||||
/*
|
||||
* set float list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListFloat(const std::vector<float> &f) { impl_->SetListFloat(f); }
|
||||
|
||||
/*
|
||||
* get bool value of attr.
|
||||
*/
|
||||
bool AttrValue::GetBool() const { return impl_->GetBool(); }
|
||||
|
||||
/*
|
||||
* get bool list value of attr.
|
||||
*/
|
||||
std::vector<bool> AttrValue::GetListBool() const { return impl_->GetListBool(); }
|
||||
|
||||
/*
|
||||
* attr add bool value to list.
|
||||
*/
|
||||
void AttrValue::AddListBool(bool b) { impl_->AddListBool(b); }
|
||||
|
||||
/*
|
||||
* get bool list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListBoolSize() const { return impl_->ListBoolSize(); }
|
||||
|
||||
/*
|
||||
* set bool value to attr.
|
||||
*/
|
||||
void AttrValue::SetBool(bool b) { impl_->SetBool(b); }
|
||||
|
||||
/*
|
||||
* set bool list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListBool(const std::vector<bool> &b) { impl_->SetListBool(b); }
|
||||
|
||||
/*
|
||||
* get data type value of attr.
|
||||
*/
|
||||
DataType AttrValue::GetDataType() const { return impl_->GetDataType(); }
|
||||
|
||||
/*
|
||||
* get data type list value of attr.
|
||||
*/
|
||||
std::vector<DataType> AttrValue::GetListDataType() const { return impl_->GetListDataType(); }
|
||||
|
||||
/*
|
||||
* attr add data type value to list.
|
||||
*/
|
||||
void AttrValue::AddListDataType(DataType type) { impl_->AddListDataType(type); }
|
||||
|
||||
/*
|
||||
* get data type list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListDataTypeSize() const { return impl_->ListDataTypeSize(); }
|
||||
|
||||
/*
|
||||
* set data type value to attr.
|
||||
*/
|
||||
void AttrValue::SetDataType(DataType type) { impl_->SetDataType(type); }
|
||||
|
||||
/*
|
||||
* set data type list value to attr.
|
||||
*/
|
||||
void AttrValue::SetListDataType(const std::vector<DataType> &type) { impl_->SetListDataType(type); }
|
||||
|
||||
/*
|
||||
* set tensor shape value to attr.
|
||||
*/
|
||||
bool AttrValue::SetTensorShape(const TensorShape *shape) { return impl_->SetTensorShape(shape); }
|
||||
|
||||
/*
|
||||
* set tensor shape list value to attr.
|
||||
*/
|
||||
uint32_t AttrValue::SetListTensorShape(const std::vector<TensorShape *> &shape) {
|
||||
return impl_->SetListTensorShape(shape);
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add tensor shape value to list.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> AttrValue::AddListTensorShape() { return impl_->AddListTensorShape(); }
|
||||
|
||||
/*
|
||||
* get tensor shape value of attr.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> AttrValue::GetTensorShape() const { return impl_->GetTensorShape(); }
|
||||
|
||||
/*
|
||||
* get tensor shape list value of attr.
|
||||
*/
|
||||
std::vector<TensorShape> AttrValue::GetListTensorShape() const { return impl_->GetListTensorShape(); }
|
||||
|
||||
/*
|
||||
* get tensor shape list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListTensorShapeSize() const { return impl_->ListTensorShapeSize(); }
|
||||
|
||||
/*
|
||||
* set tensor value to attr.
|
||||
*/
|
||||
bool AttrValue::SetTensor(const Tensor *tensor) { return impl_->SetTensor(tensor); }
|
||||
|
||||
/*
|
||||
* set tensor list value to attr.
|
||||
*/
|
||||
uint32_t AttrValue::SetListTensor(const std::vector<Tensor *> &tensor) { return impl_->SetListTensor(tensor); }
|
||||
|
||||
/*
|
||||
* attr add tensor value to list.
|
||||
*/
|
||||
std::shared_ptr<Tensor> AttrValue::AddListTensor() { return impl_->AddListTensor(); }
|
||||
|
||||
/*
|
||||
* get tensor value of attr.
|
||||
*/
|
||||
std::shared_ptr<Tensor> AttrValue::GetTensor() const { return impl_->GetTensor(); }
|
||||
|
||||
/*
|
||||
* get tensor list value of attr.
|
||||
*/
|
||||
std::vector<Tensor> AttrValue::GetListTensor() const { return impl_->GetListTensor(); }
|
||||
|
||||
/*
|
||||
* get tensor list size of attr.
|
||||
*/
|
||||
int32_t AttrValue::ListTensorSize() const { return impl_->ListTensorSize(); }
|
||||
} // namespace aicpu
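AttrValue is a thin facade that forwards every call to its AttrValueImpl (a PIMPL). A short, hedged usage sketch: CreateAttrValue and AddAttrs are used exactly as node_def_builder.cc does above, the attribute names are illustrative, and it is assumed that cpu_kernel_utils.h brings in the NodeDef and AttrValue declarations as its use in node_def_builder.cc suggests.

```cpp
#include "cpu_kernel/common/cpu_kernel_utils.h"

void FillDemoAttrs(aicpu::NodeDef *node_def) {
  auto axis = aicpu::CpuKernelUtils::CreateAttrValue();
  axis->SetInt(1);                       // scalar int attribute
  node_def->AddAttrs("axis", axis.get());

  auto dims = aicpu::CpuKernelUtils::CreateAttrValue();
  dims->SetListInt({2, 3, 4});           // int list attribute
  node_def->AddAttrs("dims", dims.get());
}
```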
|
|
@ -0,0 +1,570 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
|
||||
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_impl.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* get string value of attr.
|
||||
*/
|
||||
std::string AttrValueImpl::GetString() const { return attr_value_->s(); }
|
||||
|
||||
/*
|
||||
* get string list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListStringSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.s_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* get string list value of attr.
|
||||
*/
|
||||
std::vector<std::string> AttrValueImpl::GetListString() const {
|
||||
std::vector<std::string> ret;
|
||||
auto array = attr_value_->array();
|
||||
for (int32_t i = 0; i < array.s_size(); i++) {
|
||||
ret.emplace_back(array.s(i));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* set string list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListString(const std::vector<std::string> &bytes) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
for (const std::string &s : bytes) {
|
||||
array->add_s(s);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* set string value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetString(const std::string &byte) { attr_value_->set_s(byte); }
|
||||
|
||||
/*
|
||||
* attr add string value to list.
|
||||
*/
|
||||
void AttrValueImpl::AddListString(const std::string &str) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
array->add_s(str);
|
||||
}
|
||||
|
||||
/*
|
||||
* get int value of attr.
|
||||
*/
|
||||
int64_t AttrValueImpl::GetInt() const { return attr_value_->i(); }
|
||||
|
||||
/*
|
||||
* get int list value of attr.
|
||||
*/
|
||||
std::vector<int64_t> AttrValueImpl::GetListInt() const {
|
||||
std::vector<int64_t> ret;
|
||||
auto array = attr_value_->array();
|
||||
for (int32_t i = 0; i < array.i_size(); i++) {
|
||||
ret.emplace_back(array.i(i));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add int value to list.
|
||||
*/
|
||||
void AttrValueImpl::AddListInt(int64_t i) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
array->add_i(i);
|
||||
}
|
||||
|
||||
/*
|
||||
* get int list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListIntSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.i_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* set int value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetInt(int64_t i) { attr_value_->set_i(i); }
|
||||
|
||||
/*
|
||||
* set int list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListInt(const std::vector<int64_t> &list) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
for (const int64_t &i : list) {
|
||||
array->add_i(i);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get int list list value of attr.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> AttrValueImpl::GetListListInt() const {
|
||||
auto array = attr_value_->list_list_int();
|
||||
std::vector<std::vector<int64_t>> ret;
|
||||
for (auto idx = 0; idx < array.list_list_i_size(); ++idx) {
|
||||
std::vector<int64_t> vec;
|
||||
for (auto i = 0; i < array.list_list_i(idx).list_i_size(); ++i) {
|
||||
vec.emplace_back(array.list_list_i(idx).list_i(i));
|
||||
}
|
||||
ret.emplace_back(vec);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* set int list list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListListInt(const std::vector<std::vector<int64_t>> &list) {
|
||||
auto array = attr_value_->mutable_list_list_int();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
array->clear_list_list_i();
|
||||
for (const std::vector<int64_t> &i : list) {
|
||||
const auto list_i = array->add_list_list_i();
|
||||
for (const int64_t val : i) {
|
||||
list_i->add_list_i(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get float value of attr.
|
||||
*/
|
||||
float AttrValueImpl::GetFloat() const { return attr_value_->f(); }
|
||||
|
||||
/*
|
||||
* get float list value of attr.
|
||||
*/
|
||||
std::vector<float> AttrValueImpl::GetListFloat() const {
|
||||
std::vector<float> ret;
|
||||
auto array = attr_value_->array();
|
||||
for (int32_t i = 0; i < array.f_size(); i++) {
|
||||
ret.emplace_back(array.f(i));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add float value to list.
|
||||
*/
|
||||
void AttrValueImpl::AddListFloat(float f) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
array->add_f(f);
|
||||
}
|
||||
|
||||
/*
|
||||
* set float value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetFloat(float f) { attr_value_->set_f(f); }
|
||||
|
||||
/*
|
||||
* get float list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListFloatSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.f_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* set float list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListFloat(const std::vector<float> &list) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
for (const float &f : list) {
|
||||
array->add_f(f);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get bool value of attr.
|
||||
*/
|
||||
bool AttrValueImpl::GetBool() const { return attr_value_->b(); }
|
||||
|
||||
/*
|
||||
* get bool list value of attr.
|
||||
*/
|
||||
std::vector<bool> AttrValueImpl::GetListBool() const {
|
||||
std::vector<bool> ret;
|
||||
auto array = attr_value_->array();
|
||||
for (int32_t i = 0; i < array.b_size(); i++) {
|
||||
ret.push_back(array.b(i));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add bool value to list.
|
||||
*/
|
||||
void AttrValueImpl::AddListBool(bool b) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
array->add_b(b);
|
||||
}
|
||||
|
||||
/*
|
||||
* get bool list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListBoolSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.b_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* set bool value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetBool(bool b) { attr_value_->set_b(b); }
|
||||
|
||||
/*
|
||||
* set bool list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListBool(const std::vector<bool> &list) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
for (const bool &b : list) {
|
||||
array->add_b(b);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get data type value of attr.
|
||||
*/
|
||||
DataType AttrValueImpl::GetDataType() const { return static_cast<DataType>(attr_value_->type()); }
|
||||
|
||||
/*
|
||||
* get data type list value of attr.
|
||||
*/
|
||||
std::vector<DataType> AttrValueImpl::GetListDataType() const {
|
||||
std::vector<DataType> ret;
|
||||
auto array = attr_value_->array();
|
||||
for (int32_t i = 0; i < array.type_size(); i++) {
|
||||
ret.emplace_back(static_cast<DataType>(array.type(i)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add data type value to list.
|
||||
*/
|
||||
void AttrValueImpl::AddListDataType(DataType type) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
array->add_type(type);
|
||||
}
|
||||
|
||||
/*
|
||||
* get data type list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListDataTypeSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.type_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* set data type value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetDataType(DataType type) { attr_value_->set_type(type); }
|
||||
|
||||
/*
|
||||
* set data type list value to attr.
|
||||
*/
|
||||
void AttrValueImpl::SetListDataType(const std::vector<DataType> &list) {
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR_VOID(array, "Protobuf mutable array is nullptr")
|
||||
for (const DataType &type : list) {
|
||||
array->add_type(type);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor shape value to attr.
|
||||
*/
|
||||
bool AttrValueImpl::SetTensorShape(const TensorShape *shape) {
|
||||
KERNEL_CHECK_NULLPTR(shape, false, "Shape is null")
|
||||
|
||||
auto tensorShape = attr_value_->mutable_shape();
|
||||
KERNEL_CHECK_NULLPTR(tensorShape, false, "Protobuf mutable tensor shape is null")
|
||||
auto impl = CpuKernelUtils::GetImpl(shape);
|
||||
KERNEL_CHECK_NULLPTR(impl, false, "Get impl is null")
|
||||
auto proto = impl->GetProto();
|
||||
KERNEL_CHECK_NULLPTR(proto, false, "Get proto is null")
|
||||
*tensorShape = *proto;
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor shape list value to attr.
|
||||
*/
|
||||
uint32_t AttrValueImpl::SetListTensorShape(const std::vector<TensorShape *> &list) {
|
||||
uint32_t ret = 0;
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR(array, ret, "Protobuf mutable array is nullptr")
|
||||
|
||||
for (size_t i = 0; i < list.size(); i++) {
|
||||
auto tmpShape = array->add_shape();
|
||||
if ((list[i] == nullptr) || (tmpShape == nullptr)) {
|
||||
KERNEL_LOG_ERROR("Shape[%zu] is null or protobuf add shape ret null.", i);
|
||||
} else {
|
||||
auto impl = CpuKernelUtils::GetImpl(list[i]);
|
||||
if ((impl == nullptr) || (impl->GetProto() == nullptr)) {
|
||||
KERNEL_LOG_ERROR("Get list[%zu] impl or proto is null.", i);
|
||||
continue;
|
||||
}
|
||||
*tmpShape = *(impl->GetProto());
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add tensor shape value to list.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> AttrValueImpl::AddListTensorShape() {
|
||||
auto array = attr_value_->mutable_array();
|
||||
if (array == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable array is nullptr.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
auto shape = array->add_shape();
|
||||
if (shape == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable array add shape is nullptr.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
TensorShapeImpl *impl = new (std::nothrow) TensorShapeImpl(shape);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorShapeImpl failed.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
auto tensorShape = CpuKernelUtils::CreateTensorShape(impl);
|
||||
if (tensorShape == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return tensorShape;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor shape value of attr.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> AttrValueImpl::GetTensorShape() const {
|
||||
auto shape = attr_value_->mutable_shape();
|
||||
if (shape == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable shape is nullptr.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
TensorShapeImpl *impl = new (std::nothrow) TensorShapeImpl(shape);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorShapeImpl failed.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
auto tensorShape = CpuKernelUtils::CreateTensorShape(impl);
|
||||
if (tensorShape == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return tensorShape;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor shape list value of attr.
|
||||
*/
|
||||
std::vector<TensorShape> AttrValueImpl::GetListTensorShape() const {
|
||||
std::vector<TensorShape> ret;
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR(array, ret, "Protobuf mutable array is nullptr")
|
||||
for (int32_t i = 0; i < array->shape_size(); i++) {
|
||||
auto shape = array->mutable_shape(i);
|
||||
if (shape == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable shape[%d] is nullptr.", i);
|
||||
return std::vector<TensorShape>();
|
||||
}
|
||||
|
||||
TensorShapeImpl *impl = new (std::nothrow) TensorShapeImpl(shape);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorShapeImpl[%d] failed.", i);
|
||||
return std::vector<TensorShape>();
|
||||
} else {
|
||||
auto tensorShape = CpuKernelUtils::CreateTensorShape(impl);
|
||||
if (tensorShape == nullptr) {
|
||||
delete impl;
|
||||
return std::vector<TensorShape>();
|
||||
}
|
||||
ret.emplace_back(*tensorShape);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor shape list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListTensorShapeSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.shape_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor value to attr.
|
||||
*/
|
||||
bool AttrValueImpl::SetTensor(const Tensor *tensor) {
|
||||
KERNEL_CHECK_NULLPTR(tensor, false, "Tensor is null")
|
||||
auto tensorPtr = attr_value_->mutable_tensor();
|
||||
KERNEL_CHECK_NULLPTR(tensorPtr, false, "Protobuf mutable tensor is nullptr")
|
||||
auto impl = CpuKernelUtils::GetImpl(tensor);
|
||||
KERNEL_CHECK_NULLPTR(impl, false, "Get impl is nullptr")
|
||||
auto proto = impl->GetProto();
|
||||
KERNEL_CHECK_NULLPTR(proto, false, "Get proto is nullptr")
|
||||
*tensorPtr = *(proto);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor list value to attr.
|
||||
*/
|
||||
uint32_t AttrValueImpl::SetListTensor(const std::vector<Tensor *> &list) {
|
||||
uint32_t ret = 0;
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR(array, ret, "Protobuf mutable array is nullptr")
|
||||
for (size_t i = 0; i < list.size(); i++) {
|
||||
auto tensorPtr = array->add_tensor();
|
||||
if ((list[i] == nullptr) || (tensorPtr == nullptr)) {
|
||||
KERNEL_LOG_WARN("Tensor[%zu] is null or protobuf add tensor ret null.", i);
|
||||
} else {
|
||||
auto impl = CpuKernelUtils::GetImpl(list[i]);
|
||||
if ((impl == nullptr) || (impl->GetProto() == nullptr)) {
|
||||
KERNEL_LOG_WARN("Get list[%zu] impl or proto is null.", i);
|
||||
continue;
|
||||
}
|
||||
*tensorPtr = *(impl->GetProto());
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* attr add tensor value to list.
|
||||
*/
|
||||
std::shared_ptr<Tensor> AttrValueImpl::AddListTensor() {
|
||||
auto array = attr_value_->mutable_array();
|
||||
if (array == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable array is nullptr.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto tensor = array->add_tensor();
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable array add tensor is nullptr.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpuTensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpuTensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpuTensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor value of attr.
|
||||
*/
|
||||
std::shared_ptr<Tensor> AttrValueImpl::GetTensor() const {
|
||||
auto tensor = attr_value_->mutable_tensor();
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable tensor is nullptr.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpuTensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpuTensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpuTensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor list value of attr.
|
||||
*/
|
||||
std::vector<Tensor> AttrValueImpl::GetListTensor() const {
|
||||
std::vector<Tensor> ret;
|
||||
auto array = attr_value_->mutable_array();
|
||||
KERNEL_CHECK_NULLPTR(array, ret, "Protobuf mutable array is nullptr")
|
||||
for (int32_t i = 0; i < array->tensor_size(); i++) {
|
||||
auto tensor = array->mutable_tensor(i);
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable tensor is nullptr.");
|
||||
return std::vector<Tensor>();
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl[%d] failed.", i);
|
||||
return std::vector<Tensor>();
|
||||
} else {
|
||||
auto aicpuTensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpuTensor == nullptr) {
|
||||
delete impl;
|
||||
return std::vector<Tensor>();
|
||||
}
|
||||
ret.emplace_back(*aicpuTensor);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor list size of attr.
|
||||
*/
|
||||
int32_t AttrValueImpl::ListTensorSize() const {
|
||||
auto array = attr_value_->array();
|
||||
return array.tensor_size();
|
||||
}
|
||||
|
||||
/*
|
||||
* get attr proto.
|
||||
*/
|
||||
aicpuops::AttrValue *AttrValueImpl::GetProto() const { return attr_value_.get(); }
|
||||
} // namespace aicpu
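For reference, a minimal usage sketch of the accessors above, assuming a caller-owned aicpuops::AttrValue message and the default no-op deleter of AttrValueImpl; the function name and the values are illustrative only and not part of this change.

#include <cstdint>
#include <vector>
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
#include "proto/cpu_attr.pb.h"

// Hypothetical example: round-trip an int list attribute through the wrapper.
void ListIntRoundTrip() {
  aicpuops::AttrValue proto;          // backing protobuf message, owned by the caller
  aicpu::AttrValueImpl attr(&proto);  // non-owning wrapper (default deleter does nothing)

  attr.SetListInt({1, 2, 3});                       // appends 1, 2, 3 to array.i
  attr.AddListInt(4);                               // appends one more value
  std::vector<int64_t> values = attr.GetListInt();  // {1, 2, 3, 4}
  int32_t size = attr.ListIntSize();                // 4
  (void)values;
  (void)size;
}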
|
|
@ -0,0 +1,319 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_CPU_PROTO_ATTR_VALUE_IMPL_H
|
||||
#define AICPU_CONTEXT_CPU_PROTO_ATTR_VALUE_IMPL_H
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
#include "cpu_kernel/inc/cpu_tensor_shape.h"
|
||||
#include "proto/cpu_attr.pb.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AttrValueImpl {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
AttrValueImpl(
|
||||
aicpuops::AttrValue *attr, std::function<void(aicpuops::AttrValue *)> del_func = [](aicpuops::AttrValue *p) {})
|
||||
: attr_value_(attr, del_func) {}
|
||||
|
||||
~AttrValueImpl() = default;
|
||||
AttrValueImpl(const AttrValueImpl &) = delete;
|
||||
AttrValueImpl(AttrValueImpl &&) = delete;
|
||||
AttrValueImpl &operator=(const AttrValueImpl &) = delete;
|
||||
AttrValueImpl &operator=(AttrValueImpl &&) = delete;
|
||||
|
||||
/*
|
||||
* get string value of attr.
|
||||
* @return string: string value of attr
|
||||
*/
|
||||
std::string GetString() const;
|
||||
|
||||
/*
|
||||
* get string list value of attr.
|
||||
* @return vector<std::string>: string list value of attr
|
||||
*/
|
||||
std::vector<std::string> GetListString() const;
|
||||
|
||||
/*
|
||||
* attr add string value to list.
|
||||
* @param string: string value need to add to list
|
||||
*/
|
||||
void AddListString(const std::string &str);
|
||||
|
||||
/*
|
||||
* get string list size of attr.
|
||||
* @return int32_t: string list size of attr
|
||||
*/
|
||||
int32_t ListStringSize() const;
|
||||
|
||||
/*
|
||||
* set string value to attr.
|
||||
* @param string: string value need to set to attr
|
||||
*/
|
||||
void SetString(const std::string &byte);
|
||||
|
||||
/*
|
||||
* set string list value to attr.
|
||||
* @param vector<std::string>: string list value need to set to attr
|
||||
*/
|
||||
void SetListString(const std::vector<std::string> &bytes);
|
||||
|
||||
/*
|
||||
* get int value of attr.
|
||||
* @return int64_t: int value of attr
|
||||
*/
|
||||
int64_t GetInt() const;
|
||||
|
||||
/*
|
||||
* get int list value of attr.
|
||||
* @return vector<int64_t>: int list value of attr
|
||||
*/
|
||||
std::vector<int64_t> GetListInt() const;
|
||||
|
||||
/*
|
||||
* get int list list value of attr.
|
||||
* @return vector<vector<int64_t>>: int list list value of attr
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> GetListListInt() const;
|
||||
|
||||
/*
|
||||
* attr add int value to list.
|
||||
* @param i: int value need to add to list
|
||||
*/
|
||||
void AddListInt(int64_t i);
|
||||
|
||||
/*
|
||||
* get int list size of attr.
|
||||
* @return int32_t: int list size of attr
|
||||
*/
|
||||
int32_t ListIntSize() const;
|
||||
|
||||
/*
|
||||
* set int value to attr.
|
||||
* @param i: int value need to set to attr
|
||||
*/
|
||||
void SetInt(int64_t i);
|
||||
|
||||
/*
|
||||
* set int list value to attr.
|
||||
* @param vector<int64_t>: int list value need to set to attr
|
||||
*/
|
||||
void SetListInt(const std::vector<int64_t> &list);
|
||||
|
||||
/*
|
||||
* set int list list value to attr.
|
||||
* @param vector<vector<int64_t>>: int list list value need to set to attr
|
||||
*/
|
||||
void SetListListInt(const std::vector<std::vector<int64_t>> &list);
|
||||
|
||||
/*
|
||||
* get float value of attr.
|
||||
* @return float: float value of attr
|
||||
*/
|
||||
float GetFloat() const;
|
||||
|
||||
/*
|
||||
* get float list value of attr.
|
||||
* @return vector<float>: float list value of attr
|
||||
*/
|
||||
std::vector<float> GetListFloat() const;
|
||||
|
||||
/*
|
||||
* attr add float value to list.
|
||||
* @param f: float value need to add to list
|
||||
*/
|
||||
void AddListFloat(float f);
|
||||
|
||||
/*
|
||||
* get float list size of attr.
|
||||
* @return int32_t: float list size of attr
|
||||
*/
|
||||
int32_t ListFloatSize() const;
|
||||
|
||||
/*
|
||||
* set float value to attr.
|
||||
* @param f: float value need to set to attr
|
||||
*/
|
||||
void SetFloat(float f);
|
||||
|
||||
/*
|
||||
* set float list value to attr.
|
||||
* @param vector<float>: float list value need to set to attr
|
||||
*/
|
||||
void SetListFloat(const std::vector<float> &list);
|
||||
|
||||
/*
|
||||
* get bool value of attr.
|
||||
* @return bool: bool value of attr
|
||||
*/
|
||||
bool GetBool() const;
|
||||
|
||||
/*
|
||||
* get bool list value of attr.
|
||||
* @return vector<bool>: bool list value of attr
|
||||
*/
|
||||
std::vector<bool> GetListBool() const;
|
||||
|
||||
/*
|
||||
* attr add bool value to list.
|
||||
* @param b: bool value need to add to list
|
||||
*/
|
||||
void AddListBool(bool b);
|
||||
|
||||
/*
|
||||
* get bool list size of attr.
|
||||
* @return int32_t: bool list size of attr
|
||||
*/
|
||||
int32_t ListBoolSize() const;
|
||||
|
||||
/*
|
||||
* set bool value to attr.
|
||||
* @param b: bool value need to set to attr
|
||||
*/
|
||||
void SetBool(bool b);
|
||||
|
||||
/*
|
||||
* set bool list value to attr.
|
||||
* @param vector<bool>: bool list value need to set to attr
|
||||
*/
|
||||
void SetListBool(const std::vector<bool> &list);
|
||||
|
||||
/*
|
||||
* get data type value of attr.
|
||||
* @return DataType: data type value of attr
|
||||
*/
|
||||
DataType GetDataType() const;
|
||||
|
||||
/*
|
||||
* get data type list value of attr.
|
||||
* @return vector<DataType>: data type list value of attr
|
||||
*/
|
||||
std::vector<DataType> GetListDataType() const;
|
||||
|
||||
/*
|
||||
* attr add data type value to list.
|
||||
* @param type: data type value need to add to list
|
||||
*/
|
||||
void AddListDataType(DataType type);
|
||||
|
||||
/*
|
||||
* get data type list size of attr.
|
||||
* @return int32_t: data type list size of attr
|
||||
*/
|
||||
int32_t ListDataTypeSize() const;
|
||||
|
||||
/*
|
||||
* set data type value to attr.
|
||||
* @param type: data type value need to set to attr
|
||||
*/
|
||||
void SetDataType(DataType type);
|
||||
|
||||
/*
|
||||
* set data type list value to attr.
|
||||
* @param vector<DataType>: data type list value need to set to attr
|
||||
*/
|
||||
void SetListDataType(const std::vector<DataType> &list);
|
||||
|
||||
/*
|
||||
* set tensor shape value to attr.
|
||||
* @param shape: tensor shape value need to set to attr
|
||||
* @return bool: true->success false->failed
|
||||
*/
|
||||
bool SetTensorShape(const TensorShape *shape);
|
||||
|
||||
/*
|
||||
* set tensor shape list value to attr.
|
||||
* @param vector<TensorShape>: tensor shape list value need to set to attr
|
||||
* @return uint32_t: number of tensor shapes set successfully
|
||||
*/
|
||||
uint32_t SetListTensorShape(const std::vector<TensorShape *> &list);
|
||||
|
||||
/*
|
||||
* attr add tensor shape value to list.
|
||||
* @return shared_ptr<TensorShape>: tensor shape value ptr added to list
|
||||
*/
|
||||
std::shared_ptr<TensorShape> AddListTensorShape();
|
||||
|
||||
/*
|
||||
* get tensor shape value of attr.
|
||||
* @return TensorShape: tensor shape value of attr
|
||||
*/
|
||||
std::shared_ptr<TensorShape> GetTensorShape() const;
|
||||
|
||||
/*
|
||||
* get tensor shape list value of attr.
|
||||
* @return vector<TensorShape>: tensor shape list value of attr
|
||||
*/
|
||||
std::vector<TensorShape> GetListTensorShape() const;
|
||||
|
||||
/*
|
||||
* get tensor shape list size of attr.
|
||||
* @return int32_t: tensor shape list size of attr
|
||||
*/
|
||||
int32_t ListTensorShapeSize() const;
|
||||
|
||||
/*
|
||||
* set tensor value to attr.
|
||||
* @param tensor: tensor value need to set to attr
|
||||
* @return bool: true->success false->failed
|
||||
*/
|
||||
bool SetTensor(const Tensor *tensor);
|
||||
|
||||
/*
|
||||
* set tensor list value to attr.
|
||||
* @param vector<Tensor>: tensor list value need to set to attr
|
||||
* @return uint32_t: number of tensors set successfully
|
||||
*/
|
||||
uint32_t SetListTensor(const std::vector<Tensor *> &list);
|
||||
|
||||
/*
|
||||
* attr add tensor value to list.
|
||||
* @return shared_ptr<Tensor>: tensor value ptr added to list
|
||||
*/
|
||||
std::shared_ptr<Tensor> AddListTensor();
|
||||
|
||||
/*
|
||||
* get tensor value of attr.
|
||||
* @return Tensor: tensor value of attr
|
||||
*/
|
||||
std::shared_ptr<Tensor> GetTensor() const;
|
||||
|
||||
/*
|
||||
* get tensor list value of attr.
|
||||
* @return vector<Tensor>: tensor list value of attr
|
||||
*/
|
||||
std::vector<Tensor> GetListTensor() const;
|
||||
|
||||
/*
|
||||
* get tensor list size of attr.
|
||||
* @return int32_t: tensor list size of attr
|
||||
*/
|
||||
int32_t ListTensorSize() const;
|
||||
|
||||
/*
|
||||
* get attr proto.
|
||||
*/
|
||||
aicpuops::AttrValue *GetProto() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<aicpuops::AttrValue> attr_value_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_CPU_PROTO_ATTR_VALUE_IMPL_H
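Worth noting from the constructor above: the deleter argument decides whether the wrapper owns the protobuf message. A short sketch of both modes, with the heap-allocated message and function name being hypothetical:

#include <new>
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
#include "proto/cpu_attr.pb.h"

// Hypothetical example of the two ownership modes of AttrValueImpl.
void OwnershipModes() {
  // 1) Non-owning: default deleter is a no-op, so the proto must outlive the wrapper.
  aicpuops::AttrValue stack_proto;
  aicpu::AttrValueImpl borrowed(&stack_proto);
  borrowed.SetInt(1);

  // 2) Owning: pass a deleter so the wrapper's shared_ptr frees the message.
  auto *heap_proto = new (std::nothrow) aicpuops::AttrValue();
  if (heap_proto != nullptr) {
    aicpu::AttrValueImpl owned(heap_proto, [](aicpuops::AttrValue *p) { delete p; });
    owned.SetInt(42);
  }  // heap_proto deleted here by the wrapper
}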
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/common/cpu_node_def.h"
|
||||
#include "cpu_kernel/cpu_proto/node_def_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
NodeDef::NodeDef(NodeDefImpl *impl) : impl_(impl) {}
|
||||
|
||||
/*
|
||||
* parse node def from string.
|
||||
*/
|
||||
bool NodeDef::ParseFromString(const std::string &str) { return impl_->ParseFromString(str); }
|
||||
|
||||
/*
|
||||
* serialize node def to string.
|
||||
*/
|
||||
bool NodeDef::SerializeToString(std::string &str) const { return impl_->SerializeToString(str); }
|
||||
|
||||
/*
|
||||
* set op type to node def.
|
||||
*/
|
||||
void NodeDef::SetOpType(const std::string &op) { impl_->SetOpType(op); }
|
||||
|
||||
/*
|
||||
* get op type of node def.
|
||||
*/
|
||||
std::string NodeDef::GetOpType() const { return impl_->GetOpType(); }
|
||||
|
||||
/*
|
||||
* add input tensor to node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDef::AddInputs() { return impl_->AddInputs(); }
|
||||
|
||||
/*
|
||||
* add output tensor to node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDef::AddOutputs() { return impl_->AddOutputs(); }
|
||||
|
||||
/*
|
||||
* add attr to node def.
|
||||
*/
|
||||
bool NodeDef::AddAttrs(const std::string &name, const AttrValue *attr) { return impl_->AddAttrs(name, attr); }
|
||||
|
||||
/*
|
||||
* get input tensor size of node def.
|
||||
*/
|
||||
int32_t NodeDef::InputsSize() const { return impl_->InputsSize(); }
|
||||
|
||||
/*
|
||||
* get output tensor size of node def.
|
||||
*/
|
||||
int32_t NodeDef::OutputsSize() const { return impl_->OutputsSize(); }
|
||||
|
||||
/*
|
||||
* get input tensor of node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDef::MutableInputs(int32_t index) const { return impl_->MutableInputs(index); }
|
||||
|
||||
/*
|
||||
* get output tensor of node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDef::MutableOutputs(int32_t index) const { return impl_->MutableOutputs(index); }
|
||||
|
||||
/*
|
||||
* get attr of node def.
|
||||
*/
|
||||
std::unordered_map<std::string, std::shared_ptr<AttrValue> > NodeDef::Attrs() const { return impl_->Attrs(); }
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,224 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <utility>
|
||||
#include "cpu_kernel/cpu_proto/node_def_impl.h"
|
||||
|
||||
#include "cpu_kernel/cpu_proto/attr_value_impl.h"
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* parse node def from string.
|
||||
*/
|
||||
bool NodeDefImpl::ParseFromString(const std::string &str) {
|
||||
if (!nodedef_->ParseFromString(str)) {
|
||||
KERNEL_LOG_ERROR("ParseFromString failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* serialize node def to string.
|
||||
*/
|
||||
bool NodeDefImpl::SerializeToString(std::string &str) const {
|
||||
if (!nodedef_->SerializeToString(&str)) {
|
||||
KERNEL_LOG_ERROR("SerializeToString failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* set op type to node def.
|
||||
*/
|
||||
void NodeDefImpl::SetOpType(const std::string &op) { nodedef_->set_op(op); }
|
||||
|
||||
/*
|
||||
* get op type of node def.
|
||||
*/
|
||||
std::string NodeDefImpl::GetOpType() const { return nodedef_->op(); }
|
||||
|
||||
/*
|
||||
* add input tensor to node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDefImpl::AddInputs() {
|
||||
auto tensor = nodedef_->add_inputs();
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf node def add tensor is nullptr.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpu_tensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpu_tensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpu_tensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* add output tensor to node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDefImpl::AddOutputs() {
|
||||
auto tensor = nodedef_->add_outputs();
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf node def add tensor is nullptr.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpu_tensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpu_tensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpu_tensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* add attr to node def.
|
||||
*/
|
||||
bool NodeDefImpl::AddAttrs(const std::string &name, const AttrValue *attr) {
|
||||
if (attr == nullptr) {
|
||||
KERNEL_LOG_ERROR("Attr is null.");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto attrs = nodedef_->mutable_attrs();
|
||||
KERNEL_CHECK_NULLPTR(attrs, false, "Protobuf mutable attrs is null")
|
||||
auto impl = CpuKernelUtils::GetImpl(attr);
KERNEL_CHECK_NULLPTR(impl, false, "Get impl is null")
KERNEL_CHECK_NULLPTR(impl->GetProto(), false, "Get proto is null")
auto pair =
attrs->insert(google::protobuf::Map<std::string, aicpuops::AttrValue>::value_type(name, *(impl->GetProto())));
|
||||
if (!pair.second) {
|
||||
KERNEL_LOG_ERROR("Nodedef insert attr %s to nodeDef failed.", name.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* get input tensor size of node def.
|
||||
*/
|
||||
int32_t NodeDefImpl::InputsSize() const { return nodedef_->inputs_size(); }
|
||||
|
||||
/*
|
||||
* get output tensor size of node def.
|
||||
*/
|
||||
int32_t NodeDefImpl::OutputsSize() const { return nodedef_->outputs_size(); }
|
||||
|
||||
/*
|
||||
* get input tensor of node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDefImpl::MutableInputs(int32_t index) const {
|
||||
if ((index >= InputsSize()) || (index < 0)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Index[%d] should be less than input tensors size[%d] and noe less than "
|
||||
"0.",
|
||||
index, InputsSize());
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto tensor = nodedef_->mutable_inputs(index);
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf node def mutable inputs[%d] tensor is nullptr.", index);
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpu_tensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpu_tensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpu_tensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* get output tensor of node def.
|
||||
*/
|
||||
std::shared_ptr<Tensor> NodeDefImpl::MutableOutputs(int32_t index) const {
|
||||
if ((index >= OutputsSize()) || (index < 0)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Index[%d] should be less than output tensors size[%d] and noe less than "
|
||||
"0.",
|
||||
index, OutputsSize());
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto tensor = nodedef_->mutable_outputs(index);
|
||||
if (tensor == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf node def mutable outputs[%d] tensor is nullptr.", index);
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
TensorImpl *impl = new (std::nothrow) TensorImpl(tensor);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorImpl failed.");
|
||||
return std::shared_ptr<Tensor>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpu_tensor = CpuKernelUtils::CreateTensor(impl);
|
||||
if (aicpu_tensor == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpu_tensor;
|
||||
}
|
||||
|
||||
/*
|
||||
* get attr of node def.
|
||||
*/
|
||||
std::unordered_map<std::string, std::shared_ptr<AttrValue>> NodeDefImpl::Attrs() const {
|
||||
std::unordered_map<std::string, std::shared_ptr<AttrValue>> ret;
|
||||
auto attrs_map = nodedef_->mutable_attrs();
|
||||
KERNEL_CHECK_NULLPTR(attrs_map, ret, "Protobuf mutable attrs is null")
|
||||
|
||||
for (auto it = attrs_map->begin(); it != attrs_map->end(); ++it) {
|
||||
aicpuops::AttrValue *attr = &(it->second);
|
||||
AttrValueImpl *impl = new (std::nothrow) AttrValueImpl(attr);
|
||||
if (impl == nullptr) {
KERNEL_LOG_WARN("Create AttrValueImpl failed.");
continue;
}

auto attr_value = CpuKernelUtils::CreateAttrValue(impl);
if (attr_value == nullptr) {
KERNEL_LOG_WARN("Create AttrValue failed.");
delete impl;
continue;
}
|
||||
(void)ret.insert(std::make_pair(it->first, attr_value));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
} // namespace aicpu
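As an aside, a small sketch of how NodeDefImpl is typically driven end to end, assuming a caller-owned aicpuops::NodeDef message; the op name "Add" and the function are placeholders, not taken from this change.

#include <string>
#include "cpu_kernel/cpu_proto/node_def_impl.h"
#include "proto/cpu_node_def.pb.h"

// Hypothetical example: build a node def and serialize it to a proto3 binary string.
bool BuildAndSerialize(std::string *out) {
  aicpuops::NodeDef proto;
  aicpu::NodeDefImpl node_def(&proto);  // non-owning wrapper

  node_def.SetOpType("Add");            // stored in NodeDef.op
  auto input = node_def.AddInputs();    // returns nullptr on allocation failure
  auto output = node_def.AddOutputs();
  if (input == nullptr || output == nullptr) {
    return false;
  }
  return node_def.SerializeToString(*out);
}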
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_CPU_PROTO_NODE_DEF_IMPL_H
|
||||
#define AICPU_CONTEXT_CPU_PROTO_NODE_DEF_IMPL_H
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_attr_value.h"
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
#include "proto/cpu_node_def.pb.h"
|
||||
|
||||
namespace aicpu {
|
||||
class NodeDefImpl {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
NodeDefImpl(
|
||||
aicpuops::NodeDef *nodedef, std::function<void(aicpuops::NodeDef *)> del_func = [](aicpuops::NodeDef *p) {})
|
||||
: nodedef_(nodedef, del_func) {}
|
||||
|
||||
~NodeDefImpl() = default;
|
||||
NodeDefImpl(const NodeDefImpl &) = delete;
|
||||
NodeDefImpl(NodeDefImpl &&) = delete;
|
||||
NodeDefImpl &operator=(const NodeDefImpl &) = delete;
|
||||
NodeDefImpl &operator=(NodeDefImpl &&) = delete;
|
||||
|
||||
/*
|
||||
* parse node def from string.
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool ParseFromString(const std::string &str);
|
||||
|
||||
/*
|
||||
* serialize node def to string.
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool SerializeToString(std::string &str) const;
|
||||
|
||||
/*
|
||||
* set op type to node def.
|
||||
* @param op: op type
|
||||
*/
|
||||
void SetOpType(const std::string &op);
|
||||
|
||||
/*
|
||||
* get op type of node def.
|
||||
* @return string: op type
|
||||
*/
|
||||
std::string GetOpType() const;
|
||||
|
||||
/*
|
||||
* add input tensor to node def.
|
||||
* @return shared_ptr<Tensor>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<Tensor> AddInputs();
|
||||
|
||||
/*
|
||||
* add output tensor to node def.
|
||||
* @return shared_ptr<Tensor>: not null->success, null->failed
|
||||
*/
|
||||
std::shared_ptr<Tensor> AddOutputs();
|
||||
|
||||
/*
|
||||
* add attr to node def.
|
||||
* @param name: attr name
|
||||
* @param attr: attr need to add
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool AddAttrs(const std::string &name, const AttrValue *attr);
|
||||
|
||||
/*
|
||||
* get input tensor size of node def.
|
||||
* @return int32_t: input tensor size of node def
|
||||
*/
|
||||
int32_t InputsSize() const;
|
||||
|
||||
/*
|
||||
* get output tensor size of node def.
|
||||
* @return int32_t: output tensor size of node def
|
||||
*/
|
||||
int32_t OutputsSize() const;
|
||||
|
||||
/*
|
||||
* get input tensor of node def.
|
||||
* @param index: index of input tensor
|
||||
* @return shared_ptr<Tensor>: input tensor ptr of node def
|
||||
*/
|
||||
std::shared_ptr<Tensor> MutableInputs(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get output tensor of node def.
|
||||
* @param index: index of output tensor
|
||||
* @return shared_ptr<Tensor>: output tensor ptr of node def
|
||||
*/
|
||||
std::shared_ptr<Tensor> MutableOutputs(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get attr of node def.
|
||||
* @return std::unordered_map<std::string, std::shared_ptr<AttrValue>>: attrs
|
||||
* of node def
|
||||
*/
|
||||
std::unordered_map<std::string, std::shared_ptr<AttrValue> > Attrs() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<aicpuops::NodeDef> nodedef_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_CPU_PROTO_NODE_DEF_IMPL_H
|
|
@ -0,0 +1,36 @@
|
|||
syntax = "proto3";
|
||||
package aicpuops;
|
||||
import "cpu_tensor.proto";
|
||||
import "cpu_tensor_shape.proto";
|
||||
|
||||
message AttrValue {
|
||||
|
||||
message ArrayValue {
|
||||
repeated bytes s = 2; //"array(string)"
|
||||
repeated int64 i = 3 [ packed = true ]; //"array(int)"
|
||||
repeated float f = 4 [ packed = true ]; //"array(float)"
|
||||
repeated bool b = 5 [ packed = true ]; //"array(bool)"
|
||||
repeated int32 type = 6 [ packed = true ]; //"array(type)"
|
||||
repeated TensorShape shape = 7; //"array(shape)"
|
||||
repeated Tensor tensor = 8; //"array(tensor)"
|
||||
}
|
||||
|
||||
message ListListInt{
|
||||
message ListInt{
|
||||
repeated int64 list_i = 1; // list int
|
||||
}
|
||||
repeated ListInt list_list_i = 1; // list list int
|
||||
}
|
||||
|
||||
oneof value {
|
||||
ArrayValue array = 1;
|
||||
bytes s = 2; //"string"
|
||||
int64 i = 3; //"int"
|
||||
float f = 4; //"float"
|
||||
bool b = 5; //"bool"
|
||||
int32 type = 6; //"type"
|
||||
TensorShape shape = 7; //"shape"
|
||||
Tensor tensor = 8; //"tensor"
|
||||
ListListInt list_list_int = 9; // List List Int type
|
||||
}
|
||||
}
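For orientation, the sketch below shows how a consumer can inspect this message directly, assuming standard protobuf C++ code generation for the oneof above; it is illustrative and not part of the migrated sources.

#include "proto/cpu_attr.pb.h"

// Hypothetical example: scalar values live in the oneof, list values in ArrayValue.
void InspectAttr(const aicpuops::AttrValue &attr) {
  if (attr.value_case() == aicpuops::AttrValue::kArray) {
    for (int i = 0; i < attr.array().i_size(); ++i) {
      (void)attr.array().i(i);  // int list, mirrored by AttrValueImpl::GetListInt
    }
  } else if (attr.value_case() == aicpuops::AttrValue::kI) {
    (void)attr.i();  // scalar int, mirrored by AttrValueImpl::GetInt
  }
}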
|
|
@ -0,0 +1,18 @@
|
|||
syntax = "proto3";
|
||||
package aicpuops;
|
||||
import "cpu_attr.proto";
|
||||
import "cpu_tensor.proto";
|
||||
|
||||
message DynamicIdx {
|
||||
int32 idx = 1;
|
||||
int32 num = 2;
|
||||
}
|
||||
|
||||
message NodeDef {
|
||||
string op = 2;
|
||||
map<string, AttrValue> attrs = 3;
|
||||
repeated Tensor inputs = 4;
|
||||
repeated Tensor outputs = 5;
|
||||
map<string, DynamicIdx> dym_inputs = 6;
|
||||
map<string, DynamicIdx> dym_outputs = 7;
|
||||
}
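The sketch below ties this message to the NodeDefImpl parse/serialize pair earlier in this change; the byte string is assumed to come from a peer that serialized the same message type, and the function name is illustrative.

#include <string>
#include "cpu_kernel/cpu_proto/node_def_impl.h"
#include "proto/cpu_node_def.pb.h"

// Hypothetical example: rebuild a NodeDef from its proto3 binary form.
bool LoadNodeDef(const std::string &serialized, std::string *op_type) {
  aicpuops::NodeDef proto;
  aicpu::NodeDefImpl node_def(&proto);
  if (!node_def.ParseFromString(serialized)) {
    return false;  // ParseFromString already logs the failure
  }
  *op_type = node_def.GetOpType();
  return true;
}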
|
|
@ -0,0 +1,21 @@
|
|||
syntax = "proto3";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
import "cpu_tensor_shape.proto";
|
||||
package aicpuops;
|
||||
|
||||
message Tensor {
|
||||
|
||||
// tensor shape info
|
||||
TensorShape tensor_shape = 1;
|
||||
|
||||
// tensor content data type
|
||||
int32 tensor_type = 2;
|
||||
|
||||
// tensor memory device
|
||||
// memory device where the data is located: "DDR", "HBM" or "NONE"
|
||||
string mem_device = 3;
|
||||
string name = 4;
|
||||
uint64 data_ptr = 5;
|
||||
uint64 data_size = 6;
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
syntax = "proto3";
|
||||
package aicpuops;
|
||||
|
||||
message TensorShape {
|
||||
// One dimension of the tensor.
|
||||
message Dim {
|
||||
// size must be >= 0
|
||||
int64 size = 1;
|
||||
};
|
||||
|
||||
// group dim info
|
||||
repeated Dim dim = 2;
|
||||
|
||||
// If true, the number of dimensions in the shape is unknown.
|
||||
// If true, "dim.size()" must be 0.
|
||||
bool unknown_rank = 3;
|
||||
|
||||
// data format: "NHWC", "NCHW", "NC1HWC0" or "NONE"
|
||||
int32 data_format = 4;
|
||||
};
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
Tensor::Tensor(TensorImpl *impl) : impl_(impl) {}
|
||||
|
||||
/*
|
||||
* get tensor shape value of tensor.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> Tensor::GetTensorShape() const { return impl_->GetTensorShape(); }
|
||||
|
||||
/*
|
||||
* set tensor shape value to tensor.
|
||||
*/
|
||||
bool Tensor::SetTensorShape(const TensorShape *shape) { return impl_->SetTensorShape(shape); }
|
||||
|
||||
/*
|
||||
* get data type value of tensor.
|
||||
*/
|
||||
DataType Tensor::GetDataType() const { return impl_->GetDataType(); }
|
||||
|
||||
/*
|
||||
* set data type value to tensor.
|
||||
*/
|
||||
void Tensor::SetDataType(DataType type) { impl_->SetDataType(type); }
|
||||
|
||||
/*
|
||||
* get data ptr of tensor.
|
||||
*/
|
||||
void *Tensor::GetData() const { return impl_->GetData(); }
|
||||
|
||||
/*
|
||||
* set data ptr to tensor.
|
||||
*/
|
||||
void Tensor::SetData(void *addr) { impl_->SetData(addr); }
|
||||
|
||||
/*
|
||||
* get data size of tensor.
|
||||
*/
|
||||
uint64_t Tensor::GetDataSize() const { return impl_->GetDataSize(); }
|
||||
|
||||
/*
|
||||
* set data size to tensor.
|
||||
*/
|
||||
void Tensor::SetDataSize(uint64_t size) { impl_->SetDataSize(size); }
|
||||
|
||||
/*
|
||||
* calculate data size by tensor shape.
|
||||
*/
|
||||
int64_t Tensor::CalcDataSizeByShape() const { return impl_->CalcDataSizeByShape(); }
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
*/
|
||||
int64_t Tensor::NumElements() const { return impl_->NumElements(); }
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/cpu_proto/tensor_impl.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "cpu_kernel/common/cpu_kernel_utils.h"
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "proto/cpu_tensor_shape.pb.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* get tensor shape value of tensor.
|
||||
*/
|
||||
std::shared_ptr<TensorShape> TensorImpl::GetTensorShape() const {
|
||||
aicpuops::TensorShape *tensor_shape = tensor_->mutable_tensor_shape();
|
||||
if (tensor_shape == nullptr) {
|
||||
KERNEL_LOG_ERROR("Protobuf mutable tensor shape is null.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
TensorShapeImpl *impl = new (std::nothrow) TensorShapeImpl(tensor_shape);
|
||||
if (impl == nullptr) {
|
||||
KERNEL_LOG_ERROR("Create TensorShapeImpl failed.");
|
||||
return std::shared_ptr<TensorShape>(nullptr);
|
||||
}
|
||||
|
||||
auto aicpu_shape = CpuKernelUtils::CreateTensorShape(impl);
|
||||
if (aicpu_shape == nullptr) {
|
||||
delete impl;
|
||||
}
|
||||
return aicpu_shape;
|
||||
}
|
||||
|
||||
/*
|
||||
* set tensor shape value to tensor.
|
||||
*/
|
||||
bool TensorImpl::SetTensorShape(const TensorShape *shape) {
|
||||
KERNEL_CHECK_NULLPTR(shape, false, "Tensor shape is null")
|
||||
|
||||
aicpuops::TensorShape *tensor_shape = tensor_->mutable_tensor_shape();
|
||||
KERNEL_CHECK_NULLPTR(tensor_shape, false, "Protobuf mutable tensor shape is null")
|
||||
auto impl = CpuKernelUtils::GetImpl(shape);
|
||||
KERNEL_CHECK_NULLPTR(impl, false, "Get impl is null")
|
||||
|
||||
auto proto = impl->GetProto();
|
||||
KERNEL_CHECK_NULLPTR(proto, false, "Get proto is null")
|
||||
|
||||
*tensor_shape = *(proto);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* get data type value of tensor.
|
||||
*/
|
||||
DataType TensorImpl::GetDataType() const { return static_cast<DataType>(tensor_->tensor_type()); }
|
||||
|
||||
/*
|
||||
* set data type value to tensor.
|
||||
*/
|
||||
void TensorImpl::SetDataType(DataType type) { tensor_->set_tensor_type(type); }
|
||||
|
||||
/*
|
||||
* get data ptr of tensor.
|
||||
*/
|
||||
void *TensorImpl::GetData() const { return reinterpret_cast<void *>(static_cast<uintptr_t>(tensor_->data_ptr())); }
|
||||
|
||||
/*
|
||||
* set data ptr to tensor.
|
||||
*/
|
||||
void TensorImpl::SetData(void *addr) { tensor_->set_data_ptr(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(addr))); }
|
||||
|
||||
/*
|
||||
* get data size of tensor.
|
||||
*/
|
||||
uint64_t TensorImpl::GetDataSize() const { return tensor_->data_size(); }
|
||||
|
||||
/*
|
||||
* set data size to tensor.
|
||||
*/
|
||||
void TensorImpl::SetDataSize(uint64_t size) { tensor_->set_data_size(size); }
|
||||
|
||||
/*
|
||||
* get name of tensor.
|
||||
*/
|
||||
std::string TensorImpl::GetName() const { return tensor_->name(); }
|
||||
|
||||
/*
|
||||
* set name of tensor.
|
||||
*/
|
||||
void TensorImpl::SetName(const std::string &name) { tensor_->set_name(name); }
|
||||
|
||||
/*
|
||||
* calculate data size by tensor shape.
|
||||
*/
|
||||
int64_t TensorImpl::CalcDataSizeByShape() const {
|
||||
int64_t data_size = NumElements();
|
||||
int32_t element_size = GetSizeByDataType(static_cast<DataType>(GetDataType()));
|
||||
if ((data_size < 0) || (element_size < 0)) {
|
||||
KERNEL_LOG_WARN("Get tensor element number[%lld] or element type size[%d] less than 0.", data_size, element_size);
|
||||
return -1;
|
||||
}
|
||||
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(data_size, element_size, data_size, -1);
|
||||
return data_size;
|
||||
}
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
*/
|
||||
int64_t TensorImpl::NumElements() const {
|
||||
auto shape = GetTensorShape();
|
||||
if (shape == nullptr) {
|
||||
KERNEL_LOG_ERROR("Get tensor shape failed.");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return shape->NumElements();
|
||||
}
|
||||
|
||||
aicpuops::Tensor *TensorImpl::GetProto() const { return tensor_.get(); }
|
||||
} // namespace aicpu
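A brief worked example of the size calculation above, assuming a DT_FLOAT entry with a 4-byte element size in cpu_types.h (an assumption, not shown in this diff); the function and values are illustrative.

#include <cstdint>
#include "cpu_kernel/cpu_proto/tensor_impl.h"
#include "proto/cpu_tensor.pb.h"

// Hypothetical example: shape [2, 3, 4] with 4-byte elements -> 24 elements, 96 bytes.
void DataSizeExample() {
  aicpuops::Tensor proto;
  aicpu::TensorImpl tensor(&proto);

  auto *shape = proto.mutable_tensor_shape();
  shape->add_dim()->set_size(2);
  shape->add_dim()->set_size(3);
  shape->add_dim()->set_size(4);
  tensor.SetDataType(aicpu::DT_FLOAT);  // assumed 4-byte element type

  int64_t elements = tensor.NumElements();       // 2 * 3 * 4 = 24
  int64_t bytes = tensor.CalcDataSizeByShape();  // 24 * 4 = 96, or -1 on overflow
  (void)elements;
  (void)bytes;
}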
|
|
@ -0,0 +1,122 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_CPU_PROTO_TENSOR_IMPL_H
|
||||
#define AICPU_CONTEXT_CPU_PROTO_TENSOR_IMPL_H
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_tensor_shape.h"
|
||||
#include "proto/cpu_tensor.pb.h"
|
||||
|
||||
namespace aicpu {
|
||||
class TensorImpl {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
TensorImpl(
|
||||
aicpuops::Tensor *tensor, std::function<void(aicpuops::Tensor *)> delFunc = [](aicpuops::Tensor *p) {})
|
||||
: tensor_(tensor, delFunc) {}
|
||||
|
||||
~TensorImpl() = default;
|
||||
TensorImpl(const TensorImpl &) = delete;
|
||||
TensorImpl(TensorImpl &&) = delete;
|
||||
TensorImpl &operator=(const TensorImpl &) = delete;
|
||||
TensorImpl &operator=(TensorImpl &&) = delete;
|
||||
|
||||
/*
|
||||
* set tensor shape value to tensor.
|
||||
* @param shape: tensor shape value need to set to tensor
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool SetTensorShape(const TensorShape *shape);
|
||||
|
||||
/*
|
||||
* get tensor shape value of tensor.
|
||||
* @return std::shared_ptr<TensorShape>: tensor shape value of tensor
|
||||
*/
|
||||
std::shared_ptr<TensorShape> GetTensorShape() const;
|
||||
|
||||
/*
|
||||
* set data type value to tensor.
|
||||
* @param type: data type value need to set to tensor
|
||||
*/
|
||||
void SetDataType(DataType type);
|
||||
|
||||
/*
|
||||
* get data type value of tensor.
|
||||
* @return DataType: data type value of tensor
|
||||
*/
|
||||
DataType GetDataType() const;
|
||||
|
||||
/*
|
||||
* set data ptr to tensor.
|
||||
* @param addr: tensor data ptr
|
||||
*/
|
||||
void SetData(void *addr);
|
||||
|
||||
/*
|
||||
* get data ptr of tensor.
|
||||
* @return void *: tensor data ptr
|
||||
*/
|
||||
void *GetData() const;
|
||||
|
||||
/*
|
||||
* set data size to tensor.
|
||||
* @param size: tensor data size
|
||||
*/
|
||||
void SetDataSize(uint64_t size);
|
||||
|
||||
/*
|
||||
* get data size of tensor.
|
||||
* @return uint64_t: tensor data size
|
||||
*/
|
||||
uint64_t GetDataSize() const;
|
||||
|
||||
/*
|
||||
* get name of tensor.
|
||||
* @return std::string: tensor name
|
||||
*/
|
||||
std::string GetName() const;
|
||||
|
||||
/*
|
||||
* set name of tensor.
|
||||
* @param name: tensor name
|
||||
*/
|
||||
void SetName(const std::string &name);
|
||||
|
||||
/*
|
||||
* calculate data size by tensor shape.
|
||||
* @return success->not less than 0, failed->less than 0
|
||||
*/
|
||||
int64_t CalcDataSizeByShape() const;
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
* @return success->not less than 0, unknown->less than 0
|
||||
*/
|
||||
int64_t NumElements() const;
|
||||
|
||||
/*
|
||||
* get tensor proto.
|
||||
*/
|
||||
aicpuops::Tensor *GetProto() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<aicpuops::Tensor> tensor_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_CPU_PROTO_TENSOR_IMPL_H
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/inc/cpu_tensor_shape.h"
|
||||
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
|
||||
|
||||
namespace aicpu {
|
||||
TensorShape::TensorShape(TensorShapeImpl *tensorShape) : impl_(tensorShape) {}
|
||||
|
||||
/*
|
||||
* get dims value of tensor shape.
|
||||
*/
|
||||
std::vector<int64_t> TensorShape::GetDimSizes() const { return impl_->GetDimSizes(); }
|
||||
|
||||
/*
|
||||
* set dims value to tensor shape.
|
||||
*/
|
||||
void TensorShape::SetDimSizes(const std::vector<int64_t> &dims) { impl_->SetDimSizes(dims); }
|
||||
|
||||
/*
|
||||
* get format value of tensor shape.
|
||||
*/
|
||||
Format TensorShape::GetFormat() const { return impl_->GetFormat(); }
|
||||
|
||||
/*
|
||||
* set format value to tensor shape.
|
||||
*/
|
||||
void TensorShape::SetFormat(Format format) { impl_->SetFormat(format); }
|
||||
|
||||
/*
|
||||
* get unknown rank value of tensor shape.
|
||||
*/
|
||||
bool TensorShape::GetUnknownRank() const { return impl_->GetUnknownRank(); }
|
||||
|
||||
/*
|
||||
* set unknown rank value to tensor shape.
|
||||
*/
|
||||
void TensorShape::SetUnknownRank(bool unknownRank) { impl_->SetUnknownRank(unknownRank); }
|
||||
|
||||
/*
|
||||
* get dims size of tensor shape.
|
||||
*/
|
||||
int32_t TensorShape::GetDims() const { return impl_->GetDims(); }
|
||||
|
||||
/*
|
||||
* get dim value of tensor shape index dim.
|
||||
*/
|
||||
int64_t TensorShape::GetDimSize(int32_t index) const { return impl_->GetDimSize(index); }
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
*/
|
||||
int64_t TensorShape::NumElements() const { return impl_->NumElements(); }
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
|
||||
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
|
||||
namespace aicpu {
|
||||
/*
|
||||
* get dims value of tensor shape.
|
||||
*/
|
||||
std::vector<int64_t> TensorShapeImpl::GetDimSizes() const {
|
||||
std::vector<int64_t> ret;
|
||||
for (int32_t i = 0; i < tensor_shape_->dim_size(); i++) {
|
||||
ret.emplace_back(tensor_shape_->dim(i).size());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* set dims value to tensor shape.
|
||||
*/
|
||||
void TensorShapeImpl::SetDimSizes(const std::vector<int64_t> &dims) {
|
||||
tensor_shape_->clear_dim();
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
aicpuops::TensorShape_Dim *aicpu_dims = tensor_shape_->add_dim();
|
||||
KERNEL_CHECK_NULLPTR_VOID(aicpu_dims, "Protobuf add dim is null")
|
||||
aicpu_dims->set_size(dims[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* get format value of tensor shape.
|
||||
*/
|
||||
Format TensorShapeImpl::GetFormat() const { return static_cast<Format>(tensor_shape_->data_format()); }
|
||||
|
||||
/*
|
||||
* set format value to tensor shape.
|
||||
*/
|
||||
void TensorShapeImpl::SetFormat(Format format) { tensor_shape_->set_data_format(format); }
|
||||
|
||||
/*
|
||||
* get unknown rank value of tensor shape.
|
||||
*/
|
||||
bool TensorShapeImpl::GetUnknownRank() const { return tensor_shape_->unknown_rank(); }
|
||||
|
||||
/*
|
||||
* set unknown rank value to tensor shape.
|
||||
*/
|
||||
void TensorShapeImpl::SetUnknownRank(bool unknown_rank) { tensor_shape_->set_unknown_rank(unknown_rank); }
|
||||
|
||||
/*
|
||||
* get dims size of tensor shape.
|
||||
*/
|
||||
int32_t TensorShapeImpl::GetDims() const { return tensor_shape_->dim_size(); }
|
||||
|
||||
/*
|
||||
* get dim value of tensor shape index dim.
|
||||
*/
|
||||
int64_t TensorShapeImpl::GetDimSize(int32_t index) const {
|
||||
if ((index >= GetDims()) || (index < 0)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Dim index[%d] must be not less than 0 and not greater than dims "
|
||||
"size[%d]",
|
||||
index, GetDims());
|
||||
return 0;
|
||||
}
|
||||
|
||||
return tensor_shape_->dim(index).size();
|
||||
}
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
*/
|
||||
int64_t TensorShapeImpl::NumElements() const {
|
||||
int64_t num_elements = 1;
|
||||
for (int32_t i = 0; i < tensor_shape_->dim_size(); i++) {
|
||||
int64_t dim_size = tensor_shape_->dim(i).size();
|
||||
if (dim_size < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, dim_size, num_elements, -1);
|
||||
}
|
||||
return num_elements;
|
||||
}
|
||||
|
||||
/*
|
||||
* get tensor shape proto.
* @return aicpuops::TensorShape *: tensor shape proto ptr
|
||||
*/
|
||||
|
||||
aicpuops::TensorShape *TensorShapeImpl::GetProto() const { return tensor_shape_.get(); }
|
||||
} // namespace aicpu
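Finally, a behavioral sketch of the shape helpers above; the values are illustrative, and the negative dimension relies on NumElements treating any negative dim size as unknown.

#include <cstdint>
#include "cpu_kernel/cpu_proto/tensor_shape_impl.h"
#include "proto/cpu_tensor_shape.pb.h"

// Hypothetical example: out-of-range dim access and unknown dimensions.
void ShapeExample() {
  aicpuops::TensorShape proto;
  aicpu::TensorShapeImpl shape(&proto);

  shape.SetDimSizes({8, -1});           // -1 marks an unknown dimension
  int32_t dims = shape.GetDims();       // 2
  int64_t first = shape.GetDimSize(0);  // 8
  int64_t bad = shape.GetDimSize(2);    // out of range: logs an error, returns 0
  int64_t count = shape.NumElements();  // any negative dim -> -1 (unknown)
  (void)dims;
  (void)first;
  (void)bad;
  (void)count;
}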
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_CONTEXT_CPU_PROTO_TENSOR_SHAPE_IMPL_H
|
||||
#define AICPU_CONTEXT_CPU_PROTO_TENSOR_SHAPE_IMPL_H
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
#include "proto/cpu_tensor_shape.pb.h"
|
||||
|
||||
namespace aicpu {
|
||||
class TensorShapeImpl {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
TensorShapeImpl(
|
||||
aicpuops::TensorShape *shape,
|
||||
std::function<void(aicpuops::TensorShape *)> del_func = [](aicpuops::TensorShape *p) {})
|
||||
: tensor_shape_(shape, del_func) {}
|
||||
|
||||
~TensorShapeImpl() = default;
|
||||
TensorShapeImpl(const TensorShapeImpl &) = delete;
|
||||
TensorShapeImpl(TensorShapeImpl &&) = delete;
|
||||
TensorShapeImpl &operator=(const TensorShapeImpl &) = delete;
|
||||
TensorShapeImpl &operator=(TensorShapeImpl &&) = delete;
|
||||
|
||||
/*
|
||||
* set format value to tensor shape.
|
||||
* @param format: format value need to set to tensor shape
|
||||
*/
|
||||
void SetFormat(Format format);
|
||||
|
||||
/*
|
||||
* get format value of tensor shape.
|
||||
* @return Format: format value of tensor shape
|
||||
*/
|
||||
Format GetFormat() const;
|
||||
|
||||
/*
|
||||
* get unknown rank value of tensor shape.
|
||||
* @return bool: unknown rank value of tensor shape
|
||||
*/
|
||||
bool GetUnknownRank() const;
|
||||
|
||||
/*
|
||||
* set unknown rank value to tensor shape.
|
||||
* @param unknown_rank: unknown rank value need to set to tensor shape
|
||||
*/
|
||||
void SetUnknownRank(bool unknown_rank);
|
||||
|
||||
/*
|
||||
* set dims value to tensor shape.
|
||||
* @param dims: dims value need to set to tensor shape
|
||||
*/
|
||||
void SetDimSizes(const std::vector<int64_t> &dims);
|
||||
|
||||
/*
|
||||
* get dims value of tensor shape.
|
||||
 * @return std::vector<int64_t>: dims value of tensor shape
|
||||
*/
|
||||
std::vector<int64_t> GetDimSizes() const;
|
||||
|
||||
/*
|
||||
* get dim value of tensor shape index dim.
|
||||
* @param index: index dim of tensor shape
|
||||
* @return int64_t: dim value of tensor shape index dim
|
||||
*/
|
||||
int64_t GetDimSize(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get dims size of tensor shape.
|
||||
* @return int32_t: dims size of tensor shape
|
||||
*/
|
||||
int32_t GetDims() const;
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
* @return success->not less than 0, unknown->less than 0
|
||||
*/
|
||||
int64_t NumElements() const;
|
||||
|
||||
/*
|
||||
* get tensor shape proto.
|
||||
*/
|
||||
aicpuops::TensorShape *GetProto() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<aicpuops::TensorShape> tensor_shape_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_CONTEXT_CPU_PROTO_TENSOR_SHAPE_IMPL_H
|
|
@@ -0,0 +1,405 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/format_transfer/format_transfer_fractal_nz.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/format_transfer_utils.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
namespace {
|
||||
const int64_t kDimDefaultValue = 1;
|
||||
const int kDimSize4D = 4;
|
||||
const size_t kSingleDim = 1;
|
||||
const size_t kNdDimIndexN = 0;
|
||||
const size_t kNdDimIndexH = 1;
|
||||
const size_t kNdDimIndexW = 2;
|
||||
const size_t kDimDValueBNdFNz = 2;  // dim d-value between Nd and FractalNz
|
||||
const size_t kNdDimCountBackwardsW = 1;
|
||||
const size_t kNdDimCountBackwardsWH = 2;
|
||||
const size_t kFNzDimCountBackwardsW0 = 1;
|
||||
const size_t kFNzDimCountBackwardsW0H0 = 2;
|
||||
const size_t kFNzDimCountBackwardsW0H0H1 = 3;
|
||||
const size_t kFNzDimCountBackwardsW0H0H1W1 = 4;
|
||||
|
||||
bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; }
|
||||
|
||||
using ShapeVector = std::vector<int64_t>;
|
||||
|
||||
bool CheckShape(Format format, const ShapeVector &shape) {
|
||||
switch (format) {
|
||||
case FORMAT_ND:
|
||||
return IsShapeValid(shape);
|
||||
case FORMAT_NCHW:
|
||||
case FORMAT_NHWC:
|
||||
return CheckShapeValid(shape, kDimSize4D);
|
||||
default:
|
||||
std::string error =
|
||||
"Trans format between " + FmtToStr(FormatToSerialString(format)) + " and [FORMAT_FRACTAL_NZ] is not supported.";
|
||||
KERNEL_LOG_ERROR("%s", error.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* After the conversion to two-dimensional matrix, the memory arrangement is
|
||||
* small z and large N.
|
||||
* @src_shape: N*H*W
|
||||
 * @dst_shape: N*W1*H1*H0*W0
|
||||
* @return
|
||||
*/
|
||||
uint32_t TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape,
|
||||
ShapeVector &hw_shape) {
|
||||
dst_shape.clear();
|
||||
hw_shape.clear();
|
||||
auto w0 = GetCubeSizeByDataType(data_type);
|
||||
int64_t h0 = kCubeSize;
|
||||
switch (src_shape.size()) {
|
||||
case kSingleDim:
|
||||
dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
|
||||
dst_shape.push_back(kDimDefaultValue);
|
||||
dst_shape.push_back(h0);
|
||||
dst_shape.push_back(w0);
|
||||
hw_shape.push_back(kDimDefaultValue);
|
||||
hw_shape.push_back(kDimDefaultValue);
|
||||
hw_shape.push_back(src_shape[kNdDimIndexN]);
|
||||
if (!IsShapeValid(dst_shape)) {
|
||||
KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
default:
|
||||
auto size = src_shape.size();
|
||||
int64_t times = 1;
|
||||
for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) {
|
||||
dst_shape.push_back(src_shape[i]);
|
||||
times *= src_shape[i];
|
||||
}
|
||||
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
|
||||
dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
|
||||
dst_shape.push_back(h0);
|
||||
dst_shape.push_back(w0);
|
||||
hw_shape.push_back(times);
|
||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
|
||||
hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
|
||||
if (!IsShapeValid(dst_shape)) {
|
||||
KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {
|
||||
ShapeVector expect_src_shape;
|
||||
auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans shape from [%s] to [%s], shape [%s] to [%s], data type [%s] "
|
||||
"failed",
|
||||
FormatToSerialString(args.dst_format).c_str(), FormatToSerialString(args.src_format).c_str(),
|
||||
VectorToString(args.dst_shape).c_str(), VectorToString(args.src_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
return ret;
|
||||
}
|
||||
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) {
|
||||
int size = GetSizeByDataType(args.src_data_type);
|
||||
// data size will not be greater than INT_MAX
|
||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
|
||||
if (dst_size == 0) {
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
|
||||
if (dst == nullptr) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to trans format from [%s] to [%s], can not alloc the memory "
|
||||
"for dst buf [%ld]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(), dst_size);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
// src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0,
|
||||
// respectively. dst_shape_size >= kDimNum4D
|
||||
auto times = hw_shape.at(kNdDimIndexN);
|
||||
auto h = hw_shape.at(kNdDimIndexH);
|
||||
auto w = hw_shape.at(kNdDimIndexW);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.dst_shape.size();
|
||||
auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
|
||||
auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
|
||||
auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0];
|
||||
auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0];
|
||||
auto h1h0 = h1 * h0;
|
||||
auto h1h0w0 = h1h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
// w0 not equal 0
|
||||
auto num_w1 = w / w0;
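// Each source row of length w is copied as num_w1 full w0-wide blocks (one memcpy_s per block,
// scattered along the W1 axis with stride h1h0w0), plus a per-element tail for the last w % w0 values.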
|
||||
|
||||
for (int64_t times_idx = 0; times_idx < times; times_idx++) {
|
||||
auto times_head = times_idx * w1h1h0w0;
|
||||
auto src_times_head = times_idx * hw;
|
||||
for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
|
||||
auto h1h0_head = times_head + h1h0_idx * w0;
|
||||
auto src_h_head = src_times_head + h1h0_idx * w;
|
||||
for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
|
||||
auto dst_offset = (h1h0_head + w1_idx * h1h0w0) * size;
|
||||
auto src_offset = (src_h_head + w1_idx * w0) * size;
|
||||
auto protected_size = (dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN))
|
||||
? (dst_size - dst_offset)
|
||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
|
||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
||||
static_cast<size_t>(size * w0));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to operate the dst memory at offset [%ld], error-code "
|
||||
"[%d]",
|
||||
dst_offset, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
auto w1_head = num_w1 * w0;
|
||||
for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
|
||||
auto src_w_idx = w1_head + w0_idx;
|
||||
auto dst_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size;
|
||||
auto src_offset = (src_h_head + src_w_idx) * size;
|
||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
|
||||
? dst_size - dst_offset
|
||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
|
||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
||||
static_cast<size_t>(size));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to operate the dst memory at offset [%ld], error-code "
|
||||
"[%d]",
|
||||
dst_offset, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.data = dst;
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) {
|
||||
int size = GetSizeByDataType(args.src_data_type);
|
||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
|
||||
if (dst_size == 0) {
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
|
||||
if (dst == nullptr) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to trans format from [%s] to [%s], can not alloc the memory "
|
||||
"for dst buf [%ld]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(), dst_size);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
auto times = dst_hw_shape.at(kNdDimIndexN);
|
||||
auto h = dst_hw_shape.at(kNdDimIndexH);
|
||||
auto w = dst_hw_shape.at(kNdDimIndexW);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.src_shape.size();
|
||||
auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1];
|
||||
auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1];
|
||||
auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0];
|
||||
auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0];
|
||||
auto h1h0 = h1 * h0;
|
||||
auto h1h0w0 = h1h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
auto num_w1 = w / w0;
|
||||
errno_t ret;
|
||||
|
||||
for (int64_t times_idx = 0; times_idx < times; times_idx++) {
|
||||
auto times_head = times_idx * w1h1h0w0;
|
||||
auto dst_times_head = times_idx * hw;
|
||||
for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
|
||||
auto h1h0_head = times_head + h1h0_idx * w0;
|
||||
auto dst_h_head = dst_times_head + h1h0_idx * w;
|
||||
for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
|
||||
auto src_offset = (h1h0_head + w1_idx * h1h0w0) * size;
|
||||
auto dst_offset = (dst_h_head + w1_idx * w0) * size;
|
||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
|
||||
? dst_size - dst_offset
|
||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
|
||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
||||
static_cast<size_t>(size * w0));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to operate the dst memory at offset [%ld], error-code "
|
||||
"[%d]",
|
||||
dst_offset, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
auto w1_head = num_w1 * w0;
|
||||
for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
|
||||
auto dst_w_idx = w1_head + w0_idx;
|
||||
auto src_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size;
|
||||
auto dst_offset = (dst_h_head + dst_w_idx) * size;
|
||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
|
||||
? dst_size - dst_offset
|
||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
|
||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
|
||||
static_cast<size_t>(size));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Failed to operate the dst memory at offset [%ld], error-code "
|
||||
"[%d]",
|
||||
dst_offset, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.data = dst;
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
uint32_t FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) {
|
||||
if (!IsDataTypeSupport(args.src_data_type)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "
|
||||
"type [%s] is not supported",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (!CheckShape(args.src_format, args.src_shape) || !IsShapeValid(args.dst_shape)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "
|
||||
"type [%s] is not supported",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_LOG_INFO(
|
||||
"Begin to trans format from [%s] to [%s], src shape [%s], dst shape "
|
||||
"[%s], data type [%s]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
ShapeVector expect_shape;
|
||||
ShapeVector hw_shape;
|
||||
auto ret = TransShapeToFracNz(args.src_shape, args.src_data_type, expect_shape, hw_shape);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
if (!IsTransShapeDstCorrect(args, expect_shape)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return TransFormatFromNdToFracNz(args, result, hw_shape);
|
||||
}
|
||||
|
||||
uint32_t FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type,
|
||||
Format dst_format, ShapeVector &dst_shape, int64_t groups) {
|
||||
if (!IsDataTypeSupport(data_type)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], data type [%s] is not "
|
||||
"supported",
|
||||
FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str(),
|
||||
VectorToString(src_shape).c_str(), DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (!CheckShape(src_format, src_shape)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], data type [%s] is not "
|
||||
"supported",
|
||||
FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str(),
|
||||
VectorToString(src_shape).c_str(), DTypeStr(data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
ShapeVector hw_shape;
|
||||
return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape);
|
||||
}
|
||||
|
||||
uint32_t FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) {
|
||||
if (!IsDataTypeSupport(args.src_data_type)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "
|
||||
"type [%s] is not supported",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (!IsShapeValid(args.src_shape) || !CheckShape(args.dst_format, args.dst_shape)) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "
|
||||
"type [%s] is not supported",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_LOG_INFO(
|
||||
"Begin to trans format from [%s] to [%s], src shape [%s], dst shape "
|
||||
"[%s], data type [%s]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), VectorToString(args.dst_shape).c_str(),
|
||||
DTypeStr(args.src_data_type).c_str());
|
||||
|
||||
ShapeVector hw_shape;
|
||||
auto ret = CheckShapeRelation(args, hw_shape);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
return TransFormatFromFracNzToNd(args, result, hw_shape);
|
||||
}
|
||||
|
||||
uint32_t FormatTransferFractalNzND::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type,
|
||||
Format dst_format, ShapeVector &dst_shape, int64_t groups) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"The shape derivation from [%s] to [%s] is not unique. Trans shape is "
|
||||
"not supported",
|
||||
FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_ND, FORMAT_FRACTAL_NZ)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_NCHW, FORMAT_FRACTAL_NZ)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNz, FORMAT_NHWC, FORMAT_FRACTAL_NZ)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_ND)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_NCHW)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalNzND, FORMAT_FRACTAL_NZ, FORMAT_NHWC)
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
|
@@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_
|
||||
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/register_format_transfer.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
// transfer from nd to nz
|
||||
class FormatTransferFractalNz : public FormatTransfer {
|
||||
public:
|
||||
uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
|
||||
uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) override;
|
||||
};
|
||||
|
||||
// transfer nz to nd
|
||||
class FormatTransferFractalNzND : public FormatTransfer {
|
||||
public:
|
||||
uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
|
||||
uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) override;
|
||||
};
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
||||
|
||||
#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_NZ_H_
|
|
@@ -0,0 +1,285 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/format_transfer/format_transfer_fractal_z.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/format_transfer_utils.h"
|
||||
#include "cpu_kernel/format_transfer/formats_definitions.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
namespace {
|
||||
KernelStatus CheckDataTypeSupport(DataType data_type) {
|
||||
return GetSizeByDataType(data_type) > 0 ? KERNEL_STATUS_OK : KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
/**
|
||||
 * FZ represents the weight of convolution.
|
||||
* After the conversion to two-dimensional matrix, the memory arrangement is
|
||||
 * small n and large Z. If 4D (e.g. NCHW) is used to represent the convolution kernel,
|
||||
* N is width, HWC is height.
|
||||
*
|
||||
 * frac_z axes: (C1*H*W, No, Ni, C0), where Ni = 16, C0 = 16/32, No =
|
||||
* Ceil(N/Ni), C1 = Ceil(C/C0)
|
||||
* @return
|
||||
*/
|
||||
|
||||
uint32_t TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
auto c0 = GetCubeSizeByDataType(data_type);
|
||||
if (c0 < 0) {
|
||||
KERNEL_LOG_ERROR("Cube size must greater than or equal to 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t cin_ori = c;
|
||||
// At this point, groups is not equal to 0, which has been checked at the
|
||||
// [Transdata] entrance.
|
||||
int64_t cout_ori = n / groups;
|
||||
if (cin_ori == 0 || cout_ori == 0) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Cin_ori, cout_ori must not be equal 0, "
|
||||
"and current cin_ori, cout_ori, groups are [%ld] [%ld] [%ld]",
|
||||
cin_ori, cout_ori, groups);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
// This is equal to c0
|
||||
int64_t cube_k = GetCubeSizeByDataType(data_type);
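// e_mult is the smallest number of original groups packed together so that both
// e_mult * cin_ori aligns to cube_k and e_mult * cout_ori aligns to kCubeSize, capped at groups.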
|
||||
int64_t e_mult = std::min(
|
||||
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), groups);
|
||||
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
|
||||
int64_t c1_dim = cin_opt / cube_k;
|
||||
int64_t g_dim = Ceil(groups, e_mult);
|
||||
auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
|
||||
dst_shape.clear();
|
||||
dst_shape.push_back(g_dim * c1_dim * h * w);
|
||||
dst_shape.push_back(n1);
|
||||
dst_shape.push_back(kNiSize);
|
||||
dst_shape.push_back(cube_k);
|
||||
if (!IsShapeValid(dst_shape)) {
|
||||
KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t TransShapeNchwToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, kNchwDimsNum)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto n = src_shape.at(kNchwN);
|
||||
auto c = src_shape.at(kNchwC);
|
||||
auto h = src_shape.at(kNchwH);
|
||||
auto w = src_shape.at(kNchwW);
|
||||
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
uint32_t TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto h = src_shape.at(kHwcnH);
|
||||
auto w = src_shape.at(kHwcnW);
|
||||
auto c = src_shape.at(kHwcnC);
|
||||
auto n = src_shape.at(kHwcnN);
|
||||
|
||||
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
uint32_t TransShapeNhwcToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto n = src_shape.at(kNhwcN);
|
||||
auto h = src_shape.at(kNhwcH);
|
||||
auto w = src_shape.at(kNhwcW);
|
||||
auto c = src_shape.at(kNhwcC);
|
||||
|
||||
return TransShapeToFzWithGroups(n, c, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
// Supports NHWC/NCHW/HWCN <=> FORMAT_FRACTAL_Z (GC1HWN1N0C0);
|
||||
// the final effect is that the data is distributed diagonally.
|
||||
// For example: when the input filter format is NCHW, calculate the
|
||||
// correspondence of indices between NCHW and FORMAT_FRACTAL_Z, then convert the
|
||||
// old filter to the new filter, and finally pad 0 at the positions where there
|
||||
// is no data.
|
||||
uint32_t TransFormatWithGroups(const Format &format_4d, const std::vector<int64_t> &shape_4d, const TransArgs &args,
|
||||
TransResult &result, bool reverse) {
|
||||
int64_t h_dim = 0;
|
||||
int64_t w_dim = 0;
|
||||
int64_t c_dim = 0;
|
||||
int64_t n_dim = 0;
|
||||
int64_t d_dim = 1;
|
||||
if (GetFormatDim(d_dim, h_dim, w_dim, c_dim, n_dim, format_4d, shape_4d) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t cin_ori = c_dim;
|
||||
// At this point, groups is not equal to 0, which has been checked at the
|
||||
// [Transdata] entrance.
|
||||
int64_t cout_ori = n_dim / args.groups;
|
||||
if (CheckDimOri(cin_ori, cout_ori) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type);
|
||||
int64_t e_mult = std::min(
|
||||
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), args.groups);
|
||||
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
|
||||
int64_t cout_opt = Ceil(e_mult * cout_ori, static_cast<int64_t>(kCubeSize)) * static_cast<int64_t>(kCubeSize);
|
||||
int64_t c1_dim = cin_opt / cube_k;
|
||||
int64_t data_size = GetSizeByDataType(args.src_data_type);
|
||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * data_size;
|
||||
// The input is an empty tensor, so return success directly.
|
||||
if (dst_size == 0) {
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
|
||||
KERNEL_CHECK_NULLPTR(dst, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Failed to allcoate memory for dst buf [%lld] when trans "
|
||||
"format from [%s] to [%s]",
|
||||
dst_size, FormatToSerialString(args.src_format).c_str(),
|
||||
FormatToSerialString(args.dst_format).c_str())
|
||||
(void)memset_s(dst.get(), static_cast<size_t>(dst_size), 0, static_cast<size_t>(dst_size));
|
||||
for (int64_t g = 0; g < args.groups; g++) {
|
||||
for (int64_t d = 0; d < d_dim; d++) {
|
||||
for (int64_t c = 0; c < c_dim; c++) {
|
||||
for (int64_t h = 0; h < h_dim; h++) {
|
||||
for (int64_t w = 0; w < w_dim; w++) {
|
||||
for (int64_t n = 0; n < cout_ori; n++) {
|
||||
int64_t e_val = g % e_mult;
|
||||
int64_t dst_ci = e_val * cin_ori + c;
|
||||
int64_t dst_co = e_val * cout_ori + n;
|
||||
int64_t src_co = g * cout_ori + n;
|
||||
int64_t temporary = dst_ci % cube_k;
|
||||
int64_t inx_4d = 0;
|
||||
int64_t inx_fz = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
|
||||
w * cout_opt * cube_k + dst_co * cube_k + temporary;
|
||||
if (format_4d == FORMAT_HWCN) {
|
||||
inx_4d = d * h_dim * w_dim * c_dim * n_dim + h * w_dim * c_dim * n_dim + w * c_dim * n_dim + c * n_dim +
|
||||
src_co;
|
||||
} else if (format_4d == FORMAT_NCHW) {
|
||||
inx_4d = src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim +
|
||||
h * w_dim + w;
|
||||
} else if (format_4d == FORMAT_NHWC) {
|
||||
inx_4d = src_co * d_dim * h_dim * w_dim * c_dim + d * h_dim * w_dim * c_dim + h * w_dim * c_dim +
|
||||
w * c_dim + c;
|
||||
}
|
||||
if (!reverse) {
|
||||
copy_data(args.data, dst, inx_4d, inx_fz, data_size);
|
||||
} else {
|
||||
copy_data(args.data, dst, inx_fz, inx_4d, data_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.data = dst;
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) {
|
||||
if (args.groups == 0) {
|
||||
KERNEL_LOG_ERROR("Attr[groups] must not be equal to 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_LOG_DEBUG(
|
||||
"Begin to trans format from [%s] to [%s], src shape [%s], data type "
|
||||
"[%s], dst "
|
||||
"shape [%s], groups [%lld]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), DTypeStr(args.src_data_type).c_str(),
|
||||
VectorToString(args.dst_shape).c_str(), args.groups);
|
||||
|
||||
if (((args.src_format == FORMAT_NHWC) || (args.src_format == FORMAT_HWCN) || (args.src_format == FORMAT_NCHW)) &&
|
||||
args.dst_format == FORMAT_FRACTAL_Z) {
|
||||
std::vector<int64_t> expect_shape;
|
||||
auto ret =
|
||||
TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!IsTransShapeDstCorrect(args, expect_shape)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return TransFormatWithGroups(args.src_format, args.src_shape, args, result, false);
|
||||
} else if (((args.dst_format == FORMAT_NHWC) || (args.dst_format == FORMAT_HWCN) ||
|
||||
(args.dst_format == FORMAT_NCHW)) &&
|
||||
args.src_format == FORMAT_FRACTAL_Z) {
|
||||
std::vector<int64_t> expect_input_shape;
|
||||
auto ret =
|
||||
TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, expect_input_shape, args.groups);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
if ((!args.src_shape.empty()) && (args.src_shape != expect_input_shape)) {
|
||||
KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return TransFormatWithGroups(args.dst_format, args.dst_shape, args, result, true);
|
||||
}
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint32_t FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
|
||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
|
||||
int64_t groups) {
|
||||
if (CheckDataTypeSupport(data_type) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (src_format == FORMAT_NHWC && GetPrimaryFormat(static_cast<int32_t>(dst_format)) == FORMAT_FRACTAL_Z) {
|
||||
return TransShapeNhwcToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
if ((src_format == FORMAT_HWCN) &&
|
||||
(GetPrimaryFormat(static_cast<int32_t>(dst_format)) == static_cast<int32_t>(FORMAT_FRACTAL_Z))) {
|
||||
return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
if (src_format == FORMAT_NCHW && GetPrimaryFormat(static_cast<int32_t>(dst_format)) == FORMAT_FRACTAL_Z) {
|
||||
return TransShapeNchwToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_NCHW)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_HWCN)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_NHWC)
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
|
@@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_H
|
||||
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/register_format_transfer.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
class FormatTransferFractalZ : public FormatTransfer {
|
||||
public:
|
||||
uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
|
||||
uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) override;
|
||||
};
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
||||
|
||||
#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_H
|
|
@@ -0,0 +1,286 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/format_transfer_fractalz_3d.h"
|
||||
|
||||
#include "cpu_kernel/format_transfer/format_transfer_utils.h"
|
||||
#include "cpu_kernel/format_transfer/formats_definitions.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
namespace {
|
||||
KernelStatus CheckDataTypeSupport(DataType data_type) {
|
||||
return GetSizeByDataType(data_type) > 0 ? KERNEL_STATUS_OK : KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
/**
|
||||
 * FZ represents the weight of convolution.
|
||||
* After the conversion to two-dimensional matrix, the memory arrangement is
|
||||
 * small n and large Z. If 4D (e.g. NCHW) is used to represent the convolution kernel,
|
||||
* N is width, HWC is height.
|
||||
*
|
||||
 * frac_z_3d axes: (C1 * H * W * D, N1, Ni, C0), where Ni = 16, C0 = 16 / 32, No =
|
||||
* Ceil(N / Ni), C1 = Ceil(C / C0)
|
||||
* @return
|
||||
*/
|
||||
|
||||
uint32_t TransShapeToFz3DWithGroups(int64_t n, int64_t c, int64_t d, int64_t h, int64_t w, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
auto c0 = GetCubeSizeByDataType(data_type);
|
||||
if (c0 < 0) {
|
||||
KERNEL_LOG_ERROR("Cube size must greater than or equal to 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t cin_ori = c;
|
||||
// At this point, groups is not equal to 0, which has been checked at the [Transdata] entrance.
|
||||
int64_t cout_ori = n / groups;
|
||||
if (cin_ori == 0 || cout_ori == 0) {
|
||||
KERNEL_LOG_ERROR(
|
||||
"Check param Failed, cin_ori, cout_ori must not be equal 0, "
|
||||
"and current cin_ori, cout_ori, groups are [%ld] [%ld] [%ld]",
|
||||
cin_ori, cout_ori, groups);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t cube_k = GetCubeSizeByDataType(data_type);
|
||||
int64_t e_mult = std::min(
|
||||
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), groups);
|
||||
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
|
||||
int64_t c1_dim = cin_opt / cube_k;
|
||||
int64_t dim_g = Ceil(groups, e_mult);
|
||||
auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
|
||||
dst_shape.clear();
|
||||
dst_shape.push_back(dim_g * c1_dim * d * h * w);
|
||||
dst_shape.push_back(n1);
|
||||
dst_shape.push_back(kNiSize);
|
||||
dst_shape.push_back(cube_k);
|
||||
if (!IsShapeValid(dst_shape)) {
|
||||
KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t TransShapeNcdhwToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, static_cast<int64_t>(kNcdhwDimsNum))) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto n = src_shape.at(kNcdhwN);
|
||||
auto c = src_shape.at(kNcdhwC);
|
||||
auto d = src_shape.at(kNcdhwD);
|
||||
auto h = src_shape.at(kNcdhwH);
|
||||
auto w = src_shape.at(kNcdhwW);
|
||||
return TransShapeToFz3DWithGroups(n, c, d, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
uint32_t TransShapeDhwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, static_cast<int64_t>(kDhwcnDimsNum))) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
auto d = src_shape.at(kDhwcnD);
|
||||
auto h = src_shape.at(kDhwcnH);
|
||||
auto w = src_shape.at(kDhwcnW);
|
||||
auto c = src_shape.at(kDhwcnC);
|
||||
auto n = src_shape.at(kDhwcnN);
|
||||
|
||||
return TransShapeToFz3DWithGroups(n, c, d, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
uint32_t TransShapeNdhwcToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) {
|
||||
if (!CheckShapeValid(src_shape, kNdhwcDimsNum)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto n = src_shape.at(kNdhwcN);
|
||||
auto d = src_shape.at(kNdhwcD);
|
||||
auto h = src_shape.at(kNdhwcH);
|
||||
auto w = src_shape.at(kNdhwcW);
|
||||
auto c = src_shape.at(kNdhwcC);
|
||||
|
||||
return TransShapeToFz3DWithGroups(n, c, d, h, w, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
// Supports converting NCDHW, DHWCN, NDHWC to FORMAT_FRACTAL_Z_3D (GDC1HWN1N0C0);
|
||||
// the final effect is that the data is distributed diagonally.
|
||||
// For example: when the input filter format is NCDHW, calculate the correspondence of
|
||||
// indices between NCDHW and FORMAT_FRACTAL_Z_3D, then convert the old filter to the new
|
||||
// filter, and finally pad 0 at the positions where there is no data.
|
||||
uint32_t TransFormatWithGroups(const Format &format_5d, const std::vector<int64_t> &shape_5d, const TransArgs &args,
|
||||
TransResult &result, bool reverse) {
|
||||
int64_t h_dim = 0;
|
||||
int64_t w_dim = 0;
|
||||
int64_t c_dim = 0;
|
||||
int64_t n_dim = 0;
|
||||
int64_t d_dim = 0;
|
||||
if (GetFormatDim(d_dim, h_dim, w_dim, c_dim, n_dim, format_5d, shape_5d) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
int64_t cin_ori = c_dim;
|
||||
// At this point, groups is not equal to 0, which has been checked at the [Transdata] entrance.
|
||||
int64_t cout_ori = n_dim / args.groups;
|
||||
if (CheckDimOri(cin_ori, cout_ori) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type);
|
||||
int64_t e_mult = std::min(
|
||||
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)), args.groups);
|
||||
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
|
||||
int64_t cout_opt = Ceil(e_mult * cout_ori, static_cast<int64_t>(kCubeSize)) * static_cast<int64_t>(kCubeSize);
|
||||
int64_t c1_dim = cin_opt / cube_k;
|
||||
int64_t data_size = GetSizeByDataType(args.src_data_type);
|
||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * data_size;
|
||||
// The input is an empty tensor, so return success directly.
|
||||
if (dst_size == 0) {
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
|
||||
KERNEL_CHECK_NULLPTR(dst, KERNEL_STATUS_PARAM_INVALID,
|
||||
"Failed to allcoate memory for dst buf [%lld] when trans "
|
||||
"format from [%s] to [%s]",
|
||||
dst_size, FormatToSerialString(args.src_format).c_str(),
|
||||
FormatToSerialString(args.dst_format).c_str())
|
||||
(void)memset_s(dst.get(), static_cast<size_t>(dst_size), 0, static_cast<size_t>(dst_size));
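// Walk every (g, d, c, h, w, n) element of the 5D filter. index_fz is the flat offset in the
// FRACTAL_Z_3D layout (G1, D, C1, H, W, Nopt, C0) with G1 = Ceil(groups, e_mult); index_5d is
// the flat offset in the original NCDHW/NDHWC/DHWCN layout; reverse selects the copy direction.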
|
||||
for (int64_t g = 0; g < args.groups; g++) {
|
||||
for (int64_t d = 0; d < d_dim; d++) {
|
||||
for (int64_t c = 0; c < c_dim; c++) {
|
||||
for (int64_t h = 0; h < h_dim; h++) {
|
||||
for (int64_t w = 0; w < w_dim; w++) {
|
||||
for (int64_t n = 0; n < cout_ori; n++) {
|
||||
int64_t e_val = g % e_mult;
|
||||
int64_t dst_ci = e_val * cin_ori + c;
|
||||
int64_t dst_co = e_val * cout_ori + n;
|
||||
int64_t src_co = g * cout_ori + n;
|
||||
int64_t temporary = dst_ci % cube_k;
|
||||
int64_t index_5d = 0;
|
||||
int64_t index_fz = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
|
||||
w * cout_opt * cube_k + dst_co * cube_k + temporary;
|
||||
if (format_5d == FORMAT_DHWCN) {
|
||||
index_5d = d * h_dim * w_dim * c_dim * n_dim + h * w_dim * c_dim * n_dim + w * c_dim * n_dim +
|
||||
c * n_dim + src_co;
|
||||
} else if (format_5d == FORMAT_NCDHW) {
|
||||
index_5d = src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim +
|
||||
h * w_dim + w;
|
||||
} else if (format_5d == FORMAT_NDHWC) {
|
||||
index_5d = src_co * d_dim * h_dim * w_dim * c_dim + d * h_dim * w_dim * c_dim + h * w_dim * c_dim +
|
||||
w * c_dim + c;
|
||||
}
|
||||
if (!reverse) {
|
||||
copy_data(args.data, dst, index_5d, index_fz, data_size);
|
||||
} else {
|
||||
copy_data(args.data, dst, index_fz, index_5d, data_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.data = dst;
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResult &result) {
|
||||
KERNEL_LOG_DEBUG(
|
||||
"Begin to trans format from [%s] to [%s], src shape [%s], data type "
|
||||
"[%s], dst "
|
||||
"shape [%s]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), DTypeStr(args.src_data_type).c_str(),
|
||||
VectorToString(args.dst_shape).c_str());
|
||||
|
||||
if ((args.groups) == 0) {
|
||||
KERNEL_LOG_ERROR("Attr[groups] must not be equal 0");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (((args.src_format == FORMAT_NDHWC) || (args.src_format == FORMAT_DHWCN) || (args.src_format == FORMAT_NCDHW)) &&
|
||||
args.dst_format == FORMAT_FRACTAL_Z_3D) {
|
||||
std::vector<int64_t> expect_shape;
|
||||
auto ret =
|
||||
TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
if (!IsTransShapeDstCorrect(args, expect_shape)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return TransFormatWithGroups(args.src_format, args.src_shape, args, result, false);
|
||||
} else if (((args.dst_format == FORMAT_NDHWC) || (args.dst_format == FORMAT_DHWCN) ||
|
||||
(args.dst_format == FORMAT_NCDHW)) &&
|
||||
args.src_format == FORMAT_FRACTAL_Z_3D) {
|
||||
std::vector<int64_t> expect_input_shape;
|
||||
auto ret =
|
||||
TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, expect_input_shape, args.groups);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
if ((!args.src_shape.empty()) && (args.src_shape != expect_input_shape)) {
|
||||
KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return TransFormatWithGroups(args.dst_format, args.dst_shape, args, result, true);
|
||||
}
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
uint32_t FormatTransferFractalz3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
|
||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
|
||||
int64_t groups) {
|
||||
if (CheckDataTypeSupport(data_type) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (src_format == FORMAT_NDHWC &&
|
||||
GetPrimaryFormat(static_cast<int32_t>(dst_format)) == static_cast<int32_t>(FORMAT_FRACTAL_Z_3D)) {
|
||||
return TransShapeNdhwcToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
if ((src_format == FORMAT_DHWCN) &&
|
||||
GetPrimaryFormat(static_cast<int32_t>(dst_format)) == static_cast<int32_t>(FORMAT_FRACTAL_Z_3D)) {
|
||||
return TransShapeDhwcnToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
if (src_format == FORMAT_NCDHW &&
|
||||
GetPrimaryFormat(static_cast<int32_t>(dst_format)) == static_cast<int32_t>(FORMAT_FRACTAL_Z_3D)) {
|
||||
return TransShapeNcdhwToFzWithGroups(src_shape, data_type, dst_shape, groups);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_NCDHW, FORMAT_FRACTAL_Z_3D)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_NDHWC, FORMAT_FRACTAL_Z_3D)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_NCDHW)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_DHWCN)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_NDHWC)
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
|
@@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_3D_H
|
||||
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_3D_H
|
||||
|
||||
#include <vector>
|
||||
#include "cpu_kernel/format_transfer/register_format_transfer.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
class FormatTransferFractalz3D : public FormatTransfer {
|
||||
public:
|
||||
uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
|
||||
uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
|
||||
std::vector<int64_t> &dst_shape, int64_t groups) override;
|
||||
};
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
||||
|
||||
#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_FRACTAL_Z_3D_H
|
|
@@ -0,0 +1,209 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/format_transfer/format_transfer_ndc1hwc0.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/format_transfer/format_transfer_utils.h"
|
||||
#include "cpu_kernel/format_transfer/formats_definitions.h"
|
||||
#include "utils/kernel_util.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace aicpu {
|
||||
namespace formats {
|
||||
namespace {
|
||||
std::map<Format, std::string> kFormatTable = {
|
||||
{FORMAT_NCDHW, "NCDHW"},
|
||||
{FORMAT_NDHWC, "NDHWC"},
|
||||
};
|
||||
|
||||
KernelStatus CheckDataTypeSupport(DataType data_type) {
|
||||
return GetSizeByDataType(data_type) > 0 ? KERNEL_STATUS_OK : KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
void TransSrcDataToDstData(const TransArgs &args, const std::vector<int64_t> &shape_ndhwc,
|
||||
std::shared_ptr<uint8_t> &dst, int64_t c0, int32_t data_size) {
|
||||
const int64_t n = shape_ndhwc[0];
|
||||
const int64_t d = shape_ndhwc[1];
|
||||
const int64_t h = shape_ndhwc[2];
|
||||
const int64_t w = shape_ndhwc[3];
|
||||
const int64_t c = shape_ndhwc[4];
|
||||
// c0 is definitely a number greater than 0
|
||||
const int64_t c1 = ((c - 1) / c0) + 1;
|
||||
const int64_t hw = h * w;
|
||||
const int64_t dhw = d * hw;
|
||||
const int64_t dhwc = dhw * c;
|
||||
const int64_t hwc0 = hw * c0;
|
||||
const int64_t c1hwc0 = c1 * hwc0;
|
||||
const int64_t dc1hwc0 = d * c1hwc0;
|
||||
const int64_t ndhwc = n * dhwc;
|
||||
int64_t src_index = 0;
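// dst is laid out as NDC1HWC0 (c split into c / c0 blocks of width c0); src_index follows
// NCDHW by default and NDHWC when src_format is FORMAT_NDHWC.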
|
||||
|
||||
for (int64_t ndhwc_idx = 0; ndhwc_idx < ndhwc; ++ndhwc_idx) {
|
||||
const int64_t n_idx = ndhwc_idx / dhwc;
|
||||
const int64_t dhw_idx = ndhwc_idx % dhwc / c;
|
||||
const int64_t c_idx = ndhwc_idx % c;
|
||||
const int64_t dst_index =
|
||||
n_idx * dc1hwc0 + (dhw_idx / hw) * c1hwc0 + (c_idx / c0) * hwc0 + (dhw_idx % hw) * c0 + c_idx % c0;
|
||||
src_index = n_idx * dhwc + c_idx * dhw + dhw_idx;
|
||||
if (args.src_format == FORMAT_NDHWC) {
|
||||
src_index = n_idx * dhwc + dhw_idx * c + c_idx;
|
||||
}
|
||||
uint8_t *dst_data = dst.get() + dst_index * data_size;
|
||||
const uint8_t *src_data = args.data + src_index * data_size;
|
||||
for (int64_t index = 0; index < data_size; ++index) {
|
||||
*dst_data++ = *src_data++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t TransDstDataToNdc1hwc0(const TransArgs &args, TransResult &result) {
|
||||
const int32_t data_size = GetSizeByDataType(args.src_data_type);
|
||||
const auto dst_size = GetItemNumByShape(args.dst_shape) * data_size;
|
||||
// The input is empty tensor, we should return success directly
|
||||
if (dst_size == 0) {
|
||||
result.length = 0;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
|
||||
if (dst == nullptr) {
|
||||
KERNEL_LOG_ERROR("Failed to allocate memory for dst buf [%ld] when trans format from [%s] to [%s]", dst_size,
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
errno_t ret = memset_s(dst.get(), static_cast<size_t>(dst_size), 0, static_cast<size_t>(dst_size));
|
||||
if (ret != EOK) {
|
||||
KERNEL_LOG_ERROR("memset failed, ret is [%d]", ret);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto iter = kFormatTable.find(args.src_format);
|
||||
if (iter == kFormatTable.end()) {
|
||||
KERNEL_LOG_ERROR("src_format is wrong, now format is [%d]", static_cast<int32_t>(args.src_format));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::string cur_format = iter->second;
|
||||
size_t n_index = cur_format.find('N');
|
||||
size_t d_index = cur_format.find('D');
|
||||
size_t h_index = cur_format.find('H');
|
||||
size_t w_index = cur_format.find('W');
|
||||
size_t c_index = cur_format.find('C');
|
||||
std::vector<int64_t> shape_ndhwc;
|
||||
shape_ndhwc.push_back(args.src_shape.at(n_index));
|
||||
shape_ndhwc.push_back(args.src_shape.at(d_index));
|
||||
shape_ndhwc.push_back(args.src_shape.at(h_index));
|
||||
shape_ndhwc.push_back(args.src_shape.at(w_index));
|
||||
shape_ndhwc.push_back(args.src_shape.at(c_index));
|
||||
const int64_t c0 = GetCubeSizeByDataType(args.src_data_type);
|
||||
if (c0 <= 0) {
|
||||
KERNEL_LOG_ERROR("Failed to get c0, c0 is [%ld]", c0);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
TransSrcDataToDstData(args, shape_ndhwc, dst, c0, data_size);
|
||||
|
||||
result.data = dst;
|
||||
result.length = static_cast<size_t>(dst_size);
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t TransShapeToNdc1hwc0(const std::vector<int64_t> &src_shape, const Format &src_format,
|
||||
const DataType &data_type, std::vector<int64_t> &dst_shape) {
|
||||
auto iter = kFormatTable.find(src_format);
|
||||
if (iter == kFormatTable.end()) {
|
||||
KERNEL_LOG_ERROR("src_format is wrong, now format is [%d]", static_cast<int32_t>(src_format));
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::string cur_format = iter->second;
|
||||
size_t n_index = cur_format.find('N');
|
||||
size_t d_index = cur_format.find('D');
|
||||
size_t h_index = cur_format.find('H');
|
||||
size_t w_index = cur_format.find('W');
|
||||
size_t c_index = cur_format.find('C');
|
||||
const int64_t c0 = GetCubeSizeByDataType(data_type);
|
||||
if (c0 <= 0) {
|
||||
KERNEL_LOG_ERROR("Failed to get c0, c0 is [%ld]", c0);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (!CheckShapeValid(src_shape, static_cast<int64_t>(cur_format.length()))) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
dst_shape.clear();
|
||||
dst_shape.push_back(src_shape.at(n_index));
|
||||
dst_shape.push_back(src_shape.at(d_index));
|
||||
dst_shape.push_back(Ceil(src_shape.at(c_index), c0));
|
||||
dst_shape.push_back(src_shape.at(h_index));
|
||||
dst_shape.push_back(src_shape.at(w_index));
|
||||
dst_shape.push_back(c0);
|
||||
if (!IsShapeValid(dst_shape)) {
|
||||
KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
uint32_t FormatTransferNdc1hwc0::TransFormat(const TransArgs &args, TransResult &result) {
|
||||
KERNEL_LOG_INFO(
|
||||
"Begin to trans format from [%s] to [%s], src shape [%s], data type [%s], dst "
|
||||
"shape [%s]",
|
||||
FormatToSerialString(args.src_format).c_str(), FormatToSerialString(args.dst_format).c_str(),
|
||||
VectorToString(args.src_shape).c_str(), DTypeStr(args.src_data_type).c_str(),
|
||||
VectorToString(args.dst_shape).c_str());
|
||||
|
||||
std::vector<int64_t> expect_shape;
|
||||
auto ret =
|
||||
TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
|
||||
if (ret != KERNEL_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
if (!IsTransShapeDstCorrect(args, expect_shape)) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return TransDstDataToNdc1hwc0(args, result);
|
||||
}
|
||||
|
||||
uint32_t FormatTransferNdc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
|
||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
|
||||
int64_t groups) {
|
||||
(void)dst_format;
|
||||
(void)groups;
|
||||
if (CheckDataTypeSupport(data_type) != KERNEL_STATUS_OK) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (src_format != FORMAT_NCDHW && src_format != FORMAT_NDHWC) {
|
||||
KERNEL_LOG_ERROR("The current format is not supported, src_format is [%s]",
|
||||
FormatToSerialString(src_format).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return TransShapeToNdc1hwc0(src_shape, src_format, data_type, dst_shape);
|
||||
}
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferNdc1hwc0, FORMAT_NCDHW, FORMAT_NDC1HWC0)
|
||||
REGISTER_FORMAT_TRANSFER(FormatTransferNdc1hwc0, FORMAT_NDHWC, FORMAT_NDC1HWC0)
|
||||
} // namespace formats
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,34 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_NDC1HWC0_H
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_NDC1HWC0_H

#include <vector>

#include "cpu_kernel/format_transfer/register_format_transfer.h"

namespace aicpu {
namespace formats {
class FormatTransferNdc1hwc0 : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};
} // namespace formats
} // namespace aicpu

#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_NDC1HWC0_H
@ -0,0 +1,274 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cpu_kernel/format_transfer/format_transfer_transpose.h"

#include <memory>

#include "cpu_kernel/format_transfer/format_transfer_utils.h"
#include "cpu_kernel/format_transfer/formats_definitions.h"
#include "utils/kernel_util.h"
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
#include "securec/include/securec.h"
#include "cpu_kernel/common/status.h"

namespace aicpu {
namespace formats {
namespace {
std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
  {FORMAT_NCHW,
   {{FORMAT_NHWC, std::vector<int64_t>({kNchwN, kNchwH, kNchwW, kNchwC})},
    {FORMAT_HWCN, std::vector<int64_t>({kNchwH, kNchwW, kNchwC, kNchwN})},
    {FORMAT_CHWN, std::vector<int64_t>({kNchwC, kNchwH, kNchwW, kNchwN})}}},
  {FORMAT_NHWC,
   {{FORMAT_NCHW, std::vector<int64_t>({kNhwcN, kNhwcC, kNhwcH, kNhwcW})},
    {FORMAT_CHWN, std::vector<int64_t>({kNhwcC, kNhwcH, kNhwcW, kNhwcN})},
    {FORMAT_HWCN, std::vector<int64_t>({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}},
  {FORMAT_HWCN,
   {{FORMAT_NCHW, std::vector<int64_t>({kHwcnN, kHwcnC, kHwcnH, kHwcnW})},
    {FORMAT_NHWC, std::vector<int64_t>({kHwcnN, kHwcnH, kHwcnW, kHwcnC})},
    {FORMAT_CHWN, std::vector<int64_t>({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}},
  {FORMAT_CHWN,
   {{FORMAT_NCHW, std::vector<int64_t>({kChwnN, kChwnC, kChwnH, kChwnW})},
    {FORMAT_NHWC, std::vector<int64_t>({kChwnN, kChwnH, kChwnW, kChwnC})},
    {FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}},
};

bool ShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
  if (src_shape.empty()) {
    KERNEL_LOG_ERROR("Failed to transpose, src shape is empty");
    return false;
  }
  for (auto dim : src_shape) {
    if (dim < 0) {
      KERNEL_LOG_ERROR("Failed to transpose, negative dim [%d] in src shape [%s]", dim,
                       FmtToStr(VectorToString(src_shape)).c_str());
      return false;
    }
  }
  if (perm_arg.size() != src_shape.size()) {
    KERNEL_LOG_ERROR(
      "Failed to transpose, the size of src shape [%s] and perm arg [%s] are "
      "different",
      FmtToStr(src_shape.size()).c_str(), FmtToStr(perm_arg.size()).c_str());
    return false;
  }

  std::vector<int64_t> exists(perm_arg.size());
  for (auto perm : perm_arg) {
    if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) {
      KERNEL_LOG_ERROR("Failed to transpose, invalid perm [%s], perm arg [%s]", FmtToStr(perm).c_str(),
                       FmtToStr(VectorToString(perm_arg)).c_str());
      return false;
    }
  }
  return true;
}

bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
                         const std::vector<int64_t> &perm_arg) {
  if (src == nullptr) {
    KERNEL_LOG_ERROR("Src should not be nullptr");
    return false;
  }
  if (GetSizeByDataType(src_data_type) < 0) {
    KERNEL_LOG_ERROR("The data type [%s] is not supported", DTypeStr(src_data_type).c_str());
    return false;
  }
  return ShapeArgValid(src_shape, perm_arg);
}

void GenHeads(const std::vector<int64_t> &shape, std::vector<int64_t> &heads) {
  heads.resize(shape.size());
  heads[shape.size() - 1] = 1;
  for (auto i = static_cast<int64_t>(shape.size() - 2); i >= 0; --i) {
    heads[i] = shape[i + 1] * heads[i + 1];
  }
}

int64_t GenOffset(const std::vector<int64_t> &offsets, const std::vector<int64_t> &indexes) {
  int64_t offset = 0;
  for (size_t i = 0; i < indexes.size(); ++i) {
    offset += offsets[i] * indexes[i];
  }
  return offset;
}

void AddOne(const std::vector<int64_t> &shape, std::vector<int64_t> &indexes) {
  size_t i = indexes.size() - 1;
  indexes[i]++;
  while (i > 0) {
    if (indexes[i] >= shape[i]) {
      indexes[i] = 0;
      indexes[i - 1]++;
      --i;
    } else {
      break;
    }
  }
}

void TransShapeByPerm(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg,
                      std::vector<int64_t> &dst_shape) {
  dst_shape.resize(src_shape.size());
  for (size_t i = 0; i < perm_arg.size(); ++i) {
    dst_shape[i] = src_shape[perm_arg[i]];
  }
}
} // namespace

uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
                   const std::vector<int64_t> &perm_arg, TransResult &result) {
  if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  std::vector<int64_t> dst_shape;
  TransShapeByPerm(src_shape, perm_arg, dst_shape);
  std::vector<int64_t> src_origin_ordered_heads;
  GenHeads(src_shape, src_origin_ordered_heads);
  std::vector<int64_t> src_heads;
  TransShapeByPerm(src_origin_ordered_heads, perm_arg, src_heads);

  int64_t dst_ele_num = GetItemNumByShape(dst_shape);
  int64_t data_size = GetSizeByDataType(src_data_type);
  int64_t dst_size = data_size * dst_ele_num;

  KERNEL_LOG_INFO(
    "Begin to transpose, src shape [%s], perm arg [%s], dst shape [%s], data "
    "type [%s]",
    VectorToString(src_shape).c_str(), VectorToString(perm_arg).c_str(), VectorToString(dst_shape).c_str(),
    DTypeStr(src_data_type).c_str());
  if (dst_ele_num == 0) {
    result.length = static_cast<size_t>(dst_size);
    return KERNEL_STATUS_OK;
  }

  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
  if (dst == nullptr) {
    KERNEL_LOG_ERROR(
      "Failed to allocate memory for dst buf [%ld] when transpose from [%s] "
      "to [%s]",
      dst_size, VectorToString(src_shape).c_str(), VectorToString(dst_shape).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  int64_t dst_index = 0;
  std::vector<int64_t> dst_indexes(dst_shape.size());
  while (dst_index < dst_ele_num) {
    auto src_offset = GenOffset(src_heads, dst_indexes) * data_size;
    auto dst_offset_bytes = dst_index * data_size;
    auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
                            ? dst_size - dst_offset_bytes
                            : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
    auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset,
                        static_cast<size_t>(data_size));
    if (ret != EOK) {
      KERNEL_LOG_ERROR(
        "Failed to transpose, src shape [%s], perm arg [%s], dst shape [%s], "
        "failed to write to dst offset [%ld], current dim offset [%s]",
        VectorToString(src_shape).c_str(), VectorToString(perm_arg).c_str(), VectorToString(dst_shape).c_str(),
        dst_offset_bytes, VectorToString(dst_indexes).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
    }
    AddOne(dst_shape, dst_indexes);
    ++dst_index;
  }

  result.data = dst;
  result.length = static_cast<size_t>(dst_size);
  return KERNEL_STATUS_OK;
}

uint32_t TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> &src_shape,
                                 const std::vector<int64_t> &dst_shape, DataType src_data_type,
                                 const std::vector<int64_t> &perm_arg, TransResult &result) {
  if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  std::vector<int64_t> expected_shape;
  TransShapeByPerm(src_shape, perm_arg, expected_shape);
  if (dst_shape != expected_shape) {
    KERNEL_LOG_ERROR(
      "Failed to trans axis for perm_arg [%s], invalid dst shape [%s], "
      "expect [%s]",
      VectorToString(perm_arg).c_str(), VectorToString(dst_shape).c_str(), VectorToString(expected_shape).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }

  return Transpose(data, src_shape, src_data_type, perm_arg, result);
}

uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) {
  auto dst_iter = perm_args.find(src_format);
  if (dst_iter == perm_args.end()) {
    KERNEL_LOG_ERROR(
      "Failed to trans shape, do not support transpose from format [%s] to "
      "[%s]",
      FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  auto iter = dst_iter->second.find(dst_format);
  if (iter == dst_iter->second.end()) {
    KERNEL_LOG_ERROR(
      "Failed to trans shape, do not support transpose from format [%s] to "
      "[%s]",
      FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  perm = iter->second;
  return KERNEL_STATUS_OK;
}

uint32_t FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) {
  std::vector<int64_t> expected_shape;
  auto ret =
    TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape, args.groups);
  if (ret != KERNEL_STATUS_OK) {
    return ret;
  }
  if (!IsTransShapeDstCorrect(args, expected_shape)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

  return Transpose(args.data, args.src_shape, args.src_data_type, perm_args[args.src_format][args.dst_format], result);
}

uint32_t FormatTransferTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                             DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
                                             int64_t groups) {
  std::vector<int64_t> perm_arg;
  if (GetPermByForamt(src_format, dst_format, perm_arg) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  if (!ShapeArgValid(src_shape, perm_arg)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  TransShapeByPerm(src_shape, perm_arg, dst_shape);
  return KERNEL_STATUS_OK;
}

REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_NHWC)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_HWCN)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NCHW, FORMAT_CHWN)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_NCHW)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_CHWN)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_NHWC, FORMAT_HWCN)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_NCHW)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_NHWC)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_HWCN, FORMAT_CHWN)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_NCHW)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_NHWC)
REGISTER_FORMAT_TRANSFER(FormatTransferTranspose, FORMAT_CHWN, FORMAT_HWCN)
} // namespace formats
} // namespace aicpu
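The transpose above walks destination elements in order and maps each one back to a source byte offset through permuted row-major strides ("heads"). A minimal standalone sketch of that index math, with a toy 2x3 tensor and perm {1, 0} (illustrative only, not part of the patch):

#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides: heads[i] is the product of all dims after i, same as GenHeads.
std::vector<int64_t> Heads(const std::vector<int64_t> &shape) {
  std::vector<int64_t> heads(shape.size());
  heads[shape.size() - 1] = 1;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i) {
    heads[i] = shape[i + 1] * heads[i + 1];
  }
  return heads;
}

int main() {
  const std::vector<int64_t> src_shape = {2, 3};
  const std::vector<int64_t> perm = {1, 0};  // dst(r, c) = src(c, r)
  const auto src_heads = Heads(src_shape);   // {3, 1}
  // Permuting the strides by perm gives {1, 3}; GenOffset then dots these with
  // the destination index, exactly what the dst_index loop in Transpose does.
  const std::vector<int64_t> permuted = {src_heads[perm[0]], src_heads[perm[1]]};
  for (int64_t r = 0; r < 3; ++r) {
    for (int64_t c = 0; c < 2; ++c) {
      std::cout << "dst(" << r << "," << c << ") <- src offset " << r * permuted[0] + c * permuted[1] << "\n";
    }
  }
  return 0;
}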
@ -0,0 +1,42 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_TRANSFER_TRANSPOSE_H
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_TRANSFER_TRANSPOSE_H

#include <map>
#include <vector>

#include "cpu_kernel/format_transfer/register_format_transfer.h"

namespace aicpu {
namespace formats {
uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
                   const std::vector<int64_t> &perm_arg, TransResult &result);

uint32_t TransposeWithShapeCheck(const uint8_t *src, const std::vector<int64_t> &src_shape,
                                 const std::vector<int64_t> &dst_shape, DataType src_data_type,
                                 const std::vector<int64_t> &perm_arg, TransResult &result);
uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm);
class FormatTransferTranspose : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};
} // namespace formats
} // namespace aicpu

#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_TRANSFER_TRANSPOSE_H
@ -0,0 +1,211 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cpu_kernel/format_transfer/format_transfer_utils.h"

#include <functional>
#include <memory>
#include <numeric>

#include "cpu_kernel/format_transfer/formats_definitions.h"
#include "utils/kernel_util.h"
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"

namespace aicpu {
namespace formats {
bool IsShapeValid(const std::vector<int64_t> &shape) {
  if (shape.empty()) {
    return false;
  }
  int64_t num = 1;
  for (auto dim : shape) {
    if (dim < 0) {
      std::string error = "Invalid negative dims in the shape " + FmtToStr(VectorToString(shape));
      KERNEL_LOG_ERROR("%s", error.c_str());
      return false;
    }
    if (dim != 0 && kShapeItemNumMAX / dim < num) {
      std::string error = "Shape overflow, the total count should be less than " + FmtToStr(kShapeItemNumMAX);
      KERNEL_LOG_ERROR("%s", error.c_str());
      return false;
    }
    num *= dim;
  }
  return true;
}

bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) {
  if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) {
    std::string error = "Invalid shape, dims num " + FmtToStr(shape.size()) + ", expect " + FmtToStr(expect_dims);
    KERNEL_LOG_ERROR("%s", error.c_str());
    return false;
  }
  return IsShapeValid(shape);
}

int64_t GetCubeSizeByDataType(DataType data_type) {
  // Current cube does not support 4 bytes and longer data
  auto size = GetSizeByDataType(data_type);
  if (size <= 0) {
    std::string error = "Failed to get cube size, the data type " + FmtToStr(DTypeStr(data_type)) + " is invalid";
    KERNEL_LOG_ERROR("%s", error.c_str());
    return -1;
  } else if (size == 1) {
    return kCubeSize * 2;  // 32 bytes cube size
  } else {
    return kCubeSize;
  }
}

bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) {
  if (args.src_shape != expect_shape) {
    std::string error = "Failed to trans format from " + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
                        FmtToStr(FormatToSerialString(args.dst_format)) + ", invalid relationship between src shape " +
                        FmtToStr(VectorToString(args.src_shape)) + " and dst " +
                        FmtToStr(VectorToString(args.dst_shape));
    KERNEL_LOG_ERROR("%s", error.c_str());
    return false;
  }
  return true;
}

bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) {
  if (!args.dst_shape.empty() && args.dst_shape != expect_shape) {
    std::string error = "Failed to trans format from " + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
                        FmtToStr(FormatToSerialString(args.dst_format)) + ", the dst shape " +
                        FmtToStr(VectorToString(args.dst_shape)) + " is invalid, expect " +
                        FmtToStr(VectorToString(expect_shape));
    KERNEL_LOG_ERROR("%s", error.c_str());
    return false;
  }
  return true;
}

int64_t GetItemNumByShape(const std::vector<int64_t> &shape) {
  // shape will not be greater than INT_MAX
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
  }
  return num;
}

uint32_t TransFormat(const TransArgs &args, TransResult &result) {
  auto transfer = BuildFormatTransfer(args);
  if (transfer == nullptr) {
    std::string error = "Failed to trans data from format " + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
                        FmtToStr(FormatToSerialString(args.dst_format));
    KERNEL_LOG_WARN("%s", error.c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }

  auto src_shape_size = GetItemNumByShape(args.src_shape);
  if (args.data == nullptr && src_shape_size != 0) {
    KERNEL_LOG_WARN("Invalid input null data");
    return KERNEL_STATUS_PARAM_INVALID;
  }

  return transfer->TransFormat(args, result);
}

// greatest common divisor
int64_t Measure(int64_t x, int64_t y) {
  int64_t z = y;
  while (x % y != 0) {
    z = x % y;
    x = y;
    y = z;
  }
  return z;
}
// least common multiple
int64_t Lcm(int64_t a, int64_t b) {
  if (b == 0) {
    return -1;
  }
  int64_t temp = (a * b) / (Measure(a, b));
  return temp;
}

void copy_data(const uint8_t *input_data, std::shared_ptr<uint8_t> dst, int64_t src_index, int64_t dst_index,
               int64_t data_size) {
  char *dst_data = reinterpret_cast<char *>(dst.get() + dst_index * data_size);
  const char *src_data = reinterpret_cast<const char *>(input_data + src_index * data_size);
  for (int64_t index = 0; index < data_size; index++) {
    *dst_data++ = *src_data++;
  }
}

KernelStatus CheckDimOri(int64_t cin_ori, int64_t cout_ori) {
  if (cin_ori == 0 || cout_ori == 0) {
    KERNEL_LOG_ERROR(
      "Cin_ori, cout_ori must not be equal to 0, and current cin_ori is [%ld], "
      "cout_ori is [%ld]",
      cin_ori, cout_ori);
    return KERNEL_STATUS_PARAM_INVALID;
  }
  return KERNEL_STATUS_OK;
}

KernelStatus GetFormatDim(int64_t &d_dim, int64_t &h_dim, int64_t &w_dim, int64_t &c_dim, int64_t &n_dim,
                          const Format &input_format, const std::vector<int64_t> &dims) {
  if (input_format == FORMAT_NCDHW) {
    n_dim = dims[kNcdhwN];
    c_dim = dims[kNcdhwC];
    d_dim = dims[kNcdhwD];
    h_dim = dims[kNcdhwH];
    w_dim = dims[kNcdhwW];
  } else if (input_format == FORMAT_DHWCN) {
    d_dim = dims[kDhwcnD];
    h_dim = dims[kDhwcnH];
    w_dim = dims[kDhwcnW];
    c_dim = dims[kDhwcnC];
    n_dim = dims[kDhwcnN];
  } else if (input_format == FORMAT_NDHWC) {
    n_dim = dims[kNdhwcN];
    d_dim = dims[kNdhwcD];
    h_dim = dims[kNdhwcH];
    w_dim = dims[kNdhwcW];
    c_dim = dims[kNdhwcC];
  } else if (input_format == FORMAT_NHWC) {
    n_dim = dims[kNhwcN];
    h_dim = dims[kNhwcH];
    d_dim = 1;
    w_dim = dims[kNhwcW];
    c_dim = dims[kNhwcC];
  } else if (input_format == FORMAT_NCHW) {
    n_dim = dims[kNchwN];
    c_dim = dims[kNchwC];
    h_dim = dims[kNchwH];
    w_dim = dims[kNchwW];
    d_dim = 1;
  } else if (input_format == FORMAT_HWCN) {
    h_dim = dims[kHwcnH];
    w_dim = dims[kHwcnW];
    c_dim = dims[kHwcnC];
    n_dim = dims[kHwcnN];
    d_dim = 1;
  } else {
    KERNEL_LOG_WARN(
      "Format is not FORMAT_DHWCN or FORMAT_NDHWC or FORMAT_NCDHW or "
      "FORMAT_NHWC or FORMAT_NCHW or FORMAT_HWCN, current input "
      "format is [%d]",
      static_cast<int32_t>(input_format));
    return KERNEL_STATUS_PARAM_INVALID;
  }
  return KERNEL_STATUS_OK;
}
} // namespace formats
} // namespace aicpu
@ -0,0 +1,69 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_UTILS_H_
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_UTILS_H_

#include <memory>
#include <string>
#include <vector>
#include "cpu_kernel/common/status.h"
#include "cpu_kernel/format_transfer/register_format_transfer.h"

namespace aicpu {
namespace formats {
static const int kCubeSize = 16;
static const int kNiSize = 16;
static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;
int64_t Lcm(int64_t a, int64_t b);
bool IsShapeValid(const std::vector<int64_t> &shape);

bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims);

int64_t GetCubeSizeByDataType(DataType data_type);

bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape);

bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape);

int64_t GetItemNumByShape(const std::vector<int64_t> &shape);

void copy_data(const uint8_t *input_data, std::shared_ptr<uint8_t> dst, int64_t src_index, int64_t dst_index,
               int64_t data_size);

KernelStatus GetFormatDim(int64_t &d_dim, int64_t &h_dim, int64_t &w_dim, int64_t &c_dim, int64_t &n_dim,
                          const Format &input_format, const std::vector<int64_t> &dims);
KernelStatus CheckDimOri(int64_t cin_ori, int64_t cout_ori);

template <typename T>
T Ceil(T n1, T n2) {
  if (n1 == 0) {
    return 0;
  }
  return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

/**
 * Convert the data format, and put the converted format and length in the
 * result
 * @param args
 * @param result
 * @return
 */
uint32_t TransFormat(const TransArgs &args, TransResult &result);
} // namespace formats
} // namespace aicpu
#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_UTILS_H_
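The Ceil helper above is what turns the original channel count into the C1 dimension of the packed NDC1HWC0 layout. A small worked example with made-up values (C = 21 channels, C0 = 16 for a 2-byte dtype):

#include <cstdint>
#include <iostream>

// Same rounding-up division as the Ceil template above.
int64_t Ceil(int64_t n1, int64_t n2) {
  if (n1 == 0) return 0;
  return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

int main() {
  const int64_t c = 21;   // original channel count (example value)
  const int64_t c0 = 16;  // cube size returned by GetCubeSizeByDataType for a 2-byte dtype
  // NDC1HWC0 packs C into C1 x C0 blocks, so C1 = ceil(C / C0) and the padded
  // channel count becomes C1 * C0.
  std::cout << "C1 = " << Ceil(c, c0) << ", padded channels = " << Ceil(c, c0) * c0 << "\n";  // C1 = 2, 32
  return 0;
}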
@ -0,0 +1,51 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFERS_FORMAT_TRANSFER_DEFINITIONS_H
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFERS_FORMAT_TRANSFER_DEFINITIONS_H

namespace aicpu {
namespace formats {
enum NchwDimIndex { kNchwN, kNchwC, kNchwH, kNchwW, kNchwDimsNum };

enum NhwcDimIndex { kNhwcN, kNhwcH, kNhwcW, kNhwcC, kNhwcDimsNum };

enum HwcnDimIndex { kHwcnH, kHwcnW, kHwcnC, kHwcnN, kHwcnDimsNum };

enum ChwnDimIndex { kChwnC, kChwnH, kChwnW, kChwnN, kChwnDimsNum };

enum Nc1hwc0DimIndex { kNc1hwc0N, kNc1hwc0C1, kNc1hwc0H, kNc1hwc0W, kNc1hwc0C0, kNc1hwc0DimsNum };

enum C1hwncoc0DimIndex {
  kC1hwncoc0C1,
  kC1hwncoc0H,
  kC1hwncoc0W,
  kC1hwncoc0N,
  kC1hwncoc0Co,
  kC1hwncoc0C0,
  kC1hwncoc0DimsNum
};

enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum };

enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum };

enum NcdhwDimIndex { kNcdhwN, kNcdhwC, kNcdhwD, kNcdhwH, kNcdhwW, kNcdhwDimsNum };

enum NdhwcDimIndex { kNdhwcN, kNdhwcD, kNdhwcH, kNdhwcW, kNdhwcC, kNdhwcDimsNum };
} // namespace formats
} // namespace aicpu
#endif // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFERS_FORMAT_TRANSFER_DEFINITIONS_H_
@ -0,0 +1,63 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cpu_kernel/format_transfer/register_format_transfer.h"

#include <map>
#include <utility>

namespace aicpu {
namespace formats {
namespace {
struct FormatTransferRegistry {
  void RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) {
    src_dst_builder[src][dst] = std::move(builder);
  }
  std::map<Format, std::map<Format, FormatTransferBuilder>> src_dst_builder;
};

FormatTransferRegistry &GetFormatTransferRegistry() {
  static FormatTransferRegistry registry;
  return registry;
}
} // namespace

FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) {
  GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder));
}

std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args) {
  auto &registry = GetFormatTransferRegistry();
  auto dst_builder = registry.src_dst_builder.find(args.src_format);
  if (dst_builder == registry.src_dst_builder.end()) {
    return nullptr;
  }
  auto builder_iter = dst_builder->second.find(args.dst_format);
  if (builder_iter == dst_builder->second.end()) {
    return nullptr;
  }
  return builder_iter->second();
}

bool FormatTransferExists(const TransArgs &args) {
  auto &registry = GetFormatTransferRegistry();
  auto dst_builder = registry.src_dst_builder.find(args.src_format);
  if (dst_builder == registry.src_dst_builder.end()) {
    return false;
  }
  return dst_builder->second.count(args.dst_format) > 0;
}
} // namespace formats
} // namespace aicpu
@ -0,0 +1,81 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_KERNELS_HOST_FORMAT_TRANSFER_REGISTER_FORMAT_TRANSFER_H
#define AICPU_KERNELS_HOST_FORMAT_TRANSFER_REGISTER_FORMAT_TRANSFER_H

#include <functional>
#include <memory>
#include <vector>

#include "cpu_kernel/inc/cpu_types.h"

namespace aicpu {
namespace formats {
struct TransArgs {
  const uint8_t *data;
  // primary format
  Format src_format;
  Format dst_format;
  // For scenes that need to supplement the shape, for example, 5D to 4D.
  // It is not possible to convert the format normally if you only get the
  // src_shape; the shape must be known before it is supplemented. So both
  // src_shape and dst_shape need to be passed in here.
  std::vector<int64_t> src_shape;
  std::vector<int64_t> dst_shape;
  DataType src_data_type;
  int64_t groups;
};

struct TransResult {
  std::shared_ptr<uint8_t> data;
  // data length in bytes
  size_t length;
};

class FormatTransfer {
 public:
  virtual ~FormatTransfer() = default;
  virtual uint32_t TransFormat(const TransArgs &args, TransResult &result) = 0;
  virtual uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
                              Format dst_format, std::vector<int64_t> &dst_shape, int64_t groups) = 0;
};

using FormatTransferBuilder = std::function<std::shared_ptr<FormatTransfer>()>;

class FormatTransferRegister {
 public:
  FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst);
  ~FormatTransferRegister() = default;
};

#define REGISTER_FORMAT_TRANSFER(TransferClass, format1, format2)                      \
  namespace {                                                                          \
  FormatTransferRegister format_transfer_register_##TransferClass##format1##format2(  \
    []() { return std::make_shared<TransferClass>(); }, format1, format2);            \
  }

/**
 * Build a FormatTransfer according to 'args'
 * @param args
 * @return
 */
std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args);

bool FormatTransferExists(const TransArgs &args);
} // namespace formats
} // namespace aicpu
#endif
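For orientation, a minimal sketch of how the registry above is meant to be driven from a call site. Everything here is illustrative: the shape and dtype values are made up, and the FORMAT_NCDHW / FORMAT_NDC1HWC0 / DT_FLOAT16 enumerator names are assumed to come from cpu_types.h.

#include <cstdint>
#include <vector>

#include "cpu_kernel/format_transfer/register_format_transfer.h"

// Hypothetical call site; returns the transfer status, result holds the converted buffer.
uint32_t ConvertNcdhwToNdc1hwc0(const uint8_t *input_bytes, aicpu::formats::TransResult &result) {
  aicpu::formats::TransArgs args;
  args.data = input_bytes;
  args.src_format = aicpu::FORMAT_NCDHW;
  args.dst_format = aicpu::FORMAT_NDC1HWC0;
  args.src_shape = {1, 21, 2, 4, 4};       // N, C, D, H, W (example values)
  args.dst_shape = {};                     // empty: TransFormat derives and checks it
  args.src_data_type = aicpu::DT_FLOAT16;  // assumed enumerator name
  args.groups = 1;
  if (!aicpu::formats::FormatTransferExists(args)) {
    return 1;  // no transfer registered for this src/dst pair
  }
  // BuildFormatTransfer returns the builder registered via REGISTER_FORMAT_TRANSFER.
  auto transfer = aicpu::formats::BuildFormatTransfer(args);
  return transfer->TransFormat(args, result);
}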
@ -0,0 +1,304 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 * Description: api of attr
 */

#ifndef CPU_KERNEL_ATTR_VALUE_H
#define CPU_KERNEL_ATTR_VALUE_H
#include <memory>
#include <string>
#include <vector>

#include "cpu_kernel/inc/cpu_tensor.h"
#include "cpu_kernel/inc/cpu_tensor_shape.h"

namespace aicpu {
class AttrValueImpl;
class AICPU_VISIBILITY AttrValue {
  friend class CpuKernelUtils;

 public:
  AttrValue() = delete;
  ~AttrValue() = default;

  AttrValue(const AttrValue &) = delete;
  AttrValue(AttrValue &&) = delete;
  AttrValue &operator=(const AttrValue &) = delete;
  AttrValue &operator=(AttrValue &&) = delete;

  /*
   * get string value of attr.
   * @return string: string value of attr
   */
  std::string GetString() const;

  /*
   * get string list value of attr.
   * @return vector<std::string>: string list value of attr
   */
  std::vector<std::string> GetListString() const;

  /*
   * attr add string value to list.
   * @param string: string value need to add to list
   */
  void AddListString(const std::string &string);

  /*
   * get string list size of attr.
   * @return int32_t: string list size of attr
   */
  int32_t ListStringSize() const;

  /*
   * set string value to attr.
   * @param string: string value need to set to attr
   */
  void SetString(const std::string &string);

  /*
   * set string list value to attr.
   * @param vector<std::string>: string list value need to set to attr
   */
  void SetListString(const std::vector<std::string> &bytes);

  /*
   * get int value of attr.
   * @return int64_t: int value of attr
   */
  int64_t GetInt() const;

  /*
   * get int list value of attr.
   * @return vector<int64_t>: int list value of attr
   */
  std::vector<int64_t> GetListInt() const;

  /*
   * attr add int value to list.
   * @param i: int value need to add to list
   */
  void AddListInt(int64_t i);

  /*
   * get int list size of attr.
   * @return int32_t: int list size of attr
   */
  int32_t ListIntSize() const;

  /*
   * set int value to attr.
   * @param i: int value need to set to attr
   */
  void SetInt(int64_t i);

  /*
   * set int list value to attr.
   * @param vector<int64_t>: int list value need to set to attr
   */
  void SetListInt(const std::vector<int64_t> &i);

  /*
   * get int list list value of attr.
   * @return vector<vector<int64_t>>: int list list value of attr
   */
  std::vector<std::vector<int64_t>> GetListListInt() const;

  /*
   * set int list list value to attr.
   * @param vector<vector<int64_t>>: int list list value need to set to attr
   */
  void SetListListInt(const std::vector<std::vector<int64_t>> &i);

  /*
   * get float value of attr.
   * @return float: float value of attr
   */
  float GetFloat() const;

  /*
   * get float list value of attr.
   * @return vector<float>: float list value of attr
   */
  std::vector<float> GetListFloat() const;

  /*
   * attr add float value to list.
   * @param f: float value need to add to list
   */
  void AddListFloat(float f);

  /*
   * get float list size of attr.
   * @return int32_t: float list size of attr
   */
  int32_t ListFloatSize() const;

  /*
   * set float value to attr.
   * @param f: float value need to set to attr
   */
  void SetFloat(float f);

  /*
   * set float list value to attr.
   * @param vector<float>: float list value need to set to attr
   */
  void SetListFloat(const std::vector<float> &f);

  /*
   * get bool value of attr.
   * @return bool: bool value of attr
   */
  bool GetBool() const;

  /*
   * get bool list value of attr.
   * @return vector<bool>: bool list value of attr
   */
  std::vector<bool> GetListBool() const;

  /*
   * attr add bool value to list.
   * @param b: bool value need to add to list
   */
  void AddListBool(bool b);

  /*
   * get bool list size of attr.
   * @return int32_t: bool list size of attr
   */
  int32_t ListBoolSize() const;

  /*
   * set bool value to attr.
   * @param b: bool value need to set to attr
   */
  void SetBool(bool b);

  /*
   * set bool list value to attr.
   * @param vector<bool>: bool list value need to set to attr
   */
  void SetListBool(const std::vector<bool> &b);

  /*
   * get data type value of attr.
   * @return DataType: data type value of attr
   */
  DataType GetDataType() const;

  /*
   * get data type list value of attr.
   * @return vector<DataType>: data type list value of attr
   */
  std::vector<DataType> GetListDataType() const;

  /*
   * attr add data type value to list.
   * @param type: data type value need to add to list
   */
  void AddListDataType(DataType type);

  /*
   * get data type list size of attr.
   * @return int32_t: data type list size of attr
   */
  int32_t ListDataTypeSize() const;

  /*
   * set data type value to attr.
   * @param type: data type value need to set to attr
   */
  void SetDataType(DataType type);

  /*
   * set data type list value to attr.
   * @param vector<DataType>: data type list value need to set to attr
   */
  void SetListDataType(const std::vector<DataType> &type);

  /*
   * set tensor shape value to attr.
   * @param shape: tensor shape value need to set to attr
   * @return bool: true->success false->failed
   */
  bool SetTensorShape(const TensorShape *shape);

  /*
   * set tensor shape list value to attr.
   * @param vector<TensorShape>: tensor shape list value need to set to attr
   * @return uint32_t: success number
   */
  uint32_t SetListTensorShape(const std::vector<TensorShape *> &shape);

  /*
   * attr add tensor shape value to list.
   * @return shared_ptr<TensorShape>: tensor shape value ptr added to list
   */
  std::shared_ptr<TensorShape> AddListTensorShape();

  /*
   * get tensor shape value of attr.
   * @return TensorShape: tensor shape value of attr
   */
  std::shared_ptr<TensorShape> GetTensorShape() const;

  /*
   * get tensor shape list value of attr.
   * @return vector<TensorShape>: tensor shape list value of attr
   */
  std::vector<TensorShape> GetListTensorShape() const;

  /*
   * get tensor shape list size of attr.
   * @return int32_t: tensor shape list size of attr
   */
  int32_t ListTensorShapeSize() const;

  /*
   * set tensor value to attr.
   * @param tensor: tensor value need to set to attr
   * @return bool: true->success false->failed
   */
  bool SetTensor(const Tensor *tensor);

  /*
   * set tensor list value to attr.
   * @param vector<Tensor>: tensor list value need to set to attr
   * @return uint32_t: success number
   */
  uint32_t SetListTensor(const std::vector<Tensor *> &tensor);

  /*
   * attr add tensor value to list.
   * @return shared_ptr<Tensor>: tensor value ptr added to list
   */
  std::shared_ptr<Tensor> AddListTensor();

  /*
   * get tensor value of attr.
   * @return Tensor: tensor value of attr
   */
  std::shared_ptr<Tensor> GetTensor() const;

  /*
   * get tensor list value of attr.
   * @return vector<Tensor>: tensor list value of attr
   */
  std::vector<Tensor> GetListTensor() const;

  /*
   * get tensor list size of attr.
   * @return int32_t: tensor list size of attr
   */
  int32_t ListTensorSize() const;

 private:
  explicit AttrValue(AttrValueImpl *impl);

 private:
  std::shared_ptr<AttrValueImpl> impl_{nullptr};
};
} // namespace aicpu
#endif // CPU_KERNEL_ATTR_VALUE_H
@ -0,0 +1,78 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 * Description: api of context
 */

#ifndef CPU_KERNELS_CONTEXT_H
#define CPU_KERNELS_CONTEXT_H
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "cpu_kernel/inc/cpu_types.h"
#include "cpu_kernel/inc/cpu_tensor.h"
#include "cpu_kernel/inc/cpu_attr_value.h"

namespace aicpu {
class Device;
class NodeDef;
class AICPU_VISIBILITY CpuKernelContext {
  friend class CpuKernelUtils;

 public:
  explicit CpuKernelContext(DeviceType type);
  CpuKernelContext() = delete;
  ~CpuKernelContext() = default;
  CpuKernelContext(const CpuKernelContext &) = delete;
  CpuKernelContext(CpuKernelContext &&) = delete;
  CpuKernelContext &operator=(const CpuKernelContext &) = delete;
  CpuKernelContext &operator=(CpuKernelContext &&) = delete;

  uint32_t Init(NodeDef *nodeDef);

  /*
   * get op type.
   * @return string: op type
   */
  std::string GetOpType() const;

  /*
   * get input tensor.
   * @return Tensor *: not null->success, null->failed
   */
  Tensor *Input(uint32_t index) const;

  /*
   * get output tensor.
   * @return Tensor *: not null->success, null->failed
   */
  Tensor *Output(uint32_t index) const;

  /*
   * get attr.
   * @return AttrValue *: not null->success, null->failed
   */
  AttrValue *GetAttr(std::string name) const;

  /*
   * get input size.
   * @return uint32_t: input size
   */
  uint32_t GetInputsSize() const;

  /*
   * get output size.
   * @return uint32_t: output size
   */
  uint32_t GetOutputsSize() const;

 private:
  std::string op_;                                                      // op type
  std::vector<std::shared_ptr<Tensor> > inputs_;                        // input tensor list
  std::vector<std::shared_ptr<Tensor> > outputs_;                       // out tensor list
  std::unordered_map<std::string, std::shared_ptr<AttrValue> > attrs_;  // attr list
  std::shared_ptr<Device> device_{nullptr};
};
} // namespace aicpu
#endif // CPU_KERNELS_CONTEXT_H
@ -0,0 +1,76 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
 * Description: api of the nodedef builder
 */

#ifndef CPU_NODEDEF_BUILDER_H
#define CPU_NODEDEF_BUILDER_H

#include <memory>
#include <string>
#include <vector>
#include "cpu_kernel/inc/cpu_ops_kernel.h"

namespace aicpu {
class AICPU_VISIBILITY NodeDefBuilder {
 public:
  struct InputOutputNode {
    std::string node;
    aicpu::DataType dType;
    std::vector<int64_t> dims;
    void *data;
    aicpu::Format format;
  };

  static std::shared_ptr<NodeDef> CreateNodeDef();

  NodeDefBuilder(NodeDef *nodeDef, std::string name, std::string opName);

  NodeDefBuilder &Input(const InputOutputNode &input);

  NodeDefBuilder &Output(const InputOutputNode &output);

  NodeDefBuilder &Attr(std::string name, int32_t value);

  NodeDefBuilder &Attr(std::string name, int64_t value);

  NodeDefBuilder &Attr(std::string name, float value);

  NodeDefBuilder &Attr(std::string name, double value);

  NodeDefBuilder &Attr(std::string name, bool value);

  NodeDefBuilder &Attr(std::string name, aicpu::DataType value);

  NodeDefBuilder &Attr(std::string name, const std::vector<bool> &value);

  NodeDefBuilder &Attr(std::string name, const std::string &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<std::string> &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<float> &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<aicpu::DataType> &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &shape, std::string type);

  NodeDefBuilder &Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists, std::string type);

  NodeDefBuilder &Attr(std::string name, aicpu::Tensor *tensor);

  NodeDefBuilder &Attr(std::string name, std::vector<aicpu::Tensor *> &tensors);

 private:
  void BuildNodeFromInputOutputNode(const InputOutputNode &node, bool isInput);

  NodeDef *nodeDef_;

  std::string name_;

  std::string opName_;
};
} // namespace aicpu

#endif
@ -0,0 +1,42 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 * Description: api of cpu kernel
 */

#ifndef CPU_KERNEL_H
#define CPU_KERNEL_H

#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "cpu_kernel/inc/cpu_context.h"

namespace aicpu {
class AICPU_VISIBILITY CpuKernel {
 public:
  virtual uint32_t Compute(CpuKernelContext &ctx) = 0;

  virtual ~CpuKernel() {}
};

using KERNEL_CREATOR_FUN = std::function<std::shared_ptr<CpuKernel>(void)>;

AICPU_VISIBILITY bool RegistCpuKernel(const std::string &type, const KERNEL_CREATOR_FUN &fun);

template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
  typedef typename std::remove_const<T>::type T_nc;
  std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
  return ret;
}

#define REGISTER_CPU_KERNEL(type, clazz)                  \
  std::shared_ptr<CpuKernel> Creator_##type##_Kernel() {  \
    std::shared_ptr<clazz> ptr = nullptr;                 \
    ptr = MakeShared<clazz>();                            \
    return ptr;                                           \
  }                                                       \
  bool g_##type##_Kernel_Creator __attribute__((unused)) = RegistCpuKernel(type, Creator_##type##_Kernel)
} // namespace aicpu
#endif // CPU_KERNEL_H
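A minimal sketch of how a kernel is meant to hook into the REGISTER_CPU_KERNEL macro above. The op type, class name, and included header path are hypothetical; the status constants are the ones used throughout this patch.

#include "cpu_kernel/common/status.h"
#include "cpu_kernel/inc/cpu_ops_kernel.h"  // assumed path of the header declaring CpuKernel

namespace aicpu {
namespace {
const char *const kExample = "Example";  // hypothetical op type string
} // namespace

class ExampleKernel : public CpuKernel {
 public:
  uint32_t Compute(CpuKernelContext &ctx) override {
    Tensor *input = ctx.Input(0);
    Tensor *output = ctx.Output(0);
    if (input == nullptr || output == nullptr) {
      return KERNEL_STATUS_PARAM_INVALID;
    }
    // A real kernel would read input->GetData() and fill output->GetData() here.
    return KERNEL_STATUS_OK;
  }
};

// Expands to a Creator_kExample_Kernel factory plus a load-time RegistCpuKernel(kExample, ...) call.
REGISTER_CPU_KERNEL(kExample, ExampleKernel);
} // namespace aicpu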
@ -0,0 +1,89 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 * Description: api of tensor
 */

#ifndef CPU_KERNEL_TENSOR_H
#define CPU_KERNEL_TENSOR_H
#include <memory>

#include "cpu_kernel/inc/cpu_tensor_shape.h"

namespace aicpu {
class TensorImpl;
class AICPU_VISIBILITY Tensor {
  friend class CpuKernelUtils;

 public:
  Tensor() = delete;
  ~Tensor() = default;

  /*
   * set tensor shape value to tensor.
   * @param shape: tensor shape value need to set to tensor
   * @return bool: true->success, false->failed
   */
  bool SetTensorShape(const TensorShape *shape);

  /*
   * get tensor shape value of tensor.
   * @return std::shared_ptr<TensorShape>: tensor shape value of tensor
   */
  std::shared_ptr<TensorShape> GetTensorShape() const;

  /*
   * set data type value to tensor.
   * @param type: data type value need to set to tensor
   */
  void SetDataType(DataType type);

  /*
   * get data type value of tensor.
   * @return DataType: data type value of tensor
   */
  DataType GetDataType() const;

  /*
   * set data ptr to tensor.
   * @param addr: tensor data ptr
   */
  void SetData(void *addr);

  /*
   * get data ptr of tensor.
   * @return void *: tensor data ptr
   */
  void *GetData() const;

  /*
   * set data size to tensor.
   * @param size: tensor data size
   */
  void SetDataSize(uint64_t size);

  /*
   * get data size of tensor.
   * @return uint64_t: tensor data size
   */
  uint64_t GetDataSize() const;

  /*
   * calculate data size by tensor shape.
   * @return success->not less than 0, failed->less than 0
   */
  int64_t CalcDataSizeByShape() const;

  /*
   * get data elements number.
   * @return success->not less than 0, unknown->less than 0
   */
  int64_t NumElements() const;

 private:
  explicit Tensor(TensorImpl *impl);

 private:
  std::shared_ptr<TensorImpl> impl_{nullptr};
};
} // namespace aicpu
#endif // CPU_KERNEL_TENSOR_H
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
|
||||
* Description: api of tensor shape
|
||||
*/
|
||||
|
||||
#ifndef CPU_KERNEL_TENSOR_SHAPE_H
|
||||
#define CPU_KERNEL_TENSOR_SHAPE_H
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_types.h"
|
||||
|
||||
namespace aicpu {
|
||||
#ifdef VISIBILITY
|
||||
#define AICPU_VISIBILITY __attribute__((visibility("default")))
|
||||
#else
|
||||
#define AICPU_VISIBILITY
|
||||
#endif
|
||||
|
||||
class TensorShapeImpl;
|
||||
class AICPU_VISIBILITY TensorShape {
|
||||
friend class CpuKernelUtils;
|
||||
|
||||
public:
|
||||
TensorShape() = delete;
|
||||
~TensorShape() = default;
|
||||
|
||||
/*
|
||||
* set format value to tensor shape.
|
||||
* @param format: format value need to set to tensor shape
|
||||
*/
|
||||
void SetFormat(Format format);
|
||||
|
||||
/*
|
||||
* get format value of tensor shape.
|
||||
* @return Format: format value of tensor shape
|
||||
*/
|
||||
Format GetFormat() const;
|
||||
|
||||
/*
|
||||
* get unknown rank value of tensor shape.
|
||||
* @return bool: unknown rank value of tensor shape
|
||||
*/
|
||||
bool GetUnknownRank() const;
|
||||
|
||||
/*
|
||||
* set unknown rank value to tensor shape.
|
||||
* @param unknownRank: unknown rank value need to set to tensor shape
|
||||
*/
|
||||
void SetUnknownRank(bool unknownRank);
|
||||
|
||||
/*
|
||||
* set dims value to tensor shape.
|
||||
* @param dims: dims value need to set to tensor shape
|
||||
*/
|
||||
void SetDimSizes(const std::vector<int64_t> &dims);
|
||||
|
||||
/*
|
||||
* get dims value of tensor shape.
|
||||
* @return std::vector<int64_t>: dims value of tensor shape
|
||||
*/
|
||||
std::vector<int64_t> GetDimSizes() const;
|
||||
|
||||
/*
|
||||
* get dim value of tensor shape index dim.
|
||||
* @param index: index dim of tensor shape
|
||||
* @return int64_t: dim value of tensor shape index dim
|
||||
*/
|
||||
int64_t GetDimSize(int32_t index) const;
|
||||
|
||||
/*
|
||||
* get dims size of tensor shape.
|
||||
* @return int32_t: dims size of tensor shape
|
||||
*/
|
||||
int32_t GetDims() const;
|
||||
|
||||
/*
|
||||
* get data elements number.
|
||||
* @return success->not less than 0, unknown->less than 0
|
||||
*/
|
||||
int64_t NumElements() const;
|
||||
|
||||
private:
|
||||
explicit TensorShape(TensorShapeImpl *tensorShape);
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorShapeImpl> impl_{nullptr};
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // CPU_KERNEL_TENSOR_SHAPE_H
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
|
||||
* Description: api of types
|
||||
*/
|
||||
|
||||
#ifndef CPU_KERNEL_TYPES_H
|
||||
#define CPU_KERNEL_TYPES_H
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace aicpu {
|
||||
#ifdef VISIBILITY
|
||||
#define AICPU_VISIBILITY __attribute__((visibility("default")))
|
||||
#else
|
||||
#define AICPU_VISIBILITY
|
||||
#endif
|
||||
|
||||
enum DataType {
|
||||
DT_FLOAT = 0, // float type
|
||||
DT_FLOAT16 = 1, // fp16 type
|
||||
DT_INT8 = 2, // int8 type
|
||||
DT_INT16 = 6, // int16 type
|
||||
DT_UINT16 = 7, // uint16 type
|
||||
DT_UINT8 = 4, // uint8 type
|
||||
DT_INT32 = 3,            // int32 type
|
||||
DT_INT64 = 9, // int64 type
|
||||
DT_UINT32 = 8, // unsigned int32
|
||||
DT_UINT64 = 10, // unsigned int64
|
||||
DT_BOOL = 12, // bool type
|
||||
DT_DOUBLE = 11, // double type
|
||||
DT_STRING = 13, // string type
|
||||
DT_DUAL_SUB_INT8 = 14, // dual output int8 type
|
||||
DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type
|
||||
DT_COMPLEX64 = 16, // complex64 type
|
||||
DT_COMPLEX128 = 17, // complex128 type
|
||||
DT_QINT8 = 18, // qint8 type
|
||||
DT_QINT16 = 19, // qint16 type
|
||||
DT_QINT32 = 20, // qint32 type
|
||||
DT_QUINT8 = 21, // quint8 type
|
||||
DT_QUINT16 = 22, // quint16 type
|
||||
DT_RESOURCE = 23, // resource type
|
||||
DT_STRING_REF = 24, // string ref type
|
||||
DT_DUAL = 25, // dual output type
|
||||
DT_UNDEFINED // Used to indicate a DataType field has not been set.
|
||||
};
|
||||
|
||||
AICPU_VISIBILITY inline int GetSizeByDataType(DataType dataType) {
|
||||
const std::map<DataType, int> sizeMap = {
|
||||
{DT_FLOAT, 4}, {DT_FLOAT16, 2}, {DT_INT8, 1}, {DT_INT16, 2}, {DT_UINT16, 2},
|
||||
{DT_UINT8, 1}, {DT_INT32, 4}, {DT_INT64, 8}, {DT_UINT32, 4}, {DT_UINT64, 8},
|
||||
{DT_BOOL, 1}, {DT_DOUBLE, 8}, {DT_STRING, -1}, {DT_DUAL_SUB_INT8, 1}, {DT_DUAL_SUB_UINT8, 1},
|
||||
{DT_COMPLEX64, 8}, {DT_COMPLEX128, 16}, {DT_QINT8, 1}, {DT_QINT16, 2}, {DT_QINT32, 4},
|
||||
{DT_QUINT8, 1}, {DT_QUINT16, 2}, {DT_RESOURCE, -1}, {DT_STRING_REF, -1}, {DT_DUAL, 5}};
|
||||
auto iter = sizeMap.find(dataType);
|
||||
if (iter == sizeMap.end()) {
|
||||
return -1;
|
||||
}
|
||||
return iter->second;
|
||||
}
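A short sketch (illustrative, not part of the committed header) of how GetSizeByDataType is typically combined with a dim vector to size a buffer; the -1 returned for variable-length types such as DT_STRING has to be rejected by the caller.
// Illustrative helper only; would additionally require <vector> and <cstdint>.
inline int64_t CalcBufferSize(const std::vector<int64_t> &dims, DataType type) {
  int element_size = GetSizeByDataType(type);
  if (element_size < 0) {
    return -1;  // DT_STRING, DT_RESOURCE and DT_STRING_REF have no fixed element size
  }
  int64_t num_elements = 1;
  for (int64_t d : dims) {
    num_elements *= d;  // a production version would also guard against overflow
  }
  return num_elements * element_size;
}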
|
||||
|
||||
enum Format {
|
||||
FORMAT_NCHW = 0, // NCHW
|
||||
FORMAT_NHWC, // NHWC
|
||||
FORMAT_ND, // Nd Tensor
|
||||
FORMAT_NC1HWC0, // NC1HWC0
|
||||
FORMAT_FRACTAL_Z, // FRACTAL_Z
|
||||
FORMAT_NC1C0HWPAD,
|
||||
FORMAT_NHWC1C0,
|
||||
FORMAT_FSR_NCHW,
|
||||
FORMAT_FRACTAL_DECONV,
|
||||
FORMAT_C1HWNC0,
|
||||
FORMAT_FRACTAL_DECONV_TRANSPOSE,
|
||||
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS,
|
||||
FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4
|
||||
FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4
|
||||
FORMAT_CHWN,
|
||||
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS,
|
||||
FORMAT_HWCN,
|
||||
FORMAT_NC1KHKWHWC0,  // KH,KW: kernel h & kernel w, maxpooling max output format
|
||||
FORMAT_BN_WEIGHT,
|
||||
FORMAT_FILTER_HWCK, // filter input tensor format
|
||||
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20,
|
||||
FORMAT_HASHTABLE_LOOKUP_KEYS,
|
||||
FORMAT_HASHTABLE_LOOKUP_VALUE,
|
||||
FORMAT_HASHTABLE_LOOKUP_OUTPUT,
|
||||
FORMAT_HASHTABLE_LOOKUP_HITS = 24,
|
||||
FORMAT_C1HWNCoC0,
|
||||
FORMAT_MD,
|
||||
FORMAT_NDHWC,
|
||||
FORMAT_FRACTAL_ZZ,
|
||||
FORMAT_FRACTAL_NZ,
|
||||
FORMAT_NCDHW,
|
||||
FORMAT_DHWCN, // 3D filter input tensor format
|
||||
FORMAT_NDC1HWC0,
|
||||
FORMAT_FRACTAL_Z_3D,
|
||||
FORMAT_CN,
|
||||
FORMAT_NC,
|
||||
FORMAT_DHWNC,
|
||||
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format
|
||||
FORMAT_FRACTAL_ZN_LSTM,
|
||||
FORMAT_FRACTAL_Z_G,
|
||||
FORMAT_RESERVED,
|
||||
FORMAT_ALL,
|
||||
FORMAT_NULL
|
||||
};
|
||||
|
||||
enum DeviceType { HOST, DEVICE };
|
||||
} // namespace aicpu
|
||||
#endif // CPU_KERNEL_TYPES_H
|
|
@ -0,0 +1,159 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 2;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kGatherNd = "GatherNd";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t GatherNdCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check GatherNd Input and Output failed.");
|
||||
|
||||
Tensor *input_x = ctx.Input(0);
|
||||
Tensor *input_indices = ctx.Input(1);
|
||||
|
||||
auto shape_x = input_x->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto indices_rank = shape_indices->GetDims();
|
||||
auto indices_nd = shape_indices->GetDimSize(indices_rank - 1);
|
||||
|
||||
if (shape_x->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_x's rank is less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (indices_rank < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank is less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (indices_nd > shape_x->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Slice's length must be less than x rank. ", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto data_type0 = input_x->GetDataType();
|
||||
auto data_type1 = input_indices->GetDataType();
|
||||
|
||||
if (data_type1 != DT_INT32 && data_type1 != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type1).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type0) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("GatherNd kernel data type [%s] not support.", DTypeStr(data_type0).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_type>
|
||||
uint32_t GatherNdCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return GatherNdComputeRealKernel<int32_t, data_type>(ctx);
|
||||
case DT_INT64:
|
||||
return GatherNdComputeRealKernel<int64_t, data_type>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename indices_type, typename data_type>
|
||||
uint32_t GatherNdCpuKernel::GatherNdComputeRealKernel(CpuKernelContext &ctx) {
|
||||
auto x_shape = ctx.Input(0)->GetTensorShape();
|
||||
auto indices_shape = ctx.Input(1)->GetTensorShape();
|
||||
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
const int64_t indices_dims = indices_shape->GetDims();
|
||||
int64_t indices_nd = indices_shape->GetDimSize(indices_dims - 1);
|
||||
|
||||
const int64_t params_dims = x_shape->GetDims();
|
||||
|
||||
for (int64_t i = 0; i < indices_dims - 1; ++i) {
|
||||
n_slices *= indices_shape->GetDimSize(i);
|
||||
}
|
||||
for (int64_t i = indices_nd; i < params_dims; ++i) {
|
||||
slice_size *= x_shape->GetDimSize(i);
|
||||
}
|
||||
|
||||
int64_t remain_flat_size = x_shape->NumElements();
|
||||
std::vector<int64_t> dims_to_count = std::vector<int64_t>(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / x_shape->GetDimSize(i);
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
|
||||
auto x_data = reinterpret_cast<data_type *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<data_type *>(ctx.Output(0)->GetData());
|
||||
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t from_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
from_pos += indices_data[i * indices_nd + j] * dims_to_count[j];
|
||||
}
|
||||
std::memcpy(output_data + i * slice_size, x_data + from_pos, sizeof(data_type) * slice_size);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kGatherNd, GatherNdCpuKernel);
|
||||
|
||||
} // namespace aicpu
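A worked sketch (illustrative only) of the flat-index arithmetic in GatherNdComputeRealKernel above: dims_to_count holds the row-major strides of the leading indices_nd dimensions of x, so each index tuple maps to the start of a contiguous slice of slice_size elements.
#include <cstdint>
#include <vector>

// Example values, not part of the commit: x has shape [2, 3, 4] and indices_nd == 2.
int64_t FlatOffsetExample() {
  std::vector<int64_t> x_shape = {2, 3, 4};
  int64_t remain = 2 * 3 * 4;                // NumElements() of x
  std::vector<int64_t> dims_to_count(2, 0);
  for (int64_t i = 0; i < 2; ++i) {
    dims_to_count[i] = remain / x_shape[i];  // {12, 4}, i.e. the row-major strides
    remain = dims_to_count[i];
  }
  // The index tuple (1, 2) selects the slice x[1][2][*], which is slice_size == 4 elements long.
  return 1 * dims_to_count[0] + 2 * dims_to_count[1];  // == 20, the slice's start offset
}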
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_GATHERND_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_GATHERND_H_
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class GatherNdCpuKernel : public CpuKernel {
|
||||
public:
|
||||
GatherNdCpuKernel() = default;
|
||||
~GatherNdCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename data_type>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename indices_type, typename data_type>
|
||||
uint32_t GatherNdComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,196 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "scatter_nd.h"
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 3;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kScatterNd = "ScatterNd";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ScatterNdCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check ScatterNd Input and Output failed.");
|
||||
|
||||
Tensor *input_indices = ctx.Input(0);
|
||||
Tensor *input_x = ctx.Input(1);
|
||||
Tensor *input_shape = ctx.Input(2);
|
||||
|
||||
auto shape_x = input_x->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto shape_shape = input_shape->GetTensorShape();
|
||||
int64_t indices_shape_m = shape_indices->GetDimSize(shape_indices->GetDims() - 1);
|
||||
|
||||
if (shape_x->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_x's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_indices->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_shape->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_shape's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (indices_shape_m > shape_shape->NumElements()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_shape&input_indices ranks mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < shape_indices->GetDims() - 1; i++) {
|
||||
if (shape_indices->GetDimSize(i) != shape_x->GetDimSize(i)) {
|
||||
KERNEL_LOG_ERROR("[%s], shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_type_x = input_x->GetDataType();
|
||||
auto data_type_indices = input_indices->GetDataType();
|
||||
auto data_type_shape = input_shape->GetDataType();
|
||||
if (data_type_shape != DT_INT32 && data_type_shape != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("ScatterNd kernel data type [%s] not support.", DTypeStr(data_type_shape).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (data_type_indices != DT_INT32 && data_type_indices != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("ScatterNd kernel data type [%s] not support.", DTypeStr(data_type_indices).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (data_type_indices != data_type_shape) {
|
||||
KERNEL_LOG_ERROR("Indices and shape must have the same type.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type_x) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("ScatterNd kernel data type [%s] not support.", DTypeStr(data_type_x).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_type_x>
|
||||
uint32_t ScatterNdCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(0)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return ScatterNdComputeRealKernel<int32_t, data_type_x>(ctx);
|
||||
case DT_INT64:
|
||||
return ScatterNdComputeRealKernel<int64_t, data_type_x>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename indices_type, typename data_type_x>
|
||||
uint32_t ScatterNdCpuKernel::ScatterNdComputeRealKernel(CpuKernelContext &ctx) {
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
|
||||
const int64_t outer_dims = ctx.Input(0)->GetTensorShape()->GetDims() - 1;
|
||||
const int64_t indices_nd = ctx.Input(0)->GetTensorShape()->GetDimSize(outer_dims);
|
||||
const int64_t updates_dims = ctx.Input(1)->GetTensorShape()->GetDims();
|
||||
|
||||
auto shape_indices = ctx.Input(0)->GetTensorShape();
|
||||
auto data_shape = reinterpret_cast<indices_type *>(ctx.Input(2)->GetData());
|
||||
auto dims_shape = ctx.Input(2)->GetTensorShape()->NumElements();
|
||||
auto updates_shape = ctx.Input(1)->GetTensorShape();
|
||||
for (int64_t i = 0; i < dims_shape - indices_nd; i++) {
|
||||
if (updates_shape->GetDimSize(i + shape_indices->GetDims() - 1) != data_shape[i + indices_nd]) {
|
||||
KERNEL_LOG_ERROR("[%s], shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < outer_dims; ++i) {
|
||||
n_slices *= ctx.Input(0)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
for (int64_t i = outer_dims; i < updates_dims; ++i) {
|
||||
slice_size *= ctx.Input(1)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
const int kNumberInputTwo = 2;
|
||||
int64_t output_flat_size = 1;
|
||||
int64_t num_shape = ctx.Input(kNumberInputTwo)->NumElements();
|
||||
for (int64_t i = 0; i < num_shape; i++) {
|
||||
output_flat_size *= data_shape[i];
|
||||
}
|
||||
int64_t remain_flat_size = output_flat_size;
|
||||
std::vector<int64_t> dims_to_count(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / data_shape[i];
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto Indices_data = reinterpret_cast<indices_type *>(ctx.Input(0)->GetData());
|
||||
auto Updates_data = reinterpret_cast<data_type_x *>(ctx.Input(1)->GetData());
|
||||
auto Output_data = reinterpret_cast<data_type_x *>(ctx.Output(0)->GetData());
|
||||
|
||||
memset(Output_data, 0, sizeof(data_type_x) * output_flat_size);
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t to_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
int64_t idx = Indices_data[i * indices_nd + j];
|
||||
|
||||
if (idx < 0 || idx >= data_shape[j]) {
|
||||
KERNEL_LOG_ERROR("The indices[%d] is so big or small", idx);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
to_pos += idx * dims_to_count[j];
|
||||
}
|
||||
for (int64_t j = 0; j < slice_size; j++) {
|
||||
Output_data[to_pos + j] += Updates_data[i * slice_size + j];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kScatterNd, ScatterNdCpuKernel);
|
||||
} // namespace aicpu
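An illustrative expectation (not a test in this commit) of the accumulation semantics implemented above: the output is zero-filled with memset and duplicate index tuples add their updates.
// Example: shape = [4], indices = [[1], [1], [3]], updates = [5, 7, 2].
// ScatterNd produces {0, 12, 0, 2}: both updates aimed at index 1 accumulate via +=.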
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_SCATTERND_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_SCATTERND_H_
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
|
||||
namespace aicpu {
|
||||
class ScatterNdCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ScatterNdCpuKernel() = default;
|
||||
~ScatterNdCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename data_type0>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename indices_type, typename data_type0>
|
||||
uint32_t ScatterNdComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,210 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "scatter_nd_update.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 3;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kScatterNdUpdate = "ScatterNdUpdate";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t ScatterNdUpdateCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check ScatterNdUpdate Input and Output failed.");
|
||||
|
||||
Tensor *input_var = ctx.Input(0);
|
||||
Tensor *input_indices = ctx.Input(1);
|
||||
Tensor *input_updates = ctx.Input(2);
|
||||
|
||||
auto shape_var = input_var->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto shape_updates = input_updates->GetTensorShape();
|
||||
|
||||
if (shape_var->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_indices->GetDims() < 2) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank less than 2.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_updates->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_updates's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto index_size = shape_indices->GetDims() - 1;
|
||||
auto index_depth = shape_indices->GetDimSize(index_size);
|
||||
|
||||
if (index_depth > shape_var->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var&input_indices ranks mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::vector<int64_t> batch_shape;
|
||||
for (int64_t i = 0; i < index_size; ++i) {
|
||||
batch_shape.push_back(shape_indices->GetDimSize(i));
|
||||
}
|
||||
|
||||
for (int64_t i = index_depth; i <= shape_var->GetDims() - 1; ++i) {
|
||||
batch_shape.push_back(shape_var->GetDimSize(i));
|
||||
}
|
||||
|
||||
if (batch_shape != shape_updates->GetDimSizes()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor indices's & updates' and var's shape are dismatch .", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < index_size; i++) {
|
||||
if (shape_indices->GetDimSize(i) != shape_updates->GetDimSize(i)) {
|
||||
KERNEL_LOG_ERROR("[%s], Tensor indices and updates should have the same batch number.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_type_var = input_var->GetDataType();
|
||||
auto data_type_indices = input_indices->GetDataType();
|
||||
|
||||
if (data_type_indices != DT_INT32 && data_type_indices != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("ScatterNdUpdate kernel data type [%s] not support.", DTypeStr(data_type_indices).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type_var) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("ScatterNdUpdate kernel data type [%s] not support.", DTypeStr(data_type_var).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename var_type>
|
||||
uint32_t ScatterNdUpdateCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return ScatterNdUpdateComputeRealKernel<var_type, int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return ScatterNdUpdateComputeRealKernel<var_type, int64_t>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t ScatterNdUpdateCpuKernel::ScatterNdUpdateComputeRealKernel(CpuKernelContext &ctx) {
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
|
||||
const int64_t indices_dims = ctx.Input(1)->GetTensorShape()->GetDims() - 1;
|
||||
const int64_t indices_nd = ctx.Input(1)->GetTensorShape()->GetDimSize(indices_dims);
|
||||
const int64_t updates_dims = ctx.Input(2)->GetTensorShape()->GetDims();
|
||||
|
||||
auto shape_var = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto shape_indices = ctx.Input(1)->GetTensorShape();
|
||||
auto dims_shape = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
for (int64_t i = 0; i < dims_shape - indices_nd; i++) {
|
||||
if (ctx.Input(2)->GetTensorShape()->GetDimSize(i + shape_indices->GetDims() - 1) != shape_var[i + indices_nd]) {
|
||||
KERNEL_LOG_ERROR("[%s] shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < indices_dims; ++i) {
|
||||
n_slices *= ctx.Input(1)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
for (int i = indices_dims; i < updates_dims; ++i) {
|
||||
slice_size *= ctx.Input(2)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
|
||||
const int64_t var_flat_size = ctx.Input(0)->GetTensorShape()->NumElements();
|
||||
std::vector<int64_t> output_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
|
||||
int64_t remain_flat_size = var_flat_size;
|
||||
std::vector<int64_t> dims_to_count(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / output_shape[i];
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto Var_data = reinterpret_cast<var_type *>(ctx.Input(0)->GetData());
|
||||
auto Indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
|
||||
auto Updates_data = reinterpret_cast<var_type *>(ctx.Input(2)->GetData());
|
||||
auto Output_data = reinterpret_cast<var_type *>(ctx.Output(0)->GetData());
|
||||
|
||||
for (int64_t i = 0; i < var_flat_size; ++i) {
|
||||
Output_data[i] = Var_data[i];
|
||||
}
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t to_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
int64_t idx = Indices_data[i * indices_nd + j];
|
||||
|
||||
if (idx < 0 || idx >= output_shape[j]) {
|
||||
KERNEL_LOG_ERROR("The indices[%d] is so big or small", idx);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
to_pos += idx * dims_to_count[j];
|
||||
}
|
||||
for (int64_t j = 0; j < slice_size; j++) {
|
||||
Output_data[to_pos + j] = Updates_data[i * slice_size + j];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kScatterNdUpdate, ScatterNdUpdateCpuKernel);
|
||||
} // namespace aicpu
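For contrast with ScatterNd, an illustrative example (values are hypothetical, not a test in this commit) of the copy-then-assign behaviour of ScatterNdUpdateComputeRealKernel above.
// Example: var = [9, 9, 9, 9], indices = [[1], [3]], updates = [5, 2].
// ScatterNdUpdate copies var into the output and then assigns, giving {9, 5, 9, 2};
// nothing is zeroed, and if indices repeat, the last update for a position wins.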
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_SCATTERNDUPDATE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_SCATTERNDUPDATE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
#include <string.h>
|
||||
|
||||
namespace aicpu {
|
||||
class ScatterNdUpdateCpuKernel : public CpuKernel {
|
||||
public:
|
||||
ScatterNdUpdateCpuKernel() = default;
|
||||
~ScatterNdUpdateCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename var_type>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t ScatterNdUpdateComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,211 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensor_scatter_update.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
#include "eigen_tensor.h"
|
||||
#include "utils/kernel_util.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t kInputNum = 3;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const char *kTensorScatterUpdate = "TensorScatterUpdate";
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t TensorScatterUpdateCpuKernel::Compute(CpuKernelContext &ctx) {
|
||||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check TensorScatterUpdate Input and Output failed.");
|
||||
|
||||
Tensor *input_var = ctx.Input(0);
|
||||
Tensor *input_indices = ctx.Input(1);
|
||||
Tensor *input_updates = ctx.Input(2);
|
||||
|
||||
auto shape_var = input_var->GetTensorShape();
|
||||
auto shape_indices = input_indices->GetTensorShape();
|
||||
auto shape_updates = input_updates->GetTensorShape();
|
||||
|
||||
if (shape_var->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_indices->GetDims() < 2) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_indices's rank less than 2.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
if (shape_updates->GetDims() < 1) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_updates's rank less than 1.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto index_size = shape_indices->GetDims() - 1;
|
||||
auto index_depth = shape_indices->GetDimSize(index_size);
|
||||
|
||||
if (index_depth > shape_var->GetDims()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor input_var&input_indices ranks mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::vector<int64_t> batch_shape;
|
||||
for (int64_t i = 0; i < index_size; ++i) {
|
||||
batch_shape.push_back(shape_indices->GetDimSize(i));
|
||||
}
|
||||
|
||||
for (int64_t i = index_depth; i <= shape_var->GetDims() - 1; ++i) {
|
||||
batch_shape.push_back(shape_var->GetDimSize(i));
|
||||
}
|
||||
|
||||
if (batch_shape != shape_updates->GetDimSizes()) {
|
||||
KERNEL_LOG_ERROR("[%s] Tensor indices's & updates' and var's shape are dismatch .", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < index_size; i++) {
|
||||
if (shape_indices->GetDimSize(i) != shape_updates->GetDimSize(i)) {
|
||||
KERNEL_LOG_ERROR("[%s], Tensor indices and updates should have the same batch number.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_type_var = input_var->GetDataType();
|
||||
auto data_type_indices = input_indices->GetDataType();
|
||||
|
||||
if (data_type_indices != DT_INT32 && data_type_indices != DT_INT64) {
|
||||
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_indices).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
switch (data_type_var) {
|
||||
case DT_INT8:
|
||||
return DTYPE_CHOOSE<int8_t>(ctx);
|
||||
case DT_INT16:
|
||||
return DTYPE_CHOOSE<int16_t>(ctx);
|
||||
case DT_INT32:
|
||||
return DTYPE_CHOOSE<int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return DTYPE_CHOOSE<int64_t>(ctx);
|
||||
case DT_UINT8:
|
||||
return DTYPE_CHOOSE<uint8_t>(ctx);
|
||||
case DT_UINT16:
|
||||
return DTYPE_CHOOSE<uint16_t>(ctx);
|
||||
case DT_UINT32:
|
||||
return DTYPE_CHOOSE<uint32_t>(ctx);
|
||||
case DT_UINT64:
|
||||
return DTYPE_CHOOSE<uint64_t>(ctx);
|
||||
case DT_FLOAT16:
|
||||
return DTYPE_CHOOSE<Eigen::half>(ctx);
|
||||
case DT_FLOAT:
|
||||
return DTYPE_CHOOSE<float>(ctx);
|
||||
case DT_DOUBLE:
|
||||
return DTYPE_CHOOSE<double>(ctx);
|
||||
case DT_COMPLEX64:
|
||||
return DTYPE_CHOOSE<std::complex<float>>(ctx);
|
||||
case DT_COMPLEX128:
|
||||
return DTYPE_CHOOSE<std::complex<double>>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("TensorScatterUpdate kernel data type [%s] not support.", DTypeStr(data_type_var).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
template <typename var_type>
|
||||
uint32_t TensorScatterUpdateCpuKernel::DTYPE_CHOOSE(CpuKernelContext &ctx) {
|
||||
auto indices_type = static_cast<DataType>(ctx.Input(1)->GetDataType());
|
||||
switch (indices_type) {
|
||||
case DT_INT32:
|
||||
return TensorScatterUpdateComputeRealKernel<var_type, int32_t>(ctx);
|
||||
case DT_INT64:
|
||||
return TensorScatterUpdateComputeRealKernel<var_type, int64_t>(ctx);
|
||||
default:
|
||||
KERNEL_LOG_ERROR("[%s] Data type of input is not supported, input data type is [%s].", ctx.GetOpType().c_str(),
|
||||
DTypeStr(indices_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t TensorScatterUpdateCpuKernel::TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx) {
|
||||
int64_t n_slices = 1;
|
||||
int64_t slice_size = 1;
|
||||
|
||||
const int64_t indices_dims = ctx.Input(1)->GetTensorShape()->GetDims() - 1;
|
||||
const int64_t indices_nd = ctx.Input(1)->GetTensorShape()->GetDimSize(indices_dims);
|
||||
const int64_t updates_dims = ctx.Input(2)->GetTensorShape()->GetDims();
|
||||
|
||||
auto shape_var = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto shape_indices = ctx.Input(1)->GetTensorShape();
|
||||
auto dims_shape = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
for (int64_t i = 0; i < dims_shape - indices_nd; i++) {
|
||||
if (ctx.Input(2)->GetTensorShape()->GetDimSize(i + shape_indices->GetDims() - 1) != shape_var[i + indices_nd]) {
|
||||
KERNEL_LOG_ERROR("[%s] shape_indices and shape_updates mismatch.", ctx.GetOpType().c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < indices_dims; ++i) {
|
||||
n_slices *= ctx.Input(1)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
for (int i = indices_dims; i < updates_dims; ++i) {
|
||||
slice_size *= ctx.Input(2)->GetTensorShape()->GetDimSize(i);
|
||||
}
|
||||
|
||||
const int64_t var_flat_size = ctx.Input(0)->GetTensorShape()->NumElements();
|
||||
std::vector<int64_t> output_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
|
||||
int64_t remain_flat_size = var_flat_size;
|
||||
std::vector<int64_t> dims_to_count(indices_nd, 0);
|
||||
for (int64_t i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / output_shape[i];
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
auto Var_data = reinterpret_cast<var_type *>(ctx.Input(0)->GetData());
|
||||
auto Indices_data = reinterpret_cast<indices_type *>(ctx.Input(1)->GetData());
|
||||
auto Updates_data = reinterpret_cast<var_type *>(ctx.Input(2)->GetData());
|
||||
auto Output_data = reinterpret_cast<var_type *>(ctx.Output(0)->GetData());
|
||||
|
||||
for (int64_t i = 0; i < var_flat_size; ++i) {
|
||||
Output_data[i] = Var_data[i];
|
||||
}
|
||||
for (int64_t i = 0; i < n_slices; ++i) {
|
||||
int64_t to_pos = 0;
|
||||
for (int64_t j = 0; j < indices_nd; ++j) {
|
||||
int64_t idx = Indices_data[i * indices_nd + j];
|
||||
|
||||
if (idx < 0 || idx >= output_shape[j]) {
|
||||
KERNEL_LOG_ERROR("The indices[%d] is so big or small", idx);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
to_pos += idx * dims_to_count[j];
|
||||
}
|
||||
for (int64_t j = 0; j < slice_size; j++) {
|
||||
Output_data[to_pos + j] = Updates_data[i * slice_size + j];
|
||||
}
|
||||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
REGISTER_CPU_KERNEL(kTensorScatterUpdate, TensorScatterUpdateCpuKernel);
|
||||
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
|
||||
#define AICPU_KERNELS_NORMALIZED_TENSORSCATTERUPDATE_H_
|
||||
|
||||
#include "cpu_ops_kernel.h"
|
||||
#include "cpu_types.h"
|
||||
#include "utils/bcast.h"
|
||||
#include <string.h>
|
||||
|
||||
namespace aicpu {
|
||||
class TensorScatterUpdateCpuKernel : public CpuKernel {
|
||||
public:
|
||||
TensorScatterUpdateCpuKernel() = default;
|
||||
~TensorScatterUpdateCpuKernel() override = default;
|
||||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
template <typename var_type>
|
||||
uint32_t DTYPE_CHOOSE(CpuKernelContext &ctx);
|
||||
|
||||
template <typename var_type, typename indices_type>
|
||||
uint32_t TensorScatterUpdateComputeRealKernel(CpuKernelContext &ctx);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cpu_kernel/utils/allocator_utils.h"
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
#include "cce/fwk_adpt_struct.h"
|
||||
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
|
||||
#include "cpu_kernel/common/status.h"
|
||||
|
||||
namespace {
|
||||
std::unordered_set<uint64_t> g_allocated_ptr;
|
||||
}
|
||||
|
||||
namespace aicpu {
|
||||
uint32_t CpuKernelAllocatorUtils::ParamCheck(const std::vector<int64_t> &dims, const void *data_ptr,
|
||||
Tensor *&outputResultTensor) {
|
||||
if (dims.empty()) {
|
||||
KERNEL_LOG_ERROR("UpdateOutputDataTensor dims size == 0.");
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
KERNEL_CHECK_NULLPTR(outputResultTensor, KERNEL_STATUS_PARAM_INVALID, "outputResultTensor nullptr");
|
||||
KERNEL_CHECK_NULLPTR(data_ptr, KERNEL_STATUS_PARAM_INVALID, "data_ptr nullptr");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelAllocatorUtils::UpdateOutputDataTensor(const std::vector<int64_t> &dims, DataType type,
|
||||
const void *data_ptr, int64_t input_data_size,
|
||||
Tensor *&outputResultTensor) {
|
||||
uint32_t check_ret = ParamCheck(dims, &data_ptr, outputResultTensor);
|
||||
if (check_ret != KERNEL_STATUS_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
KERNEL_LOG_INFO("UpdateOutputDataTensor::START!!");
|
||||
|
||||
int64_t data_size = GetInputDataSize(dims, type);
|
||||
if (data_size < 0) {
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (data_size > input_data_size) {
|
||||
KERNEL_LOG_ERROR("data_size[%ld] mast less than input_data_size[%ld]!", data_size, input_data_size);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
int64_t shape_buff_size = 0;
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(int64_t(dims.size()), int64_t(sizeof(int64_t)), shape_buff_size,
|
||||
KERNEL_STATUS_PARAM_INVALID);
|
||||
|
||||
void *output_shape_ptr = malloc(shape_buff_size);
|
||||
KERNEL_CHECK_NULLPTR(output_shape_ptr, KERNEL_STATUS_PARAM_INVALID, "malloc error, size[%ld]!", shape_buff_size);
|
||||
|
||||
int32_t ret = memcpy_s(output_shape_ptr, shape_buff_size, dims.data(), shape_buff_size);
|
||||
if (ret != EOK) {
|
||||
free(output_shape_ptr);
|
||||
KERNEL_LOG_ERROR("memcpy error, size[%ld], ret[%d]!", shape_buff_size, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
aicpu::FWKAdapter::ResultSummary *result_summary =
|
||||
reinterpret_cast<aicpu::FWKAdapter::ResultSummary *>(outputResultTensor->GetData());
|
||||
result_summary->raw_data_size = data_size;
|
||||
result_summary->shape_data_size = shape_buff_size;
|
||||
|
||||
if (data_size == 0) {
|
||||
result_summary->raw_data_ptr = reinterpret_cast<uint64_t>(nullptr);
|
||||
result_summary->shape_data_ptr = reinterpret_cast<uint64_t>(output_shape_ptr);
|
||||
(void)g_allocated_ptr.insert(result_summary->shape_data_ptr);
|
||||
KERNEL_LOG_INFO("UpdateOutputDataTensor:: empty tensor END!!");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
void *output_data_ptr = malloc(data_size);
|
||||
if (output_data_ptr == nullptr) {
|
||||
KERNEL_LOG_ERROR("malloc error, size[%ld]!", data_size);
|
||||
free(output_shape_ptr);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
ret = memcpy_s(output_data_ptr, data_size, data_ptr, data_size);
|
||||
if (ret != EOK) {
|
||||
free(output_data_ptr);
|
||||
free(output_shape_ptr);
|
||||
KERNEL_LOG_ERROR("memcpy_s error, size[%ld], ret[%d]!", data_size, ret);
|
||||
return KERNEL_STATUS_INNER_ERROR;
|
||||
}
|
||||
|
||||
result_summary->raw_data_ptr = reinterpret_cast<uint64_t>(output_data_ptr);
|
||||
result_summary->shape_data_ptr = reinterpret_cast<uint64_t>(output_shape_ptr);
|
||||
KERNEL_LOG_INFO("raw_data_ptr [%p]", output_data_ptr);
|
||||
KERNEL_LOG_INFO("shape_data_ptr [%p]", output_shape_ptr);
|
||||
|
||||
(void)g_allocated_ptr.insert(result_summary->raw_data_ptr);
|
||||
(void)g_allocated_ptr.insert(result_summary->shape_data_ptr);
|
||||
KERNEL_LOG_INFO("UpdateOutputDataTensor :: END!!");
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
int64_t CpuKernelAllocatorUtils::GetInputDataSize(const std::vector<int64_t> &dims, DataType type) {
|
||||
int64_t num_elements = 1;
|
||||
int64_t dim_size = 0;
|
||||
for (size_t i = 0; i < dims.size(); i++) {
|
||||
dim_size = dims[i];
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, dim_size, num_elements, KERNEL_STATUS_PARAM_INVALID);
|
||||
}
|
||||
|
||||
int64_t data_size = 0;
|
||||
int element_size = GetSizeByDataType(type);
|
||||
KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, int64_t(element_size), data_size, KERNEL_STATUS_PARAM_INVALID);
|
||||
|
||||
if (data_size < 0) {
|
||||
KERNEL_LOG_ERROR("UpdateOutputDataTensor data_size[%ld].", data_size);
|
||||
}
|
||||
|
||||
return data_size;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelAllocatorUtils::CheckOutputDataPtr(const uint64_t data_ptr) {
|
||||
auto find_data_ptr = g_allocated_ptr.find(data_ptr);
|
||||
if ((find_data_ptr == g_allocated_ptr.end())) {
|
||||
KERNEL_LOG_ERROR("CheckOutputDataPtr invalid [%lu].", data_ptr);
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t CpuKernelAllocatorUtils::DeleteOutputDataPtr(const uint64_t data_ptr) {
|
||||
KERNEL_LOG_INFO("DeleteOutputDataPtr [%lu]", data_ptr);
|
||||
auto find_data_ptr = g_allocated_ptr.find(data_ptr);
|
||||
if (find_data_ptr != g_allocated_ptr.end()) {
|
||||
free(reinterpret_cast<void *>(data_ptr));
|
||||
g_allocated_ptr.erase(find_data_ptr);
|
||||
} else {
|
||||
KERNEL_LOG_EVENT("DeleteOutputDataPtr invalid [%lu].", data_ptr);
|
||||
}
|
||||
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef AICPU_UTILS_ALLOCATOR_UTILS_H_
|
||||
#define AICPU_UTILS_ALLOCATOR_UTILS_H_
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_kernel/inc/cpu_attr_value.h"
|
||||
#include "cpu_kernel/inc/cpu_context.h"
|
||||
#include "cpu_kernel/common/cpu_node_def.h"
|
||||
#include "cpu_kernel/inc/cpu_tensor.h"
|
||||
|
||||
namespace aicpu {
|
||||
class AICPU_VISIBILITY CpuKernelAllocatorUtils {
|
||||
public:
|
||||
static uint32_t ParamCheck(const std::vector<int64_t> &dims, const void *data_ptr, Tensor *&outputResultTensor);
|
||||
static uint32_t UpdateOutputDataTensor(const std::vector<int64_t> &dims, DataType type, const void *data_ptr,
|
||||
int64_t input_data_size, Tensor *&outputResultTensor);
|
||||
static uint32_t CheckOutputDataPtr(const uint64_t data_ptr);
|
||||
static uint32_t DeleteOutputDataPtr(const uint64_t data_ptr);
|
||||
static int64_t GetInputDataSize(const std::vector<int64_t> &dims, DataType type);
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_UTILS_ALLOCATOR_UTILS_H_
|
|
@@ -0,0 +1,309 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cpu_kernel/utils/bcast.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
#include "cpu_kernel/common/status.h"

namespace {
const int64_t kNoBroadcastValue = 1;

enum class State { UNKNOWN, SAME, X_ONE, Y_ONE };
}  // namespace

namespace aicpu {
uint32_t Bcast::Init(const std::vector<int64_t> &x, const std::vector<int64_t> &y) {
  State prev = State::UNKNOWN;
  for (size_t i = 0; i < x.size(); ++i) {
    State curr = State::UNKNOWN;
    const int64_t x_i = x[i];
    const int64_t y_i = y[i];
    int64_t o_i = 0;
    int64_t bx_i = 0;
    int64_t by_i = 0;
    if (x_i == y_i) {
      // No broadcast
      o_i = x_i;
      bx_i = kNoBroadcastValue;
      by_i = kNoBroadcastValue;
      curr = State::SAME;
    } else if (x_i == kNoBroadcastValue) {
      // x broadcast to y on this dimension
      o_i = y_i;
      bx_i = y_i;
      by_i = kNoBroadcastValue;
      curr = State::X_ONE;
    } else if (y_i == kNoBroadcastValue) {
      // y broadcast to x on this dimension
      o_i = x_i;
      bx_i = kNoBroadcastValue;
      by_i = x_i;
      curr = State::Y_ONE;
    } else {
      valid_ = false;
      KERNEL_LOG_ERROR("Broadcast failed, x_shape[%zu]=%ld, y_shape[%zu]=%ld", i, x_i, i, y_i);
      return KERNEL_STATUS_PARAM_INVALID;
    }
    shape_out_.emplace_back(o_i);
    if (curr == State::SAME && x_i == kNoBroadcastValue) {
      continue;
    } else if (prev == curr) {
      result_shape_.back() *= o_i;
      x_reshape_.back() *= x_i;
      x_bcast_.back() *= bx_i;
      y_reshape_.back() *= y_i;
      y_bcast_.back() *= by_i;
    } else {
      result_shape_.emplace_back(o_i);
      x_reshape_.emplace_back(x_i);
      x_bcast_.emplace_back(bx_i);
      y_reshape_.emplace_back(y_i);
      y_bcast_.emplace_back(by_i);
    }
    prev = curr;
  }
  return KERNEL_STATUS_OK;
}

Bcast::Bcast(std::vector<int64_t> &x_shape, std::vector<int64_t> &y_shape) : valid_(true) {
  if (x_shape == y_shape) {
    int64_t elements_num = 1;
    for (size_t i = 0; i < x_shape.size(); ++i) {
      elements_num *= x_shape[i];
      shape_out_.emplace_back(x_shape[i]);
    }
    x_reshape_.emplace_back(elements_num);
    y_reshape_.emplace_back(elements_num);
    result_shape_.emplace_back(elements_num);
    x_bcast_.emplace_back(kNoBroadcastValue);
    y_bcast_.emplace_back(kNoBroadcastValue);
  } else {
    std::vector<int64_t> x = x_shape;
    std::vector<int64_t> y = y_shape;
    std::reverse(x.begin(), x.end());
    std::reverse(y.begin(), y.end());
    if (x.size() > y.size()) {
      y.resize(x.size(), kNoBroadcastValue);
    } else {
      x.resize(y.size(), kNoBroadcastValue);
    }

    auto ret = Init(x, y);
    if (ret != KERNEL_STATUS_OK) {
      return;
    }

    if (result_shape_.empty()) {
      // when both x and y are scalar
      result_shape_.emplace_back(kNoBroadcastValue);
      x_reshape_.emplace_back(kNoBroadcastValue);
      x_bcast_.emplace_back(kNoBroadcastValue);
      y_reshape_.emplace_back(kNoBroadcastValue);
      y_bcast_.emplace_back(kNoBroadcastValue);
    }
    std::reverse(result_shape_.begin(), result_shape_.end());
    std::reverse(x_reshape_.begin(), x_reshape_.end());
    std::reverse(x_bcast_.begin(), x_bcast_.end());
    std::reverse(y_reshape_.begin(), y_reshape_.end());
    std::reverse(y_bcast_.begin(), y_bcast_.end());

    // generate strides, just for row major
    int32_t size = static_cast<int32_t>(result_shape_.size());
    x_input_strides_.resize(size, 0);
    y_input_strides_.resize(size, 0);
    x_output_strides_.resize(size, 0);
    y_output_strides_.resize(size, 0);
    x_input_strides_[size - 1] = 1;
    y_input_strides_[size - 1] = 1;
    x_output_strides_[size - 1] = 1;
    y_output_strides_[size - 1] = 1;
    for (int32_t i = size - 2; i >= 0; --i) {
      x_input_strides_[i] = x_input_strides_[i + 1] * x_reshape_[i + 1];
      y_input_strides_[i] = y_input_strides_[i + 1] * y_reshape_[i + 1];
      x_output_strides_[i] = x_output_strides_[i + 1] * result_shape_[i + 1];
      y_output_strides_[i] = y_output_strides_[i + 1] * result_shape_[i + 1];
    }
  }
}

int64_t Bcast::GetBroadcastXIndex(int64_t index) const {
  int64_t input_index = 0;
  const size_t num_dims = result_shape_.size();
  for (size_t i = 0; i < num_dims - 1; ++i) {
    const int64_t idx = index / x_output_strides_[i];
    if (x_bcast_[i] == kNoBroadcastValue) {
      input_index += idx * x_input_strides_[i];
    } else {
      if (x_reshape_[i] != kNoBroadcastValue) {
        input_index += (idx % x_reshape_[i]) * x_input_strides_[i];
      }
    }
    index -= idx * x_output_strides_[i];
  }
  if (x_bcast_[num_dims - 1] == kNoBroadcastValue) {
    input_index += index;
  } else {
    if (x_reshape_[num_dims - 1] != kNoBroadcastValue) {
      input_index += (index % x_reshape_[num_dims - 1]);
    }
  }
  return input_index;
}

int64_t Bcast::GetBroadcastYIndex(int64_t index) const {
  int64_t input_index = 0;
  const size_t num_dims = result_shape_.size();
  for (size_t i = 0; i < num_dims - 1; ++i) {
    const int64_t idx = index / y_output_strides_[i];
    if (y_bcast_[i] == kNoBroadcastValue) {
      input_index += idx * y_input_strides_[i];
    } else {
      if (y_reshape_[i] != kNoBroadcastValue) {
        input_index += (idx % y_reshape_[i]) * y_input_strides_[i];
      }
    }
    index -= idx * y_output_strides_[i];
  }
  if (y_bcast_[num_dims - 1] == kNoBroadcastValue) {
    input_index += index;
  } else {
    if (y_reshape_[num_dims - 1] != kNoBroadcastValue) {
      input_index += (index % y_reshape_[num_dims - 1]);
    }
  }
  return input_index;
}

uint32_t Bcast::GenerateBcastInfo(const BCalcInfo &calcInfo) {
  const std::vector<int64_t> &shape_x = calcInfo.input_0->GetTensorShape()->GetDimSizes();
  const std::vector<int64_t> &shape_y = calcInfo.input_1->GetTensorShape()->GetDimSizes();
  const std::vector<int64_t> &shape_out = calcInfo.output->GetTensorShape()->GetDimSizes();
  x_reshape_ = shape_x;
  y_reshape_ = shape_y;
  shape_out_ = shape_out;
  if (shape_x.empty() && shape_y.empty() && shape_out.empty()) {
    // Eigen support scalar
    return KERNEL_STATUS_OK;
  }

  // resize shape_x or shape_y to make size equal
  std::reverse(x_reshape_.begin(), x_reshape_.end());
  std::reverse(y_reshape_.begin(), y_reshape_.end());

  size_t dim_num_x = x_reshape_.size();
  size_t dim_num_y = y_reshape_.size();
  size_t max_size = dim_num_x > dim_num_y ? dim_num_x : dim_num_y;
  if (dim_num_x < dim_num_y) {
    x_reshape_.resize(max_size, kNoBroadcastValue);
  } else if (dim_num_x > dim_num_y) {
    y_reshape_.resize(max_size, kNoBroadcastValue);
  }
  std::reverse(x_reshape_.begin(), x_reshape_.end());
  std::reverse(y_reshape_.begin(), y_reshape_.end());
  // Check if shape match
  if (shape_out.size() != max_size) {
    KERNEL_LOG_ERROR("shape mismatch, max_dim_in=%zu, dim_out=%zu.", max_size, shape_out.size());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  for (size_t i = 0; i < max_size; i++) {
    if (shape_out_[i] != std::max(x_reshape_[i], y_reshape_[i])) {
      KERNEL_LOG_ERROR(
        "shape mismatch, dim_x[%zu]=%ld, dim_y[%zu]=%ld, "
        "dim_out[%zu]=%ld.",
        i, x_reshape_[i], i, y_reshape_[i], i, shape_out_[i]);
      return KERNEL_STATUS_PARAM_INVALID;
    }
  }

  // generate broadcast info
  x_bcast_.resize(max_size, kNoBroadcastValue);
  y_bcast_.resize(max_size, kNoBroadcastValue);
  for (size_t i = 0; i < max_size; i++) {
    if (x_reshape_[i] == y_reshape_[i]) {
      continue;
    }
    if (x_reshape_[i] == kNoBroadcastValue) {
      x_bcast_[i] = y_reshape_[i];
    } else if (y_reshape_[i] == kNoBroadcastValue) {
      y_bcast_[i] = x_reshape_[i];
    } else {
      KERNEL_LOG_ERROR("Broadcast not support, dim_x[%zu]=%ld, dim_y[%zu]=%ld.", i, x_reshape_[i], i, y_reshape_[i]);
      return KERNEL_STATUS_PARAM_INVALID;
    }
  }
  return KERNEL_STATUS_OK;
}

void Bcast::GetBcastVec(BCalcInfo &calcInfo) {
  calcInfo.reshape_0 = std::move(x_reshape_);
  calcInfo.reshape_1 = std::move(y_reshape_);
  calcInfo.shape_out = std::move(shape_out_);
  calcInfo.bcast_0 = std::move(x_bcast_);
  calcInfo.bcast_1 = std::move(y_bcast_);
}

void Bcast::BCastIndexes(std::vector<int64_t> &x_indexes, std::vector<int64_t> &y_indexes) {
  std::reverse(x_reshape_.begin(), x_reshape_.end());
  std::reverse(y_reshape_.begin(), y_reshape_.end());
  std::reverse(shape_out_.begin(), shape_out_.end());

  // Process 0-th dimension
  int64_t x_dim = 1;
  int64_t y_dim = 1;
  int64_t out_dim = 1;

  // If shape_out_ is not empty, get dim of shape vector
  if (!shape_out_.empty()) {
    x_dim = x_reshape_.at(0);
    y_dim = y_reshape_.at(0);
    out_dim = shape_out_.at(0);
  }

  int64_t x_bias = x_dim;
  int64_t y_bias = y_dim;

  for (int64_t i = 0; i < out_dim; i++) {
    x_indexes.push_back(x_dim == 1 ? 0 : i);
    y_indexes.push_back(y_dim == 1 ? 0 : i);
  }

  // Process the remaining dimensions
  for (size_t i = 1; i < shape_out_.size(); i++) {
    x_dim = x_reshape_.at(i);    // i-th dimension of x.
    y_dim = y_reshape_.at(i);    // i-th dimension of y.
    out_dim = shape_out_.at(i);  // i-th dimension of shape_out_.

    std::vector<int64_t>::size_type stride = x_indexes.size();
    for (int64_t j = 1; j < out_dim; j++) {
      for (std::vector<int64_t>::size_type k = 0; k < stride; k++) {
        x_indexes.push_back(x_indexes.at(k) + (x_dim == 1 ? 0 : (j * x_bias)));
        y_indexes.push_back(y_indexes.at(k) + (y_dim == 1 ? 0 : (j * y_bias)));
      }
    }
    x_bias *= x_dim;
    y_bias *= y_dim;
  }

  std::reverse(x_reshape_.begin(), x_reshape_.end());
  std::reverse(y_reshape_.begin(), y_reshape_.end());
  std::reverse(shape_out_.begin(), shape_out_.end());
}
}  // namespace aicpu
@@ -0,0 +1,84 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_BCAST_H_
#define _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_BCAST_H_

#include <vector>

#include "cpu_kernel/inc/cpu_context.h"

namespace aicpu {
// broadcast shape type
// 1. SAME_SHAPE : x and y have the same shape
// 2. X_ONE : x has only one element
// 3. Y_ONE : y has only one element
enum class BcastShapeType {
  SAME_SHAPE = 0,
  X_ONE_ELEMENT = 1,
  Y_ONE_ELEMENT = 2,
  DIFF_SHAPE = 3,
};

struct BCalcInfo {
  BCalcInfo() : input_0(nullptr), input_1(nullptr), output(nullptr) {}
  Tensor *input_0;
  Tensor *input_1;
  Tensor *output;
  std::vector<int64_t> reshape_0;
  std::vector<int64_t> reshape_1;
  std::vector<int64_t> shape_out;
  std::vector<int64_t> bcast_0;
  std::vector<int64_t> bcast_1;
  std::vector<int64_t> x_indexes;
  std::vector<int64_t> y_indexes;
};

class Bcast {
 public:
  Bcast() : valid_(true){};
  Bcast(std::vector<int64_t> &x_shape, std::vector<int64_t> &y_shape);
  ~Bcast() = default;

  uint32_t GenerateBcastInfo(const BCalcInfo &calcInfo);
  void GetBcastVec(BCalcInfo &calcInfo);
  void BCastIndexes(std::vector<int64_t> &x_indexes, std::vector<int64_t> &y_indexes);
  int64_t GetBroadcastXIndex(int64_t index) const;
  int64_t GetBroadcastYIndex(int64_t index) const;
  bool IsValid() const { return valid_; }
  const std::vector<int64_t> &x_reshape() const { return x_reshape_; }
  const std::vector<int64_t> &y_reshape() const { return y_reshape_; }
  const std::vector<int64_t> &result_shape() const { return result_shape_; }
  const std::vector<int64_t> &x_bcast() const { return x_bcast_; }
  const std::vector<int64_t> &y_bcast() const { return y_bcast_; }

 private:
  uint32_t Init(const std::vector<int64_t> &x, const std::vector<int64_t> &y);

  bool valid_;
  std::vector<int64_t> x_reshape_;
  std::vector<int64_t> y_reshape_;
  std::vector<int64_t> shape_out_;
  std::vector<int64_t> x_bcast_;
  std::vector<int64_t> y_bcast_;
  std::vector<int64_t> result_shape_;
  std::vector<int64_t> x_input_strides_;
  std::vector<int64_t> y_input_strides_;
  std::vector<int64_t> x_output_strides_;
  std::vector<int64_t> y_output_strides_;
};
}  // namespace aicpu
#endif  // _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_BCAST_H_
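A minimal usage sketch of the Bcast interface declared above, assuming only this header; the shapes, the main() wrapper and the printed output are illustrative and not part of the kernel code. It maps every flat output position back to the flat positions that would be read from the two (broadcast) inputs.

// Illustrative sketch only; assumes the Bcast API declared above.
#include <cstdint>
#include <iostream>
#include <vector>
#include "cpu_kernel/utils/bcast.h"

int main() {
  std::vector<int64_t> x_shape{2, 3};  // hypothetical input shapes
  std::vector<int64_t> y_shape{1, 3};
  aicpu::Bcast bcast(x_shape, y_shape);
  if (!bcast.IsValid()) {
    return 1;
  }
  // result_shape() collapses adjacent dimensions with the same broadcast pattern,
  // so its element count equals the number of output elements.
  int64_t total = 1;
  for (auto d : bcast.result_shape()) {
    total *= d;
  }
  for (int64_t i = 0; i < total; ++i) {
    std::cout << "out " << i << " <- x " << bcast.GetBroadcastXIndex(i) << ", y " << bcast.GetBroadcastYIndex(i)
              << std::endl;
  }
  return 0;
}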
@@ -0,0 +1,124 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cpu_kernel/utils/broadcast_iterator.h"

#include <algorithm>
#include <utility>

namespace aicpu {
BroadcastIterator::BroadcastIterator(std::vector<int64_t> &input_shape_a, std::vector<int64_t> &input_shape_b,
                                     std::vector<int64_t> &output_shape)
    : input_shape_a_(std::move(input_shape_a)),
      input_shape_b_(std::move(input_shape_b)),
      output_shape_(std::move(output_shape)) {
  output_dimension_ = output_shape_.size();  // Assign dimension to int for iterator
  BroadcastShape();
  // Allocate strides memory
  input_strides_a_.resize(output_dimension_);
  input_strides_b_.resize(output_dimension_);
  input_back_strides_a_.resize(output_dimension_);
  input_back_strides_b_.resize(output_dimension_);
  coordinates_.resize(output_dimension_);
  InitStrides();
}

void BroadcastIterator::SetPos(int64_t pos) {
  for (int i = output_dimension_ - 1; i >= 0 && pos != 0; --i) {
    coordinates_[i] = pos % output_shape_[i];
    input_pos_[0] += coordinates_[i] * input_strides_a_[i];
    input_pos_[1] += coordinates_[i] * input_strides_b_[i];
    pos /= output_shape_[i];
  }
}

void BroadcastIterator::GenNextPos() {
  // Calculate output next coordinate
  for (int i = output_dimension_ - 1; i >= 0; --i) {
    if (coordinates_[i] + 1 == output_shape_[i]) {
      coordinates_[i] = 0;
      input_pos_[0] -= input_back_strides_a_[i];
      input_pos_[1] -= input_back_strides_b_[i];
    } else {
      ++coordinates_[i];
      input_pos_[0] += input_strides_a_[i];
      input_pos_[1] += input_strides_b_[i];
      break;
    }
  }
}

void BroadcastIterator::BroadcastShape() {
  size_t input_dimension_a = input_shape_a_.size();
  if (input_dimension_a < output_dimension_) {
    input_shape_a_.insert(input_shape_a_.begin(), output_dimension_ - input_dimension_a, 1);
  }

  size_t input_dimension_b = input_shape_b_.size();
  if (input_dimension_b < output_dimension_) {
    input_shape_b_.insert(input_shape_b_.begin(), output_dimension_ - input_dimension_b, 1);
  }
}

void BroadcastIterator::InitStrides() {
  input_strides_a_[output_dimension_ - 1] = 1;
  input_strides_b_[output_dimension_ - 1] = 1;
  for (int i = output_dimension_ - 2; i >= 0; --i) {
    input_strides_a_[i] = input_shape_a_[i + 1] * input_strides_a_[i + 1];
    input_strides_b_[i] = input_shape_b_[i + 1] * input_strides_b_[i + 1];
    input_back_strides_a_[i + 1] = (input_shape_a_[i + 1] - 1) * input_strides_a_[i + 1];
    input_back_strides_b_[i + 1] = (input_shape_b_[i + 1] - 1) * input_strides_b_[i + 1];
  }

  // Update strides for broadcast
  // While the axis value is 1, the stride is 0
  (void)std::transform(input_strides_a_.begin(), input_strides_a_.end(), input_shape_a_.begin(),
                       input_strides_a_.begin(), [](const int64_t &a, const int64_t &b) { return (b == 1) ? 0 : a; });
  (void)std::transform(input_strides_b_.begin(), input_strides_b_.end(), input_shape_b_.begin(),
                       input_strides_b_.begin(), [](const int64_t &a, const int64_t &b) { return (b == 1) ? 0 : a; });
}

uint32_t GetBroadcastShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                           std::vector<int64_t> &broadcast_shape) {
  int64_t x_len = x.size();
  int64_t y_len = y.size();
  int64_t length = x_len < y_len ? x_len : y_len;
  std::vector<int64_t> broadcast_shape_back;
  for (int64_t i = -length; i < 0; ++i) {
    if (x[x_len + i] == 1) {
      broadcast_shape_back.push_back(y[y_len + i]);
    } else if (y[y_len + i] == 1) {
      broadcast_shape_back.push_back(x[x_len + i]);
    } else if (x[x_len + i] == y[y_len + i]) {
      broadcast_shape_back.push_back(x[x_len + i]);
    } else {
      return KERNEL_STATUS_PARAM_INVALID;
    }
  }
  if (length == x_len) {
    for (int64_t i = 0; i < y_len - length; ++i) {
      broadcast_shape.push_back(y[i]);
    }
  } else {
    for (int64_t i = 0; i < x_len - length; ++i) {
      broadcast_shape.push_back(x[i]);
    }
  }
  for (int64_t i = 0; i < length; ++i) {
    broadcast_shape.push_back(broadcast_shape_back[i]);
  }
  return KERNEL_STATUS_OK;
}
}  // namespace aicpu
@@ -0,0 +1,67 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_UTILS_BROADCAST_ITERATOR_H
#define AICPU_UTILS_BROADCAST_ITERATOR_H

#include <array>
#include <cstdint>
#include <vector>

#include "cpu_kernel/common/status.h"

namespace aicpu {
class BroadcastIterator {
 public:
  BroadcastIterator(std::vector<int64_t> &input_shape_a, std::vector<int64_t> &input_shape_b,
                    std::vector<int64_t> &output_shape);
  virtual ~BroadcastIterator() = default;
  inline int64_t GetInputPosA() const { return input_pos_[0]; }
  inline int64_t GetInputPosB() const { return input_pos_[1]; }
  /**
   * @brief set broadcast start position
   * @param broadcast start position
   */
  void SetPos(int64_t pos);
  /**
   * @brief generate next position
   */
  void GenNextPos();

 private:
  void BroadcastShape();
  void InitStrides();

  std::vector<int64_t> coordinates_;
  std::vector<int64_t> input_shape_a_;
  std::vector<int64_t> input_shape_b_;
  std::vector<int64_t> output_shape_;
  std::vector<int64_t> input_strides_a_;
  std::vector<int64_t> input_strides_b_;
  std::vector<int64_t> input_back_strides_a_;
  std::vector<int64_t> input_back_strides_b_;
  std::array<int64_t, 2> input_pos_ = {{0, 0}};
  size_t output_dimension_{0};
};

/**
 * @brief get broadcast shape
 * @param shape to broadcast
 * @return status
 */
uint32_t GetBroadcastShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                           std::vector<int64_t> &broadcast_shape);
}  // namespace aicpu
#endif
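A minimal usage sketch of the iterator declared above, assuming only this header; the helper name, buffer layout and the assumption that KERNEL_STATUS_OK lives in namespace aicpu are illustrative. It performs a broadcast-aware elementwise add over two flat, row-major buffers.

// Illustrative sketch only; assumes the BroadcastIterator API declared above.
#include <cstdint>
#include <vector>
#include "cpu_kernel/utils/broadcast_iterator.h"

// Broadcast-aware elementwise add over flat buffers a and b (hypothetical helper).
std::vector<int64_t> BroadcastAdd(const std::vector<int64_t> &a, std::vector<int64_t> a_shape,
                                  const std::vector<int64_t> &b, std::vector<int64_t> b_shape) {
  std::vector<int64_t> out_shape;
  if (aicpu::GetBroadcastShape(a_shape, b_shape, out_shape) != aicpu::KERNEL_STATUS_OK) {
    return {};  // shapes are not broadcast-compatible
  }
  int64_t total = 1;
  for (auto d : out_shape) {
    total *= d;
  }
  std::vector<int64_t> out(total);
  aicpu::BroadcastIterator iter(a_shape, b_shape, out_shape);  // note: the shape vectors are moved from
  iter.SetPos(0);
  for (int64_t i = 0; i < total; ++i) {
    out[i] = a[iter.GetInputPosA()] + b[iter.GetInputPosB()];
    iter.GenNextPos();
  }
  return out;
}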
@@ -0,0 +1,65 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_UTILS_DISTINCT_UNIFORM_INT_DISTRIBUTION_H
#define AICPU_UTILS_DISTINCT_UNIFORM_INT_DISTRIBUTION_H

#include <random>
#include <unordered_set>
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"

namespace aicpu {
template <typename IntType = int>
class DistinctUniformIntDistribution {
 public:
  using ResultType = IntType;

 private:
  using SetType = std::unordered_set<ResultType>;
  using DistrType = std::uniform_int_distribution<ResultType>;

 public:
  DistinctUniformIntDistribution(ResultType inf, ResultType sup)
      : inf_(inf), sup_(sup), range_(sup_ - inf_ + 1), distr_(inf_, sup_) {}
  ~DistinctUniformIntDistribution() = default;
  void Reset() {
    uset_.clear();
    distr_.reset();
  }

  template <typename Generator>
  ResultType exec(Generator &engine) {
    if (not(uset_.size() < range_)) {
      std::terminate();
    }
    ResultType res;
    do {
      res = distr_(engine);
    } while (uset_.count(res) > 0);
    uset_.insert(res);
    return res;
  }

 private:
  const ResultType inf_;
  const ResultType sup_;
  const size_t range_ = 0;
  DistrType distr_;
  SetType uset_;
};
}  // namespace aicpu

#endif  // AICPU_UTILS_DISTINCT_UNIFORM_INT_DISTRIBUTION_H_
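A minimal usage sketch of the distribution above; the header path and the main() wrapper are illustrative assumptions. It draws five distinct integers from [1, 10] by rejecting values already seen.

// Illustrative sketch only; draws 5 distinct integers in [1, 10].
#include <iostream>
#include <random>
#include "cpu_kernel/utils/distinct_uniform_int_distribution.h"  // assumed include path

int main() {
  std::mt19937 engine(42);  // any uniform random bit generator works
  aicpu::DistinctUniformIntDistribution<int> dist(1, 10);
  for (int i = 0; i < 5; ++i) {
    std::cout << dist.exec(engine) << ' ';  // never repeats a value until Reset()
  }
  std::cout << std::endl;
  return 0;
}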
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -13,17 +13,9 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/kernel_log.h"

#include "eigen_tensor.h"

namespace aicpu {
static int log_level = AICPU_LOG_ERROR;

int LogSetLevel(int level) {
  log_level = level;
  return log_level;
}

int LogGetLevel(void) { return log_level; }

bool CheckLogLevel(int log_level_check) { return log_level >= log_level_check; }
const Tensor *EigenTensor::GetTensor() const { return tensor_; }
}  // namespace aicpu
@@ -0,0 +1,170 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_EIGENTENSOR_H
#define AICPU_EIGENTENSOR_H

#include "cpu_tensor.h"
#include "kernel_log.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace aicpu {
// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
  // Rank-<NDIMS> tensor of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned> Tensor;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstTensor;

  // Unaligned Rank-<NDIMS> tensor of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> > UnalignedTensor;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> > UnalignedConstTensor;

  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>, Eigen::Aligned> Tensor32Bit;

  // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
  typedef Eigen::TensorMap<Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>, Eigen::Aligned>
    Scalar;
  typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>, Eigen::Aligned>
    ConstScalar;

  // Unaligned Scalar tensor of scalar type T.
  typedef Eigen::TensorMap<Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> > UnalignedScalar;
  typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> >
    UnalignedConstScalar;

  // Rank-1 tensor (vector) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> Flat;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstFlat;
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> Vec;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstVec;

  // Unaligned Rank-1 tensor (vector) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> > UnalignedFlat;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > UnalignedConstFlat;
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> > UnalignedVec;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > UnalignedConstVec;

  // Rank-2 tensor (matrix) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned> Matrix;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned> ConstMatrix;

  // Unaligned Rank-2 tensor (matrix) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> > UnalignedMatrix;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> > UnalignedConstMatrix;
};
}  // namespace aicpu

namespace aicpu {

class EigenTensor {
 public:
  EigenTensor() = delete;
  EigenTensor(Tensor *tensor, void *data) : tensor_(tensor), tensor_data_(data) {}
  ~EigenTensor() = default;

  /*
   * Get tensor
   * @return succ: tensor, error : nullptr
   */
  const Tensor *GetTensor() const;

  /*
   * Eigen vec
   * @return Eigen vec
   */
  template <typename T>
  typename TTypes<T>::Vec vec() {
    return tensor<T, 1>();
  }

  /*
   * Eigen matrix
   * @return Eigen matrix
   */
  template <typename T>
  typename TTypes<T>::Matrix matrix() {
    return tensor<T, 2>();
  }

  /*
   * Eigen ConstMatrix
   * @return Eigen ConstMatrix
   */
  template <typename T>
  typename TTypes<T>::ConstMatrix matrix() const {
    return tensor<T, 2>();
  }

  /*
   * Eigen tensor
   * @return Eigen tensor
   */
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor tensor() {
    return typename TTypes<T, NDIMS>::Tensor(reinterpret_cast<T *>(tensor_data_), AsEigenDSizes<NDIMS>());
  }

  /*
   * Eigen ConstTensor
   * @return Eigen ConstTensor
   */
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor tensor() const {
    return typename TTypes<T, NDIMS>::ConstTensor(reinterpret_cast<const T *>(tensor_data_), AsEigenDSizes<NDIMS>());
  }

  /*
   * Eigen Flat
   * @return Eigen Flat
   */
  template <typename T>
  typename TTypes<T>::Flat flat() {
    return typename TTypes<T>::Flat(reinterpret_cast<T *>(tensor_data_), {tensor_->GetTensorShape()->NumElements()});
  }

  /*
   * Shape as Eigen::DSizes; when NDIMS exceeds the tensor rank,
   * the rest of the sizes are padded with 1.
   * @return Eigen::DSizes: pad the rest of the sizes with 1
   */
  template <int NDIMS, typename IndexType>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const {
    Eigen::DSizes<IndexType, NDIMS> dsizes;
    for (int d = 0; d < tensor_->GetTensorShape()->GetDims(); d++) {
      dsizes[d] = static_cast<IndexType>(tensor_->GetTensorShape()->GetDimSize(d));
    }
    for (int d = tensor_->GetTensorShape()->GetDims(); d < NDIMS; d++) {
      dsizes[d] = 1;
    }
    return dsizes;
  }

  /*
   * Fill `*dsizes` from `*this`
   * @return Eigen::DSizes: pad the rest of the sizes with 1
   */
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const {
    return AsEigenDSizesWithPadding<NDIMS, IndexType>();
  }

 private:
  Tensor *tensor_;
  void *tensor_data_;
};
}  // namespace aicpu

#endif  // AICPU_EIGENTENSOR_H
@@ -0,0 +1,79 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_UTILS_EQUAL_UTIL_H
#define AICPU_UTILS_EQUAL_UTIL_H

#include "cpu_kernel/inc/cpu_ops_kernel.h"
#include "utils/bcast.h"

namespace aicpu {
/**
 * @brief compute elementwise equal or not-equal over broadcast index pairs
 * @param ctx op context
 * @param calcInfo broadcast calculation info
 * @param flag true for equal, false for not equal
 * @return status code
 */
template <typename T>
uint32_t EqualCalculate(const CpuKernelContext &ctx, BCalcInfo &calcInfo, bool flag) {
  auto input_x1 = reinterpret_cast<T *>(calcInfo.input_0->GetData());
  auto input_x2 = reinterpret_cast<T *>(calcInfo.input_1->GetData());
  auto output_y = reinterpret_cast<bool *>(calcInfo.output->GetData());
  KERNEL_CHECK_NULLPTR(input_x1, KERNEL_STATUS_PARAM_INVALID, "Get input x1 data failed.")
  KERNEL_CHECK_NULLPTR(input_x2, KERNEL_STATUS_PARAM_INVALID, "Get input x2 data failed.")
  KERNEL_CHECK_NULLPTR(output_y, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.")
  size_t data_num = calcInfo.x_indexes.size();
  auto shard_equal = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      auto x_index = input_x1 + calcInfo.x_indexes[i];
      auto y_index = input_x2 + calcInfo.y_indexes[i];
      output_y[i] = (flag == true) ? (*x_index == *y_index) : (*x_index != *y_index);
    }
  };
  KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, 1, shard_equal), "Equal calculate failed.");
  return KERNEL_STATUS_OK;
}
/**
 * @brief verify parameters, generate broadcast info and run the calculation
 * @param ctx op context
 * @param flag true for equal, false for not equal
 * @return status code
 */
template <typename T>
uint32_t EqualCompute(const CpuKernelContext &ctx, bool flag) {
  BCalcInfo calcInfo;
  calcInfo.input_0 = ctx.Input(0);
  calcInfo.input_1 = ctx.Input(1);
  calcInfo.output = ctx.Output(0);
  DataType input0_type = calcInfo.input_0->GetDataType();
  DataType input1_type = calcInfo.input_1->GetDataType();
  KERNEL_CHECK_FALSE((input0_type == input1_type), KERNEL_STATUS_PARAM_INVALID,
                     "DataType of x1 [%d] should be same as x2 [%d].", static_cast<int32_t>(input0_type),
                     static_cast<int32_t>(input1_type))
  KERNEL_LOG_INFO(
    "CpuKernel[%s], input x1 : addr[%p], size[%llu];"
    "input x2: addr[%p], size[%llu];"
    "output: addr[%p], size[%llu].",
    ctx.GetOpType().c_str(), calcInfo.input_0->GetData(), calcInfo.input_0->GetDataSize(), calcInfo.input_1->GetData(),
    calcInfo.input_1->GetDataSize(), calcInfo.output->GetData(), calcInfo.output->GetDataSize());

  Bcast bcast;
  KERNEL_HANDLE_ERROR(bcast.GenerateBcastInfo(calcInfo), "Generate broadcast info failed.");
  bcast.BCastIndexes(calcInfo.x_indexes, calcInfo.y_indexes);
  bcast.GetBcastVec(calcInfo);

  return EqualCalculate<T>(ctx, calcInfo, flag);
}
}  // namespace aicpu
#endif
@@ -0,0 +1,238 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/kernel_util.h"

#include <algorithm>
#include <map>
#include <string>
#include <vector>

namespace aicpu {
namespace {
const std::map<Format, std::string> kFormatToStringMap = {
  {FORMAT_NCHW, "NCHW"},
  {FORMAT_NHWC, "NHWC"},
  {FORMAT_ND, "ND"},
  {FORMAT_NC1HWC0, "NC1HWC0"},
  {FORMAT_FRACTAL_Z, "FRACTAL_Z"},
  {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
  {FORMAT_NHWC1C0, "NHWC1C0"},
  {FORMAT_FSR_NCHW, "FSR_NCHW"},
  {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
  {FORMAT_C1HWNC0, "C1HWNC0"},
  {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
  {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
  {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
  {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
  {FORMAT_CHWN, "CHWN"},
  {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
  {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
  {FORMAT_BN_WEIGHT, "BN_WEIGHT"},
  {FORMAT_FILTER_HWCK, "FILTER_HWCK"},
  {FORMAT_HWCN, "HWCN"},
  {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
  {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
  {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
  {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
  {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
  {FORMAT_MD, "MD"},
  {FORMAT_NDHWC, "NDHWC"},
  {FORMAT_NCDHW, "NCDHW"},
  {FORMAT_DHWCN, "DHWCN"},
  {FORMAT_DHWNC, "DHWNC"},
  {FORMAT_NDC1HWC0, "NDC1HWC0"},
  {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
  {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
  {FORMAT_C1HWNCoC0, "C1HWNCoC0"},
  {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
  {FORMAT_CN, "CN"},
  {FORMAT_NC, "NC"},
  {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
  {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"},
  {FORMAT_RESERVED, "FORMAT_RESERVED"},
  {FORMAT_ALL, "ALL"},
  {FORMAT_NULL, "NULL"}};
}  // namespace

std::string FormatToSerialString(Format format) {
  auto it = kFormatToStringMap.find(static_cast<Format>(GetPrimaryFormat(static_cast<int32_t>(format))));
  if (it != kFormatToStringMap.end()) {
    if (HasSubFormat(static_cast<int32_t>(format))) {
      return it->second + ":" + std::to_string(GetSubFormat(static_cast<int32_t>(format)));
    }
    return it->second;
  } else {
    KERNEL_LOG_ERROR("Format not support [%u]", format);
    return "UNDEFINED";
  }
}

const std::map<std::string, DataType> dtype_maps{{"DT_FLOAT", DT_FLOAT},
                                                 {"DT_FLOAT16", DT_FLOAT16},
                                                 {"DT_INT8", DT_INT8},
                                                 {"DT_INT16", DT_INT16},
                                                 {"DT_UINT16", DT_UINT16},
                                                 {"DT_UINT8", DT_UINT8},
                                                 {"DT_INT32", DT_INT32},
                                                 {"DT_INT64", DT_INT64},
                                                 {"DT_UINT32", DT_UINT32},
                                                 {"DT_UINT64", DT_UINT64},
                                                 {"DT_BOOL", DT_BOOL},
                                                 {"DT_DOUBLE", DT_DOUBLE},
                                                 {"DT_STRING", DT_STRING},
                                                 {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8},
                                                 {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8},
                                                 {"DT_COMPLEX64", DT_COMPLEX64},
                                                 {"DT_COMPLEX128", DT_COMPLEX128},
                                                 {"DT_QINT8", DT_QINT8},
                                                 {"DT_QINT16", DT_QINT16},
                                                 {"DT_QINT32", DT_QINT32},
                                                 {"DT_QUINT8", DT_QUINT8},
                                                 {"DT_QUINT16", DT_QUINT16},
                                                 {"DT_RESOURCE", DT_RESOURCE},
                                                 {"DT_STRING_REF", DT_STRING_REF},
                                                 {"DT_DUAL", DT_DUAL},
                                                 {"DT_UNDEFINED", DT_UNDEFINED}};

bool IsEmptyTensor(Tensor *tensor) {
  auto dims = tensor->GetTensorShape()->GetDimSizes();
  if (tensor->GetData() == nullptr) {
    for (uint32_t i = 0; i < dims.size(); i++) {
      if (dims[i] == 0) {
        return true;
      }
    }
  }
  return false;
}

uint32_t NormalMathCheck(CpuKernelContext &ctx) {
  const uint32_t kInputNum = 2;
  const uint32_t kOutputNum = 1;

  if ((ctx.GetInputsSize() != kInputNum) || (ctx.GetOutputsSize() != kOutputNum)) {
    KERNEL_LOG_ERROR(
      "[%s] Input size or Output size is unexpected,"
      "expected input size [%u], real input size [%u],"
      "expected output size [%u], real output size [%u]",
      ctx.GetOpType().c_str(), kInputNum, ctx.GetInputsSize(), kOutputNum, ctx.GetOutputsSize());
    return KERNEL_STATUS_PARAM_INVALID;
  }

  Tensor *input_0 = ctx.Input(kFirstInputIndex);
  KERNEL_CHECK_NULLPTR(input_0, KERNEL_STATUS_PARAM_INVALID, "[%s] Get input[0] failed", ctx.GetOpType().c_str());
  Tensor *input_1 = ctx.Input(kSecondInputIndex);
  KERNEL_CHECK_NULLPTR(input_1, KERNEL_STATUS_PARAM_INVALID, "[%s] Get input[1] failed", ctx.GetOpType().c_str());

  if (input_0->GetDataType() != input_1->GetDataType()) {
    KERNEL_LOG_ERROR(
      "[%s] dtype of inputs not matched, input[0] data_type is [%d], "
      "input[1] data_type is [%d]",
      ctx.GetOpType().c_str(), input_0->GetDataType(), input_1->GetDataType());
    return KERNEL_STATUS_PARAM_INVALID;
  }

  Tensor *output = ctx.Output(kFirstOutputIndex);
  KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "[%s] get output failed", ctx.GetOpType().c_str());
  return KERNEL_STATUS_OK;
}

uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num) {
  if (inputs_num != kDynamicInput) {
    KERNEL_CHECK_FALSE((ctx.GetInputsSize() >= inputs_num), KERNEL_STATUS_PARAM_INVALID,
                       "[%s] need [%u] inputs, but got [%u].", ctx.GetOpType().c_str(), inputs_num,
                       ctx.GetInputsSize());
    for (uint32_t i = 0; i < inputs_num; ++i) {
      Tensor *input = ctx.Input(i);
      KERNEL_CHECK_NULLPTR(input, KERNEL_STATUS_INNER_ERROR, "[%s] get input[%u] failed.", ctx.GetOpType().c_str(), i);
      auto input_shape = input->GetTensorShape();
      KERNEL_CHECK_NULLPTR(input_shape, KERNEL_STATUS_PARAM_INVALID, "%s input[%u] tensor shape is nullptr.",
                           ctx.GetOpType().c_str(), i);
      if (!IsEmptyTensor(input)) {
        auto input_data = input->GetData();
        KERNEL_CHECK_NULLPTR(input_data, KERNEL_STATUS_PARAM_INVALID, "%s input[%u] tensor data is nullptr.",
                             ctx.GetOpType().c_str(), i);
      }
    }
  }

  if (outputs_num != kDynamicOutput) {
    KERNEL_CHECK_FALSE((ctx.GetOutputsSize() == outputs_num), KERNEL_STATUS_PARAM_INVALID,
                       "[%s] need [%u] outputs, but got [%u].", ctx.GetOpType().c_str(), outputs_num,
                       ctx.GetOutputsSize());
    for (uint32_t i = 0; i < outputs_num; ++i) {
      Tensor *output = ctx.Output(i);
      KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_INNER_ERROR, "[%s] get output[%u] failed.", ctx.GetOpType().c_str(),
                           i);
      auto output_shape = output->GetTensorShape();
      KERNEL_CHECK_NULLPTR(output_shape, KERNEL_STATUS_PARAM_INVALID, "%s output[%u] tensor shape is nullptr.",
                           ctx.GetOpType().c_str(), i);
      if (!IsEmptyTensor(output)) {
        auto output_data = output->GetData();
        KERNEL_CHECK_NULLPTR(output_data, KERNEL_STATUS_PARAM_INVALID, "%s output[%u] tensor data is nullptr.",
                             ctx.GetOpType().c_str(), i);
      }
    }
  }
  return KERNEL_STATUS_OK;
}

uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
                     const std::vector<std::string> &attr_names) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, inputs_num, outputs_num), "Check Greater params failed.");
  for (auto const &attr_name : attr_names) {
    auto attr = ctx.GetAttr(attr_name);
    KERNEL_CHECK_NULLPTR(attr, KERNEL_STATUS_PARAM_INVALID, "%s get attr[%s] is nullptr.", ctx.GetOpType().c_str(),
                         attr_name.c_str());
  }
  return KERNEL_STATUS_OK;
}

bool IsScalar(const std::vector<int64_t> &shape) { return (shape.size() == 0); }

bool IsVector(const std::vector<int64_t> &shape) { return (shape.size() == 1); }

bool IsMatrix(const std::vector<int64_t> &shape) { return (shape.size() == 2); }

bool IsSquareMatrix(const std::vector<int64_t> &shape) { return ((shape.size() == 2) && (shape[0] == shape[1])); }

bool AddrAlignedCheck(const void *addr, uint64_t alignment) {
  return reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(addr)) % alignment == 0;
}

bool IsVectorOrHigher(const std::vector<int64_t> &shape) { return (shape.size() >= 1); }

DataType DType(std::string dtype_str) {
  auto iter = dtype_maps.find(dtype_str);
  if (iter != dtype_maps.end()) {
    return iter->second;
  } else {
    return DT_UNDEFINED;
  }
}

std::string DTypeStr(DataType dtype) {
  auto iter =
    std::find_if(dtype_maps.begin(), dtype_maps.end(),
                 [dtype](const std::map<std::string, DataType>::value_type &kv) { return (kv.second == dtype); });
  if (iter != dtype_maps.end()) {
    return iter->first;
  } else {
    return std::string("DT_UNDEFINED");
  }
}
}  // namespace aicpu
@@ -0,0 +1,254 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AICPU_UTILS_KERNEL_UTIL_H_
#define AICPU_UTILS_KERNEL_UTIL_H_

#include <climits>
#include <cmath>
#include <sstream>
#include <string>
#include <vector>

#include "cpu_kernel/inc/cpu_context.h"
#include "mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/kernel_log.h"
#include "cpu_kernel/common/status.h"

namespace aicpu {
constexpr uint32_t kResvCpuNum = 2;
constexpr uint32_t kThreadNum = 32;
constexpr uint32_t kFirstInputIndex = 0;
constexpr uint32_t kSecondInputIndex = 1;
constexpr uint32_t kThirdInputIndex = 2;
constexpr uint32_t kFirstOutputIndex = 0;
constexpr uint32_t kSecondOutputIndex = 1;
constexpr uint32_t kDynamicInput = -1;
constexpr uint32_t kDynamicOutput = -2;
constexpr uint64_t kEigenAlignmentBytes = 16;

constexpr uint64_t kFormatNCHWIndexN = 0;
constexpr uint64_t kFormatNCHWIndexC = 1;
constexpr uint64_t kFormatNCHWIndexH = 2;
constexpr uint64_t kFormatNCHWIndexW = 3;

constexpr uint64_t kFormatCHWIndexC = 0;
constexpr uint64_t kFormatCHWIndexH = 1;
constexpr uint64_t kFormatCHWIndexW = 2;

constexpr uint64_t kFormatNHWCIndexN = 0;
constexpr uint64_t kFormatNHWCIndexH = 1;
constexpr uint64_t kFormatNHWCIndexW = 2;
constexpr uint64_t kFormatNHWCIndexC = 3;

constexpr uint64_t kFormatHWCIndexH = 0;
constexpr uint64_t kFormatHWCIndexW = 1;
constexpr uint64_t kFormatHWCIndexC = 2;

const size_t INPUT_NUM0 = 0;
const size_t INPUT_NUM1 = 1;
const size_t INPUT_NUM2 = 2;
const size_t INPUT_NUM3 = 3;
const size_t INPUT_NUM4 = 4;
const size_t INPUT_NUM5 = 5;
const size_t INPUT_NUM6 = 6;
const size_t INPUT_NUM7 = 7;
const size_t INPUT_NUM8 = 8;
const size_t INPUT_NUM9 = 9;
const size_t INPUT_NUM32 = 32;
/*
 * str cat util function
 * param[in] params need concat to string
 * return concatenated string
 */
template <typename T>
std::string ConcatString(T arg) {
  std::ostringstream oss;
  oss << arg;
  return oss.str();
}

template <typename T, typename... Ts>
std::string ConcatString(T arg, Ts... arg_left) {
  std::ostringstream oss;
  oss << arg;
  oss << ConcatString(arg_left...);
  return oss.str();
}

/**
 * @brief get debug string of vector
 * @param values values in vector
 * @return string of values
 */
template <typename T>
inline std::string VectorToString(const std::vector<T> &values) {
  std::stringstream ss;
  for (auto iter = values.begin(); iter != values.end(); ++iter) {
    ss << *iter;
    if (iter != values.end() - 1) {
      ss << ", ";
    }
  }
  return ss.str();
}

template <typename T>
std::string FmtToStr(const T &t) {
  std::string fmt;
  std::stringstream st;
  st << "[" << t << "]";
  fmt = st.str();
  return fmt;
}

std::string FormatToSerialString(Format format);

/**
 * Get primary-format from format,
 * in bits field:
 * ------------------------------------------
 * | 1 byte   | 2 bytes    | 1 byte         |
 * |----------|------------|----------------|
 * | reserved | sub-format | primary-format |
 * ------------------------------------------
 * @param format
 * @return
 */
inline int32_t GetPrimaryFormat(int32_t format) { return static_cast<int32_t>(static_cast<uint32_t>(format) & 0xff); }

inline int32_t GetSubFormat(int32_t format) {
  return static_cast<int32_t>((static_cast<uint32_t>(format) & 0xffff00) >> 8);
}

inline bool HasSubFormat(int32_t format) { return GetSubFormat(format) > 0; }

/**
 * @brief Judge whether tensor is empty
 * @param tensor need judged tensor
 * @return true: is empty tensor, false: isn't empty tensor
 */
bool IsEmptyTensor(Tensor *tensor);

/**
 * @brief multiply two nonnegative int64's
 * @param x mul value x
 * @param y mul value y
 * @param xy product of x and y
 * @return true: normal, false: overflow
 */
inline bool MulWithoutOverflow(const int64_t x, const int64_t y, int64_t &xy) {
  // Multiply in uint64 rather than int64 since signed overflow is undefined.
  // Negative values will wrap around to large unsigned values in the casts
  // (see section 4.7 [conv.integral] of the C++14 standard).
  const uint64_t ux = static_cast<uint64_t>(x);
  const uint64_t uy = static_cast<uint64_t>(y);
  const uint64_t uxy = ux * uy;

  // Check if we overflow uint64, using a cheap check if both inputs are small
  if ((ux | uy) >> 32 != 0) {
    // Ensure nonnegativity. Note that negative numbers will appear "large"
    // to the unsigned comparisons above.
    if (x < 0 || y < 0) {
      KERNEL_LOG_ERROR("Can't multiply negative numbers.");
      return false;
    }

    // Otherwise, detect overflow using a division
    if (ux != 0 && uxy / ux != uy) {
      return false;
    }
  }

  // Cast back to signed. Any negative value will signal an error.
  xy = static_cast<int64_t>(uxy);
  return true;
}

/**
 * @brief add two int64's
 * @param x add value x
 * @param y add value y
 * @param sum sum of x and y
 * @return true: normal, false: overflow
 */
inline bool AddWithoutOverflow(const int64_t x, const int64_t y, int64_t &sum) {
  const uint64_t ux = static_cast<uint64_t>(x);
  const uint64_t uy = static_cast<uint64_t>(y);
  const uint64_t usum = ux + uy;
  sum = static_cast<int64_t>(usum);

  return !(((x >= 0) == (y >= 0)) && ((sum >= 0) != (x >= 0)));
}

/**
 * @brief normal check for calculation
 * @param ctx context
 * @return status code
 */
uint32_t NormalMathCheck(CpuKernelContext &ctx);

/**
 * @brief normal check for kernel
 * @param ctx context
 * @param inputs_num num of inputs
 * @param outputs_num num of outputs
 * @return status code
 */
uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num);

/**
 * @brief normal check for kernel
 * @param ctx context
 * @param inputs_num num of inputs
 * @param outputs_num num of outputs
 * @param attr_names names of attrs
 * @return status code
 */
uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
                     const std::vector<std::string> &attr_names);

bool IsScalar(const std::vector<int64_t> &shape);

bool IsMatrix(const std::vector<int64_t> &shape);

bool IsVector(const std::vector<int64_t> &shape);

bool IsSquareMatrix(const std::vector<int64_t> &shape);
/**
 * @brief check if addr is aligned
 * @param addr address for check
 * @return true: aligned, false: not aligned
 */
bool AddrAlignedCheck(const void *addr, uint64_t alignment = kEigenAlignmentBytes);

bool IsVectorOrHigher(const std::vector<int64_t> &shape);

/**
 * @brief get data type from string
 * @param dtype_str string of data type
 * @return DataType
 */
DataType DType(std::string dtype_str);

/**
 * @brief get string from data type
 * @param dtype data type
 * @return string of data type
 */
std::string DTypeStr(DataType dtype);

}  // namespace aicpu
#endif
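A minimal sketch of how the overflow-checked helpers above can be used; the SafeNumElements wrapper, its name and the include path are illustrative assumptions, not part of the kernel code. It accumulates an element count for a shape and reports failure instead of silently wrapping.

// Illustrative sketch only; assumes the helpers declared above.
#include <cstdint>
#include <vector>
#include "utils/kernel_util.h"  // assumed include path

// Compute the element count of a shape, failing on int64 overflow (hypothetical helper).
bool SafeNumElements(const std::vector<int64_t> &shape, int64_t &num) {
  num = 1;
  for (int64_t dim : shape) {
    if (!aicpu::MulWithoutOverflow(num, dim, num)) {
      return false;  // product no longer fits in int64
    }
  }
  return true;
}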
@@ -0,0 +1,185 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_PHILOX_RANDOM_H
#define _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_PHILOX_RANDOM_H

#include <stdint.h>
#include "cpu_kernel/common/status.h"

/**
 * A class that represents an inline array.
 * Arguments:
 *   T: the array element type;
 *   ElementCount: the fixed size of the array;
 */
template <typename T, int ElementCount>
class Array {
 public:
  static constexpr int kElementCount = ElementCount;
  Array() {
    for (int i = 0; i < ElementCount; ++i) {
      data_[i] = T(0);
    }
  }

  const T &operator[](int index) const { return data_[index]; }

  T &operator[](int index) { return data_[index]; }

  size_t size() const { return ElementCount; }

 private:
  T data_[ElementCount];
};

class PhiloxRandom {
 public:
  using ResultType = Array<uint32_t, 4>;
  using ResultElementType = uint32_t;
  // The number of elements that will be returned.
  static constexpr int kResultElementCount = 4;
  // Cost of generation of a single element (in cycles).
  static constexpr int kElementCost = 10;
  /*
   * The type for the 64-bit key stored in the form of two 32-bit uint
   * that are used in the diffusion process.
   */
  using Key = Array<uint32_t, 2>;

  PhiloxRandom() {}

  PhiloxRandom(int64_t seed, uint64_t offset) {
    const uint32_t seed_low_index = 0;
    const uint32_t seed_high_index = 1;
    const uint32_t offset_low_index = 2;
    const uint32_t offset_high_index = 3;
    key_[seed_low_index] = static_cast<uint32_t>(seed);
    key_[seed_high_index] = static_cast<uint32_t>(seed >> 32);
    counter_[offset_low_index] = static_cast<uint32_t>(offset);
    counter_[offset_high_index] = static_cast<uint32_t>(offset >> 32);
  }

  ResultType const &counter() const { return counter_; }

  Key const &key() const { return key_; }

  // Skip the specified number of samples of 128-bits in the current stream.
  void Skip(uint64_t count) {
    const uint32_t count_lo = static_cast<uint32_t>(count);
    uint32_t count_hi = static_cast<uint32_t>(count >> 32);

    counter_[0] += count_lo;
    if (counter_[0] < count_lo) {
      ++count_hi;
    }

    counter_[1] += count_hi;
    if (counter_[1] < count_hi) {
      if (++counter_[2] == 0) {
        ++counter_[3];
      }
    }
  }
  /*
   * Returns a group of four random numbers using the underlying Philox
   * algorithm.
   */
  ResultType operator()() {
    ResultType counter = counter_;
    Key key = key_;
    /*
     * Run the single rounds for ten times. Manually unrolling the loop
     * for better performance.
     */
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    RaiseKey(&key);
    counter = ComputeSingleRound(counter, key);
    SkipOne();
    return counter;
  }

 private:
  // We use the same constants as recommended by the original paper.
  static constexpr uint32_t kPhiloxW32A = 0x9E3779B9;
  static constexpr uint32_t kPhiloxW32B = 0xBB67AE85;
  static constexpr uint32_t kPhiloxM4x32A = 0xD2511F53;
  static constexpr uint32_t kPhiloxM4x32B = 0xCD9E8D57;

  // Helper function to skip the next sample of 128-bits in the current stream.
  void SkipOne() {
    if (++counter_[0] == 0) {
      if (++counter_[1] == 0) {
        if (++counter_[2] == 0) {
          ++counter_[3];
        }
      }
    }
  }
  /*
   * Helper function to return the lower and higher 32-bits from two 32-bit
   * integer multiplications.
   */
  static void MultiplyHighLow(uint32_t a, uint32_t b, uint32_t *result_low, uint32_t *result_high) {
    const uint64_t product = static_cast<uint64_t>(a) * b;
    *result_low = static_cast<uint32_t>(product);
    *result_high = static_cast<uint32_t>(product >> 32);
  }

  // Helper function for a single round of the underlying Philox algorithm.
  static ResultType ComputeSingleRound(const ResultType &counter, const Key &key) {
    uint32_t lo0;
    uint32_t hi0;
    MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);

    uint32_t lo1;
    uint32_t hi1;
    MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);

    ResultType result;
    result[0] = hi1 ^ counter[1] ^ key[0];
    result[1] = lo1;
    result[2] = hi0 ^ counter[3] ^ key[1];
    result[3] = lo0;
    return result;
  }

  void RaiseKey(Key *key) {
    (*key)[0] += kPhiloxW32A;
    (*key)[1] += kPhiloxW32B;
  }

 private:
  ResultType counter_;
  Key key_;
};
#endif  // _AICPU_AICPU_DEVICE_CPU_KERNELS_UTILS_PHILOX_RANDOM_H_
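A minimal usage sketch of the counter-based generator above; the include path, seed/offset values and the main() wrapper are illustrative. Because Philox is counter-based, Skip() lets independent workers jump to disjoint positions of the same stream without communicating.

// Illustrative sketch only; draws 32-bit values from two non-overlapping positions of one stream.
#include <cstdint>
#include <iostream>
#include "cpu_kernel/utils/philox_random.h"  // assumed include path

int main() {
  PhiloxRandom rng(/*seed=*/1234, /*offset=*/0);
  PhiloxRandom rng_far(/*seed=*/1234, /*offset=*/0);
  rng_far.Skip(1000);  // e.g. a second worker jumps ahead 1000 blocks of 4 results

  PhiloxRandom::ResultType block = rng();          // four uint32 values
  PhiloxRandom::ResultType far_block = rng_far();  // four values from the skipped-to position
  for (int i = 0; i < PhiloxRandom::kResultElementCount; ++i) {
    std::cout << block[i] << ' ' << far_block[i] << std::endl;
  }
  return 0;
}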
@@ -0,0 +1,35 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "sampling_kernels.h"
#include <algorithm>
#include "kernel_log.h"
#include "status.h"
using namespace std;

namespace aicpu {
SamplingKernelType SamplingKernelTypeFromString(std::string str) {
  if (str == "lanczos1") return Lanczos1Kernel;
  if (str == "lanczos3") return Lanczos3Kernel;
  if (str == "lanczos5") return Lanczos5Kernel;
  if (str == "gaussian") return GaussianKernel;
  if (str == "box") return BoxKernel;
  if (str == "triangle") return TriangleKernel;
  if (str == "keyscubic") return KeysCubicKernel;
  if (str == "mitchellcubic") return MitchellCubicKernel;
  return SamplingKernelTypeEnd;
}
}  // namespace aicpu
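A small sketch of how the string-to-enum mapping above can be used to validate a kernel-type attribute; the helper name and the assumption that SamplingKernelTypeEnd is the "unknown" sentinel in namespace aicpu are illustrative.

// Illustrative sketch only; assumes the declarations in sampling_kernels.h.
#include <string>
#include "sampling_kernels.h"

bool IsSupportedSamplingKernel(const std::string &name) {
  // SamplingKernelTypeEnd is returned for unrecognized names.
  return aicpu::SamplingKernelTypeFromString(name) != aicpu::SamplingKernelTypeEnd;
}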