forked from mindspore-Ecosystem/mindspore
parent 8f5852f526
author lilinjie <lilinjie11@huawei.com> 1672055454 +0800
committer lilinjie <lilinjie11@huawei.com> 1672146913 +0800

roll back cpu_kernel/common

This commit is contained in:
parent 8f5852f526
commit 07dffaea7e
@@ -95,3 +95,8 @@
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constParameter"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "constVariable"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "unreadVariable"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "nullPointerRedundantCheck"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "variableScope"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "unreadVariable"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "useStlAlgorithm"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "constParameter"

@@ -107,28 +107,21 @@
 "mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include"
 
 # AICPU migration
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "build/include_subdir"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "runtime/references"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "build/include_what_you_use"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "whitespace/indent"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "whitespace/ending_newline"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "runtime/explicit"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "readability/braces"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "readability/namespace"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "whitespace/braces"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "build/include"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "whitespace/end_of_line"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "readability/casting"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/" "build/namespaces"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "build/include_subdir"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "runtime/references"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "build/include_what_you_use"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/indent"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/ending_newline"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "runtime/explicit"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "readability/braces"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "readability/namespace"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/braces"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "build/include"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "whitespace/end_of_line"
-"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/utils/" "readability/casting"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include_subdir"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include_what_you_use"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/indent"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/ending_newline"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/explicit"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/braces"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/namespace"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/braces"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/end_of_line"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/casting"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/namespaces"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/references"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/multiline_comment"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/parens"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/alt_tokens"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/comments"
+"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/string"
@@ -18,7 +18,7 @@
 #include "cpu_kernel/common/notification.h"
 
 namespace aicpu {
-uint32_t AsyncCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t AsyncCpuKernel::Compute(CpuKernelContext &ctx) {
   Notification n;
   uint32_t ret = ComputeAsync(ctx, [&n](uint32_t status) { n.Notify(); });
   n.WaitForNotification();

@@ -26,9 +26,9 @@ class AICPU_VISIBILITY AsyncCpuKernel : public CpuKernel {
 
   using DoneCallback = std::function<void(uint32_t status)>;
 
-  virtual uint32_t ComputeAsync(const CpuKernelContext &ctx, DoneCallback done) = 0;
+  virtual uint32_t ComputeAsync(CpuKernelContext &ctx, DoneCallback done) = 0;
 
-  uint32_t Compute(const CpuKernelContext &ctx);
+  uint32_t Compute(CpuKernelContext &ctx) override;
 };
 }  // namespace aicpu
-#endif  // ASYNC_CPU_KERNEL_H
+#endif  // ASYNC_CPU_KERNEL_H
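The rolled-back Compute above is a small sync-over-async adapter: it starts ComputeAsync and blocks on a Notification until the done callback fires. Below is a self-contained sketch of that pattern. The Notification here is a stand-in built on std::condition_variable; the real class comes from cpu_kernel/common/notification.h and its internals are not shown in this diff.

#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>

// Stand-in for aicpu::Notification (assumed interface): a one-shot event.
class Notification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> lock(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

using DoneCallback = std::function<void(uint32_t status)>;

// Hypothetical async kernel body: fires the callback when the work completes.
// In real code the callback would typically fire from another thread or event.
uint32_t ComputeAsync(DoneCallback done) {
  done(0);
  return 0;
}

// Blocking adapter mirroring AsyncCpuKernel::Compute in the hunk above:
// kick off the async work, then park until the callback signals.
uint32_t Compute() {
  Notification n;
  uint32_t ret = ComputeAsync([&n](uint32_t status) { n.Notify(); });
  n.WaitForNotification();
  return ret;
}

int main() { std::cout << Compute() << '\n'; }

Because the one-shot flag persists, the adapter is safe even when the callback runs before WaitForNotification is reached, as happens in this single-threaded sketch.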
@@ -21,12 +21,13 @@
 #include "aicpu_sharder/aicpu_context.h"
 
 namespace aicpu {
-using NotifyWaitFunc = void (*)(void *notify_param, const uint32_t param_len);
-using RegEventCbFunc = bool (*)(const uint32_t event_id, const uint32_t sub_event_id,
-                                const std::function<void(void *)> &cb);
-using RegEventCbWithTimesFunc = bool (*)(const uint32_t event_id, const uint32_t sub_event_id,
-                                         const std::function<void(void *)> &cb, const int32_t times);
-using UnregEventCbFunc = void (*)(const uint32_t event_id, const uint32_t sub_event_id);
+typedef void (*NotifyWaitFunc)(void *notify_param, const uint32_t param_len);
+typedef bool (*RegEventCbFunc)(const uint32_t event_id, const uint32_t sub_event_id,
+                               const std::function<void(void *)> &cb);
+typedef bool (*RegEventCbWithTimesFunc)(const uint32_t event_id, const uint32_t sub_event_id,
+                                        const std::function<void(void *)> &cb, const int32_t times);
+typedef void (*UnregEventCbFunc)(const uint32_t event_id, const uint32_t sub_event_id);
+
 class AsyncEventUtil {
  public:
  static AsyncEventUtil &GetInstance();

@@ -53,4 +54,4 @@ class AsyncEventUtil {
   UnregEventCbFunc unreg_event_cb_func_;
 };
 }  // namespace aicpu
-#endif  // AICPU_CONTEXT_COMMON_ASYNC_EVENT_H_
+#endif  // AICPU_CONTEXT_COMMON_ASYNC_EVENT_H_
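This hunk trades C++11 alias declarations back for the older typedef spelling. For function-pointer types the two forms are interchangeable; a minimal, compilable check that both name the same type (the _A/_T suffixes are illustrative, not from the codebase):

#include <functional>
#include <type_traits>

// Alias-declaration form (the removed lines above)...
using NotifyWaitFuncA = void (*)(void *notify_param, const unsigned param_len);
// ...and typedef form (the restored lines) declare the identical type.
typedef void (*NotifyWaitFuncT)(void *notify_param, const unsigned param_len);

static_assert(std::is_same<NotifyWaitFuncA, NotifyWaitFuncT>::value,
              "alias and typedef produce identical function-pointer types");

int main() { return 0; }

The alias form is usually preferred only for readability and because it supports templates; reverting it changes nothing about the generated code.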
@@ -49,8 +49,8 @@ int32_t CpuKernelCache::InitParameter() {
 /*
  * update framework output tensor shape.
  */
-uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg *ext_info_msg, const CpuKernelContext &ctx) const {
-  if (ext_info_msg->unknown_shape) {
+uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg &ext_info_msg, const CpuKernelContext &ctx) const {
+  if (ext_info_msg.unknown_shape) {
     for (size_t i = 0; i < ctx.GetOutputsSize(); ++i) {
       Tensor *output = ctx.Output(i);
       KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)

@@ -58,12 +58,12 @@ uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg *ext_info_msg, const Cp
       KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
 
       for (int32_t index = 0; index < shape->GetDims(); ++index) {
-        ext_info_msg->output_shape_and_type[i]->dims[index] = shape->GetDimSize(index);
+        ext_info_msg.output_shape_and_type[i]->dims[index] = shape->GetDimSize(index);
       }
     }
   }
-  for (auto it = ext_info_msg->unknown_shape_output_index_addr.begin();
-       it != ext_info_msg->unknown_shape_output_index_addr.end(); ++it) {
+  for (auto it = ext_info_msg.unknown_shape_output_index_addr.begin();
+       it != ext_info_msg.unknown_shape_output_index_addr.end(); ++it) {
     Tensor *output = ctx.Output(it->first);
     KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%u] failed.", it->first)
     auto shape = output->GetTensorShape();

@@ -83,7 +83,7 @@ uint32_t CpuKernelCache::UpdateFWKOutputShape(ExtInfoMsg *ext_info_msg, const Cp
  * get shape information from framework.
  */
 void CpuKernelCache::GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *shape_and_type,
-                                             std::vector<int64_t> *dims) const {
+                                             std::vector<int64_t> &dims) const {
   for (uint32_t index = 0; index < FWKAdapter::kMaxShapeDims; ++index) {
     // LLONG_MIN for dim end flag
     if (shape_and_type->dims[index] == LLONG_MIN) {

@@ -91,22 +91,22 @@ void CpuKernelCache::GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *sha
     }
     int64_t dim_value = shape_and_type->dims[index];
     KERNEL_LOG_INFO("Get extend shape[%u] is [%ld]", index, dim_value);
-    dims->emplace_back(dim_value);
+    dims.emplace_back(dim_value);
   }
 }
 
-void CpuKernelCache::GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> *dims) const {
+void CpuKernelCache::GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> &dims) const {
   for (size_t index = 0; index < len; ++index) {
     KERNEL_LOG_INFO("Get arrays shape[%zu] is [%ld]", index, shape[index]);
-    dims->emplace_back(shape[index]);
+    dims.emplace_back(shape[index]);
   }
 }
 
 /*
  * update tensor information.
  */
-uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg *ext_info_msg,
-                                      const CpuKernelContext &ctx) const {
+uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg &ext_info_msg,
+                                      CpuKernelContext &ctx) const {
   KERNEL_LOG_INFO("Update tensor info begin.");
   if (io_addrs.size() != ctx.GetInputsSize() + ctx.GetOutputsSize()) {
     KERNEL_LOG_ERROR(
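The recurring change in this file is output parameters moving from pointers back to non-const references (dims, ext_info_msg, ctx). A minimal sketch of what that means at the call site; GetDimsPtr and GetDimsRef are illustrative names modeled on GetDimsFromArrays, not functions from this codebase. Note that cpplint's runtime/references check flags the reference form, which is consistent with the suppressions added at the top of this commit.

#include <cstdint>
#include <iostream>
#include <vector>

// Pointer-style out-parameter (the removed variant): callers pass &dims.
void GetDimsPtr(const int64_t *shape, size_t len, std::vector<int64_t> *dims) {
  for (size_t i = 0; i < len; ++i) {
    dims->emplace_back(shape[i]);
  }
}

// Reference-style out-parameter (the restored variant): callers pass dims.
void GetDimsRef(const int64_t *shape, size_t len, std::vector<int64_t> &dims) {
  for (size_t i = 0; i < len; ++i) {
    dims.emplace_back(shape[i]);
  }
}

int main() {
  const int64_t shape[] = {2, 3, 4};
  std::vector<int64_t> a;
  std::vector<int64_t> b;
  GetDimsPtr(shape, 3, &a);  // the & makes the mutation visible at the call site
  GetDimsRef(shape, 3, b);   // the call site reads like pass-by-value
  std::cout << a.size() << ' ' << b.size() << '\n';  // prints: 3 3
}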
@@ -116,14 +116,14 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
     return KERNEL_STATUS_PARAM_INVALID;
   }
 
-  if ((ext_info_msg->unknown_shape) && ((ext_info_msg->input_shape_and_type.size() != ctx.GetInputsSize()) ||
-                                        (ext_info_msg->output_shape_and_type.size() != ctx.GetOutputsSize()))) {
+  if ((ext_info_msg.unknown_shape) && ((ext_info_msg.input_shape_and_type.size() != ctx.GetInputsSize()) ||
+                                       (ext_info_msg.output_shape_and_type.size() != ctx.GetOutputsSize()))) {
     KERNEL_LOG_ERROR(
       "Input shape_and_type size error, input size[%zu], input "
       "shape_and_type "
       "size[%zu], output size[%zu], output shape_and_type size[%zu].",
-      ctx.GetInputsSize(), ext_info_msg->input_shape_and_type.size(), ctx.GetOutputsSize(),
-      ext_info_msg->output_shape_and_type.size());
+      ctx.GetInputsSize(), ext_info_msg.input_shape_and_type.size(), ctx.GetOutputsSize(),
+      ext_info_msg.output_shape_and_type.size());
     return KERNEL_STATUS_PARAM_INVALID;
   }
 

@@ -131,8 +131,8 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
   for (size_t i = 0; i < ctx.GetInputsSize(); ++i, ++addr_index) {
     Tensor *input = ctx.Input(i);
     KERNEL_CHECK_NULLPTR(input, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] failed.", i)
-    auto iter = ext_info_msg->unknown_shape_input_index_addr.find(static_cast<uint32_t>(i));
-    if (iter != ext_info_msg->unknown_shape_input_index_addr.end()) {
+    auto iter = ext_info_msg.unknown_shape_input_index_addr.find(static_cast<uint32_t>(i));
+    if (iter != ext_info_msg.unknown_shape_input_index_addr.end()) {
      iter->second = io_addrs[addr_index];
      ge::RuntimeTensorDesc *tensor_desc =
        reinterpret_cast<ge::RuntimeTensorDesc *>(static_cast<uintptr_t>(io_addrs[addr_index]));

@@ -140,7 +140,7 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
       KERNEL_CHECK_FALSE((tensor_desc->shape[0] <= ge::kMaxDimSize), KERNEL_STATUS_PARAM_INVALID,
                          "Max shape size[%lld], but got input[%zu] shape size[%lld]", ge::kMaxDimSize, i,
                          tensor_desc->shape[0])
-      GetDimsFromArrays(&(tensor_desc->shape[1]), static_cast<size_t>(tensor_desc->shape[0]), &dims);
+      GetDimsFromArrays(&(tensor_desc->shape[1]), static_cast<size_t>(tensor_desc->shape[0]), dims);
       auto shape = input->GetTensorShape();
       KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] shape failed.", i)
       shape->SetDimSizes(dims);

@@ -149,9 +149,9 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
       input->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(io_addrs[addr_index])));
     }
 
-    if (ext_info_msg->unknown_shape) {
+    if (ext_info_msg.unknown_shape) {
       std::vector<int64_t> dims;
-      GetDimsFromShapeAndType(ext_info_msg->input_shape_and_type[i], &dims);
+      GetDimsFromShapeAndType(ext_info_msg.input_shape_and_type[i], dims);
       auto shape = input->GetTensorShape();
       KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get input[%zu] shape failed.", i)
       shape->SetDimSizes(dims);

@@ -165,13 +165,13 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
     KERNEL_LOG_INFO("Set input[%zu] addr[%lu] success.", i, io_addrs[addr_index]);
   }
 
-  bool no_tiling = ext_info_msg->unknown_shape_output_index_addr.empty();
+  bool no_tiling = ext_info_msg.unknown_shape_output_index_addr.empty();
 
   for (size_t i = 0; i < ctx.GetOutputsSize(); i++, addr_index++) {
     Tensor *output = ctx.Output(i);
     KERNEL_CHECK_NULLPTR(output, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] failed.", i)
-    auto iter = ext_info_msg->unknown_shape_output_index_addr.find(static_cast<uint32_t>(i));
-    if (iter != ext_info_msg->unknown_shape_output_index_addr.end()) {
+    auto iter = ext_info_msg.unknown_shape_output_index_addr.find(static_cast<uint32_t>(i));
+    if (iter != ext_info_msg.unknown_shape_output_index_addr.end()) {
      iter->second = io_addrs[addr_index];
      ge::RuntimeTensorDesc *tensor_desc =
        reinterpret_cast<ge::RuntimeTensorDesc *>(static_cast<uintptr_t>(io_addrs[addr_index]));

@@ -180,15 +180,15 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
       output->SetData(reinterpret_cast<void *>(static_cast<uintptr_t>(io_addrs[addr_index])));
     }
 
-    if (ext_info_msg->unknown_shape) {
+    if (ext_info_msg.unknown_shape) {
       std::vector<int64_t> dims;
-      GetDimsFromShapeAndType(ext_info_msg->output_shape_and_type[i], &dims);
+      GetDimsFromShapeAndType(ext_info_msg.output_shape_and_type[i], dims);
       auto shape = output->GetTensorShape();
       KERNEL_CHECK_NULLPTR(shape, KERNEL_STATUS_PARAM_INVALID, "Get output[%zu] shape failed.", i)
       shape->SetDimSizes(dims);
     }
 
-    KERNEL_CHECK_FALSE((ext_info_msg->unknown_shape || (!no_tiling) || (output->NumElements() >= 0)),
+    KERNEL_CHECK_FALSE((ext_info_msg.unknown_shape || (!no_tiling) || (output->NumElements() >= 0)),
                        KERNEL_STATUS_PARAM_INVALID,
                        "Output[%zu] data elements number must be >= 0 "
                        "when known shape, got size[%lld].",
@@ -203,7 +203,7 @@ uint32_t CpuKernelCache::UpdateTensor(const std::vector<uint64_t> &io_addrs, Ext
 /*
  * parse extend tensor shape types information.
 */
-uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool *unknown_shape) const {
+uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) const {
   if (ext_info->infoLen != sizeof(int32_t)) {
     KERNEL_LOG_ERROR(
       "Parse extend shape type failed, as info length must be [%zu], but got "

@@ -212,7 +212,7 @@ uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info,
     return KERNEL_STATUS_PARAM_INVALID;
   }
 
-  *unknown_shape = true;
+  unknown_shape = true;
   KERNEL_LOG_INFO("Kernel has unknown shape.");
   return KERNEL_STATUS_OK;
 }

@@ -221,7 +221,7 @@ uint32_t CpuKernelCache::ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info,
  * parse extend tensor shape and types information.
 */
 uint32_t CpuKernelCache::ParseExtShapeAndType(bool unknown_shape, FWKAdapter::ExtInfo *ext_info,
-                                              std::vector<FWKAdapter::ShapeAndType *> *shape_and_type) const {
+                                              std::vector<FWKAdapter::ShapeAndType *> &shape_and_type) const {
   if (!unknown_shape) {
     return KERNEL_STATUS_OK;
   }

@@ -238,7 +238,7 @@ uint32_t CpuKernelCache::ParseExtShapeAndType(bool unknown_shape, FWKAdapter::Ex
 
   auto shapes = reinterpret_cast<FWKAdapter::ShapeAndType *>(ext_info->infoMsg);
   for (uint32_t index = 0; index < size; ++index) {
-    shape_and_type->emplace_back(&shapes[index]);
+    shape_and_type.emplace_back(&shapes[index]);
   }
   return KERNEL_STATUS_OK;
 }

@@ -246,7 +246,7 @@ uint32_t CpuKernelCache::ParseExtShapeAndType(bool unknown_shape, FWKAdapter::Ex
 /*
  * parse extend session information.
 */
-uint32_t CpuKernelCache::ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t *kernel_id) const {
+uint32_t CpuKernelCache::ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t &kernel_id) const {
   // no overflow
   KERNEL_LOG_INFO("Parse extend session info.");
   auto need_len = sizeof(SessionInfo);

@@ -259,7 +259,7 @@ uint32_t CpuKernelCache::ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint
   }
 
   auto session = reinterpret_cast<SessionInfo *>(ext_info->infoMsg);
-  *kernel_id = session->kernelId;
+  kernel_id = session->kernelId;
   return KERNEL_STATUS_OK;
 }
 
@@ -271,7 +271,7 @@ bool CpuKernelCache::GetBitStatus(uint64_t num, uint64_t pos) { return ((num & (
 /*
  * parse bitmap information.
 */
-uint32_t CpuKernelCache::ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool *unknown_shape) {
+uint32_t CpuKernelCache::ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) {
   if (ext_info->infoLen != sizeof(int64_t)) {
     KERNEL_LOG_ERROR(
       "Parse extend bitmap failed, as info length must be [%zu], but got "

@@ -281,27 +281,27 @@ uint32_t CpuKernelCache::ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, boo
   }
 
   uint64_t bit_map = *(reinterpret_cast<const int64_t *>(ext_info->infoMsg));
-  *unknown_shape = (!GetBitStatus(bit_map, 0));
-  KERNEL_LOG_INFO("Unknown_shape_ is [%d].", *unknown_shape);
+  unknown_shape = (!GetBitStatus(bit_map, 0));
+  KERNEL_LOG_INFO("Unknown_shape_ is [%d].", unknown_shape);
   return KERNEL_STATUS_OK;
 }
 
 // parse async wait info
-uint32_t CpuKernelCache::ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t *wait_type, uint32_t *wait_id) const {
+uint32_t CpuKernelCache::ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t &wait_type, uint32_t &wait_id) const {
   if (ext_info->infoLen != sizeof(FWKAdapter::AsyncWait)) {
     KERNEL_LOG_ERROR("Parse extend async wait failed, as info length must be [%zu], but got [%u].",
                      sizeof(FWKAdapter::AsyncWait), ext_info->infoLen);
     return KERNEL_STATUS_PARAM_INVALID;
   }
   FWKAdapter::AsyncWait *async_info = reinterpret_cast<FWKAdapter::AsyncWait *>(ext_info->infoMsg);
-  *wait_type = async_info->waitType;
-  *wait_id = async_info->waitId;
-  KERNEL_LOG_INFO("async wait type [%u], notify_id[%u].", *wait_type, *wait_id);
+  wait_type = async_info->waitType;
+  wait_id = async_info->waitId;
+  KERNEL_LOG_INFO("async wait type [%u], notify_id[%u].", wait_type, wait_id);
   return KERNEL_STATUS_OK;
 }
 
 uint32_t CpuKernelCache::ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info,
-                                                   std::map<uint32_t, uint64_t> *unknown_shape_index_addr) const {
+                                                   std::map<uint32_t, uint64_t> &unknown_shape_index_addr) const {
   if (ext_info->infoLen % sizeof(uint32_t) != 0) {
     KERNEL_LOG_ERROR(
       "Parse unknown shape index extend info length[%u] failed, must be "
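ParseExtBitMap reads a 64-bit bitmap and derives unknown_shape from bit 0. The GetBitStatus one-liner is truncated in the hunk header above, so the sketch below assumes the conventional (num & (1 << pos)) != 0 mask test that the visible fragment suggests:

#include <cstdint>
#include <iostream>

// Assumed implementation of the truncated GetBitStatus one-liner:
// report whether bit `pos` of `num` is set.
bool GetBitStatus(uint64_t num, uint64_t pos) { return ((num & (1ULL << pos)) != 0); }

int main() {
  uint64_t bit_map = 0;  // bit 0 clear: shape is unknown
  bool unknown_shape = !GetBitStatus(bit_map, 0);
  std::cout << std::boolalpha << unknown_shape << '\n';  // true

  bit_map = 1;           // bit 0 set: shape is known
  unknown_shape = !GetBitStatus(bit_map, 0);
  std::cout << unknown_shape << '\n';                    // false
}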
@@ -313,7 +313,7 @@ uint32_t CpuKernelCache::ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info
   KERNEL_LOG_INFO("Parse extend unknown shape index, size[%u].", size);
   auto indexes = reinterpret_cast<uint32_t *>(ext_info->infoMsg);
   for (uint32_t i = 0U; i < size; ++i) {
-    (*unknown_shape_index_addr)[indexes[i]] = 0U;
+    unknown_shape_index_addr[indexes[i]] = 0U;
   }
   return KERNEL_STATUS_OK;
 }

@@ -321,10 +321,10 @@ uint32_t CpuKernelCache::ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info
 /*
  * parse extend information.
 */
-uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg *ext_info_msg) {
+uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg &ext_info_msg) {
   KERNEL_LOG_INFO("Parse extend info and update shape begin.");
   uint32_t offset = 0;
-  ext_info_msg->async_flag = false;
+  ext_info_msg.async_flag = false;
   FWKAdapter::ExtInfo *ext_info = nullptr;
   char *extInfo_buf = reinterpret_cast<char *>(static_cast<uintptr_t>(param_head->extInfoAddr));
   while (offset + sizeof(FWKAdapter::ExtInfo) <= param_head->extInfoLength) {

@@ -340,36 +340,36 @@ uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg *ext
     uint32_t ret = KERNEL_STATUS_OK;
     switch (ext_info->infoType) {
       case FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE:
-        ret = ParseExtShapeType(ext_info, &ext_info_msg->unknown_shape);
+        ret = ParseExtShapeType(ext_info, ext_info_msg.unknown_shape);
         break;
       case FWKAdapter::FWK_ADPT_EXT_INPUT_SHAPE:
-        ret = ParseExtShapeAndType(ext_info_msg->unknown_shape, ext_info, &ext_info_msg->input_shape_and_type);
+        ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.input_shape_and_type);
         break;
       case FWKAdapter::FWK_ADPT_EXT_OUTPUT_SHAPE:
-        ret = ParseExtShapeAndType(ext_info_msg->unknown_shape, ext_info, &ext_info_msg->output_shape_and_type);
+        ret = ParseExtShapeAndType(ext_info_msg.unknown_shape, ext_info, ext_info_msg.output_shape_and_type);
         break;
       case FWKAdapter::FWK_ADPT_EXT_SESSION_INFO:
-        ext_info_msg->has_sess_info = true;
-        ret = ParseExtSessionInfo(ext_info, &ext_info_msg->kernel_id);
+        ext_info_msg.has_sess_info = true;
+        ret = ParseExtSessionInfo(ext_info, ext_info_msg.kernel_id);
         break;
       case FWKAdapter::FWK_ADPT_EXT_BITMAP:
-        ret = ParseExtBitMap(ext_info, &ext_info_msg->unknown_shape);
+        ret = ParseExtBitMap(ext_info, ext_info_msg.unknown_shape);
         break;
       case FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT: {
-        ret = ParseAsyncWait(ext_info, &ext_info_msg->wait_type, &ext_info_msg->wait_id);
+        ret = ParseAsyncWait(ext_info, ext_info_msg.wait_type, ext_info_msg.wait_id);
         bool flag = ((ret == KERNEL_STATUS_OK) &&
-                     (ext_info_msg->wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_NULL) &&
-                     (ext_info_msg->wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_INVALID));
+                     (ext_info_msg.wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_NULL) &&
+                     (ext_info_msg.wait_type != FWKAdapter::FWKExtWaitType::FWK_ADPT_WAIT_TYPE_INVALID));
         if (flag) {
-          ext_info_msg->async_flag = true;
+          ext_info_msg.async_flag = true;
         }
         break;
       }
       case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_INPUT_INDEX:
-        ret = ParseExtUnknownShapeIndex(ext_info, &ext_info_msg->unknown_shape_input_index_addr);
+        ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_input_index_addr);
         break;
       case FWKAdapter::FWK_ADPT_EXT_UNKNOWN_SHAPE_OUTPUT_INDEX:
-        ret = ParseExtUnknownShapeIndex(ext_info, &ext_info_msg->unknown_shape_output_index_addr);
+        ret = ParseExtUnknownShapeIndex(ext_info, ext_info_msg.unknown_shape_output_index_addr);
         break;
       default:
         KERNEL_LOG_INFO("Ignore infoType[%d], infoLen[%u].", ext_info->infoType, ext_info->infoLen);
@@ -391,8 +391,8 @@ uint32_t CpuKernelCache::ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg *ext
 /*
  * parse io address.
 */
-uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> *io_addrs, char **nodedef,
-                                     uint32_t *nodedef_len) const {
+uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> &io_addrs, char *&nodedef,
+                                     uint32_t &nodedef_len) const {
   auto param_base = reinterpret_cast<char *>(param_head);
   char *extend_param_base = param_base + sizeof(AicpuParamHead);
   uint32_t extend_param_len = param_head->length - sizeof(AicpuParamHead);

@@ -414,7 +414,7 @@ uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uin
 
   auto io_addr_base = reinterpret_cast<uint64_t *>(extend_param_base);
   for (uint32_t i = 0; i < param_head->ioAddrNum; ++i) {
-    io_addrs->push_back(io_addr_base[i]);
+    io_addrs.push_back(io_addr_base[i]);
   }
   extend_param_base = extend_param_base + addr_len;
   extend_param_len -= addr_len;

@@ -428,10 +428,10 @@ uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uin
     return KERNEL_STATUS_PARAM_INVALID;
   }
 
-  *nodedef_len = *reinterpret_cast<uint32_t *>(extend_param_base);
+  nodedef_len = *reinterpret_cast<uint32_t *>(extend_param_base);
   extend_param_base += sizeof(uint32_t);
-  *nodedef = extend_param_base;
-  KERNEL_LOG_INFO("Parse io addr success, io number[%zu], nodedef length[%u].", io_addrs->size(), *nodedef_len);
+  nodedef = extend_param_base;
+  KERNEL_LOG_INFO("Parse io addr success, io number[%zu], nodedef length[%u].", io_addrs.size(), nodedef_len);
   return KERNEL_STATUS_OK;
 }
 

@@ -440,7 +440,7 @@ uint32_t CpuKernelCache::ParseIoAddr(AicpuParamHead *param_head, std::vector<uin
 */
 std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContext(bool has_sess_info, uint64_t kernel_id,
                                                                       const char *nodedef, uint32_t nodedef_len,
-                                                                      std::shared_ptr<NodeDef> *nodedef_proto) {
+                                                                      std::shared_ptr<NodeDef> &nodedef_proto) {
   std::shared_ptr<CpuKernelContext> ctx = nullptr;
   KERNEL_LOG_INFO("Get cpu kernel context begin, kernel id[%lu].", kernel_id);
   if (has_sess_info) {
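ParseIoAddr, shown above, walks a packed parameter buffer: an AicpuParamHead, then ioAddrNum 64-bit I/O addresses, then a uint32 node-def length followed by the serialized node-def bytes. A self-contained sketch of that layout and walk; the ParamHead struct here is a reduced stand-in with only the fields this diff touches (the real AicpuParamHead carries more):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Reduced stand-in for AicpuParamHead; field set is an assumption.
struct ParamHead {
  uint32_t length;     // total length of head plus extend area
  uint32_t ioAddrNum;  // number of io addresses following the head
};

// Mirrors the ParseIoAddr walk: addresses first, then [len][bytes] node def.
bool ParseIoAddr(char *param_base, std::vector<uint64_t> &io_addrs, char *&nodedef, uint32_t &nodedef_len) {
  auto head = reinterpret_cast<ParamHead *>(param_base);
  char *p = param_base + sizeof(ParamHead);
  auto addrs = reinterpret_cast<uint64_t *>(p);
  for (uint32_t i = 0; i < head->ioAddrNum; ++i) {
    io_addrs.push_back(addrs[i]);
  }
  p += head->ioAddrNum * sizeof(uint64_t);
  nodedef_len = *reinterpret_cast<uint32_t *>(p);
  nodedef = p + sizeof(uint32_t);
  return true;
}

int main() {
  // Build a toy buffer: head, two addresses, then a 5-byte node def.
  std::vector<char> buf(sizeof(ParamHead) + 2 * sizeof(uint64_t) + sizeof(uint32_t) + 5);
  auto head = reinterpret_cast<ParamHead *>(buf.data());
  head->length = static_cast<uint32_t>(buf.size());
  head->ioAddrNum = 2;
  uint64_t addrs[2] = {0x1000, 0x2000};
  std::memcpy(buf.data() + sizeof(ParamHead), addrs, sizeof(addrs));
  uint32_t len = 5;
  std::memcpy(buf.data() + sizeof(ParamHead) + sizeof(addrs), &len, sizeof(len));
  std::memcpy(buf.data() + sizeof(ParamHead) + sizeof(addrs) + sizeof(len), "hello", 5);

  std::vector<uint64_t> io_addrs;
  char *nodedef = nullptr;
  uint32_t nodedef_len = 0;
  ParseIoAddr(buf.data(), io_addrs, nodedef, nodedef_len);
  std::cout << io_addrs.size() << " addrs, nodedef_len=" << nodedef_len << '\n';  // 2 addrs, nodedef_len=5
}

The real function additionally validates extend_param_len at each step before reading, as the surrounding context lines show; the sketch omits those bounds checks for brevity.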
@@ -452,22 +452,22 @@ std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContext(bool has_s
   }
 
   std::string str_data(nodedef, nodedef_len);
-  *nodedef_proto = CpuKernelUtils::CreateNodeDef();
-  KERNEL_CHECK_NULLPTR(*nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr), "Create node def failed.")
-  if (!(*nodedef_proto)->ParseFromString(str_data)) {
+  nodedef_proto = CpuKernelUtils::CreateNodeDef();
+  KERNEL_CHECK_NULLPTR(nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr), "Create node def failed.")
+  if (!nodedef_proto->ParseFromString(str_data)) {
     return std::shared_ptr<CpuKernelContext>(nullptr);
   }
 
   CpuKernelContext *tmp = new (std::nothrow) CpuKernelContext(DEVICE);
   KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context failed.")
   ctx = std::shared_ptr<CpuKernelContext>(tmp);
-  uint32_t ret = ctx->Init(nodedef_proto->get());
+  uint32_t ret = ctx->Init(nodedef_proto.get());
   if (ret != KERNEL_STATUS_OK) {
     return std::shared_ptr<CpuKernelContext>(nullptr);
   }
 
   if (has_sess_info) {
-    CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData(*nodedef_proto, ctx);
+    CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
     KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
     std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
     SetCache(kernel_id, cache_shared);
@@ -485,7 +485,7 @@ int32_t CpuKernelCache::RunKernel(void *param) {
   std::vector<uint64_t> io_addrs;
   char *nodedef = nullptr;
   uint32_t nodedef_len = 0;
-  uint32_t ret = ParseIoAddr(param_head, &io_addrs, &nodedef, &nodedef_len);
+  uint32_t ret = ParseIoAddr(param_head, io_addrs, nodedef, nodedef_len);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }

@@ -496,17 +496,17 @@ int32_t CpuKernelCache::RunKernel(void *param) {
     KERNEL_LOG_ERROR("Create ExtInfoMsg failed");
     return -1;
   }
-  ret = ParseExtMsg(param_head, ext_info_msg.get());
+  ret = ParseExtMsg(param_head, *ext_info_msg);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }
 
   std::shared_ptr<NodeDef> nodedef_proto = nullptr;
   auto ctx =
-    GetCpuKernelContext(ext_info_msg->has_sess_info, ext_info_msg->kernel_id, nodedef, nodedef_len, &nodedef_proto);
+    GetCpuKernelContext(ext_info_msg->has_sess_info, ext_info_msg->kernel_id, nodedef, nodedef_len, nodedef_proto);
   KERNEL_CHECK_NULLPTR(ctx, KERNEL_STATUS_INNER_ERROR, "Get cpu kernel context from buff failed.")
 
-  ret = UpdateTensor(io_addrs, ext_info_msg.get(), *ctx);
+  ret = UpdateTensor(io_addrs, *ext_info_msg, *ctx);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }

@@ -514,13 +514,13 @@ int32_t CpuKernelCache::RunKernel(void *param) {
   if (ext_info_msg->async_flag) {
     ret = CpuKernelRegister::Instance().RunCpuKernelAsync(
       *ctx, ext_info_msg->wait_type, ext_info_msg->wait_id,
-      [&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(ext_info_msg.get(), *ctx); });
+      [&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(*ext_info_msg, *ctx); });
   } else {
     ret = CpuKernelRegister::Instance().RunCpuKernel(*ctx);
     if (ret != KERNEL_STATUS_OK) {
       return -1;
     }
-    ret = UpdateFWKOutputShape(ext_info_msg.get(), *ctx);
+    ret = UpdateFWKOutputShape(*ext_info_msg, *ctx);
   }
   if (ret == KERNEL_STATUS_END_OF_SEQUENCE) {
     return ret;
@@ -539,7 +539,7 @@ int32_t CpuKernelCache::RunCpuKernelWithBlock(void *param, struct BlkDimInfo *bl
   std::vector<uint64_t> io_addrs;
   char *nodedef = nullptr;
   uint32_t nodedef_len = 0;
-  uint32_t ret = ParseIoAddr(param_head, &io_addrs, &nodedef, &nodedef_len);
+  uint32_t ret = ParseIoAddr(param_head, io_addrs, nodedef, nodedef_len);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }

@@ -550,16 +550,16 @@ int32_t CpuKernelCache::RunCpuKernelWithBlock(void *param, struct BlkDimInfo *bl
     KERNEL_LOG_ERROR("Create ExtInfoMsg failed");
     return -1;
   }
-  ret = ParseExtMsg(param_head, ext_info_msg.get());
+  ret = ParseExtMsg(param_head, *ext_info_msg);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }
 
   std::shared_ptr<NodeDef> nodedef_proto = nullptr;
-  auto ctx = GetCpuKernelContextWithBlock(ext_info_msg, nodedef, nodedef_len, &nodedef_proto, blkdim_info);
+  auto ctx = GetCpuKernelContextWithBlock(ext_info_msg, nodedef, nodedef_len, nodedef_proto, blkdim_info);
   KERNEL_CHECK_NULLPTR(ctx, KERNEL_STATUS_INNER_ERROR, "Get cpu kernel context from buff failed.")
 
-  ret = UpdateTensor(io_addrs, ext_info_msg.get(), *ctx);
+  ret = UpdateTensor(io_addrs, *ext_info_msg, *ctx);
   if (ret != KERNEL_STATUS_OK) {
     return -1;
   }

@@ -567,13 +567,13 @@ int32_t CpuKernelCache::RunCpuKernelWithBlock(void *param, struct BlkDimInfo *bl
   if (ext_info_msg->async_flag) {
     ret = CpuKernelRegister::Instance().RunCpuKernelAsync(
       *ctx, ext_info_msg->wait_type, ext_info_msg->wait_id,
-      [&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(ext_info_msg.get(), *ctx); });
+      [&, ctx, ext_info_msg]() { return UpdateFWKOutputShape(*ext_info_msg, *ctx); });
   } else {
     ret = CpuKernelRegister::Instance().RunCpuKernel(*ctx);
     if (ret != KERNEL_STATUS_OK) {
       return -1;
     }
-    ret = UpdateFWKOutputShape(ext_info_msg.get(), *ctx);
+    ret = UpdateFWKOutputShape(*ext_info_msg, *ctx);
   }
   if (ret != KERNEL_STATUS_OK) {
     return -1;

@@ -586,7 +586,7 @@ int32_t CpuKernelCache::RunCpuKernelWithBlock(void *param, struct BlkDimInfo *bl
 std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContextWithBlock(std::shared_ptr<ExtInfoMsg> extInfoMsg,
                                                                                const char *nodedef,
                                                                                uint32_t nodedef_len,
-                                                                               std::shared_ptr<NodeDef> *nodedef_proto,
+                                                                               std::shared_ptr<NodeDef> &nodedef_proto,
                                                                                struct BlkDimInfo *blkdim_info) {
   std::shared_ptr<CpuKernelContext> ctx = nullptr;
   KERNEL_LOG_INFO("Get cpu kernel context with block info begin. kernel id[%lu]", extInfoMsg->kernel_id);
@@ -598,34 +598,34 @@ std::shared_ptr<CpuKernelContext> CpuKernelCache::GetCpuKernelContextWithBlock(s
     }
   }
   std::string str_data(nodedef, nodedef_len);
-  *nodedef_proto = CpuKernelUtils::CreateNodeDef();
-  KERNEL_CHECK_NULLPTR(*nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr),
+  nodedef_proto = CpuKernelUtils::CreateNodeDef();
+  KERNEL_CHECK_NULLPTR(nodedef_proto, std::shared_ptr<CpuKernelContext>(nullptr),
                        "Create node def with block info failed.")
-  if (!(*nodedef_proto)->ParseFromString(str_data)) {
+  if (!nodedef_proto->ParseFromString(str_data)) {
     return std::shared_ptr<CpuKernelContext>(nullptr);
   }
 
   if (blkdim_info->blockNum != 1) {
     auto blockNum = CpuKernelUtils::CreateAttrValue();
     blockNum->SetInt(blkdim_info->blockNum);
-    (*nodedef_proto)->AddAttrs("block_num", blockNum.get());
+    nodedef_proto->AddAttrs("block_num", blockNum.get());
 
     auto blockid = CpuKernelUtils::CreateAttrValue();
     blockid->SetInt(blkdim_info->blockId);
-    (*nodedef_proto)->AddAttrs("block_id", blockid.get());
+    nodedef_proto->AddAttrs("block_id", blockid.get());
     KERNEL_LOG_INFO("AddAttrs block info , blockNum[%u] blockId[%u].", blkdim_info->blockNum, blkdim_info->blockId);
   }
 
   CpuKernelContext *tmp = new (std::nothrow) CpuKernelContext(DEVICE);
   KERNEL_CHECK_NULLPTR(tmp, std::shared_ptr<CpuKernelContext>(nullptr), "Create context with block info failed.")
   ctx = std::shared_ptr<CpuKernelContext>(tmp);
-  uint32_t ret = ctx->Init(nodedef_proto->get());
+  uint32_t ret = ctx->Init(nodedef_proto.get());
   if (ret != KERNEL_STATUS_OK) {
     return std::shared_ptr<CpuKernelContext>(nullptr);
   }
 
   if (extInfoMsg->has_sess_info) {
-    CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData((*nodedef_proto), ctx);
+    CpuCacheData *cache_ptr = new (std::nothrow) CpuCacheData(nodedef_proto, ctx);
     KERNEL_CHECK_NULLPTR(cache_ptr, std::shared_ptr<CpuKernelContext>(nullptr), "Create cpu cache data failed.")
     std::shared_ptr<CpuCacheData> cache_shared = std::shared_ptr<CpuCacheData>(cache_ptr);
     SetCache(extInfoMsg->kernel_id, cache_shared);
@@ -83,41 +83,40 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * update framework output tensor shape.
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t UpdateFWKOutputShape(ExtInfoMsg *ext_info_msg, const CpuKernelContext &ctx) const;
+  uint32_t UpdateFWKOutputShape(ExtInfoMsg &ext_info_msg, const CpuKernelContext &ctx) const;
 
   /*
   * get shape information from framework.
   * @param dims: shape information
   */
-  void GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *shape_and_type, std::vector<int64_t> *dims) const;
+  void GetDimsFromShapeAndType(const FWKAdapter::ShapeAndType *shape_and_type, std::vector<int64_t> &dims) const;
 
   /*
   * get shape information from arrays.
   * @param dims: shape information
   */
-  void GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> *dims) const;
+  void GetDimsFromArrays(const int64_t *shape, size_t len, std::vector<int64_t> &dims) const;
 
   /*
   * update tensor information.
   * @param ctx: kernel context
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg *ext_info_msg,
-                        const CpuKernelContext &ctx) const;
+  uint32_t UpdateTensor(const std::vector<uint64_t> &io_addrs, ExtInfoMsg &ext_info_msg, CpuKernelContext &ctx) const;
 
   /*
   * parse extend tensor shape types information.
   * @param ext_info: extend information
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool *unknown_shape) const;
+  uint32_t ParseExtShapeType(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape) const;
 
   /*
   * parse extend tensor bitmap information.
   * @param ext_info: extend information
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool *unknown_shape);
+  uint32_t ParseExtBitMap(const FWKAdapter::ExtInfo *ext_info, bool &unknown_shape);
 
   /*
   * parse extend tensor shape and types information.

@@ -126,7 +125,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @return uint32_t: 0 indicates success, while the others fail
   */
   uint32_t ParseExtShapeAndType(bool unknown_shape, FWKAdapter::ExtInfo *ext_info,
-                                std::vector<FWKAdapter::ShapeAndType *> *shape_and_type) const;
+                                std::vector<FWKAdapter::ShapeAndType *> &shape_and_type) const;
 
   /*
   * parse extend unknown shape index information.

@@ -135,7 +134,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @return uint32_t: 0 indicates success, while the others fail
   */
   uint32_t ParseExtUnknownShapeIndex(FWKAdapter::ExtInfo *ext_info,
-                                     std::map<uint32_t, uint64_t> *unknown_shape_index_addr) const;
+                                     std::map<uint32_t, uint64_t> &unknown_shape_index_addr) const;
 
   /*
   * parse extend session information.

@@ -143,7 +142,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @param kernel_id: kernel id from extend information
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t *kernel_id) const;
+  uint32_t ParseExtSessionInfo(FWKAdapter::ExtInfo *ext_info, uint64_t &kernel_id) const;
 
   /*
   * parse extend async wait info

@@ -152,7 +151,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @param wait_id : event wait id
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t *wait_type, uint32_t *wait_id) const;
+  uint32_t ParseAsyncWait(FWKAdapter::ExtInfo *ext_info, uint8_t &wait_type, uint32_t &wait_id) const;
 
   /*
   * parse extend information.

@@ -160,7 +159,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @param ext_info_msg: extend info msg
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg *ext_info_msg);
+  uint32_t ParseExtMsg(AicpuParamHead *param_head, ExtInfoMsg &ext_info_msg);
 
   /*
   * parse io address.

@@ -170,8 +169,8 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @param nodedef_len: kernel node def length
   * @return uint32_t: 0 indicates success, while the others fail
   */
-  uint32_t ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> *io_addrs, char **nodedef,
-                       uint32_t *nodedef_len) const;
+  uint32_t ParseIoAddr(AicpuParamHead *param_head, std::vector<uint64_t> &io_addrs, char *&nodedef,
+                       uint32_t &nodedef_len) const;
 
   /*
   * get cpu kernel context from cache

@@ -180,7 +179,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   * @return uint32_t: 0 indicates success, while the others fail
   */
   std::shared_ptr<CpuKernelContext> GetCpuKernelContext(bool has_sess_info, uint64_t kernel_id, const char *nodedef,
-                                                        uint32_t nodedef_len, std::shared_ptr<NodeDef> *nodedef_proto);
+                                                        uint32_t nodedef_len, std::shared_ptr<NodeDef> &nodedef_proto);
 
   /*
   * get cpu kernel context from cache

@@ -191,7 +190,7 @@ class CpuKernelCache : public KernelCache<CpuCacheData> {
   */
   std::shared_ptr<CpuKernelContext> GetCpuKernelContextWithBlock(std::shared_ptr<ExtInfoMsg> extInfoMsg,
                                                                  const char *nodedef, uint32_t nodedef_len,
-                                                                 std::shared_ptr<NodeDef> *nodedef_proto,
+                                                                 std::shared_ptr<NodeDef> &nodedef_proto,
                                                                  struct BlkDimInfo *blkdim_info);
 
   /*
@@ -85,7 +85,7 @@ std::vector<std::string> CpuKernelRegister::GetAllRegisteredOpTypes() const {
  * param ctx: context of kernel
  * @return uint32_t: 0->success other->failed
 */
-uint32_t CpuKernelRegister::RunCpuKernel(const CpuKernelContext &ctx) {
+uint32_t CpuKernelRegister::RunCpuKernel(CpuKernelContext &ctx) {
   std::string type = ctx.GetOpType();
   KERNEL_LOG_INFO("RunCpuKernel[%s] begin.", type.c_str());
   auto kernel = GetCpuKernel(type);

@@ -114,8 +114,8 @@ uint32_t CpuKernelRegister::RunCpuKernel(const CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 
-uint32_t CpuKernelRegister::RunCpuKernelAsync(const CpuKernelContext &ctx, const uint8_t wait_type,
-                                              const uint32_t wait_id, std::function<uint32_t()> cb) {
+uint32_t CpuKernelRegister::RunCpuKernelAsync(CpuKernelContext &ctx, const uint8_t wait_type, const uint32_t wait_id,
+                                              std::function<uint32_t()> cb) {
   std::string type = ctx.GetOpType();
   KERNEL_LOG_INFO("RunCpuKernelAsync[%s] begin.", type.c_str());
   auto kernel = GetCpuKernel(type);

@@ -51,7 +51,7 @@ class AICPU_VISIBILITY CpuKernelRegister {
   * param ctx: context of kernel
   * @return uint32_t: 0->success other->failed
   */
-  uint32_t RunCpuKernel(const CpuKernelContext &ctx);
+  uint32_t RunCpuKernel(CpuKernelContext &ctx);
 
   /*
   * run async cpu kernel.

@@ -61,7 +61,7 @@ class AICPU_VISIBILITY CpuKernelRegister {
   * @param cb : callback function
   * @return uint32_t: 0->success other->failed
   */
-  uint32_t RunCpuKernelAsync(const CpuKernelContext &ctx, const uint8_t wait_type, const uint32_t wait_id,
+  uint32_t RunCpuKernelAsync(CpuKernelContext &ctx, const uint8_t wait_type, const uint32_t wait_id,
                              std::function<uint32_t()> cb);
 
   // CpuKernel registration function to register different types of kernel to
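Dropping const from the CpuKernelContext parameter through RunCpuKernel and RunCpuKernelAsync is not cosmetic: kernel implementations mutate the context (for example when writing outputs), and a const reference cannot be forwarded to a mutating callee. A tiny illustration with a hypothetical stand-in type:

#include <cstdint>

// Hypothetical stand-in for CpuKernelContext.
struct Context {
  void SetOutput(int64_t v) { output = v; }  // mutating accessor
  int64_t output = 0;
};

// Entry points that mutate the context must take it by non-const reference.
uint32_t RunKernel(Context &ctx) {
  ctx.SetOutput(42);
  return 0;
}

int main() {
  Context ctx;
  // With `const Context &ctx` in RunKernel, the SetOutput call would not
  // compile, which is why the rollback removes const along the whole chain.
  return static_cast<int>(RunKernel(ctx));
}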
|
@ -74,9 +74,9 @@ std::string CpuKernelUtils::GetTensorName(const Tensor *tensor) {
|
|||
/*
|
||||
* set tensor name.
|
||||
*/
|
||||
void CpuKernelUtils::SetTensorName(const std::string &name, std::shared_ptr<Tensor> *tensor) {
|
||||
void CpuKernelUtils::SetTensorName(const std::string &name, std::shared_ptr<Tensor> &tensor) {
|
||||
KERNEL_LOG_INFO("Set tensor name[%s]", name.c_str());
|
||||
auto impl = GetImpl(tensor->get());
|
||||
auto impl = GetImpl(tensor.get());
|
||||
KERNEL_CHECK_NULLPTR_VOID(impl, "Get Tensor impl failed.")
|
||||
impl->SetName(name);
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ class AICPU_VISIBILITY CpuKernelUtils {
|
|||
/*
|
||||
* set tensor name.
|
||||
*/
|
||||
static void SetTensorName(const std::string &name, std::shared_ptr<Tensor> *tensor);
|
||||
static void SetTensorName(const std::string &name, std::shared_ptr<Tensor> &tensor);
|
||||
|
||||
/*
|
||||
* create Tensor shape.
|
||||
|
|
|
@ -41,7 +41,7 @@ class AICPU_VISIBILITY NodeDef {
|
|||
* serialize string to node def.
|
||||
* @return bool: true->success, false->failed
|
||||
*/
|
||||
bool SerializeToString(std::string *str) const;
|
||||
bool SerializeToString(std::string &str) const;
|
||||
|
||||
/*
|
||||
* set op type to node def.
|
||||
|
|
|
@@ -22,7 +22,7 @@
 #include "cpu_kernel/common/host_sharder.h"
 
 namespace aicpu {
-Device::Device(DeviceType device) : device_(device), sharder_(InitSharder(device)) {}
+Device::Device(DeviceType device) : device_(device), sharder_(InitSharder(device)){};
 
 Device::~Device() {
   if (sharder_ != nullptr) {
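The restored constructor ends in {}; and the trailing semicolon after the function body is an empty declaration: legal since C++11 but exactly the kind of nit style checkers flag, which fits the lint suppressions added at the top of this commit. A minimal reproduction, with hypothetical types:

// Minimal reproduction of the two constructor spellings from the hunk above.
enum DeviceType { HOST, DEVICE_KIND };

class DeviceClean {
 public:
  // Clean form: brace body, no trailing semicolon.
  explicit DeviceClean(DeviceType device) : device_(device) {}

 private:
  DeviceType device_;
};

class DeviceRolledBack {
 public:
  // Rolled-back form: the `;` after `{}` is an empty declaration.
  // It compiles under C++11 and later, but linters commonly warn about it.
  explicit DeviceRolledBack(DeviceType device) : device_(device){};

 private:
  DeviceType device_;
};

int main() {
  DeviceClean a(HOST);
  DeviceRolledBack b(HOST);
  (void)a;
  (void)b;
}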
@@ -25,49 +25,50 @@
 #include "cpu_kernel/common/session_cache.h"
 #include "cpu_kernel/common/status.h"
 
+using namespace aicpu;
 namespace {
 // max param len limit 10k.
 constexpr uint32_t kMaxParamLen = 10240;
 // max extend info len limit 20k.
 constexpr uint32_t kMaxExtendLen = 20480;
-const char kContextKeyStreamId[] = "streamId";
+const std::string kContextKeyStreamId = "streamId";
 
-uint32_t ParseExtSessionInfo(aicpu::AicpuParamHead *param_head, SessionInfo **session) {
+uint32_t ParseExtSessionInfo(AicpuParamHead *param_head, SessionInfo *&session) {
   KERNEL_LOG_INFO("Parse extend session info begin.");
   uint32_t offset = 0;
-  aicpu::FWKAdapter::ExtInfo *ext_info = nullptr;
+  FWKAdapter::ExtInfo *ext_info = nullptr;
   char *ext_info_buf = reinterpret_cast<char *>(static_cast<uintptr_t>(param_head->extInfoAddr));
-  while (offset + sizeof(aicpu::FWKAdapter::ExtInfo) <= param_head->extInfoLength) {
-    ext_info = reinterpret_cast<aicpu::FWKAdapter::ExtInfo *>(ext_info_buf + offset);
+  while (offset + sizeof(FWKAdapter::ExtInfo) <= param_head->extInfoLength) {
+    ext_info = reinterpret_cast<FWKAdapter::ExtInfo *>(ext_info_buf + offset);
     if (ext_info == nullptr) {
       KERNEL_LOG_ERROR(
         "Extend info is nullptr, extend info length[%u], extend info "
         "offset[%u].",
         param_head->extInfoLength, offset);
-      return aicpu::KERNEL_STATUS_PARAM_INVALID;
+      return KERNEL_STATUS_PARAM_INVALID;
     }
 
-    if (ext_info->infoType == aicpu::FWKAdapter::FWK_ADPT_EXT_SESSION_INFO) {
+    if (ext_info->infoType == FWKAdapter::FWK_ADPT_EXT_SESSION_INFO) {
       auto need_len = sizeof(SessionInfo);
       if (ext_info->infoLen != need_len) {
         KERNEL_LOG_ERROR(
           "Parse extend session info failed, as info length must be "
           "[%zu], but %u.",
          sizeof(SessionInfo), ext_info->infoLen);
-        return aicpu::KERNEL_STATUS_PARAM_INVALID;
+        return KERNEL_STATUS_PARAM_INVALID;
       }
 
-      *session = reinterpret_cast<SessionInfo *>(ext_info->infoMsg);
+      session = reinterpret_cast<SessionInfo *>(ext_info->infoMsg);
       KERNEL_LOG_INFO("Parse extend session info success.");
     }
 
     // not overflow
-    offset += aicpu::FWKAdapter::kExtInfoHeadSize;
+    offset += FWKAdapter::kExtInfoHeadSize;
     offset += ext_info->infoLen;
   }
 
   KERNEL_LOG_INFO("Parse extend session info end.");
-  return aicpu::KERNEL_STATUS_OK;
+  return KERNEL_STATUS_OK;
 }
 }  // namespace
 

@@ -76,94 +76,94 @@ __attribute__((visibility("default"))) uint32_t RunCpuKernel(void *param) {
   KERNEL_LOG_INFO("RunCpuKernel C begin");
   if (param == nullptr) {
     KERNEL_LOG_ERROR("Param is null.");
-    return aicpu::KERNEL_STATUS_PARAM_INVALID;
+    return KERNEL_STATUS_PARAM_INVALID;
   }
 
   // parse param_len
-  aicpu::AicpuParamHead *param_head = static_cast<aicpu::AicpuParamHead *>(param);
-  if ((param_head->length < sizeof(aicpu::AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
+  AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
+  if ((param_head->length < sizeof(AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
       (param_head->extInfoLength > kMaxExtendLen)) {
     KERNEL_LOG_ERROR(
       "Param length[%u] not in [%zu, %u] or extend info length[%u] is "
       "greater "
      "than the limit[%u].",
-      param_head->length, sizeof(aicpu::AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
-    return aicpu::KERNEL_STATUS_PARAM_INVALID;
+      param_head->length, sizeof(AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
+    return KERNEL_STATUS_PARAM_INVALID;
   }
 
   SessionInfo *session = nullptr;
-  uint32_t ret = ParseExtSessionInfo(param_head, &session);
-  if (ret != aicpu::KERNEL_STATUS_OK) {
+  uint32_t ret = ParseExtSessionInfo(param_head, session);
+  if (ret != KERNEL_STATUS_OK) {
     return ret;
   }
 
   if (session == nullptr) {
     KERNEL_LOG_INFO("RunCpuKernel directly.");
-    aicpu::CpuKernelCache cache;
+    CpuKernelCache cache;
     cache.Init(false);
     return cache.RunKernel(param);
   }
 
   std::string stream_id_value = "0";
-  auto status = aicpu::GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
-  if (status != aicpu::AICPU_ERROR_NONE) {
+  auto status = GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
+  if (status != AICPU_ERROR_NONE) {
     KERNEL_LOG_ERROR("GetThreadLocalCtx failed, ret[%d].", status);
-    return aicpu::KERNEL_STATUS_INNER_ERROR;
+    return KERNEL_STATUS_INNER_ERROR;
   }
   uint64_t stream_id = atoi(stream_id_value.c_str());
   KERNEL_LOG_INFO(
     "RunCpuKernel from cache, stream id[%lu], session id[%lu], session "
     "flag[%d].",
     stream_id, session->sessionId, session->sessFlag);
-  return aicpu::SessionCache<aicpu::CpuCacheData>::Instance().RunKernel<aicpu::CpuKernelCache>(
-    param, session->sessionId, stream_id, session->sessFlag);
+  return SessionCache<CpuCacheData>::Instance().RunKernel<CpuKernelCache>(param, session->sessionId, stream_id,
+                                                                          session->sessFlag);
 }
 
 __attribute__((visibility("default"))) uint32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info) {
-  if (param == nullptr || blkdim_info == nullptr) {
-    KERNEL_LOG_ERROR("Param is null.");
-    return aicpu::KERNEL_STATUS_PARAM_INVALID;
-  }
   KERNEL_LOG_INFO("RunCpuKernelWithBlock C begin. blockid[%u], blockdim[%u].", blkdim_info->blockId,
                   blkdim_info->blockNum);
+  if (param == nullptr || blkdim_info == nullptr) {
+    KERNEL_LOG_ERROR("Param is null.");
+    return KERNEL_STATUS_PARAM_INVALID;
+  }
 
   // parse param_len
-  aicpu::AicpuParamHead *param_head = static_cast<aicpu::AicpuParamHead *>(param);
-  if ((param_head->length < sizeof(aicpu::AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
+  AicpuParamHead *param_head = static_cast<AicpuParamHead *>(param);
+  if ((param_head->length < sizeof(AicpuParamHead)) || (param_head->length > kMaxParamLen) ||
      (param_head->extInfoLength > kMaxExtendLen)) {
     KERNEL_LOG_ERROR(
       "Param length[%u] not in [%zu, %u] or extend info length[%u] is "
       "greater "
      "than the limit[%u].",
-      param_head->length, sizeof(aicpu::AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
-    return aicpu::KERNEL_STATUS_PARAM_INVALID;
+      param_head->length, sizeof(AicpuParamHead), kMaxParamLen, param_head->extInfoLength, kMaxExtendLen);
+    return KERNEL_STATUS_PARAM_INVALID;
   }
 
   SessionInfo *session = nullptr;
-  uint32_t ret = ParseExtSessionInfo(param_head, &session);
-  if (ret != aicpu::KERNEL_STATUS_OK) {
+  uint32_t ret = ParseExtSessionInfo(param_head, session);
+  if (ret != KERNEL_STATUS_OK) {
    return ret;
   }
 
   if (session == nullptr) {
     KERNEL_LOG_INFO("RunCpuKernelWithBlock directly.");
-    aicpu::CpuKernelCache cache;
+    CpuKernelCache cache;
     cache.Init(false);
     return cache.RunCpuKernelWithBlock(param, blkdim_info);
   }
 
   std::string stream_id_value = "0";
-  auto status = aicpu::GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
-  if (status != aicpu::AICPU_ERROR_NONE) {
+  auto status = GetThreadLocalCtx(kContextKeyStreamId, &stream_id_value);
+  if (status != AICPU_ERROR_NONE) {
     KERNEL_LOG_ERROR("GetThreadLocalCtx failed, ret[%d].", status);
-    return aicpu::KERNEL_STATUS_INNER_ERROR;
+    return KERNEL_STATUS_INNER_ERROR;
   }
   uint64_t stream_id = atoi(stream_id_value.c_str());
   KERNEL_LOG_INFO(
     "RunCpuKernel from cache, stream id[%lu], session id[%lu], session "
     "flag[%d].",
     stream_id, session->sessionId, session->sessFlag);
-  return aicpu::SessionCache<aicpu::CpuCacheData>::Instance().RunCpuKernelWithBlock<aicpu::CpuKernelCache>(
+  return SessionCache<CpuCacheData>::Instance().RunCpuKernelWithBlock<CpuKernelCache>(
     param, session->sessionId, stream_id, session->sessFlag, blkdim_info);
 }
 }
@ -26,4 +26,4 @@ extern "C" {
|
|||
uint32_t RunCpuKernel(void *param);
|
||||
uint32_t RunCpuKernelWithBlock(void *param, struct BlkDimInfo *blkdim_info);
|
||||
}
|
||||
#endif // AICPU_CONTEXT_COMMON_DEVICE_CPU_KERNEL_H
|
||||
#endif // AICPU_CONTEXT_COMMON_DEVICE_CPU_KERNEL_H
|
|
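RunCpuKernel and RunCpuKernelWithBlock are exported with C linkage and default visibility, which lets a host runtime resolve them from the shared object by their unmangled names. A hedged sketch of such a lookup; the library name and call sequence here are illustrative, not taken from this commit:

#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

int main() {
  // Library name is an assumption; the real .so name is not part of this diff.
  void *handle = dlopen("libcpu_kernels.so", RTLD_NOW);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // C linkage means the symbol is unmangled, so dlsym finds it by name.
  using RunCpuKernelFn = uint32_t (*)(void *param);
  auto run = reinterpret_cast<RunCpuKernelFn>(dlsym(handle, "RunCpuKernel"));
  if (run == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }
  // A real caller would pass a packed AicpuParamHead buffer here:
  // uint32_t ret = run(param_buffer);
  dlclose(handle);
  return 0;
}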
@@ -1,17 +1,18 @@
// Copyright 2022 Huawei Technologies Co., Ltd
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "cpu_kernel/common/eigen_threadpool.h"

#include <sys/sysinfo.h>

@@ -1,17 +1,18 @@
// Copyright 2022 Huawei Technologies Co., Ltd
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
#define AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
#define EIGEN_USE_THREADS

@@ -56,5 +57,5 @@ class EigenThreadPool {
  static std::unique_ptr<Eigen::ThreadPool> eigen_threadpool_;
  static std::unique_ptr<Eigen::ThreadPoolDevice> threadpool_device_;
};
}  // namespace aicpu
};  // namespace aicpu
#endif  // AICPU_CONTEXT_COMMON_EIGEN_THREAD_POOL_H
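The eigen_threadpool wrapper above builds on Eigen's unsupported ThreadPool module. A minimal standalone sketch of that underlying API follows; names are local to the example, and the pool size of 4 is arbitrary (the wrapper presumably derives it from get_nprocs(), given the <sys/sysinfo.h> include):

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/ThreadPool>

#include <atomic>
#include <cstdio>

int main() {
  Eigen::ThreadPool pool(4);  // the wrapper holds this in a unique_ptr
  std::atomic<int> done{0};
  for (int i = 0; i < 8; ++i) {
    pool.Schedule([&done] { done.fetch_add(1); });  // enqueue a task
  }
  // A real wrapper would block on a barrier; spinning keeps the sketch short.
  while (done.load() < 8) {
  }
  std::printf("ran %d tasks\n", done.load());
  return 0;
}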
@@ -20,7 +20,7 @@
namespace aicpu {
class HostSharder : public Sharder {
 public:
  explicit HostSharder(DeviceType device) : Sharder(device) {}
  explicit HostSharder(DeviceType device) : Sharder(device){};

  ~HostSharder() = default;

@@ -13,176 +13,179 @@
#include "cpu_kernel/common/cpu_kernel_utils.h"

namespace aicpu {
std::shared_ptr<NodeDef> NodeDefBuilder::CreateNodeDef() { return CpuKernelUtils::CpuKernelUtils::CreateNodeDef(); }

NodeDefBuilder::NodeDefBuilder(NodeDef *nodeDef, const std::string &name, std::string opName) : name_(name) {
  nodeDef_ = nodeDef;
  nodeDef_->SetOpType(opName);
std::shared_ptr<NodeDef> NodeDefBuilder::CreateNodeDef() {
  return CpuKernelUtils::CpuKernelUtils::CreateNodeDef();
}

void NodeDefBuilder::BuildNodeFromInputOutputNode(const InputOutputNode &node, bool isInput) {
  std::shared_ptr<Tensor> tensor;
  if (isInput) {
    tensor = nodeDef_->AddInputs();
  } else {
    tensor = nodeDef_->AddOutputs();
  }
  aicpu::CpuKernelUtils::SetTensorName(node.node, &tensor);
  tensor->SetDataType(node.dType);
  auto shape = tensor->GetTensorShape();
  shape->SetDimSizes(node.dims);
  shape->SetFormat(node.format);
  int64_t dataSize = 1;
  for (size_t i = 0; i < node.dims.size(); i++) {
    dataSize = dataSize * node.dims[i];
  }
  dataSize = dataSize * GetSizeByDataType(node.dType);
  if (node.dims.empty()) {
    dataSize = GetSizeByDataType(node.dType);
  }
  if (node.data == nullptr) {
    dataSize = 0;
  }
  tensor->SetDataSize(dataSize);
  tensor->SetData(node.data);
NodeDefBuilder::NodeDefBuilder(NodeDef *nodeDef, std::string name, std::string opName) {
  nodeDef_ = nodeDef;
  name_ = name;
  nodeDef_->SetOpType(opName);
}

NodeDefBuilder &NodeDefBuilder::Input(const InputOutputNode &input) {
  BuildNodeFromInputOutputNode(input, true);
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Output(const InputOutputNode &output) {
  BuildNodeFromInputOutputNode(output, false);
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, int32_t value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, int64_t value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, float value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, double value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, bool value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetBool(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, aicpu::DataType value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetDataType(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<bool> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListBool(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::string &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetString(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<std::string> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListString(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListListInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<float> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<aicpu::DataType> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListDataType(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &dims, const std::string &type) {
  if (type == "shape") {
    auto shape = CpuKernelUtils::CreateAttrValue();
    auto value = CpuKernelUtils::CreateTensorShape();
    value->SetDimSizes(dims);
    shape->SetTensorShape(value.get());
    nodeDef_->AddAttrs(name, shape.get());
  }
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists,
                                     const std::string &type) {
  if (type == "shape_list") {
    auto shapeItems = CpuKernelUtils::CreateAttrValue();
    for (size_t i = 0; i < shapeLists.size(); i++) {
      auto value = shapeItems->AddListTensorShape();
      value->SetDimSizes(shapeLists[i]);
void NodeDefBuilder::BuildNodeFromInputOutputNode(const InputOutputNode& node, bool isInput) {
  std::shared_ptr<Tensor> tensor;
  if (isInput) {
    tensor = nodeDef_->AddInputs();
  } else {
    tensor = nodeDef_->AddOutputs();
  }
    nodeDef_->AddAttrs(name, shapeItems.get());
  }
  return *this;
  aicpu::CpuKernelUtils::SetTensorName(node.node, tensor);
  tensor->SetDataType(node.dType);
  auto shape = tensor->GetTensorShape();
  shape->SetDimSizes(node.dims);
  shape->SetFormat(node.format);
  int64_t dataSize = 1;
  for (size_t i = 0; i < node.dims.size(); i++) {
    dataSize = dataSize * node.dims[i];
  }
  dataSize = dataSize * GetSizeByDataType(node.dType);
  if (node.dims.empty()) {
    dataSize = GetSizeByDataType(node.dType);
  }
  if (node.data == nullptr) {
    dataSize = 0;
  }
  tensor->SetDataSize(dataSize);
  tensor->SetData(node.data);
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, aicpu::Tensor *tensor) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetTensor(tensor);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
NodeDefBuilder& NodeDefBuilder::Input(const InputOutputNode& input) {
  BuildNodeFromInputOutputNode(input, true);
  return *this;
}

NodeDefBuilder &NodeDefBuilder::Attr(std::string name, const std::vector<aicpu::Tensor *> &tensors) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListTensor(tensors);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
NodeDefBuilder& NodeDefBuilder::Output(const InputOutputNode& output) {
  BuildNodeFromInputOutputNode(output, false);
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int32_t value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, int64_t value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, float value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, double value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, bool value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetBool(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::DataType value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetDataType(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<bool> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListBool(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::string &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetString(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::string> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListString(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListListInt(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<float> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListFloat(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<aicpu::DataType> &value) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListDataType(value);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<int64_t> &dims, std::string type) {
  if (type == "shape") {
    auto shape = CpuKernelUtils::CreateAttrValue();
    auto value = CpuKernelUtils::CreateTensorShape();
    value->SetDimSizes(dims);
    shape->SetTensorShape(value.get());
    nodeDef_->AddAttrs(name, shape.get());
  }
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists,
                                     std::string type) {
  if (type == "shape_list") {
    auto shapeItems = CpuKernelUtils::CreateAttrValue();
    for (size_t i = 0; i < shapeLists.size(); i++) {
      auto value = shapeItems->AddListTensorShape();
      value->SetDimSizes(shapeLists[i]);
    }
    nodeDef_->AddAttrs(name, shapeItems.get());
  }
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, aicpu::Tensor *tensor) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetTensor(tensor);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}

NodeDefBuilder& NodeDefBuilder::Attr(std::string name, std::vector<aicpu::Tensor *> &tensors) {
  auto attr = CpuKernelUtils::CreateAttrValue();
  attr->SetListTensor(tensors);
  nodeDef_->AddAttrs(name, attr.get());
  return *this;
}
}
}  // namespace aicpu
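For context, the builder above is driven as a fluent chain. A hedged usage fragment against the API visible in this diff; InputOutputNode's field order is inferred from BuildNodeFromInputOutputNode, and the DT_FLOAT/FORMAT_ND enum names plus the x_data/y_data/z_data buffers are assumptions local to this sketch:

// x_data, y_data, z_data would be raw buffers owned by a test harness.
auto node_def = NodeDefBuilder::CreateNodeDef();
NodeDefBuilder(node_def.get(), "add_node", "Add")
    .Input({"x", DT_FLOAT, {2, 2}, FORMAT_ND, x_data})   // name, dType, dims, format, data
    .Input({"y", DT_FLOAT, {2, 2}, FORMAT_ND, y_data})
    .Output({"z", DT_FLOAT, {2, 2}, FORMAT_ND, z_data})
    .Attr("alpha", 1.0f);  // dispatches to the float overload above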
@@ -31,7 +31,7 @@ class NodeDefBuilder {

  static std::shared_ptr<NodeDef> CreateNodeDef();

  NodeDefBuilder(NodeDef *nodeDef, const std::string &name, std::string opName);
  NodeDefBuilder(NodeDef *nodeDef, std::string name, std::string opName);

  NodeDefBuilder &Input(const InputOutputNode &input);

@@ -63,13 +63,13 @@ class NodeDefBuilder {

  NodeDefBuilder &Attr(std::string name, const std::vector<aicpu::DataType> &value);

  NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &dims, const std::string &type);
  NodeDefBuilder &Attr(std::string name, const std::vector<int64_t> &dims, std::string type);

  NodeDefBuilder &Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists, const std::string &type);
  NodeDefBuilder &Attr(std::string name, const std::vector<std::vector<int64_t>> &shapeLists, std::string type);

  NodeDefBuilder &Attr(std::string name, aicpu::Tensor *tensor);

  NodeDefBuilder &Attr(std::string name, const std::vector<aicpu::Tensor *> &tensors);
  NodeDefBuilder &Attr(std::string name, std::vector<aicpu::Tensor *> &tensors);

 private:
  void BuildNodeFromInputOutputNode(const InputOutputNode &node, bool isInput);

@@ -22,6 +22,7 @@
#include <mutex>

namespace aicpu {

class Notification {
 public:
  Notification() : notified_(0) {}

@@ -47,9 +48,10 @@ class Notification {
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::atomic<bool> notified_;
  std::mutex mu_;               // protects mutations of notified_
  std::condition_variable cv_;  // signaled when notified_ becomes non-zero
  std::atomic<bool> notified_;  // mutations under mu_
};

}  // namespace aicpu
#endif  // AICPU_CONTEXT_COMMON_NOTIFICATION_H
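The comments added to the members document the usual mutex-plus-condition-variable notification pattern. A self-contained sketch of the wait/notify pair those members support; the method names here are illustrative, since only the data members appear in this hunk:

#include <atomic>
#include <condition_variable>
#include <mutex>

class NotificationSketch {
 public:
  NotificationSketch() : notified_(false) {}

  void Notify() {
    std::unique_lock<std::mutex> lock(mu_);  // mutate under the lock...
    notified_.store(true, std::memory_order_release);
    cv_.notify_all();                        // ...so no wakeup is lost
  }

  void WaitForNotification() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return notified_.load(std::memory_order_acquire); });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::atomic<bool> notified_;
};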
@@ -34,4 +34,4 @@ struct RuntimeTensorDesc {
#pragma pack(pop)
}  // namespace ge

#endif  // INC_GE_RUNTIME_TENSOR_DESC_H_
#endif  // INC_GE_RUNTIME_TENSOR_DESC_H_
@@ -49,14 +49,14 @@ class SessionCache {
    if (sess_flag) {
      KERNEL_LOG_DEBUG("SessionCache KernelCache from session, id[%llu].", session_id);
      std::unique_lock<std::mutex> lock(session_mutex_);
      int32_t ret = GetOrCreateKernelCache<T>(&session_kernel_cache_, session_id, sess_flag, &kernel);
      int32_t ret = GetOrCreateKernelCache<T>(session_kernel_cache_, session_id, sess_flag, kernel);
      if (ret != 0) {
        return ret;
      }
    } else {
      KERNEL_LOG_DEBUG("SessionCache KernelCache from stream, id[%llu].", stream_id);
      std::unique_lock<std::mutex> lock(stream_mutex_);
      int32_t ret = GetOrCreateKernelCache<T>(&stream_kernel_cache_, stream_id, sess_flag, &kernel);
      int32_t ret = GetOrCreateKernelCache<T>(stream_kernel_cache_, stream_id, sess_flag, kernel);
      if (ret != 0) {
        return ret;
      }

@@ -79,14 +79,14 @@ class SessionCache {
    if (sess_flag) {
      KERNEL_LOG_DEBUG("SessionCache KernelCache from session, id[%llu].", session_id);
      std::unique_lock<std::mutex> lock(session_mutex_);
      int32_t ret = GetOrCreateKernelCache<T>(&session_kernel_cache_, session_id, sess_flag, &kernel);
      int32_t ret = GetOrCreateKernelCache<T>(session_kernel_cache_, session_id, sess_flag, kernel);
      if (ret != 0) {
        return ret;
      }
    } else {
      KERNEL_LOG_DEBUG("SessionCache KernelCache from stream, id[%llu].", stream_id);
      std::unique_lock<std::mutex> lock(stream_mutex_);
      int32_t ret = GetOrCreateKernelCache<T>(&stream_kernel_cache_, stream_id, sess_flag, &kernel);
      int32_t ret = GetOrCreateKernelCache<T>(stream_kernel_cache_, stream_id, sess_flag, kernel);
      if (ret != 0) {
        return ret;
      }

@@ -103,24 +103,24 @@ class SessionCache {
  SessionCache &operator=(SessionCache &&) = delete;

  template <class T>
  int32_t GetOrCreateKernelCache(std::map<uint64_t, std::shared_ptr<KernelCache<C>>> *kernel_map, uint64_t id,
                                 bool sess_flag, std::shared_ptr<KernelCache<C>> *kernel) {
    auto iter = kernel_map->find(id);
    if (iter != kernel_map->end()) {
  int32_t GetOrCreateKernelCache(std::map<uint64_t, std::shared_ptr<KernelCache<C>>> &kernel_map, uint64_t id,
                                 bool sess_flag, std::shared_ptr<KernelCache<C>> &kernel) {
    auto iter = kernel_map.find(id);
    if (iter != kernel_map.end()) {
      KERNEL_LOG_DEBUG("Get kernel from cache success, id[%llu].", id);
      *kernel = iter->second;
      kernel = iter->second;
    } else {
      KernelCache<C> *cache = new (std::nothrow) T();
      if (cache == nullptr) {
        KERNEL_LOG_DEBUG("Create kernel cache failed, id[%llu].", id);
        return -1;
      }
      *kernel = std::shared_ptr<KernelCache<C>>(cache);
      int32_t ret = (*kernel)->Init(sess_flag);
      kernel = std::shared_ptr<KernelCache<C>>(cache);
      int32_t ret = kernel->Init(sess_flag);
      if (ret != 0) {
        return ret;
      }
      kernel_map->insert(std::make_pair(id, *kernel));
      kernel_map.insert(std::make_pair(id, kernel));
      KERNEL_LOG_DEBUG("Create kernel cache, id[%llu].", id);
    }
    return 0;
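GetOrCreateKernelCache above is the classic get-or-create map idiom, rolled back to reference out-parameters. A standalone sketch of the same idiom, where Value stands in for KernelCache<C>:

#include <cstdint>
#include <map>
#include <memory>

struct Value {
  int32_t Init(bool flag) { (void)flag; return 0; }
};

int32_t GetOrCreate(std::map<uint64_t, std::shared_ptr<Value>> &m, uint64_t id,
                    std::shared_ptr<Value> &out) {
  auto iter = m.find(id);
  if (iter != m.end()) {
    out = iter->second;  // cache hit: hand back the shared entry
    return 0;
  }
  out = std::make_shared<Value>();
  int32_t ret = out->Init(false);
  if (ret != 0) {
    return ret;  // failed entries are never inserted into the cache
  }
  m.emplace(id, out);
  return 0;
}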
@@ -30,7 +30,7 @@ class ThreadCtx {

  virtual uint32_t SetThreadCtxInfo(CtxType type, const std::string &key, const std::string &value) const = 0;

  virtual uint32_t GetThreadCtxInfo(CtxType type, const std::string &key, std::string *value) const = 0;
  virtual uint32_t GetThreadCtxInfo(CtxType type, const std::string &key, std::string &value) const = 0;

  virtual uint32_t RemoveThreadCtxInfo(CtxType type, const std::string &key) const = 0;

@@ -29,7 +29,7 @@ class AttrValueImpl {
  friend class CpuKernelUtils;

 public:
  explicit AttrValueImpl(
  AttrValueImpl(
    aicpuops::AttrValue *attr, std::function<void(aicpuops::AttrValue *)> del_func = [](aicpuops::AttrValue *p) {})
      : attr_value_(attr, del_func) {}

@@ -27,7 +27,7 @@ bool NodeDef::ParseFromString(const std::string &str) { return impl_->ParseFromS
/*
 * serialize string to node def.
 */
bool NodeDef::SerializeToString(std::string *str) const { return impl_->SerializeToString(str); }
bool NodeDef::SerializeToString(std::string &str) const { return impl_->SerializeToString(str); }

/*
 * set op type to node def.

@@ -38,8 +38,8 @@ bool NodeDefImpl::ParseFromString(const std::string &str) {
/*
 * serialize string to node def.
 */
bool NodeDefImpl::SerializeToString(std::string *str) const {
  if (!nodedef_->SerializeToString(str)) {
bool NodeDefImpl::SerializeToString(std::string &str) const {
  if (!nodedef_->SerializeToString(&str)) {
    KERNEL_LOG_ERROR("SerializeToString failed");
    return false;
  }
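Note that the reference-style SerializeToString still hands protobuf a pointer internally, since google::protobuf::Message::SerializeToString takes std::string*. A minimal standalone restatement of the bridging pattern (the helper name is this sketch's, not the codebase's):

bool SerializeNodeDef(const aicpuops::NodeDef &node_def, std::string &str) {
  return node_def.SerializeToString(&str);  // &str bridges reference -> pointer
}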
@@ -29,7 +29,7 @@ class NodeDefImpl {
  friend class CpuKernelUtils;

 public:
  explicit NodeDefImpl(
  NodeDefImpl(
    aicpuops::NodeDef *nodedef, std::function<void(aicpuops::NodeDef *)> del_func = [](aicpuops::NodeDef *p) {})
      : nodedef_(nodedef, del_func) {}

@@ -49,7 +49,7 @@ class NodeDefImpl {
   * serialize string to node def.
   * @return bool: true->success, false->failed
   */
  bool SerializeToString(std::string *str) const;
  bool SerializeToString(std::string &str) const;

  /*
   * set op type to node def.

@@ -68,4 +68,4 @@ int64_t Tensor::CalcDataSizeByShape() const { return impl_->CalcDataSizeByShape(
 * get data elements number.
 */
int64_t Tensor::NumElements() const { return impl_->NumElements(); }
}  // namespace aicpu
}  // namespace aicpu

@@ -134,4 +134,4 @@ int64_t TensorImpl::NumElements() const {
}

aicpuops::Tensor *TensorImpl::GetProto() const { return tensor_.get(); }
}  // namespace aicpu
}  // namespace aicpu

@@ -27,7 +27,7 @@ class TensorImpl {
  friend class CpuKernelUtils;

 public:
  explicit TensorImpl(
  TensorImpl(
    aicpuops::Tensor *tensor, std::function<void(aicpuops::Tensor *)> delFunc = [](aicpuops::Tensor *p) {})
      : tensor_(tensor, delFunc) {}

@@ -63,4 +63,4 @@ int64_t TensorShape::GetDimSize(int32_t index) const { return impl_->GetDimSize(
 * get data elements number.
 */
int64_t TensorShape::NumElements() const { return impl_->NumElements(); }
}  // namespace aicpu
}  // namespace aicpu

@@ -103,4 +103,4 @@ int64_t TensorShapeImpl::NumElements() const {
 */

aicpuops::TensorShape *TensorShapeImpl::GetProto() const { return tensor_shape_.get(); }
}  // namespace aicpu
}  // namespace aicpu

@@ -27,7 +27,7 @@ class TensorShapeImpl {
  friend class CpuKernelUtils;

 public:
  explicit TensorShapeImpl(
  TensorShapeImpl(
    aicpuops::TensorShape *shape,
    std::function<void(aicpuops::TensorShape *)> del_func = [](aicpuops::TensorShape *p) {})
      : tensor_shape_(shape, del_func) {}
@@ -68,23 +68,23 @@ bool CheckShape(Format format, const ShapeVector &shape) {
 * @dst_shape: N*W1*H1*H0*w0
 * @return
 */
uint32_t TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector *dst_shape,
                            ShapeVector *hw_shape) {
  dst_shape->clear();
  hw_shape->clear();
uint32_t TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape,
                            ShapeVector &hw_shape) {
  dst_shape.clear();
  hw_shape.clear();
  auto w0 = GetCubeSizeByDataType(data_type);
  int64_t h0 = kCubeSize;
  switch (src_shape.size()) {
    case kSingleDim:
      dst_shape->push_back(Ceil(src_shape[kNdDimIndexN], w0));
      dst_shape->push_back(kDimDefaultValue);
      dst_shape->push_back(h0);
      dst_shape->push_back(w0);
      hw_shape->push_back(kDimDefaultValue);
      hw_shape->push_back(kDimDefaultValue);
      hw_shape->push_back(src_shape[kNdDimIndexN]);
      if (!IsShapeValid(*dst_shape)) {
        KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(*dst_shape).c_str());
      dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0));
      dst_shape.push_back(kDimDefaultValue);
      dst_shape.push_back(h0);
      dst_shape.push_back(w0);
      hw_shape.push_back(kDimDefaultValue);
      hw_shape.push_back(kDimDefaultValue);
      hw_shape.push_back(src_shape[kNdDimIndexN]);
      if (!IsShapeValid(dst_shape)) {
        KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(dst_shape).c_str());
        return KERNEL_STATUS_PARAM_INVALID;
      }
      return KERNEL_STATUS_OK;

@@ -92,27 +92,27 @@ uint32_t TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Sh
      auto size = src_shape.size();
      int64_t times = 1;
      for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) {
        dst_shape->push_back(src_shape[i]);
        dst_shape.push_back(src_shape[i]);
        times *= src_shape[i];
      }
      dst_shape->push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
      dst_shape->push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
      dst_shape->push_back(h0);
      dst_shape->push_back(w0);
      hw_shape->push_back(times);
      hw_shape->push_back(src_shape[size - kNdDimCountBackwardsWH]);
      hw_shape->push_back(src_shape[size - kNdDimCountBackwardsW]);
      if (!IsShapeValid(*dst_shape)) {
        KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(*dst_shape).c_str());
      dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0));
      dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0));
      dst_shape.push_back(h0);
      dst_shape.push_back(w0);
      hw_shape.push_back(times);
      hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]);
      hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]);
      if (!IsShapeValid(dst_shape)) {
        KERNEL_LOG_ERROR("Failed to check dst shape [%s]", VectorToString(dst_shape).c_str());
        return KERNEL_STATUS_PARAM_INVALID;
      }
      return KERNEL_STATUS_OK;
  }
}

uint32_t CheckShapeRelation(const TransArgs &args, ShapeVector *hw_shape) {
uint32_t CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {
  ShapeVector expect_src_shape;
  auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, &expect_src_shape, hw_shape);
  auto ret = TransShapeToFracNz(args.dst_shape, args.src_data_type, expect_src_shape, hw_shape);
  if (ret != KERNEL_STATUS_OK) {
    KERNEL_LOG_ERROR(
      "Trans shape from [%s] to [%s], shape [%s] to [%s], data type [%s] "
@@ -128,12 +128,12 @@ uint32_t CheckShapeRelation(const TransArgs &args, ShapeVector *hw_shape) {
  return KERNEL_STATUS_OK;
}

uint32_t TransFormatFromNdToFracNz(const TransArgs &args, TransResult *result, const ShapeVector &hw_shape) {
uint32_t TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) {
  int size = GetSizeByDataType(args.src_data_type);
  // data size will not be greater than INT_MAX
  int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
  if (dst_size == 0) {
    result->length = static_cast<size_t>(dst_size);
    result.length = static_cast<size_t>(dst_size);
    return KERNEL_STATUS_OK;
  }

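The hunks above implement the ND -> FRACTAL_NZ shape rule: keep the leading batch dims, then append [Ceil(W, w0), Ceil(H, h0), h0, w0]. A standalone sketch for the rank >= 2 case (the single-dim branch is handled separately above); h0 = w0 = 16 mirrors the usual cube size, while the real values come from kCubeSize and GetCubeSizeByDataType:

#include <cstdint>
#include <vector>

inline int64_t Ceil(int64_t x, int64_t y) { return (x + y - 1) / y; }

std::vector<int64_t> NdToFracNz(const std::vector<int64_t> &src, int64_t h0 = 16,
                                int64_t w0 = 16) {
  const size_t n = src.size();                           // requires n >= 2
  std::vector<int64_t> dst(src.begin(), src.end() - 2);  // leading batch dims
  dst.push_back(Ceil(src[n - 1], w0));                   // W1
  dst.push_back(Ceil(src[n - 2], h0));                   // H1
  dst.push_back(h0);
  dst.push_back(w0);
  return dst;  // e.g. {2, 100, 40} -> {2, 3, 7, 16, 16}
}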
@@ -206,16 +206,16 @@ uint32_t TransFormatFromNdToFracNz(const TransArgs &args, TransResult *result, c
      }
    }
  }
  result->data = dst;
  result->length = static_cast<size_t>(dst_size);
  result.data = dst;
  result.length = static_cast<size_t>(dst_size);
  return KERNEL_STATUS_OK;
}

uint32_t TransFormatFromFracNzToNd(const TransArgs &args, TransResult *result, const ShapeVector &dst_hw_shape) {
uint32_t TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) {
  int size = GetSizeByDataType(args.src_data_type);
  int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
  if (dst_size == 0) {
    result->length = static_cast<size_t>(dst_size);
    result.length = static_cast<size_t>(dst_size);
    return KERNEL_STATUS_OK;
  }

@@ -286,13 +286,13 @@ uint32_t TransFormatFromFracNzToNd(const TransArgs &args, TransResult *result, c
      }
    }
  }
  result->data = dst;
  result->length = static_cast<size_t>(dst_size);
  result.data = dst;
  result.length = static_cast<size_t>(dst_size);
  return KERNEL_STATUS_OK;
}
}  // namespace

uint32_t FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult *result) {
uint32_t FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult &result) {
  if (!IsDataTypeSupport(args.src_data_type)) {
    KERNEL_LOG_ERROR(
      "Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "
@@ -319,7 +319,7 @@ uint32_t FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult
    DTypeStr(args.src_data_type).c_str());
  ShapeVector expect_shape;
  ShapeVector hw_shape;
  auto ret = TransShapeToFracNz(args.src_shape, args.src_data_type, &expect_shape, &hw_shape);
  auto ret = TransShapeToFracNz(args.src_shape, args.src_data_type, expect_shape, hw_shape);
  if (ret != KERNEL_STATUS_OK) {
    return ret;
  }

@@ -330,7 +330,7 @@ uint32_t FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult
}

uint32_t FormatTransferFractalNz::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type,
                                             Format dst_format, ShapeVector *dst_shape, int64_t groups) {
                                             Format dst_format, ShapeVector &dst_shape, int64_t groups) {
  if (!IsDataTypeSupport(data_type)) {
    KERNEL_LOG_ERROR(
      "Trans format from [%s] to [%s], src shape [%s], data type [%s] is not "

@@ -348,10 +348,10 @@ uint32_t FormatTransferFractalNz::TransShape(Format src_format, const ShapeVecto
    return KERNEL_STATUS_PARAM_INVALID;
  }
  ShapeVector hw_shape;
  return TransShapeToFracNz(src_shape, data_type, dst_shape, &hw_shape);
  return TransShapeToFracNz(src_shape, data_type, dst_shape, hw_shape);
}

uint32_t FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult *result) {
uint32_t FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResult &result) {
  if (!IsDataTypeSupport(args.src_data_type)) {
    KERNEL_LOG_ERROR(
      "Trans format from [%s] to [%s], src shape [%s], dst shape [%s], data "

@@ -379,7 +379,7 @@ uint32_t FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResu
    DTypeStr(args.src_data_type).c_str());

  ShapeVector hw_shape;
  auto ret = CheckShapeRelation(args, &hw_shape);
  auto ret = CheckShapeRelation(args, hw_shape);
  if (ret != KERNEL_STATUS_OK) {
    return ret;
  }

@@ -387,7 +387,7 @@ uint32_t FormatTransferFractalNzND::TransFormat(const TransArgs &args, TransResu
}

uint32_t FormatTransferFractalNzND::TransShape(Format src_format, const ShapeVector &src_shape, DataType data_type,
                                               Format dst_format, ShapeVector *dst_shape, int64_t groups) {
                                               Format dst_format, ShapeVector &dst_shape, int64_t groups) {
  KERNEL_LOG_ERROR(
    "The shape derivation from [%s] to [%s] is not unique. Trans shape is "
    "not supported",

@@ -25,17 +25,17 @@ namespace formats {
// transfer from nd to nz
class FormatTransferFractalNz : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> *dst_shape, int64_t groups) override;
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};

// transfer nz to nd
class FormatTransferFractalNzND : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> *dst_shape, int64_t groups) override;
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};
}  // namespace formats
}  // namespace aicpu
@@ -45,7 +45,7 @@ KernelStatus CheckDataTypeSupport(DataType data_type) {
 */

uint32_t TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type,
                                  std::vector<int64_t> *dst_shape, int64_t groups) {
                                  std::vector<int64_t> &dst_shape, int64_t groups) {
  auto c0 = GetCubeSizeByDataType(data_type);
  if (c0 < 0) {
    KERNEL_LOG_ERROR("Cube size must greater than or equal to 0");

@@ -70,20 +70,20 @@ uint32_t TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Da
  int64_t c1_dim = cin_opt / cube_k;
  int64_t g_dim = Ceil(groups, e_mult);
  auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
  dst_shape->clear();
  dst_shape->push_back(g_dim * c1_dim * h * w);
  dst_shape->push_back(n1);
  dst_shape->push_back(kNiSize);
  dst_shape->push_back(cube_k);
  if (!IsShapeValid(*dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(*dst_shape).c_str());
  dst_shape.clear();
  dst_shape.push_back(g_dim * c1_dim * h * w);
  dst_shape.push_back(n1);
  dst_shape.push_back(kNiSize);
  dst_shape.push_back(cube_k);
  if (!IsShapeValid(dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  return KERNEL_STATUS_OK;
}

uint32_t TransShapeNchwToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                      std::vector<int64_t> *dst_shape, int64_t groups) {
                                      std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, kNchwDimsNum)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -96,7 +96,7 @@ uint32_t TransShapeNchwToFzWithGroups(const std::vector<int64_t> &src_shape, Dat
}

uint32_t TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                      std::vector<int64_t> *dst_shape, int64_t groups) {
                                      std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -110,7 +110,7 @@ uint32_t TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, Dat
}

uint32_t TransShapeNhwcToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                      std::vector<int64_t> *dst_shape, int64_t groups) {
                                      std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -130,13 +130,13 @@ uint32_t TransShapeNhwcToFzWithGroups(const std::vector<int64_t> &src_shape, Dat
// old filter to the new filter, and finally added 0 to the position where there
// is no data.
uint32_t TransFormatWithGroups(const Format &format_4d, const std::vector<int64_t> &shape_4d, const TransArgs &args,
                               TransResult *result, bool reverse) {
                               TransResult &result, bool reverse) {
  int64_t h_dim = 0;
  int64_t w_dim = 0;
  int64_t c_dim = 0;
  int64_t n_dim = 0;
  int64_t d_dim = 1;
  if (GetFormatDim(&d_dim, &h_dim, &w_dim, &c_dim, &n_dim, format_4d, shape_4d) != KERNEL_STATUS_OK) {
  if (GetFormatDim(d_dim, h_dim, w_dim, c_dim, n_dim, format_4d, shape_4d) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  int64_t cin_ori = c_dim;

@@ -156,7 +156,7 @@ uint32_t TransFormatWithGroups(const Format &format_4d, const std::vector<int64_
  int64_t dst_size = GetItemNumByShape(args.dst_shape) * data_size;
  // The input is empty tensor, we should return success directly.
  if (dst_size == 0) {
    result->length = static_cast<size_t>(dst_size);
    result.length = static_cast<size_t>(dst_size);
    return KERNEL_STATUS_OK;
  }
  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());

@@ -203,13 +203,13 @@ uint32_t TransFormatWithGroups(const Format &format_4d, const std::vector<int64_
      }
    }
  }
  result->data = dst;
  result->length = static_cast<size_t>(dst_size);
  result.data = dst;
  result.length = static_cast<size_t>(dst_size);
  return KERNEL_STATUS_OK;
}
}  // namespace

uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult *result) {
uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) {
  if (args.groups == 0) {
    KERNEL_LOG_ERROR("Attr[groups] must not be equal to 0");
    return KERNEL_STATUS_PARAM_INVALID;

@@ -226,7 +226,7 @@ uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult
      args.dst_format == FORMAT_FRACTAL_Z) {
    std::vector<int64_t> expect_shape;
    auto ret =
      TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, &expect_shape, args.groups);
      TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
    if (ret != KERNEL_STATUS_OK) {
      return ret;
    }

@@ -239,8 +239,8 @@ uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult
       (args.dst_format == FORMAT_NCHW)) &&
      args.src_format == FORMAT_FRACTAL_Z) {
    std::vector<int64_t> expect_input_shape;
    auto ret = TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, &expect_input_shape,
                          args.groups);
    auto ret =
      TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, expect_input_shape, args.groups);
    if (ret != KERNEL_STATUS_OK) {
      KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
      return ret;

@@ -256,7 +256,7 @@ uint32_t FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult
}

uint32_t FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                            DataType data_type, Format dst_format, std::vector<int64_t> *dst_shape,
                                            DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
                                            int64_t groups) {
  if (CheckDataTypeSupport(data_type) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;

@@ -282,4 +282,4 @@ REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_NCHW)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_HWCN)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_FRACTAL_Z, FORMAT_NHWC)
}  // namespace formats
}  // namespace aicpu
}  // namespace aicpu
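REGISTER_FORMAT_TRANSFER's definition is not part of this diff. For orientation only, a common shape for such self-registration macros, written here as an assumption rather than the actual implementation: a static registrar object whose constructor stores a factory keyed by the (src, dst) format pair.

#include <functional>
#include <map>
#include <memory>
#include <utility>

// Hypothetical registrar; the real one lives elsewhere in cpu_kernel.
class FormatTransferRegistrar {
 public:
  using Builder = std::function<std::shared_ptr<FormatTransfer>()>;
  FormatTransferRegistrar(Format src, Format dst, Builder builder) {
    Registry()[{src, dst}] = std::move(builder);
  }
  static std::map<std::pair<Format, Format>, Builder> &Registry() {
    static std::map<std::pair<Format, Format>, Builder> registry;
    return registry;  // function-local static avoids init-order issues
  }
};

#define REGISTER_FORMAT_TRANSFER(clazz, src, dst)                 \
  static const FormatTransferRegistrar g_##clazz##_##src##_##dst( \
    (src), (dst), []() { return std::make_shared<clazz>(); })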
@@ -24,9 +24,9 @@ namespace aicpu {
namespace formats {
class FormatTransferFractalZ : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> *dst_shape, int64_t groups) override;
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};
}  // namespace formats
}  // namespace aicpu
@@ -45,7 +45,7 @@ KernelStatus CheckDataTypeSupport(DataType data_type) {
 */

uint32_t TransShapeToFz3DWithGroups(int64_t n, int64_t c, int64_t d, int64_t h, int64_t w, DataType data_type,
                                    std::vector<int64_t> *dst_shape, int64_t groups) {
                                    std::vector<int64_t> &dst_shape, int64_t groups) {
  auto c0 = GetCubeSizeByDataType(data_type);
  if (c0 < 0) {
    KERNEL_LOG_ERROR("Cube size must greater than or equal to 0");

@@ -68,20 +68,20 @@ uint32_t TransShapeToFz3DWithGroups(int64_t n, int64_t c, int64_t d, int64_t h,
  int64_t c1_dim = cin_opt / cube_k;
  int64_t dim_g = Ceil(groups, e_mult);
  auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
  dst_shape->clear();
  dst_shape->push_back(dim_g * c1_dim * d * h * w);
  dst_shape->push_back(n1);
  dst_shape->push_back(kNiSize);
  dst_shape->push_back(cube_k);
  if (!IsShapeValid(*dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(*dst_shape).c_str());
  dst_shape.clear();
  dst_shape.push_back(dim_g * c1_dim * d * h * w);
  dst_shape.push_back(n1);
  dst_shape.push_back(kNiSize);
  dst_shape.push_back(cube_k);
  if (!IsShapeValid(dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }
  return KERNEL_STATUS_OK;
}

uint32_t TransShapeNcdhwToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                       std::vector<int64_t> *dst_shape, int64_t groups) {
                                       std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, static_cast<int64_t>(kNcdhwDimsNum))) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -94,7 +94,7 @@ uint32_t TransShapeNcdhwToFzWithGroups(const std::vector<int64_t> &src_shape, Da
}

uint32_t TransShapeDhwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                       std::vector<int64_t> *dst_shape, int64_t groups) {
                                       std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, static_cast<int64_t>(kDhwcnDimsNum))) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -108,7 +108,7 @@ uint32_t TransShapeDhwcnToFzWithGroups(const std::vector<int64_t> &src_shape, Da
}

uint32_t TransShapeNdhwcToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                       std::vector<int64_t> *dst_shape, int64_t groups) {
                                       std::vector<int64_t> &dst_shape, int64_t groups) {
  if (!CheckShapeValid(src_shape, kNdhwcDimsNum)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -128,13 +128,13 @@ uint32_t TransShapeNdhwcToFzWithGroups(const std::vector<int64_t> &src_shape, Da
// index between NCDHW and FORMAT_FRACTAL_Z_3D , then Convert the old filter to the new
// filter, and finally added 0 to the position where there is no data.
uint32_t TransFormatWithGroups(const Format &format_5d, const std::vector<int64_t> &shape_5d, const TransArgs &args,
                               TransResult *result, bool reverse) {
                               TransResult &result, bool reverse) {
  int64_t h_dim = 0;
  int64_t w_dim = 0;
  int64_t c_dim = 0;
  int64_t n_dim = 0;
  int64_t d_dim = 0;
  if (GetFormatDim(&d_dim, &h_dim, &w_dim, &c_dim, &n_dim, format_5d, shape_5d) != KERNEL_STATUS_OK) {
  if (GetFormatDim(d_dim, h_dim, w_dim, c_dim, n_dim, format_5d, shape_5d) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  int64_t cin_ori = c_dim;

@@ -153,7 +153,7 @@ uint32_t TransFormatWithGroups(const Format &format_5d, const std::vector<int64_
  int64_t dst_size = GetItemNumByShape(args.dst_shape) * data_size;
  // The input is empty tensor, we should return success directly.
  if (dst_size == 0) {
    result->length = static_cast<size_t>(dst_size);
    result.length = static_cast<size_t>(dst_size);
    return KERNEL_STATUS_OK;
  }
  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());

@@ -200,13 +200,14 @@ uint32_t TransFormatWithGroups(const Format &format_5d, const std::vector<int64_
      }
    }
  }
  result->data = dst;
  result->length = static_cast<size_t>(dst_size);
  result.data = dst;
  result.length = static_cast<size_t>(dst_size);
  return KERNEL_STATUS_OK;
}

}  // namespace

uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResult *result) {
uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResult &result) {
  KERNEL_LOG_DEBUG(
    "Begin to trans format from [%s] to [%s], src shape [%s], data type "
    "[%s], dst "

@@ -224,7 +225,7 @@ uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResul
      args.dst_format == FORMAT_FRACTAL_Z_3D) {
    std::vector<int64_t> expect_shape;
    auto ret =
      TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, &expect_shape, args.groups);
      TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
    if (ret != KERNEL_STATUS_OK) {
      return ret;
    }

@@ -236,8 +237,8 @@ uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResul
       (args.dst_format == FORMAT_NCDHW)) &&
      args.src_format == FORMAT_FRACTAL_Z_3D) {
    std::vector<int64_t> expect_input_shape;
    auto ret = TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, &expect_input_shape,
                          args.groups);
    auto ret =
      TransShape(args.dst_format, args.dst_shape, args.src_data_type, args.src_format, expect_input_shape, args.groups);
    if (ret != KERNEL_STATUS_OK) {
      KERNEL_LOG_ERROR("Check dst shape failed, dst shape [%s]", VectorToString(args.dst_shape).c_str());
      return ret;

@@ -254,7 +255,7 @@ uint32_t FormatTransferFractalz3D::TransFormat(const TransArgs &args, TransResul
}

uint32_t FormatTransferFractalz3D::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                              DataType data_type, Format dst_format, std::vector<int64_t> *dst_shape,
                                              DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
                                              int64_t groups) {
  if (CheckDataTypeSupport(data_type) != KERNEL_STATUS_OK) {
    return KERNEL_STATUS_PARAM_INVALID;

@@ -282,4 +283,4 @@ REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_N
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_DHWCN)
REGISTER_FORMAT_TRANSFER(FormatTransferFractalz3D, FORMAT_FRACTAL_Z_3D, FORMAT_NDHWC)
}  // namespace formats
}  // namespace aicpu
}  // namespace aicpu
@@ -23,9 +23,9 @@ namespace aicpu {
namespace formats {
class FormatTransferFractalz3D : public FormatTransfer {
 public:
  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
  uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
                      std::vector<int64_t> *dst_shape, int64_t groups) override;
                      std::vector<int64_t> &dst_shape, int64_t groups) override;
};
}  // namespace formats
}  // namespace aicpu
@@ -40,7 +40,7 @@ KernelStatus CheckDataTypeSupport(DataType data_type) {
}

void TransSrcDataToDstData(const TransArgs &args, const std::vector<int64_t> &shape_ndhwc,
                           std::shared_ptr<uint8_t> *dst, int64_t c0, int32_t data_size) {
                           std::shared_ptr<uint8_t> &dst, int64_t c0, int32_t data_size) {
  const int64_t n = shape_ndhwc[0];
  const int64_t d = shape_ndhwc[1];
  const int64_t h = shape_ndhwc[2];

@@ -55,6 +55,7 @@ void TransSrcDataToDstData(const TransArgs &args, const std::vector<int64_t> &sh
  const int64_t c1hwc0 = c1 * hwc0;
  const int64_t dc1hwc0 = d * c1hwc0;
  const int64_t ndhwc = n * dhwc;
  int64_t src_index = 0;

  for (int64_t ndhwc_idx = 0; ndhwc_idx < ndhwc; ++ndhwc_idx) {
    const int64_t n_idx = ndhwc_idx / dhwc;

@@ -62,11 +63,11 @@ void TransSrcDataToDstData(const TransArgs &args, const std::vector<int64_t> &sh
    const int64_t c_idx = ndhwc_idx % c;
    const int64_t dst_index =
      n_idx * dc1hwc0 + (dhw_idx / hw) * c1hwc0 + (c_idx / c0) * hwc0 + (dhw_idx % hw) * c0 + c_idx % c0;
    int64_t src_index = n_idx * dhwc + c_idx * dhw + dhw_idx;
    src_index = n_idx * dhwc + c_idx * dhw + dhw_idx;
    if (args.src_format == FORMAT_NDHWC) {
      src_index = n_idx * dhwc + dhw_idx * c + c_idx;
    }
    uint8_t *dst_data = dst->get() + dst_index * data_size;
    uint8_t *dst_data = dst.get() + dst_index * data_size;
    const uint8_t *src_data = args.data + src_index * data_size;
    for (int64_t index = 0; index < data_size; ++index) {
      *dst_data++ = *src_data++;

@@ -74,12 +75,12 @@ void TransSrcDataToDstData(const TransArgs &args, const std::vector<int64_t> &sh
  }
}

uint32_t TransDstDataToNdc1hwc0(const TransArgs &args, TransResult *result) {
uint32_t TransDstDataToNdc1hwc0(const TransArgs &args, TransResult &result) {
  const int32_t data_size = GetSizeByDataType(args.src_data_type);
  const auto dst_size = GetItemNumByShape(args.dst_shape) * data_size;
  // The input is empty tensor, we should return success directly
  if (dst_size == 0) {
    result->length = 0;
    result.length = 0;
    return KERNEL_STATUS_OK;
  }
  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());

@@ -117,16 +118,16 @@ uint32_t TransDstDataToNdc1hwc0(const TransArgs &args, TransResult *result) {
    KERNEL_LOG_ERROR("Failed to get c0, c0 is [%ld]", c0);
    return KERNEL_STATUS_PARAM_INVALID;
  }
  TransSrcDataToDstData(args, shape_ndhwc, &dst, c0, data_size);
  TransSrcDataToDstData(args, shape_ndhwc, dst, c0, data_size);

  result->data = dst;
  result->length = static_cast<size_t>(dst_size);
  result.data = dst;
  result.length = static_cast<size_t>(dst_size);

  return KERNEL_STATUS_OK;
}

uint32_t TransShapeToNdc1hwc0(const std::vector<int64_t> &src_shape, const Format &src_format,
                              const DataType &data_type, std::vector<int64_t> *dst_shape) {
                              const DataType &data_type, std::vector<int64_t> &dst_shape) {
  auto iter = kFormatTable.find(src_format);
  if (iter == kFormatTable.end()) {
    KERNEL_LOG_ERROR("src_format is wrong, now format is [%d]", static_cast<int32_t>(src_format));

@@ -148,15 +149,15 @@ uint32_t TransShapeToNdc1hwc0(const std::vector<int64_t> &src_shape, const Forma
  if (!CheckShapeValid(src_shape, static_cast<int64_t>(cur_format.length()))) {
    return KERNEL_STATUS_PARAM_INVALID;
  }
  dst_shape->clear();
  dst_shape->push_back(src_shape.at(n_index));
  dst_shape->push_back(src_shape.at(d_index));
  dst_shape->push_back(Ceil(src_shape.at(c_index), c0));
  dst_shape->push_back(src_shape.at(h_index));
  dst_shape->push_back(src_shape.at(w_index));
  dst_shape->push_back(c0);
  if (!IsShapeValid(*dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(*dst_shape).c_str());
  dst_shape.clear();
  dst_shape.push_back(src_shape.at(n_index));
  dst_shape.push_back(src_shape.at(d_index));
  dst_shape.push_back(Ceil(src_shape.at(c_index), c0));
  dst_shape.push_back(src_shape.at(h_index));
  dst_shape.push_back(src_shape.at(w_index));
  dst_shape.push_back(c0);
  if (!IsShapeValid(dst_shape)) {
    KERNEL_LOG_ERROR("Check shape failed, dst shape [%s]", VectorToString(dst_shape).c_str());
    return KERNEL_STATUS_PARAM_INVALID;
  }

@@ -164,7 +165,7 @@ uint32_t TransShapeToNdc1hwc0(const std::vector<int64_t> &src_shape, const Forma
}
}  // namespace

uint32_t FormatTransferNdc1hwc0::TransFormat(const TransArgs &args, TransResult *result) {
uint32_t FormatTransferNdc1hwc0::TransFormat(const TransArgs &args, TransResult &result) {
  KERNEL_LOG_INFO(
    "Begin to trans format from [%s] to [%s], src shape [%s], data type [%s], dst "
    "shape [%s]",

@@ -174,7 +175,7 @@ uint32_t FormatTransferNdc1hwc0::TransFormat(const TransArgs &args, TransResult

  std::vector<int64_t> expect_shape;
  auto ret =
    TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, &expect_shape, args.groups);
    TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, args.groups);
  if (ret != KERNEL_STATUS_OK) {
    return ret;
  }

@@ -186,7 +187,7 @@ uint32_t FormatTransferNdc1hwc0::TransFormat(const TransArgs &args, TransResult
}

uint32_t FormatTransferNdc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                            DataType data_type, Format dst_format, std::vector<int64_t> *dst_shape,
                                            DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
                                            int64_t groups) {
  (void)dst_format;
  (void)groups;

@@ -205,4 +206,4 @@ uint32_t FormatTransferNdc1hwc0::TransShape(Format src_format, const std::vector
REGISTER_FORMAT_TRANSFER(FormatTransferNdc1hwc0, FORMAT_NCDHW, FORMAT_NDC1HWC0)
REGISTER_FORMAT_TRANSFER(FormatTransferNdc1hwc0, FORMAT_NDHWC, FORMAT_NDC1HWC0)
}  // namespace formats
}  // namespace aicpu
}  // namespace aicpu
@@ -24,9 +24,9 @@ namespace aicpu {
 namespace formats {
 class FormatTransferNdc1hwc0 : public FormatTransfer {
  public:
-  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
+  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
   uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
-                      std::vector<int64_t> *dst_shape, int64_t groups) override;
+                      std::vector<int64_t> &dst_shape, int64_t groups) override;
 };
 }  // namespace formats
 }  // namespace aicpu

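The recurring change in this commit is mechanical: out-parameters move from `T *` pointers to `T &` references, so call sites drop the `&` and callees drop the dereference and the implicit nullptr state. A minimal standalone sketch of the two conventions; the names here are illustrative, not the kernel API:

```cpp
#include <cstdint>
#include <vector>

// Old convention: pointer out-parameter; the callee must dereference and
// could in principle be handed a null pointer.
uint32_t FillShapeOld(std::vector<int64_t> *dst_shape) {
  dst_shape->clear();
  dst_shape->push_back(16);
  return 0;
}

// New convention: reference out-parameter; no dereference, no null state.
uint32_t FillShapeNew(std::vector<int64_t> &dst_shape) {
  dst_shape.clear();
  dst_shape.push_back(16);
  return 0;
}

int main() {
  std::vector<int64_t> shape;
  FillShapeOld(&shape);  // caller takes the address
  FillShapeNew(shape);   // caller passes the object directly
  return 0;
}
```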
@@ -46,21 +46,18 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{
            {FORMAT_HWCN, std::vector<int64_t>({kChwnH, kChwnW, kChwnC, kChwnN})}}},
 };
 
-bool LessThanZero(int64_t dim) { return dim < 0; }
-
 bool ShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) {
   if (src_shape.empty()) {
     KERNEL_LOG_ERROR("Failed to transpose, src shape is empty");
     return false;
   }
 
-  std::vector<int64_t>::const_iterator it = find_if(src_shape.begin(), src_shape.end(), LessThanZero);
-  if (it == src_shape.end()) {
-    KERNEL_LOG_ERROR("Failed to transpose, negative dim [%d] in src shape [%s]", *it,
-                     FmtToStr(VectorToString(src_shape)).c_str());
-    return false;
+  for (auto dim : src_shape) {
+    if (dim < 0) {
+      KERNEL_LOG_ERROR("Failed to transpose, negative dim [%d] in src shape [%s]", dim,
+                       FmtToStr(VectorToString(src_shape)).c_str());
+      return false;
+    }
   }
 
   if (perm_arg.size() != src_shape.size()) {
     KERNEL_LOG_ERROR(
       "Failed to transpose, the size of src shape [%s] and perm arg [%s] are "
@@ -93,11 +90,11 @@ bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_sha
   return ShapeArgValid(src_shape, perm_arg);
 }
 
-void GenHeads(const std::vector<int64_t> &shape, std::vector<int64_t> *heads) {
-  heads->resize(shape.size());
-  (*heads)[shape.size() - 1] = 1;
+void GenHeads(const std::vector<int64_t> &shape, std::vector<int64_t> &heads) {
+  heads.resize(shape.size());
+  heads[shape.size() - 1] = 1;
   for (auto i = static_cast<int64_t>(shape.size() - 2); i >= 0; --i) {
-    (*heads)[i] = shape[i + 1] * (*heads)[i + 1];
+    heads[i] = shape[i + 1] * heads[i + 1];
   }
 }
 
@@ -109,13 +106,13 @@ int64_t GenOffset(const std::vector<int64_t> &offsets, const std::vector<int64_t
   return offset;
 }
 
-void AddOne(const std::vector<int64_t> &shape, std::vector<int64_t> *indexes) {
-  size_t i = indexes->size() - 1;
-  (*indexes)[i]++;
+void AddOne(const std::vector<int64_t> &shape, std::vector<int64_t> &indexes) {
+  size_t i = indexes.size() - 1;
+  indexes[i]++;
   while (i > 0) {
-    if ((*indexes)[i] >= shape[i]) {
-      (*indexes)[i] = 0;
-      (*indexes)[i - 1]++;
+    if (indexes[i] >= shape[i]) {
+      indexes[i] = 0;
+      indexes[i - 1]++;
       --i;
     } else {
       break;
@@ -124,25 +121,25 @@ void AddOne(const std::vector<int64_t> &shape, std::vector<int64_t> *indexes) {
 }
 
 void TransShapeByPerm(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg,
-                      std::vector<int64_t> *dst_shape) {
-  dst_shape->resize(src_shape.size());
+                      std::vector<int64_t> &dst_shape) {
+  dst_shape.resize(src_shape.size());
   for (size_t i = 0; i < perm_arg.size(); ++i) {
-    (*dst_shape)[i] = src_shape[perm_arg[i]];
+    dst_shape[i] = src_shape[perm_arg[i]];
   }
 }
 }  // namespace
 
 uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
-                   const std::vector<int64_t> &perm_arg, TransResult *result) {
+                   const std::vector<int64_t> &perm_arg, TransResult &result) {
   if (!IsTransposeArgValid(src, src_shape, src_data_type, perm_arg)) {
     return KERNEL_STATUS_PARAM_INVALID;
   }
   std::vector<int64_t> dst_shape;
-  TransShapeByPerm(src_shape, perm_arg, &dst_shape);
+  TransShapeByPerm(src_shape, perm_arg, dst_shape);
   std::vector<int64_t> src_origin_ordered_heads;
-  GenHeads(src_shape, &src_origin_ordered_heads);
+  GenHeads(src_shape, src_origin_ordered_heads);
   std::vector<int64_t> src_heads;
-  TransShapeByPerm(src_origin_ordered_heads, perm_arg, &src_heads);
+  TransShapeByPerm(src_origin_ordered_heads, perm_arg, src_heads);
 
   int64_t dst_ele_num = GetItemNumByShape(dst_shape);
   int64_t data_size = GetSizeByDataType(src_data_type);
@@ -154,7 +151,7 @@ uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Da
     VectorToString(src_shape).c_str(), VectorToString(perm_arg).c_str(), VectorToString(dst_shape).c_str(),
     DTypeStr(src_data_type).c_str());
   if (dst_ele_num == 0) {
-    result->length = static_cast<size_t>(dst_size);
+    result.length = static_cast<size_t>(dst_size);
     return KERNEL_STATUS_OK;
   }
 
@@ -184,23 +181,23 @@ uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Da
                        dst_offset_bytes, VectorToString(dst_indexes).c_str());
       return KERNEL_STATUS_PARAM_INVALID;
     }
-    AddOne(dst_shape, &dst_indexes);
+    AddOne(dst_shape, dst_indexes);
     ++dst_index;
   }
 
-  result->data = dst;
-  result->length = static_cast<size_t>(dst_size);
+  result.data = dst;
+  result.length = static_cast<size_t>(dst_size);
   return KERNEL_STATUS_OK;
 }
 
 uint32_t TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> &src_shape,
                                  const std::vector<int64_t> &dst_shape, DataType src_data_type,
-                                 const std::vector<int64_t> &perm_arg, TransResult *result) {
+                                 const std::vector<int64_t> &perm_arg, TransResult &result) {
   if (!IsTransposeArgValid(data, src_shape, src_data_type, perm_arg)) {
     return KERNEL_STATUS_PARAM_INVALID;
   }
   std::vector<int64_t> expected_shape;
-  TransShapeByPerm(src_shape, perm_arg, &expected_shape);
+  TransShapeByPerm(src_shape, perm_arg, expected_shape);
   if (dst_shape != expected_shape) {
     KERNEL_LOG_ERROR(
       "Failed to trans axis for perm_arg [%s], invalid dst shape [%s], "
@@ -212,7 +209,7 @@ uint32_t TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t>
   return Transpose(data, src_shape, src_data_type, perm_arg, result);
 }
 
-uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> *perm) {
+uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) {
   auto dst_iter = perm_args.find(src_format);
   if (dst_iter == perm_args.end()) {
     KERNEL_LOG_ERROR(
@@ -229,14 +226,14 @@ uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64
                      FormatToSerialString(src_format).c_str(), FormatToSerialString(dst_format).c_str());
     return KERNEL_STATUS_PARAM_INVALID;
   }
-  *perm = iter->second;
+  perm = iter->second;
   return KERNEL_STATUS_OK;
 }
 
-uint32_t FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult *result) {
+uint32_t FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) {
   std::vector<int64_t> expected_shape;
   auto ret =
-    TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, &expected_shape, args.groups);
+    TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape, args.groups);
   if (ret != KERNEL_STATUS_OK) {
     return ret;
   }
@@ -248,10 +245,10 @@ uint32_t FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult
 }
 
 uint32_t FormatTransferTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
-                                             DataType data_type, Format dst_format, std::vector<int64_t> *dst_shape,
+                                             DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape,
                                              int64_t groups) {
   std::vector<int64_t> perm_arg;
-  if (GetPermByForamt(src_format, dst_format, &perm_arg) != KERNEL_STATUS_OK) {
+  if (GetPermByForamt(src_format, dst_format, perm_arg) != KERNEL_STATUS_OK) {
     return KERNEL_STATUS_PARAM_INVALID;
   }
   if (!ShapeArgValid(src_shape, perm_arg)) {

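Two things happen in the transpose hunks above. First, the removed `find_if` branch logged a "negative dim" precisely when `it == src_shape.end()`, i.e. when no negative dimension was found; the explicit loop reports only actual negative dims. Second, `GenHeads` computes row-major strides ("heads") and `AddOne` advances a multi-dimensional index like an odometer. A minimal sketch of the stride computation, under an illustrative name rather than the kernel's:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// heads[i] = number of elements one step along axis i spans in row-major
// order, i.e. the product of all dimensions to the right of axis i.
void GenHeadsDemo(const std::vector<int64_t> &shape, std::vector<int64_t> &heads) {
  heads.resize(shape.size());
  heads[shape.size() - 1] = 1;
  for (auto i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i) {
    heads[i] = shape[i + 1] * heads[i + 1];
  }
}

int main() {
  std::vector<int64_t> heads;
  GenHeadsDemo({2, 3, 4}, heads);
  for (auto h : heads) std::cout << h << ' ';  // prints: 12 4 1
  std::cout << '\n';
  return 0;
}
```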
@@ -24,17 +24,17 @@
 namespace aicpu {
 namespace formats {
 uint32_t Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type,
-                   const std::vector<int64_t> &perm_arg, TransResult *result);
+                   const std::vector<int64_t> &perm_arg, TransResult &result);
 
 uint32_t TransposeWithShapeCheck(const uint8_t *src, const std::vector<int64_t> &src_shape,
                                  const std::vector<int64_t> &dst_shape, DataType src_data_type,
-                                 const std::vector<int64_t> &perm_arg, TransResult *result);
-uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> *perm);
+                                 const std::vector<int64_t> &perm_arg, TransResult &result);
+uint32_t GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm);
 class FormatTransferTranspose : public FormatTransfer {
  public:
-  uint32_t TransFormat(const TransArgs &args, TransResult *result) override;
+  uint32_t TransFormat(const TransArgs &args, TransResult &result) override;
   uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
-                      std::vector<int64_t> *dst_shape, int64_t groups) override;
+                      std::vector<int64_t> &dst_shape, int64_t groups) override;
 };
 }  // namespace formats
 }  // namespace aicpu

@@ -70,7 +70,7 @@ int64_t GetCubeSizeByDataType(DataType data_type) {
   }
 }
 
-bool IsTransShapeSrcCorrect(const TransArgs &args, const std::vector<int64_t> &expect_shape) {
+bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) {
   if (args.src_shape != expect_shape) {
     std::string error = "Failed to trans format from" + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
                         FmtToStr(FormatToSerialString(args.dst_format)) + ", invalid relationship between src shape " +
@@ -82,7 +82,7 @@ bool IsTransShapeSrcCorrect(const TransArgs &args, const std::vector<int64_t> &e
   return true;
 }
 
-bool IsTransShapeDstCorrect(const TransArgs &args, const std::vector<int64_t> &expect_shape) {
+bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) {
   if (!args.dst_shape.empty() && args.dst_shape != expect_shape) {
     std::string error = "Failed to trans format from " + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
                         FmtToStr(FormatToSerialString(args.dst_format)) + ", the dst shape" +
@@ -97,11 +97,13 @@ bool IsTransShapeDstCorrect(const TransArgs &args, const std::vector<int64_t> &e
 int64_t GetItemNumByShape(const std::vector<int64_t> &shape) {
   // shape will not be greater than INT_MAX
   int64_t num = 1;
-  num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
+  for (auto dim : shape) {
+    num *= dim;
+  }
   return num;
 }
 
-uint32_t TransFormat(const TransArgs &args, TransResult *result) {
+uint32_t TransFormat(const TransArgs &args, TransResult &result) {
   auto transfer = BuildFormatTransfer(args);
   if (transfer == nullptr) {
     std::string error = "Failed to trans data from format " + FmtToStr(FormatToSerialString(args.src_format)) + " to " +
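The GetItemNumByShape hunk swaps std::accumulate for a plain loop. One plausible motivation, an observation rather than anything stated in the commit: with an `int` literal as the initial value, std::accumulate deduces an `int` accumulator, so each partial product is narrowed back to 32 bits, while the loop keeps the product in int64_t throughout. A sketch of the difference:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape = {1 << 20, 1 << 20};  // 2^40 elements in total
  // The int literal 1 makes std::accumulate's accumulator an int, so the
  // 64-bit result of multiplies<int64_t> is narrowed at every step:
  int narrowed = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
  // The explicit loop keeps the product in 64 bits throughout:
  int64_t full = 1;
  for (auto dim : shape) {
    full *= dim;
  }
  // On typical targets this prints: 0 vs 1099511627776
  std::cout << narrowed << " vs " << full << '\n';
  return 0;
}
```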
@@ -157,44 +159,44 @@ KernelStatus CheckDimOri(int64_t cin_ori, int64_t cout_ori) {
   return KERNEL_STATUS_OK;
 }
 
-KernelStatus GetFormatDim(int64_t *d_dim, int64_t *h_dim, int64_t *w_dim, int64_t *c_dim, int64_t *n_dim,
+KernelStatus GetFormatDim(int64_t &d_dim, int64_t &h_dim, int64_t &w_dim, int64_t &c_dim, int64_t &n_dim,
                           const Format &input_format, const std::vector<int64_t> &dims) {
   if (input_format == FORMAT_NCDHW) {
-    *n_dim = dims[kNcdhwN];
-    *c_dim = dims[kNcdhwC];
-    *d_dim = dims[kNcdhwD];
-    *h_dim = dims[kNcdhwH];
-    *w_dim = dims[kNcdhwW];
+    n_dim = dims[kNcdhwN];
+    c_dim = dims[kNcdhwC];
+    d_dim = dims[kNcdhwD];
+    h_dim = dims[kNcdhwH];
+    w_dim = dims[kNcdhwW];
   } else if (input_format == FORMAT_DHWCN) {
-    *d_dim = dims[kDhwcnD];
-    *h_dim = dims[kDhwcnH];
-    *w_dim = dims[kDhwcnW];
-    *c_dim = dims[kDhwcnC];
-    *n_dim = dims[kDhwcnN];
+    d_dim = dims[kDhwcnD];
+    h_dim = dims[kDhwcnH];
+    w_dim = dims[kDhwcnW];
+    c_dim = dims[kDhwcnC];
+    n_dim = dims[kDhwcnN];
   } else if (input_format == FORMAT_NDHWC) {
-    *n_dim = dims[kNdhwcN];
-    *d_dim = dims[kNdhwcD];
-    *h_dim = dims[kNdhwcH];
-    *w_dim = dims[kNdhwcW];
-    *c_dim = dims[kNdhwcC];
+    n_dim = dims[kNdhwcN];
+    d_dim = dims[kNdhwcD];
+    h_dim = dims[kNdhwcH];
+    w_dim = dims[kNdhwcW];
+    c_dim = dims[kNdhwcC];
   } else if (input_format == FORMAT_NHWC) {
-    *n_dim = dims[kNhwcN];
-    *h_dim = dims[kNhwcH];
-    *d_dim = 1;
-    *w_dim = dims[kNhwcW];
-    *c_dim = dims[kNhwcC];
+    n_dim = dims[kNhwcN];
+    h_dim = dims[kNhwcH];
+    d_dim = 1;
+    w_dim = dims[kNhwcW];
+    c_dim = dims[kNhwcC];
   } else if (input_format == FORMAT_NCHW) {
-    *n_dim = dims[kNchwN];
-    *c_dim = dims[kNchwC];
-    *h_dim = dims[kNchwH];
-    *w_dim = dims[kNchwW];
-    *d_dim = 1;
+    n_dim = dims[kNchwN];
+    c_dim = dims[kNchwC];
+    h_dim = dims[kNchwH];
+    w_dim = dims[kNchwW];
+    d_dim = 1;
   } else if (input_format == FORMAT_HWCN) {
-    *h_dim = dims[kHwcnH];
-    *w_dim = dims[kHwcnW];
-    *c_dim = dims[kHwcnC];
-    *n_dim = dims[kHwcnN];
-    *d_dim = 1;
+    h_dim = dims[kHwcnH];
+    w_dim = dims[kHwcnW];
+    c_dim = dims[kHwcnC];
+    n_dim = dims[kHwcnN];
+    d_dim = 1;
   } else {
     KERNEL_LOG_WARN(
       "Format is not FORMAT_DHWCN or FORMAT_NDHWC or FORMAT_NCDHW or "

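GetFormatDim now writes its five per-axis outputs through int64_t references instead of pointers. A cut-down sketch of the same shape-unpacking pattern, with hypothetical names; as in the NHWC/NCHW/HWCN branches above, four-dimensional formats would set the depth to 1:

```cpp
#include <cstdint>
#include <vector>

enum class DemoFormat { NCDHW, NDHWC };  // stand-in for the real Format enum

// Unpacks dims into named axes through references, one assignment per axis.
void GetDimsDemo(int64_t &d, int64_t &h, int64_t &w, int64_t &c, int64_t &n,
                 DemoFormat fmt, const std::vector<int64_t> &dims) {
  if (fmt == DemoFormat::NCDHW) {
    n = dims[0]; c = dims[1]; d = dims[2]; h = dims[3]; w = dims[4];
  } else {  // NDHWC
    n = dims[0]; d = dims[1]; h = dims[2]; w = dims[3]; c = dims[4];
  }
}

int main() {
  int64_t d = 0, h = 0, w = 0, c = 0, n = 0;
  GetDimsDemo(d, h, w, c, n, DemoFormat::NCDHW, {8, 3, 16, 32, 32});
  return (n == 8 && c == 3 && d == 16 && h == 32 && w == 32) ? 0 : 1;
}
```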
@@ -25,9 +25,9 @@
 
 namespace aicpu {
 namespace formats {
-const int kCubeSize = 16;
-const int kNiSize = 16;
-const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;
+static const int kCubeSize = 16;
+static const int kNiSize = 16;
+static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL;
 int64_t Lcm(int64_t a, int64_t b);
 bool IsShapeValid(const std::vector<int64_t> &shape);
 
@@ -35,16 +35,16 @@ bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dim
 
 int64_t GetCubeSizeByDataType(DataType data_type);
 
-bool IsTransShapeSrcCorrect(const TransArgs &args, const std::vector<int64_t> &expect_shape);
+bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape);
 
-bool IsTransShapeDstCorrect(const TransArgs &args, const std::vector<int64_t> &expect_shape);
+bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape);
 
 int64_t GetItemNumByShape(const std::vector<int64_t> &shape);
 
 void copy_data(const uint8_t *input_data, std::shared_ptr<uint8_t> dst, int64_t src_index, int64_t dst_index,
                int64_t data_size);
 
-KernelStatus GetFormatDim(int64_t *d_dim, int64_t *h_dim, int64_t *w_dim, int64_t *c_dim, int64_t *n_dim,
+KernelStatus GetFormatDim(int64_t &d_dim, int64_t &h_dim, int64_t &w_dim, int64_t &c_dim, int64_t &n_dim,
                           const Format &input_format, const std::vector<int64_t> &dims);
 KernelStatus CheckDimOri(int64_t cin_ori, int64_t cout_ori);
 
@@ -63,7 +63,7 @@ T Ceil(T n1, T n2) {
  * @param result
  * @return
  */
-uint32_t TransFormat(const TransArgs &args, TransResult *result);
+uint32_t TransFormat(const TransArgs &args, TransResult &result);
 }  // namespace formats
 }  // namespace aicpu
 #endif  // AICPU_KERNELS_HOST_FORMAT_TRANSFER_FORMAT_TRANSFER_UTILS_H_

@@ -60,4 +60,4 @@ bool FormatTransferExists(const TransArgs &args) {
   return dst_builder->second.count(args.dst_format) > 0;
 }
 }  // namespace formats
-}  // namespace aicpu
\ No newline at end of file
+}  // namespace aicpu

@@ -48,9 +48,9 @@ struct TransResult {
 class FormatTransfer {
  public:
   virtual ~FormatTransfer() = default;
-  virtual uint32_t TransFormat(const TransArgs &args, TransResult *result) = 0;
+  virtual uint32_t TransFormat(const TransArgs &args, TransResult &result) = 0;
   virtual uint32_t TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,
-                              Format dst_format, std::vector<int64_t> *dst_shape, int64_t groups) = 0;
+                              Format dst_format, std::vector<int64_t> &dst_shape, int64_t groups) = 0;
 };
 
 using FormatTransferBuilder = std::function<std::shared_ptr<FormatTransfer>()>;
@@ -78,4 +78,4 @@ std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args);
 bool FormatTransferExists(const TransArgs &args);
 }  // namespace formats
 }  // namespace aicpu
-#endif
\ No newline at end of file
+#endif

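FormatTransfer subclasses reach the framework through the REGISTER_FORMAT_TRANSFER macro and the FormatTransferBuilder factory type shown above. A toy sketch of that builder-registry idea with illustrative types; the real registry keys on source and destination formats and sits behind BuildFormatTransfer:

```cpp
#include <functional>
#include <map>
#include <memory>
#include <utility>

enum class Fmt { NCDHW, NDC1HWC0 };  // illustrative stand-in for Format

struct Transfer {
  virtual ~Transfer() = default;
};

// A (src, dst) format pair maps to a factory producing the transfer object.
using Builder = std::function<std::shared_ptr<Transfer>()>;
std::map<std::pair<Fmt, Fmt>, Builder> &Registry() {
  static std::map<std::pair<Fmt, Fmt>, Builder> r;
  return r;
}

struct DemoTransfer : Transfer {};

int main() {
  // Registration normally happens via a macro at namespace scope.
  Registry()[{Fmt::NCDHW, Fmt::NDC1HWC0}] = [] { return std::make_shared<DemoTransfer>(); };
  auto it = Registry().find({Fmt::NCDHW, Fmt::NDC1HWC0});
  return it != Registry().end() ? 0 : 1;
}
```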
@@ -60,7 +60,7 @@ class AICPU_VISIBILITY NodeDefBuilder {
 
   NodeDefBuilder &Attr(std::string name, aicpu::Tensor *tensor);
 
-  NodeDefBuilder &Attr(std::string name, std::vector<aicpu::Tensor *> *tensors);
+  NodeDefBuilder &Attr(std::string name, std::vector<aicpu::Tensor *> &tensors);
 
  private:
   void BuildNodeFromInputOutputNode(const InputOutputNode &node, bool isInput);

@@ -15,7 +15,7 @@
 namespace aicpu {
 class AICPU_VISIBILITY CpuKernel {
  public:
-  virtual uint32_t Compute(const CpuKernelContext &ctx) = 0;
+  virtual uint32_t Compute(CpuKernelContext &ctx) = 0;
 
   virtual ~CpuKernel() {}
 };

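Every kernel in this commit follows from this one line: the pure-virtual Compute now takes a non-const CpuKernelContext reference, so each subclass's override and every private helper that receives ctx must change in lockstep. A stripped-down sketch of a conforming subclass, using placeholder types rather than the real headers:

```cpp
#include <cstdint>

struct DemoContext {};  // stand-in for CpuKernelContext

class DemoKernel {
 public:
  // Mirrors the new signature: non-const reference, no const on ctx.
  virtual uint32_t Compute(DemoContext &ctx) = 0;
  virtual ~DemoKernel() {}
};

class AddOneKernel : public DemoKernel {
 public:
  uint32_t Compute(DemoContext &ctx) override {
    (void)ctx;  // a real kernel would read inputs/outputs from ctx
    return 0;
  }
};

int main() {
  DemoContext ctx;
  AddOneKernel k;
  return static_cast<int>(k.Compute(ctx));
}
```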
@@ -104,14 +104,14 @@ void ComputeSingleThread(int64_t start, int64_t end, AdaptiveCalcArgs<SCALAR_T>
 }
 
 template <typename SCALAR_T>
-uint32_t AdaptiveAvgPool2dOutFrame(const CpuKernelContext &ctx, AdaptiveCalcArgs<SCALAR_T> args, int64_t num) {
+uint32_t AdaptiveAvgPool2dOutFrame(CpuKernelContext &ctx, AdaptiveCalcArgs<SCALAR_T> args, int64_t num) {
   auto shard_frame = [&](int64_t start, int64_t end) { ComputeSingleThread(start, end, args); };
   SWITCH_PARALLEL(shard_frame, args.size_d, num);
   return KERNEL_STATUS_OK;
 }
 
 template <typename SCALAR_T>
-uint32_t AdaptiveAvgPool2dOutTemplate(const CpuKernelContext &ctx) {
+uint32_t AdaptiveAvgPool2dOutTemplate(CpuKernelContext &ctx) {
   Tensor &input = *(ctx.Input(kFirstInputIndex));
   auto input_shape_ptr = input.GetTensorShape();
   int32_t input_dims = input_shape_ptr->GetDims();
@@ -176,7 +176,7 @@ uint32_t AdaptiveAvgPool2dOutTemplate(const CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 
-uint32_t AdaptiveAvgPool2d::Compute(const CpuKernelContext &ctx) {
+uint32_t AdaptiveAvgPool2d::Compute(CpuKernelContext &ctx) {
   // check params
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output number failed.",
                       kAdaptiveAvgPool2d);

@@ -26,7 +26,7 @@ class AdaptiveAvgPool2d : public CpuKernel {
  public:
   AdaptiveAvgPool2d() = default;
   ~AdaptiveAvgPool2d() = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 };
 
 inline int StartIndex(int offset, int out_size, int in_size) {

@@ -52,7 +52,7 @@ inline int EndIndex(int offset, int out_size, int in_size) {
 
 namespace aicpu {
 template <typename SCALAR_T>
-uint32_t AdaptiveAvgPool2dGradOutFrame(const CpuKernelContext &ctx, AdaptiveCalcArgs<SCALAR_T> args) {
+uint32_t AdaptiveAvgPool2dGradOutFrame(CpuKernelContext &ctx, AdaptiveCalcArgs<SCALAR_T> args) {
   uint32_t min_core_num = 1;
   int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2);
 
@@ -111,7 +111,7 @@ uint32_t AdaptiveAvgPool2dGradOutFrame(const CpuKernelContext &ctx, AdaptiveCalc
 }
 
 template <typename SCALAR_T>
-uint32_t AdaptiveAvgPool2dGradOutCpuTemplate(const CpuKernelContext &ctx) {
+uint32_t AdaptiveAvgPool2dGradOutCpuTemplate(CpuKernelContext &ctx) {
   Tensor &input = *(ctx.Input(kFirstInputIndex));
 
   auto input_shape_ptr = input.GetTensorShape();
@@ -178,7 +178,7 @@ uint32_t AdaptiveAvgPool2dGradOutCpuTemplate(const CpuKernelContext &ctx) {
   return AdaptiveAvgPool2dGradOutFrame<SCALAR_T>(ctx, args);
 }
 
-uint32_t AdaptiveAvgPool2dGrad::Compute(const CpuKernelContext &ctx) {
+uint32_t AdaptiveAvgPool2dGrad::Compute(CpuKernelContext &ctx) {
   Tensor *input_0 = ctx.Input(kFirstInputIndex);
   KERNEL_CHECK_NULLPTR(input_0, KERNEL_STATUS_PARAM_INVALID, "Get input tensor failed.");
   KERNEL_CHECK_NULLPTR(input_0->GetData(), KERNEL_STATUS_PARAM_INVALID, "Get input data failed.");

@@ -26,7 +26,7 @@ class AdaptiveAvgPool2dGrad : public CpuKernel {
   ~AdaptiveAvgPool2dGrad() = default;
 
  protected:
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 };
 }  // namespace aicpu
 #endif  // AICPU_KERNELS_NORMALIZED_ADAPTIVE_AVG_POOL_2DGRAD_H_

@@ -102,7 +102,7 @@ uint32_t CacheSwapTableMsCpuKernel::DoCompute() {
   return calls[indices_type_](inputs_, outputs_, batch_size_, output_size_, one_line_col_, type_size);
 }
 
-uint32_t CacheSwapTableMsCpuKernel::GetInputAndCheck(const CpuKernelContext &ctx) {
+uint32_t CacheSwapTableMsCpuKernel::GetInputAndCheck(CpuKernelContext &ctx) {
   KERNEL_LOG_INFO("GetInputAndCheck start!");
   // get input Tensors
   const uint32_t kNumInput = 3;
@@ -137,7 +137,7 @@ uint32_t CacheSwapTableMsCpuKernel::GetInputAndCheck(const CpuKernelContext &ctx
   return KERNEL_STATUS_OK;
 }
 
-uint32_t CacheSwapTableMsCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t CacheSwapTableMsCpuKernel::Compute(CpuKernelContext &ctx) {
   uint32_t res = GetInputAndCheck(ctx);
   if (res != KERNEL_STATUS_OK) {
     return res;

@@ -24,12 +24,12 @@ namespace aicpu {
 class CacheSwapTableMsCpuKernel : public CpuKernel {
  public:
   ~CacheSwapTableMsCpuKernel() = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
   uint32_t DoCompute();
 
-  uint32_t GetInputAndCheck(const CpuKernelContext &ctx);
+  uint32_t GetInputAndCheck(CpuKernelContext &ctx);
 
   int64_t batch_size_ = 1;
   int64_t one_line_col_ = 1;

@@ -37,7 +37,7 @@ void FillGenerateCase(Tensor *&value_tensor, Tensor *&output) {
   }
 }
 
-uint32_t FillCpuKernel::GetDimsByType(const CpuKernelContext &ctx) {
+uint32_t FillCpuKernel::GetDimsByType(CpuKernelContext &ctx) {
   dims.clear();
   Tensor *dims_tensor = ctx.Input(0);
   KERNEL_CHECK_NULLPTR(dims_tensor, KERNEL_STATUS_PARAM_INVALID, "Get dims input failed")
@@ -63,7 +63,7 @@ uint32_t FillCpuKernel::GetDimsByType(const CpuKernelContext &ctx) {
   return ret;
 }
 
-uint32_t FillCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t FillCpuKernel::Compute(CpuKernelContext &ctx) {
   uint32_t check = GetDimsByType(ctx);
   if (check != KERNEL_STATUS_OK) {
     return check;

@@ -24,10 +24,10 @@ class FillCpuKernel : public CpuKernel {
  public:
   FillCpuKernel() = default;
   ~FillCpuKernel() override = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t GetDimsByType(const CpuKernelContext &ctx);
+  uint32_t GetDimsByType(CpuKernelContext &ctx);
   /**
    * @brief calc dims from input dims tensor
    * @param dims_tensor input dims tensor

@@ -38,7 +38,7 @@ constexpr int64_t kParallelDataNums = 8 * 1024;
 }  // namespace
 
 namespace aicpu {
-uint32_t LogMatrixDeterminantCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t LogMatrixDeterminantCpuKernel::Compute(CpuKernelContext &ctx) {
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "[%s] check input and output failed.",
                       kLogMatrixDeterminant);
   KERNEL_HANDLE_ERROR(LogMatrixDeterminantCheck(ctx), "[%s] check params failed.", kLogMatrixDeterminant);
@@ -55,7 +55,7 @@ uint32_t LogMatrixDeterminantCpuKernel::Compute(const CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 
-uint32_t LogMatrixDeterminantCpuKernel::LogMatrixDeterminantCheck(const CpuKernelContext &ctx) {
+uint32_t LogMatrixDeterminantCpuKernel::LogMatrixDeterminantCheck(CpuKernelContext &ctx) {
   auto input_0 = ctx.Input(0);
   auto output_0 = ctx.Output(0);
   auto output_1 = ctx.Output(1);
@@ -96,7 +96,7 @@ uint32_t LogMatrixDeterminantCpuKernel::LogMatrixDeterminantCheck(const CpuKerne
 }
 
 template <typename T>
-uint32_t LogMatrixDeterminantCpuKernel::LogMatrixDeterminantCompute(const CpuKernelContext &ctx) {
+uint32_t LogMatrixDeterminantCpuKernel::LogMatrixDeterminantCompute(CpuKernelContext &ctx) {
   auto input_x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output_sign = reinterpret_cast<T *>(ctx.Output(0)->GetData());
   auto output_y = reinterpret_cast<T *>(ctx.Output(1)->GetData());

@@ -23,13 +23,13 @@ class LogMatrixDeterminantCpuKernel : public CpuKernel {
  public:
   LogMatrixDeterminantCpuKernel() = default;
   ~LogMatrixDeterminantCpuKernel() override = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t LogMatrixDeterminantCheck(const CpuKernelContext &ctx);
+  uint32_t LogMatrixDeterminantCheck(CpuKernelContext &ctx);
 
   template <typename T>
-  uint32_t LogMatrixDeterminantCompute(const CpuKernelContext &ctx);
+  uint32_t LogMatrixDeterminantCompute(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
 #endif

@@ -113,7 +113,7 @@ void UpdateIndexByCarry(std::vector<int64_t> &preIndex, const std::vector<int64_
 }  // namespace
 
 namespace aicpu {
-uint32_t MaskedSelectCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t MaskedSelectCpuKernel::Compute(CpuKernelContext &ctx) {
   // check params
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kMaskedSelectInputNum, kMaskedSelectOutputNum), "[%s] check params failed.",
                       kMaskedSelect);
@@ -165,7 +165,7 @@ uint32_t MaskedSelectCpuKernel::Compute(const CpuKernelContext &ctx) {
 }
 
 template <typename T>
-uint32_t MaskedSelectCpuKernel::ParallelCompute(const CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
+uint32_t MaskedSelectCpuKernel::ParallelCompute(CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
                                                 const std::vector<int64_t> &inputShapeMask,
                                                 const std::vector<int64_t> &outputShape, int64_t dataNum) {
   T *x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
@@ -241,7 +241,7 @@ uint32_t MaskedSelectCpuKernel::ParallelCompute(const CpuKernelContext &ctx, con
 }
 
 template <typename T>
-uint32_t MaskedSelectCpuKernel::MaskedSelectCompute(const CpuKernelContext &ctx) {
+uint32_t MaskedSelectCpuKernel::MaskedSelectCompute(CpuKernelContext &ctx) {
   T *x = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   KERNEL_CHECK_NULLPTR(x, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[0] failed.",
                        kMaskedSelect);
@@ -264,7 +264,7 @@ uint32_t MaskedSelectCpuKernel::MaskedSelectCompute(const CpuKernelContext &ctx)
     return static_cast<uint32_t>(KERNEL_STATUS_OK);
   }
   std::vector<int64_t> output_shape;
-  auto ret = GetBroadcastShape(input_shape_a, input_shape_b, &output_shape);
+  auto ret = GetBroadcastShape(input_shape_a, input_shape_b, output_shape);
   KERNEL_CHECK_FALSE(ret == KERNEL_STATUS_OK, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID),
                      "Shape of x and mask can't be broadcast.");
   int64_t tensor_size = 1;
@@ -278,7 +278,7 @@ uint32_t MaskedSelectCpuKernel::MaskedSelectCompute(const CpuKernelContext &ctx)
   }
 
   int64_t j = 0;
-  BroadcastIterator iter(input_shape_a, input_shape_b, &output_shape);
+  BroadcastIterator iter(input_shape_a, input_shape_b, output_shape);
   iter.SetPos(0);
   for (int64_t i = 0; i < tensor_size; ++i) {
     if (mask[iter.GetInputPosB()]) {

@@ -22,7 +22,7 @@ namespace aicpu {
 class MaskedSelectCpuKernel : public CpuKernel {
  public:
   ~MaskedSelectCpuKernel() = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
   /**
@@ -31,9 +31,9 @@ class MaskedSelectCpuKernel : public CpuKernel {
    * @return status if success
    */
   template <typename T>
-  uint32_t MaskedSelectCompute(const CpuKernelContext &ctx);
+  uint32_t MaskedSelectCompute(CpuKernelContext &ctx);
   template <typename T>
-  uint32_t ParallelCompute(const CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
+  uint32_t ParallelCompute(CpuKernelContext &ctx, const std::vector<int64_t> &inputShapeX,
                            const std::vector<int64_t> &inputShapeMask, const std::vector<int64_t> &outputShape,
                            int64_t dataNum);
 };

@@ -30,7 +30,7 @@ const char *const kMaskedSelectGrad = "MaskedSelectGrad";
 }  // namespace
 
 namespace aicpu {
-uint32_t MaskedSelectGradCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t MaskedSelectGradCpuKernel::Compute(CpuKernelContext &ctx) {
   // check params
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kMaskedSelectGradInputNum, kMaskedSelectGradOutputNum),
                       "[%s] check params failed.", kMaskedSelectGrad);
@@ -82,7 +82,7 @@ uint32_t MaskedSelectGradCpuKernel::Compute(const CpuKernelContext &ctx) {
 }
 
 template <typename T>
-uint32_t MaskedSelectGradCpuKernel::MaskedSelectGradCompute(const CpuKernelContext &ctx) {
+uint32_t MaskedSelectGradCpuKernel::MaskedSelectGradCompute(CpuKernelContext &ctx) {
   bool *mask = reinterpret_cast<bool *>(ctx.Input(1)->GetData());
   KERNEL_CHECK_NULLPTR(mask, static_cast<uint32_t>(KERNEL_STATUS_PARAM_INVALID), "[%s] get input_data[1] failed.",
                        kMaskedSelectGrad);
@@ -96,7 +96,7 @@ uint32_t MaskedSelectGradCpuKernel::MaskedSelectGradCompute(const CpuKernelConte
   auto input_shape_a = ctx.Input(0)->GetTensorShape()->GetDimSizes();
   auto input_shape_b = ctx.Input(1)->GetTensorShape()->GetDimSizes();
   std::vector<int64_t> output_shape;
-  auto ret = GetBroadcastShape(input_shape_a, input_shape_b, &output_shape);
+  auto ret = GetBroadcastShape(input_shape_a, input_shape_b, output_shape);
   KERNEL_CHECK_FALSE(ret == KERNEL_STATUS_OK, KERNEL_STATUS_PARAM_INVALID, "Shape of x and mask can't be broadcast.");
   int64_t tensor_size = 1;
   for (const int64_t &d : output_shape) {
@@ -107,7 +107,7 @@ uint32_t MaskedSelectGradCpuKernel::MaskedSelectGradCompute(const CpuKernelConte
     dx[k] = NUM_ZERO;
   }
   int64_t j = 0;
-  BroadcastIterator iter(input_shape_a, input_shape_b, &output_shape);
+  BroadcastIterator iter(input_shape_a, input_shape_b, output_shape);
   iter.SetPos(0);
   for (int64_t i = 0; i < tensor_size; ++i) {
     if (mask[iter.GetInputPosB()]) {

@@ -22,7 +22,7 @@ namespace aicpu {
 class MaskedSelectGradCpuKernel : public CpuKernel {
  public:
   ~MaskedSelectGradCpuKernel() = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
   /**
@@ -31,7 +31,7 @@ class MaskedSelectGradCpuKernel : public CpuKernel {
    * @return status if success
    */
   template <typename T>
-  uint32_t MaskedSelectGradCompute(const CpuKernelContext &ctx);
+  uint32_t MaskedSelectGradCompute(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
 #endif

@@ -49,7 +49,7 @@ const char *kMedian = "Median";
 }  // namespace
 
 namespace aicpu {
-uint32_t MedianCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t MedianCpuKernel::Compute(CpuKernelContext &ctx) {
   KERNEL_HANDLE_ERROR(MedianCheck(ctx), "Median check params failed.");
   auto data_type = ctx.Input(0)->GetDataType();
   AttrValue *global_ptr = ctx.GetAttr("global_median");
@@ -80,7 +80,7 @@ uint32_t MedianCpuKernel::Compute(const CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 
-uint32_t MedianCpuKernel::MedianCheck(const CpuKernelContext &ctx) {
+uint32_t MedianCpuKernel::MedianCheck(CpuKernelContext &ctx) {
   auto global_median = ctx.GetAttr("global_median");
   KERNEL_CHECK_NULLPTR(global_median, KERNEL_STATUS_PARAM_INVALID, "Get attr global_median failed.");
   bool global_median_value = global_median->GetBool();
@@ -127,7 +127,7 @@ uint32_t MedianCpuKernel::MedianCheck(const CpuKernelContext &ctx) {
 }
 
 template <typename T>
-uint32_t MedianCpuKernel::GlobalMedianCompute(const CpuKernelContext &ctx) {
+uint32_t MedianCpuKernel::GlobalMedianCompute(CpuKernelContext &ctx) {
   auto input_x0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output_y0 = reinterpret_cast<T *>(ctx.Output(0)->GetData());
   size_t data_num = ctx.Input(0)->GetTensorShape()->NumElements();
@@ -138,7 +138,7 @@ uint32_t MedianCpuKernel::GlobalMedianCompute(const CpuKernelContext &ctx) {
 }
 
 template <typename T>
-uint32_t MedianCpuKernel::MedianCompute(const CpuKernelContext &ctx) {
+uint32_t MedianCpuKernel::MedianCompute(CpuKernelContext &ctx) {
   auto input_x0 = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output_y0 = reinterpret_cast<T *>(ctx.Output(0)->GetData());
   auto output_y1 = reinterpret_cast<int64_t *>(ctx.Output(1)->GetData());

@@ -25,14 +25,14 @@ class MedianCpuKernel : public CpuKernel {
  public:
   MedianCpuKernel() = default;
   ~MedianCpuKernel() override = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t MedianCheck(const CpuKernelContext &ctx);
+  uint32_t MedianCheck(CpuKernelContext &ctx);
   template <typename T>
-  uint32_t GlobalMedianCompute(const CpuKernelContext &ctx);
+  uint32_t GlobalMedianCompute(CpuKernelContext &ctx);
   template <typename T>
-  uint32_t MedianCompute(const CpuKernelContext &ctx);
+  uint32_t MedianCompute(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
 #endif

@@ -52,7 +52,7 @@ const int64_t kParallelDataNumMid = 16 * 1024;
 }  // namespace
 
 namespace aicpu {
-uint32_t MedianGradCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t MedianGradCpuKernel::Compute(CpuKernelContext &ctx) {
   // check params
   KERNEL_HANDLE_ERROR(MedianGradParamCheck(ctx), "MedianGrad check params failed.");
   auto data_type_x = ctx.Input(1)->GetDataType();
@@ -85,7 +85,7 @@ uint32_t MedianGradCpuKernel::Compute(const CpuKernelContext &ctx)
   return KERNEL_STATUS_OK;
 }
 
-uint32_t MedianGradCpuKernel::MedianGradParamCheck(const CpuKernelContext &ctx) {
+uint32_t MedianGradCpuKernel::MedianGradParamCheck(CpuKernelContext &ctx) {
   auto global_median_ptr = ctx.GetAttr("global_median");
   KERNEL_CHECK_NULLPTR(global_median_ptr, KERNEL_STATUS_PARAM_INVALID, "Get attr global_median failed.");
   bool global_median = global_median_ptr->GetBool();
@@ -133,7 +133,7 @@ uint32_t MedianGradCpuKernel::MedianGradParamCheck(const CpuKernelContext &ctx)
 }
 
 template <typename T1, typename T2>
-uint32_t MedianGradCpuKernel::GlobalMedianGradCompute(const CpuKernelContext &ctx) {
+uint32_t MedianGradCpuKernel::GlobalMedianGradCompute(CpuKernelContext &ctx) {
   auto y_grad = reinterpret_cast<T1 *>(ctx.Input(0)->GetData());
   auto x = reinterpret_cast<T1 *>(ctx.Input(1)->GetData());
   auto y = reinterpret_cast<T1 *>(ctx.Input(2)->GetData());
@@ -175,7 +175,7 @@ uint32_t MedianGradCpuKernel::GlobalMedianGradCompute(const CpuKernelContext &ct
 }
 
 template <typename T1, typename T2>
-uint32_t MedianGradCpuKernel::MedianGradCompute(const CpuKernelContext &ctx) {
+uint32_t MedianGradCpuKernel::MedianGradCompute(CpuKernelContext &ctx) {
   auto y_grad = reinterpret_cast<T1 *>(ctx.Input(0)->GetData());
   auto indices = reinterpret_cast<int64_t *>(ctx.Input(3)->GetData());
   auto x_grad = reinterpret_cast<T2 *>(ctx.Output(0)->GetData());

@@ -27,16 +27,16 @@ class MedianGradCpuKernel : public CpuKernel {
   ~MedianGradCpuKernel() override = default;
 
  protected:
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t MedianGradParamCheck(const CpuKernelContext &ctx);
+  uint32_t MedianGradParamCheck(CpuKernelContext &ctx);
 
   template <typename T1, typename T2>
-  uint32_t MedianGradCompute(const CpuKernelContext &ctx);
+  uint32_t MedianGradCompute(CpuKernelContext &ctx);
 
   template <typename T1, typename T2>
-  uint32_t GlobalMedianGradCompute(const CpuKernelContext &ctx);
+  uint32_t GlobalMedianGradCompute(CpuKernelContext &ctx);
 };
 }  // namespace aicpu
 #endif

@@ -27,7 +27,7 @@ const char *kNMSWithMask = "NMSWithMask";
 }  // namespace
 
 namespace aicpu {
-uint32_t NMSWithMaskCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t NMSWithMaskCpuKernel::Compute(CpuKernelContext &ctx) {
   // check param
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "NMSWithMask check input or output is failed");
   AttrValue *iou_threshold = ctx.GetAttr("iou_threshold");
@@ -64,7 +64,7 @@ uint32_t NMSWithMaskCpuKernel::Compute(const CpuKernelContext &ctx) {
 }
 
 template <typename T>
-uint32_t NMSWithMaskCpuKernel::DoCompute(const CpuKernelContext &ctx) {
+uint32_t NMSWithMaskCpuKernel::DoCompute(CpuKernelContext &ctx) {
   auto input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output = reinterpret_cast<T *>(ctx.Output(OUTPUT)->GetData());
   auto sel_idx = reinterpret_cast<int *>(ctx.Output(SEL_IDX)->GetData());

@@ -32,11 +32,11 @@ class NMSWithMaskCpuKernel : public CpuKernel {
  public:
   NMSWithMaskCpuKernel() = default;
   ~NMSWithMaskCpuKernel() override = default;
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
   template <typename T>
-  uint32_t DoCompute(const CpuKernelContext &ctx);
+  uint32_t DoCompute(CpuKernelContext &ctx);
 
   int num_input_{0};
   float iou_value_{0.0};

@@ -44,7 +44,7 @@ const char *const kReduceSum = "ReduceSum";
 }  // namespace
 
 namespace aicpu {
-uint32_t ReduceSumCpuKernel::Compute(const CpuKernelContext &ctx) {
+uint32_t ReduceSumCpuKernel::Compute(CpuKernelContext &ctx) {
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, kReduceSumInputNum, kReduceSumOutputNum), "[%s] check input and output failed.",
                       kReduceSum);
   KERNEL_HANDLE_ERROR(ReduceSumCheck(ctx), "[%s] check params failed.", kReduceSum);
@@ -69,7 +69,7 @@ uint32_t ReduceSumCpuKernel::Compute(const CpuKernelContext &ctx) {
   }
   return KERNEL_STATUS_OK;
 }
-uint32_t ReduceSumCpuKernel::ReduceSumCheck(const CpuKernelContext &ctx) const {
+uint32_t ReduceSumCpuKernel::ReduceSumCheck(CpuKernelContext &ctx) const {
   KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
   KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.");
   KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get output failed.");
@@ -81,7 +81,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumCheck(const CpuKernelContext &ctx) const {
   return KERNEL_STATUS_OK;
 }
 template <typename T>
-uint32_t ReduceSumCpuKernel::ReduceSumCompute(const CpuKernelContext &ctx) {
+uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
   std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
   auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
@@ -141,7 +141,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes(const T *input_data, std::vector<i
   return result;
 }
 template <typename T, typename T2>
-uint32_t ReduceSumCpuKernel::ReduceSumCompute2(const CpuKernelContext &ctx) {
+uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
   std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
   auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
   auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
@@ -218,7 +218,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes2(const T *input_data, int64_t inpu
   }
   return result;
 }
-uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(const CpuKernelContext &ctx, std::vector<int64_t> &axes) {
+uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes) {
   int32_t rank = ctx.Input(0)->GetTensorShape()->GetDims();
   auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
   int64_t axes_num = ctx.Input(1)->NumElements();

@@ -25,26 +25,26 @@ class ReduceSumCpuKernel : public CpuKernel {
   ReduceSumCpuKernel() = default;
   ~ReduceSumCpuKernel() override = default;
 
-  uint32_t Compute(const CpuKernelContext &ctx) override;
+  uint32_t Compute(CpuKernelContext &ctx) override;
 
  private:
-  uint32_t ReduceSumCheck(const CpuKernelContext &ctx) const;
+  uint32_t ReduceSumCheck(CpuKernelContext &ctx) const;
 
   template <typename T>
-  uint32_t ReduceSumCompute(const CpuKernelContext &ctx);
+  uint32_t ReduceSumCompute(CpuKernelContext &ctx);
 
   template <typename T>
   uint32_t ReduceSumOneAxes(const T *input_data, std::vector<int64_t> &input_shape, T *output_data, int64_t output_num,
                             std::vector<int64_t> &axes, uint32_t &axes_idx);
 
   template <typename T, typename T2>
-  uint32_t ReduceSumCompute2(const CpuKernelContext &ctx);
+  uint32_t ReduceSumCompute2(CpuKernelContext &ctx);
 
   template <typename T, typename T2>
   uint32_t ReduceSumOneAxes2(const T *input_data, int64_t input_num, std::vector<int64_t> input_shape, T *output_data,
                              int64_t output_num, std::vector<int64_t> &axes, uint32_t &axes_idx);
 
-  uint32_t ReduceSumDedupAxes(const CpuKernelContext &ctx, std::vector<int64_t> &axes);
+  uint32_t ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes);
 
   uint32_t ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes, uint32_t &axes_idx,
                               int64_t &inner, int64_t &outer, int64_t &depth) const;

@@ -28,19 +28,19 @@ std::unordered_set<uint64_t> g_allocated_ptr;
 
 namespace aicpu {
 uint32_t CpuKernelAllocatorUtils::ParamCheck(const std::vector<int64_t> &dims, const void *data_ptr,
-                                             Tensor **outputResultTensor) {
+                                             Tensor *&outputResultTensor) {
   if (dims.empty()) {
     KERNEL_LOG_ERROR("UpdateOutputDataTensor dims size == 0.");
     return KERNEL_STATUS_PARAM_INVALID;
   }
-  KERNEL_CHECK_NULLPTR(*outputResultTensor, KERNEL_STATUS_PARAM_INVALID, "outputResultTensor nullptr");
+  KERNEL_CHECK_NULLPTR(outputResultTensor, KERNEL_STATUS_PARAM_INVALID, "outputResultTensor nullptr");
   KERNEL_CHECK_NULLPTR(data_ptr, KERNEL_STATUS_PARAM_INVALID, "data_ptr nullptr");
   return KERNEL_STATUS_OK;
 }
 
 uint32_t CpuKernelAllocatorUtils::UpdateOutputDataTensor(const std::vector<int64_t> &dims, DataType type,
                                                          const void *data_ptr, int64_t input_data_size,
-                                                         Tensor **outputResultTensor) {
+                                                         Tensor *&outputResultTensor) {
   uint32_t check_ret = ParamCheck(dims, &data_ptr, outputResultTensor);
   if (check_ret != KERNEL_STATUS_OK) {
     return check_ret;
@@ -72,7 +72,7 @@ uint32_t CpuKernelAllocatorUtils::UpdateOutputDataTensor(const std::vector<int64
   }
 
   aicpu::FWKAdapter::ResultSummary *result_summary =
-    reinterpret_cast<aicpu::FWKAdapter::ResultSummary *>((*outputResultTensor)->GetData());
+    reinterpret_cast<aicpu::FWKAdapter::ResultSummary *>(outputResultTensor->GetData());
   result_summary->raw_data_size = data_size;
   result_summary->shape_data_size = shape_buff_size;
 
@@ -113,8 +113,9 @@ uint32_t CpuKernelAllocatorUtils::UpdateOutputDataTensor(const std::vector<int64
 
 int64_t CpuKernelAllocatorUtils::GetInputDataSize(const std::vector<int64_t> &dims, DataType type) {
   int64_t num_elements = 1;
+  int64_t dim_size = 0;
   for (size_t i = 0; i < dims.size(); i++) {
-    int64_t dim_size = dims[i];
+    dim_size = dims[i];
     KERNEL_CHECK_ASSIGN_64S_MULTI(num_elements, dim_size, num_elements, KERNEL_STATUS_PARAM_INVALID);
   }
 

@@ -27,9 +27,9 @@
 namespace aicpu {
 class AICPU_VISIBILITY CpuKernelAllocatorUtils {
  public:
-  static uint32_t ParamCheck(const std::vector<int64_t> &dims, const void *data_ptr, Tensor **outputResultTensor);
+  static uint32_t ParamCheck(const std::vector<int64_t> &dims, const void *data_ptr, Tensor *&outputResultTensor);
   static uint32_t UpdateOutputDataTensor(const std::vector<int64_t> &dims, DataType type, const void *data_ptr,
-                                         int64_t input_data_size, Tensor **outputResultTensor);
+                                         int64_t input_data_size, Tensor *&outputResultTensor);
   static uint32_t CheckOutputDataPtr(const uint64_t data_ptr);
   static uint32_t DeleteOutputDataPtr(const uint64_t data_ptr);
   static int64_t GetInputDataSize(const std::vector<int64_t> &dims, DataType type);

@@ -83,7 +83,7 @@ uint32_t Bcast::Init(const std::vector<int64_t> &x, const std::vector<int64_t> &
   return KERNEL_STATUS_OK;
 }
 
-Bcast::Bcast(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape) : valid_(true) {
+Bcast::Bcast(std::vector<int64_t> &x_shape, std::vector<int64_t> &y_shape) : valid_(true) {
   if (x_shape == y_shape) {
     int64_t elements_num = 1;
     for (size_t i = 0; i < x_shape.size(); ++i) {
@@ -252,15 +252,15 @@ uint32_t Bcast::GenerateBcastInfo(const BCalcInfo &calcInfo) {
   return KERNEL_STATUS_OK;
 }
 
-void Bcast::GetBcastVec(BCalcInfo *calcInfo) {
-  calcInfo->reshape_0 = std::move(x_reshape_);
-  calcInfo->reshape_1 = std::move(y_reshape_);
-  calcInfo->shape_out = std::move(shape_out_);
-  calcInfo->bcast_0 = std::move(x_bcast_);
-  calcInfo->bcast_1 = std::move(y_bcast_);
+void Bcast::GetBcastVec(BCalcInfo &calcInfo) {
+  calcInfo.reshape_0 = std::move(x_reshape_);
+  calcInfo.reshape_1 = std::move(y_reshape_);
+  calcInfo.shape_out = std::move(shape_out_);
+  calcInfo.bcast_0 = std::move(x_bcast_);
+  calcInfo.bcast_1 = std::move(y_bcast_);
 }
 
-void Bcast::BCastIndexes(std::vector<int64_t> *x_indexes, std::vector<int64_t> *y_indexes) {
+void Bcast::BCastIndexes(std::vector<int64_t> &x_indexes, std::vector<int64_t> &y_indexes) {
   std::reverse(x_reshape_.begin(), x_reshape_.end());
   std::reverse(y_reshape_.begin(), y_reshape_.end());
   std::reverse(shape_out_.begin(), shape_out_.end());
@@ -281,8 +281,8 @@ void Bcast::BCastIndexes(std::vector<int64_t> *x_indexes, std::vector<int64_t> *
   int64_t y_bias = y_dim;
 
   for (int64_t i = 0; i < out_dim; i++) {
-    x_indexes->push_back(x_dim == 1 ? 0 : i);
-    y_indexes->push_back(y_dim == 1 ? 0 : i);
+    x_indexes.push_back(x_dim == 1 ? 0 : i);
+    y_indexes.push_back(y_dim == 1 ? 0 : i);
   }
 
   // Process the remaining dimensions
@@ -291,11 +291,11 @@ void Bcast::BCastIndexes(std::vector<int64_t> *x_indexes, std::vector<int64_t> *
     y_dim = y_reshape_.at(i);    // i-th dimension of y.
     out_dim = shape_out_.at(i);  // i-th dimension of shape_out_.
 
-    std::vector<int64_t>::size_type stride = x_indexes->size();
+    std::vector<int64_t>::size_type stride = x_indexes.size();
     for (int64_t j = 1; j < out_dim; j++) {
       for (std::vector<int64_t>::size_type k = 0; k < stride; k++) {
-        x_indexes->push_back(x_indexes->at(k) + (x_dim == 1 ? 0 : (j * x_bias)));
-        y_indexes->push_back(y_indexes->at(k) + (y_dim == 1 ? 0 : (j * y_bias)));
+        x_indexes.push_back(x_indexes.at(k) + (x_dim == 1 ? 0 : (j * x_bias)));
+        y_indexes.push_back(y_indexes.at(k) + (y_dim == 1 ? 0 : (j * y_bias)));
       }
     }
     x_bias *= x_dim;

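BCastIndexes expands each operand's flat indexes so that element i of the output pairs x[x_indexes[i]] with y[y_indexes[i]]. A hand-rolled sketch of the result for one broadcast axis, illustrative rather than the class API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // x has shape {1, 3}, y has shape {2, 3}: x's single row is repeated, so
  // its flat indexes cycle 0 1 2 while y's run straight through 0..5.
  std::vector<int64_t> x_idx, y_idx;
  const int64_t rows = 2, cols = 3;
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      x_idx.push_back(c);             // x_dim == 1 on the row axis: stay at row 0
      y_idx.push_back(r * cols + c);  // y advances normally
    }
  }
  for (size_t i = 0; i < x_idx.size(); ++i) {
    std::cout << x_idx[i] << '/' << y_idx[i] << ' ';  // 0/0 1/1 2/2 0/3 1/4 2/5
  }
  std::cout << '\n';
  return 0;
}
```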
@@ -49,13 +49,13 @@ struct BCalcInfo {
 
 class Bcast {
  public:
-  Bcast() : valid_(true) {}
-  Bcast(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape);
+  Bcast() : valid_(true){};
+  Bcast(std::vector<int64_t> &x_shape, std::vector<int64_t> &y_shape);
   ~Bcast() = default;
 
   uint32_t GenerateBcastInfo(const BCalcInfo &calcInfo);
-  void GetBcastVec(BCalcInfo *calcInfo);
-  void BCastIndexes(std::vector<int64_t> *x_indexes, std::vector<int64_t> *y_indexes);
+  void GetBcastVec(BCalcInfo &calcInfo);
+  void BCastIndexes(std::vector<int64_t> &x_indexes, std::vector<int64_t> &y_indexes);
   int64_t GetBroadcastXIndex(int64_t index) const;
   int64_t GetBroadcastYIndex(int64_t index) const;
   bool IsValid() const { return valid_; }

@@ -19,11 +19,11 @@
 #include <utility>
 
 namespace aicpu {
-BroadcastIterator::BroadcastIterator(const std::vector<int64_t> &input_shape_a,
-                                     const std::vector<int64_t> &input_shape_b, std::vector<int64_t> *output_shape)
+BroadcastIterator::BroadcastIterator(std::vector<int64_t> &input_shape_a, std::vector<int64_t> &input_shape_b,
                                     std::vector<int64_t> &output_shape)
     : input_shape_a_(std::move(input_shape_a)),
       input_shape_b_(std::move(input_shape_b)),
-      output_shape_(std::move(*output_shape)) {
+      output_shape_(std::move(output_shape)) {
   output_dimension_ = output_shape_.size();  // Assign dimension to int for iterator
   BroadcastShape();
   // Allocate strides memory
@@ -91,7 +91,7 @@ void BroadcastIterator::InitStrides() {
 }
 
 uint32_t GetBroadcastShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
-                           std::vector<int64_t> *broadcast_shape) {
+                           std::vector<int64_t> &broadcast_shape) {
   int64_t x_len = x.size();
   int64_t y_len = y.size();
   int64_t length = x_len < y_len ? x_len : y_len;
@@ -109,15 +109,15 @@ uint32_t GetBroadcastShape(const std::vector<int64_t> &x, const std::vector<int6
   }
   if (length == x_len) {
     for (int64_t i = 0; i < y_len - length; ++i) {
-      broadcast_shape->push_back(y[i]);
+      broadcast_shape.push_back(y[i]);
     }
   } else {
     for (int64_t i = 0; i < x_len - length; ++i) {
-      broadcast_shape->push_back(x[i]);
+      broadcast_shape.push_back(x[i]);
    }
   }
   for (int64_t i = 0; i < length; ++i) {
-    broadcast_shape->push_back(broadcast_shape_back[i]);
+    broadcast_shape.push_back(broadcast_shape_back[i]);
   }
   return KERNEL_STATUS_OK;
 }

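GetBroadcastShape, now filling a reference instead of a pointer, implements the usual align-from-the-right broadcasting rule: dimensions are compared from the trailing end, and a 1 stretches to match the other operand. A self-contained sketch of that rule under illustrative names:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns false when the shapes are not broadcastable.
bool BroadcastShapeDemo(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                        std::vector<int64_t> &out) {
  size_t nx = x.size(), ny = y.size(), n = std::max(nx, ny);
  out.assign(n, 1);
  for (size_t i = 0; i < n; ++i) {
    // Missing leading dimensions are treated as 1.
    int64_t dx = i < n - nx ? 1 : x[i - (n - nx)];
    int64_t dy = i < n - ny ? 1 : y[i - (n - ny)];
    if (dx != dy && dx != 1 && dy != 1) return false;
    out[i] = std::max(dx, dy);
  }
  return true;
}

int main() {
  std::vector<int64_t> out;
  if (BroadcastShapeDemo({8, 1, 6}, {7, 1}, out)) {
    for (auto d : out) std::cout << d << ' ';  // prints: 8 7 6
    std::cout << '\n';
  }
  return 0;
}
```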
@@ -25,8 +25,8 @@
 namespace aicpu {
 class BroadcastIterator {
  public:
-  BroadcastIterator(const std::vector<int64_t> &input_shape_a, const std::vector<int64_t> &input_shape_b,
-                    std::vector<int64_t> *output_shape);
+  BroadcastIterator(std::vector<int64_t> &input_shape_a, std::vector<int64_t> &input_shape_b,
+                    std::vector<int64_t> &output_shape);
   virtual ~BroadcastIterator() = default;
   inline int64_t GetInputPosA() const { return input_pos_[0]; }
   inline int64_t GetInputPosB() const { return input_pos_[1]; }
@@ -52,8 +52,7 @@ class BroadcastIterator {
   std::vector<int64_t> input_strides_b_;
   std::vector<int64_t> input_back_strides_a_;
   std::vector<int64_t> input_back_strides_b_;
-  static const size_t size_two = 2;
-  std::array<int64_t, size_two> input_pos_ = {{0, 0}};
+  std::array<int64_t, 2> input_pos_ = {{0, 0}};
   size_t output_dimension_{0};
 };
 
@@ -63,6 +62,6 @@ class BroadcastIterator {
  * @return status
  */
 uint32_t GetBroadcastShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
-                           std::vector<int64_t> *broadcast_shape);
+                           std::vector<int64_t> &broadcast_shape);
 }  // namespace aicpu
 #endif

@@ -41,8 +41,8 @@ class DistinctUniformIntDistribution {
   }
 
   template <typename Generator>
-  ResultType exec(const Generator &engine) {
-    if (!(uset_.size() < range_)) {
+  ResultType exec(Generator &engine) {
+    if (not(uset_.size() < range_)) {
       std::terminate();
     }
     ResultType res;

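The exec hunk drops const from the Generator parameter, which matches how the standard random facilities work: drawing from an engine mutates its state, so the engine cannot be const. A minimal sketch of the same draw-until-unseen idea with a standard engine; the helper name and structure here are illustrative:

```cpp
#include <iostream>
#include <random>
#include <unordered_set>

// Rejection sampling: draw uniformly and retry until the value is new.
// The engine is taken by non-const reference because operator() mutates it.
int SampleDistinct(std::mt19937 &engine, std::unordered_set<int> &seen, int lo, int hi) {
  std::uniform_int_distribution<int> dist(lo, hi);
  int v = dist(engine);
  while (seen.count(v) != 0) {
    v = dist(engine);
  }
  seen.insert(v);
  return v;
}

int main() {
  std::mt19937 engine(42);
  std::unordered_set<int> seen;
  for (int i = 0; i < 5; ++i) {
    std::cout << SampleDistinct(engine, seen, 0, 9) << ' ';  // five distinct digits
  }
  std::cout << '\n';
  return 0;
}
```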
@@ -18,4 +18,4 @@
 
 namespace aicpu {
 const Tensor *EigenTensor::GetTensor() const { return tensor_; }
-}  // namespace aicpu
\ No newline at end of file
+}  // namespace aicpu

@@ -26,18 +26,18 @@ namespace aicpu {
  * @return status code
  */
 template <typename T>
-uint32_t EqualCalculate(const CpuKernelContext &ctx, BCalcInfo *calcInfo, bool flag) {
-  auto input_x1 = reinterpret_cast<T *>(calcInfo->input_0->GetData());
-  auto input_x2 = reinterpret_cast<T *>(calcInfo->input_1->GetData());
-  auto output_y = reinterpret_cast<bool *>(calcInfo->output->GetData());
+uint32_t EqualCalculate(const CpuKernelContext &ctx, BCalcInfo &calcInfo, bool flag) {
+  auto input_x1 = reinterpret_cast<T *>(calcInfo.input_0->GetData());
+  auto input_x2 = reinterpret_cast<T *>(calcInfo.input_1->GetData());
+  auto output_y = reinterpret_cast<bool *>(calcInfo.output->GetData());
   KERNEL_CHECK_NULLPTR(input_x1, KERNEL_STATUS_PARAM_INVALID, "Get input x1 data failed.")
   KERNEL_CHECK_NULLPTR(input_x2, KERNEL_STATUS_PARAM_INVALID, "Get input x2 data failed.")
   KERNEL_CHECK_NULLPTR(output_y, KERNEL_STATUS_PARAM_INVALID, "Get output data failed.")
-  size_t data_num = calcInfo->x_indexes.size();
+  size_t data_num = calcInfo.x_indexes.size();
   auto shard_equal = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
-      auto x_index = input_x1 + calcInfo->x_indexes[i];
-      auto y_index = input_x2 + calcInfo->y_indexes[i];
+      auto x_index = input_x1 + calcInfo.x_indexes[i];
+      auto y_index = input_x2 + calcInfo.y_indexes[i];
       output_y[i] = (flag == true) ? (*x_index == *y_index) : (*x_index != *y_index);
     }
   };

@@ -73,7 +73,7 @@ uint32_t EqualCompute(const CpuKernelContext &ctx, bool flag) {
   bcast.BCastIndexes(calcInfo.x_indexes, calcInfo.y_indexes);
   bcast.GetBcastVec(calcInfo);
 
-  return EqualCalculate<T>(ctx, &calcInfo, flag);
+  return EqualCalculate<T>(ctx, calcInfo, flag);
 }
 }  // namespace aicpu
 #endif

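EqualCalculate shards the element-wise comparison over index ranges via a lambda. A self-contained sketch of that shard pattern follows; it runs serially here, where a real kernel would hand shard_equal to a parallel-for utility, and the data values are made up.

// Sketch of the shard pattern used by shard_equal above: a worker lambda
// over [start, end) that can be split across threads.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> x = {1, 2, 3, 4}, y = {1, 0, 3, 0};
  std::vector<bool> out(x.size());
  auto shard_equal = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) out[i] = (x[i] == y[i]);
  };
  shard_equal(0, out.size());  // serial run; a kernel would partition this range
  for (bool b : out) std::cout << b << ' ';  // 1 0 1 0
}
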
@@ -66,7 +66,7 @@ const std::map<Format, std::string> kFormatToStringMap = {
   {FORMAT_RESERVED, "FORMAT_RESERVED"},
   {FORMAT_ALL, "ALL"},
   {FORMAT_NULL, "NULL"}};
-}  // namespace
+}
 
 std::string FormatToSerialString(Format format) {
   auto it = kFormatToStringMap.find(static_cast<Format>(GetPrimaryFormat(static_cast<int32_t>(format))));

@@ -120,7 +120,7 @@ bool IsEmptyTensor(Tensor *tensor) {
   return false;
 }
 
-uint32_t NormalMathCheck(const CpuKernelContext &ctx) {
+uint32_t NormalMathCheck(CpuKernelContext &ctx) {
   const uint32_t kInputNum = 2;
   const uint32_t kOutputNum = 1;

@@ -151,7 +151,7 @@ uint32_t NormalMathCheck(const CpuKernelContext &ctx) {
   return KERNEL_STATUS_OK;
 }
 
-uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num) {
+uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num) {
   if (inputs_num != kDynamicInput) {
     KERNEL_CHECK_FALSE((ctx.GetInputsSize() >= inputs_num), KERNEL_STATUS_PARAM_INVALID,
                        "[%s] need [%u] inputs, but got [%u].", ctx.GetOpType().c_str(), inputs_num,

@@ -191,7 +191,7 @@ uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, con
   return KERNEL_STATUS_OK;
 }
 
-uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
+uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
                      const std::vector<std::string> &attr_names) {
   KERNEL_HANDLE_ERROR(NormalCheck(ctx, inputs_num, outputs_num), "Check Greater params failed.");
   for (auto const &attr_name : attr_names) {

@@ -211,9 +211,6 @@ bool IsMatrix(const std::vector<int64_t> &shape) { return (shape.size() == 2); }
 bool IsSquareMatrix(const std::vector<int64_t> &shape) { return ((shape.size() == 2) && (shape[0] == shape[1])); }
 
 bool AddrAlignedCheck(const void *addr, uint64_t alignment) {
-  if (alignment == 0) {
-    return false;
-  }
   return reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(addr)) % alignment == 0;
 }

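The rollback drops the alignment == 0 guard, restoring the bare modulo check. A self-contained illustration of that check follows; like the rolled-back code, it assumes the caller passes a nonzero alignment.

// Alignment test via pointer-to-integer conversion and modulo.
#include <cstdint>
#include <iostream>

bool AddrAligned(const void *addr, uint64_t alignment) {
  // Assumes alignment != 0, as the rolled-back kernel code does.
  return reinterpret_cast<uintptr_t>(addr) % alignment == 0;
}

int main() {
  alignas(16) char buf[32];
  std::cout << AddrAligned(buf, 16) << '\n';      // 1: buf is 16-byte aligned
  std::cout << AddrAligned(buf + 1, 16) << '\n';  // 0: offset by one byte
}
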
@@ -130,8 +130,7 @@ std::string FormatToSerialString(Format format);
 inline int32_t GetPrimaryFormat(int32_t format) { return static_cast<int32_t>(static_cast<uint32_t>(format) & 0xff); }
 
 inline int32_t GetSubFormat(int32_t format) {
-  constexpr size_t OffsetEight = 8;
-  return static_cast<int32_t>((static_cast<uint32_t>(format) & 0xffff00) >> OffsetEight);
+  return static_cast<int32_t>((static_cast<uint32_t>(format) & 0xffff00) >> 8);
 }
 
 inline bool HasSubFormat(int32_t format) { return GetSubFormat(format) > 0; }

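GetPrimaryFormat and GetSubFormat decode a packed format id: the low byte holds the primary format and the next two bytes hold the sub-format. A runnable illustration with made-up ids:

// Pack and unpack a format id using the same masks and shift as above.
#include <cstdint>
#include <iostream>

int main() {
  const int32_t primary = 0x2A;  // hypothetical primary format id
  const int32_t sub = 0x03;      // hypothetical sub-format id
  const int32_t packed = (sub << 8) | primary;
  std::cout << (packed & 0xff) << '\n';                                    // 42: primary
  std::cout << ((static_cast<uint32_t>(packed) & 0xffff00) >> 8) << '\n';  // 3: sub-format
}
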
@@ -150,16 +149,16 @@ bool IsEmptyTensor(Tensor *tensor);
  * @param xy product of x and y
  * @return true: normal, false: overflow
  */
-inline bool MulWithoutOverflow(const int64_t x, const int64_t y, int64_t *xy) {
+inline bool MulWithoutOverflow(const int64_t x, const int64_t y, int64_t &xy) {
   // Multiply in uint64 rather than int64 since signed overflow is undefined.
   // Negative values will wrap around to large unsigned values in the casts
   // (see section 4.7 [conv.integral] of the C++14 standard).
   const uint64_t ux = static_cast<uint64_t>(x);
   const uint64_t uy = static_cast<uint64_t>(y);
   const uint64_t uxy = ux * uy;
-  constexpr size_t OffsetThirtyTwo = 32;
 
   // Check if we overflow uint64, using a cheap check if both inputs are small
-  if (((ux | uy) >> OffsetThirtyTwo) != 0) {
+  if ((ux | uy) >> 32 != 0) {
     // Ensure nonnegativity. Note that negative numbers will appear "large"
     // to the unsigned comparisons above.
     if (x < 0 || y < 0) {

@@ -174,7 +173,7 @@ inline bool MulWithoutOverflow(const int64_t x, const int64_t y, int64_t *xy) {
   }
 
   // Cast back to signed. Any negative value will signal an error.
-  *xy = static_cast<int64_t>(uxy);
+  xy = static_cast<int64_t>(uxy);
   return true;
 }

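MulWithoutOverflow multiplies in uint64 and treats a negative signed result as overflow. A self-contained restatement with a boundary case follows; MulNoOverflow is a hedged re-implementation of the same idea, not the header's exact code.

// Overflow-checked multiply: compute in uint64, reject wraparound.
#include <cstdint>
#include <iostream>

bool MulNoOverflow(int64_t x, int64_t y, int64_t &xy) {
  const uint64_t ux = static_cast<uint64_t>(x);
  const uint64_t uy = static_cast<uint64_t>(y);
  const uint64_t uxy = ux * uy;
  if ((ux | uy) >> 32 != 0) {          // at least one operand needs more than 32 bits
    if (x < 0 || y < 0) return false;  // conservative, as in the header: bail on negatives
    if (uy != 0 && uxy / uy != ux) return false;  // uint64 multiplication wrapped
  }
  xy = static_cast<int64_t>(uxy);
  return xy >= 0;  // a negative cast result signals overflow past INT64_MAX
}

int main() {
  int64_t r;
  std::cout << MulNoOverflow(1LL << 32, 1LL << 31, r) << '\n';     // 0: 2^63 overflows int64
  std::cout << MulNoOverflow(123456, 789, r) << ' ' << r << '\n';  // 1 97406784
}
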
@@ -185,13 +184,13 @@ inline bool MulWithoutOverflow(const int64_t x, const int64_t y, int64_t *xy) {
  * @param sum sum of x and y
  * @return true: normal, false: overflow
  */
-inline bool AddWithoutOverflow(const int64_t x, const int64_t y, int64_t *sum) {
+inline bool AddWithoutOverflow(const int64_t x, const int64_t y, int64_t &sum) {
   const uint64_t ux = static_cast<uint64_t>(x);
   const uint64_t uy = static_cast<uint64_t>(y);
   const uint64_t usum = ux + uy;
-  *sum = static_cast<int64_t>(usum);
+  sum = static_cast<int64_t>(usum);
 
-  return !(((x >= 0) == (y >= 0)) && (((*sum) >= 0) != (x >= 0)));
+  return !(((x >= 0) == (y >= 0)) && ((sum >= 0) != (x >= 0)));
 }
 
 /**

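AddWithoutOverflow relies on the sign rule: signed addition can only overflow when both operands share a sign and the wrapped sum flips it. A runnable restatement (AddNoOverflow is an illustrative name):

// Overflow-checked add: compute in uint64, detect a sign flip.
#include <cstdint>
#include <iostream>
#include <limits>

bool AddNoOverflow(int64_t x, int64_t y, int64_t &sum) {
  sum = static_cast<int64_t>(static_cast<uint64_t>(x) + static_cast<uint64_t>(y));
  // Overflow iff the operands share a sign and the wrapped sum flips it.
  return !(((x >= 0) == (y >= 0)) && ((sum >= 0) != (x >= 0)));
}

int main() {
  int64_t s;
  std::cout << AddNoOverflow(std::numeric_limits<int64_t>::max(), 1, s) << '\n';  // 0: overflow
  std::cout << AddNoOverflow(40, 2, s) << ' ' << s << '\n';                       // 1 42
}
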
@@ -199,7 +198,7 @@ inline bool AddWithoutOverflow(const int64_t x, const int64_t y, int64_t *sum) {
  * @param ctx context
  * @return status code
  */
-uint32_t NormalMathCheck(const CpuKernelContext &ctx);
+uint32_t NormalMathCheck(CpuKernelContext &ctx);
 
 /**
  * @brief normal check for kernel

@@ -208,7 +207,7 @@ uint32_t NormalMathCheck(const CpuKernelContext &ctx);
  * @param outputs_num num of outputs
  * @return status code
  */
-uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num);
+uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num);
 
 /**
  * @brief normal check for kernel

@@ -218,7 +217,7 @@ uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, con
  * @param attr_names names of attrs
  * @return status code
  */
-uint32_t NormalCheck(const CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
+uint32_t NormalCheck(CpuKernelContext &ctx, const uint32_t inputs_num, const uint32_t outputs_num,
                      const std::vector<std::string> &attr_names);
 
 bool IsScalar(const std::vector<int64_t> &shape);

@@ -250,5 +249,6 @@ DataType DType(std::string dtype_str);
  * @return string of data type
  */
 std::string DTypeStr(DataType dtype);
+
 }  // namespace aicpu
 #endif

@@ -48,8 +48,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
                                                             kSliceGradOpName,
                                                             kRandomShuffleOpName,
                                                             kRangeOpName};
-  static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kACosOpName,
-                                                               mindspore::kAdaptiveAvgPool2dOpName,
+  static const std::set<std::string> kMigrateAicpuKernelOps = {mindspore::kAdaptiveAvgPool2dOpName,
                                                                mindspore::kAdaptiveAvgPool2dGradOpName,
                                                                mindspore::kCacheSwapTableOpName,
                                                                mindspore::kFillOpName,

@@ -166,3 +166,8 @@ from .parallel_concat import _parallel_concat_aicpu
 from .concat_offset import _concat_offset_aicpu
 from .range import _range_aicpu
 from .slice_grad import _slice_grad_aicpu
+from .median import _median_aicpu
+from .median_grad import _median_grad_aicpu
+from .reduce_sum import _reduce_sum_aicpu
+from .adaptive_avg_pool_2d_v1 import _adaptive_avg_pool_2d_v1_aicpu
+from .fill_v2 import _fill_v2_aicpu