batch_to_space,depth_to_space,argmin,argmax support int8

This commit is contained in:
chenjianping 2020-08-03 17:20:29 +08:00
parent 5338128283
commit 131cad16e8
35 changed files with 2049 additions and 270 deletions

View File

@ -38,7 +38,7 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
return RET_PARAM_INVALID;
}
if (argmax_prim->topK() == -1) {
if (argmax_prim->topK() == 1) {
output_shape.erase(output_shape.begin() + axis);
} else if (argmax_prim->axisType() == 1) {
output_shape[axis] = argmax_prim->topK();

View File

@ -37,7 +37,7 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
return RET_PARAM_INVALID;
}
std::vector<int> output_shape(input->shape());
if (argmin_prim->topK() == -1) {
if (argmin_prim->topK() == 1) {
output_shape.erase(output_shape.begin() + axis);
} else if (argmin_prim->axisType() == 1) {
output_shape[axis] = argmin_prim->topK();

View File

@ -27,7 +27,7 @@
#include "src/runtime/kernel/arm/opclib/reshape_parameter.h"
#include "src/runtime/kernel/arm/opclib/fp32/stack.h"
#include "src/runtime/kernel/arm/opclib/unstack.h"
#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h"
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/kernel/arm/opclib/fp32/pooling.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
@ -56,7 +56,7 @@
#include "src/runtime/kernel/arm/opclib/fp32/gatherNd.h"
#include "src/runtime/kernel/arm/opclib/resize.h"
#include "src/runtime/kernel/arm/opclib/scatter_nd.h"
#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/opclib/batch_to_space.h"
#include "src/runtime/kernel/arm/opclib/fp32/crop.h"
#include "src/runtime/kernel/arm/fp32/flatten.h"
#include "src/runtime/kernel/arm/opclib/fp32/unsqueeze.h"

View File

@ -0,0 +1,149 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/arg_min_max_base.h"
#include "src/runtime/kernel/arm/opclib/arg_min_max.h"
#include "src/runtime/kernel/arm/fp32/argminmax.h"
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "include/context.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_PARAM_INVALID;
using mindspore::lite::RET_FORMAT_ERR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ArgMax;
using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel {
int ArgMinMaxBaseCPUKernel::Init() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
switch (opParameter->type_) {
case PrimitiveType_ArgMax:
param->get_max_ = true;
break;
case PrimitiveType_ArgMin:
param->get_max_ = false;
break;
default:
MS_LOG(ERROR) << "Unexpected type " << opParameter->type_;
return RET_ERROR;
}
auto in_shape = inputs_.at(0)->shape();
auto dims_size = in_shape.size();
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
param->axis_ = axis;
param->dims_size_ = dims_size;
if (param->topk_ <= 0) {
MS_LOG(ERROR) << "Invalid topk " << param->topk_;
return RET_PARAM_INVALID;
}
param->topk_ = MSMIN(param->topk_, in_shape[axis]);
if (param->topk_ > 1) {
if (context_ != nullptr && context_->allocator != nullptr) {
param->arg_elements_
= reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis]));
data_from_allocator_ = true;
} else {
param->arg_elements_ = reinterpret_cast<ArgElement *>(malloc(sizeof(ArgElement) * in_shape[axis]));
}
if (param->arg_elements_ == nullptr) {
MS_LOG(ERROR) << "malloc memroy fail!";
return RET_ERROR;
}
}
return RET_OK;
}
int ArgMinMaxBaseCPUKernel::Run() {
auto input = inputs_.at(0);
auto input_data = reinterpret_cast<const void *>(inputs_.at(0)->Data());
auto output_data = outputs_.at(0)->Data();
auto shape = input->shape().data();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
ArgMinMax(input_data, output_data, reinterpret_cast<const int *>(shape), param);
return RET_OK;
}
void ArgMinMaxBaseCPUKernel::FreeTmpMemory() {
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
if (param->arg_elements_ == nullptr) {
return;
}
if (data_from_allocator_) {
context_->allocator->Free(param->arg_elements_);
} else {
free(param->arg_elements_);
}
param->arg_elements_ = nullptr;
}
kernel::LiteKernel *CpuArgMinMaxInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto kernel = new (std::nothrow) ArgMinMaxInt8CPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ArgMinMaxInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuArgMinMaxFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto kernel = new (std::nothrow) ArgMinMaxCPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ArgMinMaxCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ArgMax, CpuArgMinMaxInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_ArgMin, CpuArgMinMaxInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class ArgMinMaxBaseCPUKernel : public LiteKernel {
public:
ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs), context_(ctx), data_from_allocator_(false) {
opParameter->thread_num_ = ctx->threadNum;
}
virtual ~ArgMinMaxBaseCPUKernel() {
FreeTmpMemory();
}
int Init() override;
int ReSize() override { return 0; }
int Run() override;
void FreeTmpMemory();
private:
const lite::Context *context_;
bool data_from_allocator_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARG_MIN_MAX_BASE_H_

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/batch_to_space_base.h"
#include "src/runtime/kernel/arm/opclib/batch_to_space.h"
#include "src/runtime/kernel/arm/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "include/context.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_FORMAT_ERR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BatchToSpace;
namespace mindspore::kernel {
int BatchToSpaceBaseCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) {
no_crop_ = false;
}
}
return RET_OK;
}
kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) BatchToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BatchToSpace, CpuBatchToSpaceInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/concat_parameter.h"
namespace mindspore::kernel {
class BatchToSpaceBaseCPUKernel : public LiteKernel {
public:
BatchToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs) {
opParameter->thread_num_ = ctx->threadNum;
}
virtual ~BatchToSpaceBaseCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override { return 0; }
bool IsNoCrop() const {
return no_crop_;
}
private:
bool no_crop_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_BATCH_TO_SPACE_BASE_H_

View File

@ -34,7 +34,7 @@ class ConcatBaseCPUKernel : public LiteKernel {
concat_param_ = reinterpret_cast<ConcatParameter *>(opParameter);
}
~ConcatBaseCPUKernel() = default;
virtual ~ConcatBaseCPUKernel() = default;
int Init() override;

View File

@ -0,0 +1,114 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
#include "src/runtime/kernel/arm/fp32/depth_to_space.h"
#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h"
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "include/context.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_PARAM_INVALID;
using mindspore::lite::RET_FORMAT_ERR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthToSpace;
namespace mindspore::kernel {
int DepthToSpaceBaseCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
if (param->block_size_ <= 0) {
MS_LOG(ERROR) << "Input block_size should > 0!";
return RET_PARAM_INVALID;
}
auto shape_size = inputs_[0]->shape().size();
if (shape_size != DIMENSION_4D) {
MS_LOG(ERROR) << "Input shape size should be " << DIMENSION_4D;
return RET_PARAM_INVALID;
}
int32_t in_strides[DIMENSION_4D];
ComputeStrides(const_cast<int *>(inputs_[0]->shape().data()), in_strides, shape_size);
param->in_stride_dim0_ = in_strides[0];
param->in_stride_dim1_ = in_strides[1];
param->in_stride_dim2_ = in_strides[2];
int32_t out_strides[DIMENSION_4D];
ComputeStrides(const_cast<int *>(outputs_[0]->shape().data()), out_strides, shape_size);
param->out_stride_dim0_ = out_strides[0];
param->out_stride_dim1_ = out_strides[1];
param->out_stride_dim2_ = out_strides[2];
return RET_OK;
}
kernel::LiteKernel *CpuDepthToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) DepthToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(op_parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthToSpace, CpuDepthToSpaceInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
namespace mindspore::kernel {
class DepthToSpaceBaseCPUKernel : public LiteKernel {
public:
DepthToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs) {
opParameter->thread_num_ = ctx->threadNum;
}
virtual ~DepthToSpaceBaseCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override { return 0; }
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_

View File

@ -18,7 +18,7 @@
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h"
#include "src/runtime/kernel/arm/opclib/arg_min_max.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -29,69 +29,19 @@ using mindspore::schema::PrimitiveType_ArgMax;
using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel {
namespace {
constexpr int kInputNum = 1;
constexpr int kOutputNum = 1;
} // namespace
int ArgMinMaxCPUKernel::Init() {
switch (opParameter->type_) {
case PrimitiveType_ArgMax:
get_max_ = true;
break;
case PrimitiveType_ArgMin:
get_max_ = false;
break;
default:
MS_LOG(ERROR) << "Unexpected type " << opParameter->type_;
return RET_ERROR;
auto ret = ArgMinMaxBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
auto dims_size = inputs_.at(0)->shape().size();
axis_ = reinterpret_cast<ArgMinMaxParameter *>(opParameter)->axis_;
axis_ = axis_ < 0 ? axis_ + dims_size : axis_;
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
param->data_type_ = kNumberTypeFloat32;
return RET_OK;
}
int ArgMinMaxCPUKernel::Run() {
auto input = inputs_.at(0);
auto input_data = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto output_data = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto shape = input->shape().data();
int dims_number = input->shape().size();
bool out_value = reinterpret_cast<ArgMinMaxParameter *>(opParameter)->out_value_;
if (get_max_) {
ArgMax(input_data, shape, dims_number, axis_, out_value, output_data);
} else {
ArgMin(input_data, shape, dims_number, axis_, out_value, output_data);
}
return RET_OK;
auto ret = ArgMinMaxBaseCPUKernel::Run();
ArgMinMaxBaseCPUKernel::FreeTmpMemory();
return ret;
}
kernel::LiteKernel *CpuArgMinMaxFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) ArgMinMaxCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ArgMinMaxCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -17,23 +17,20 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/arg_min_max_base.h"
namespace mindspore::kernel {
class ArgMinMaxCPUKernel : public LiteKernel {
class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel {
public:
ArgMinMaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs) : LiteKernel(parameter, inputs, outputs) {}
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ArgMinMaxCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override;
private:
int axis_;
bool get_max_;
};
} // namespace mindspore::kernel

View File

@ -17,28 +17,14 @@
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/opclib/batch_to_space.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_FORMAT_ERR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BatchToSpace;
namespace mindspore::kernel {
int BatchToSpaceCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) {
no_crop_ = false;
}
}
return RET_OK;
return BatchToSpaceBaseCPUKernel::Init();
}
int BatchToSpaceCPUKernel::Run() {
@ -50,7 +36,7 @@ int BatchToSpaceCPUKernel::Run() {
auto out_shape = output->shape();
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
if (no_crop_) {
if (IsNoCrop()) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(float));
} else {
@ -60,31 +46,4 @@ int BatchToSpaceCPUKernel::Run() {
return RET_OK;
}
kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,26 +15,21 @@
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/batch_to_space_base.h"
namespace mindspore::kernel {
class BatchToSpaceCPUKernel : public LiteKernel {
class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel {
public:
BatchToSpaceCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs), no_crop_(true) {}
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~BatchToSpaceCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override;
private:
bool no_crop_;
};
} // namespace mindspore::kernel

View File

@ -17,7 +17,8 @@
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h"
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
@ -30,15 +31,12 @@ using mindspore::schema::PrimitiveType_DepthToSpace;
namespace mindspore::kernel {
int DepthToSpaceCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;
auto ret = DepthToSpaceBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
if (param->block_size_ <= 0) {
MS_LOG(ERROR) << "Input block_size should > 0!";
return RET_PARAM_INVALID;
}
param->data_type_size_ = sizeof(float);
return RET_OK;
}
@ -48,42 +46,13 @@ int DepthToSpaceCPUKernel::Run() {
const float *input_data = reinterpret_cast<const float *>(input->Data());
float *output_data = reinterpret_cast<float *>(output->Data());
auto in_shape = input->shape();
auto out_shape = output->shape();
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
if (input->GetFormat() == schema::Format_NHWC) {
DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape.data(), in_shape.size(),
param->block_size_);
DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param);
return RET_OK;
} else {
MS_LOG(ERROR) << "Only support NHWC now!";
MS_LOG(ERROR) << "Depth_to_space only support NHWC now!";
return RET_ERROR;
}
}
kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -17,14 +17,14 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
namespace mindspore::kernel {
class DepthToSpaceCPUKernel : public LiteKernel {
class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
public:
DepthToSpaceCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs) : LiteKernel(parameter, inputs, outputs) {}
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~DepthToSpaceCPUKernel() = default;
int Init() override;

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/arg_min_max.h"
#include "include/errorcode.h"
using mindspore::lite::RET_OK;
using mindspore::lite::RET_ERROR;
namespace mindspore::kernel {
int ArgMinMaxInt8CPUKernel::Init() {
auto ret = ArgMinMaxBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
param->data_type_ = kNumberTypeInt8;
return RET_OK;
}
int ArgMinMaxInt8CPUKernel::Run() {
auto ret = ArgMinMaxBaseCPUKernel::Run();
FreeTmpMemory();
return ret;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/arg_min_max_base.h"
namespace mindspore::kernel {
class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel {
public:
ArgMinMaxInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ArgMinMaxInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/batch_to_space.h"
#include "include/errorcode.h"
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int BatchToSpaceInt8CPUKernel::Init() {
return BatchToSpaceBaseCPUKernel::Init();
}
int BatchToSpaceInt8CPUKernel::Run() {
auto input = inputs_[0];
auto output = outputs_[0];
const int8_t *input_data = reinterpret_cast<const int8_t *>(input->Data());
int8_t *output_data = reinterpret_cast<int8_t *>(output->Data());
auto in_shape = input->shape();
auto out_shape = output->shape();
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
if (IsNoCrop()) {
BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
sizeof(int8_t));
} else {
BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
sizeof(int8_t));
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/batch_to_space_base.h"
namespace mindspore::kernel {
class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel {
public:
BatchToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~BatchToSpaceInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BATCH_TO_SPACE_INT8_H_

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
#include "include/errorcode.h"
using mindspore::lite::RET_OK;
using mindspore::lite::RET_ERROR;
namespace mindspore::kernel {
int DepthToSpaceInt8CPUKernel::Init() {
auto ret = DepthToSpaceBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
param->data_type_size_ = sizeof(int8_t);
return RET_OK;
}
int DepthToSpaceInt8CPUKernel::Run() {
auto input = inputs_[0];
auto output = outputs_[0];
const int8_t *input_data = reinterpret_cast<const int8_t *>(input->Data());
int8_t *output_data = reinterpret_cast<int8_t *>(output->Data());
auto in_shape = input->shape();
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
if (input->GetFormat() == schema::Format_NHWC) {
DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), param);
return RET_OK;
} else {
MS_LOG(ERROR) << "Depth_to_space only support NHWC now!";
return RET_ERROR;
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
namespace mindspore::kernel {
class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
public:
DepthToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~DepthToSpaceInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEPTH_TO_SPACE_INT8_H_

View File

@ -0,0 +1,158 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/arg_min_max.h"
#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h"
#include "src/runtime/kernel/arm/opclib/int8/arg_min_max.h"
#define FLOAT_DATA_TYPE 43
#define INT8_DATA_TYPE 32
void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count,
int *after_axis_count) {
*pre_axis_count = 1;
for (int i = 0; i < axis; ++i) {
*pre_axis_count = (*pre_axis_count) * shape[i];
}
*axis_count = shape[axis];
*after_axis_count = 1;
for (int i = axis + 1; i < dims_number; ++i) {
*after_axis_count = (*after_axis_count) * shape[i];
}
}
void ArgMinMaxTopk1(const void *input, void *output, const int *shape, ArgMinMaxParameter *param) {
int pre_axis_count = 1;
int axis_count = 1;
int after_axis_count = 1;
GetCalcParameter(shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
switch (param->data_type_) {
case FLOAT_DATA_TYPE: {
if (param->get_max_) {
ArgMax(reinterpret_cast<const float *>(input), reinterpret_cast<float *>(output), param, pre_axis_count,
axis_count, after_axis_count);
} else {
ArgMin(reinterpret_cast<const float *>(input), reinterpret_cast<float *>(output), param, pre_axis_count,
axis_count, after_axis_count);
}
break;
}
case INT8_DATA_TYPE: {
if (param->get_max_) {
ArgMax(reinterpret_cast<const int8_t *>(input), reinterpret_cast<int8_t *>(output), param, pre_axis_count,
axis_count, after_axis_count);
} else {
ArgMin(reinterpret_cast<const int8_t *>(input), reinterpret_cast<int8_t *>(output), param, pre_axis_count,
axis_count, after_axis_count);
}
break;
}
default:
break;
}
}
void ArgMinMaxTopknFp32(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->get_max_) {
switch (param->axis_) {
case 0:
ArgMaxDim0(input, output, in_shape, param);
break;
case 1:
ArgMaxDim1(input, output, in_shape, param);
break;
case 2:
ArgMaxDim2(input, output, in_shape, param);
break;
case 3:
ArgMaxDim3(input, output, in_shape, param);
break;
}
} else {
switch (param->axis_) {
case 0:
ArgMinDim0(input, output, in_shape, param);
break;
case 1:
ArgMinDim1(input, output, in_shape, param);
break;
case 2:
ArgMinDim2(input, output, in_shape, param);
break;
case 3:
ArgMinDim3(input, output, in_shape, param);
break;
}
}
}
void ArgMinMaxTopknInt8(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->get_max_) {
switch (param->axis_) {
case 0:
ArgMaxDim0(input, output, in_shape, param);
break;
case 1:
ArgMaxDim1(input, output, in_shape, param);
break;
case 2:
ArgMaxDim2(input, output, in_shape, param);
break;
case 3:
ArgMaxDim3(input, output, in_shape, param);
break;
}
} else {
switch (param->axis_) {
case 0:
ArgMinDim0(input, output, in_shape, param);
break;
case 1:
ArgMinDim1(input, output, in_shape, param);
break;
case 2:
ArgMinDim2(input, output, in_shape, param);
break;
case 3:
ArgMinDim3(input, output, in_shape, param);
break;
}
}
}
void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->topk_ == 1) {
ArgMinMaxTopk1(input, output, in_shape, param);
return;
}
switch (param->data_type_) {
case FLOAT_DATA_TYPE: {
ArgMinMaxTopknFp32(reinterpret_cast<const float *>(input), reinterpret_cast<float *>(output), in_shape, param);
return;
}
case INT8_DATA_TYPE: {
ArgMinMaxTopknInt8(reinterpret_cast<const int8_t *>(input), reinterpret_cast<int8_t *>(output), in_shape, param);
return;
}
default:
break;
}
}
#undef FLOAT_DATA_TYPE
#undef INT8_DATA_TYPE

View File

@ -13,17 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/arg_min_max_parameter.h"
struct DepthToSpaceParameter {
OpParameter op_parameter_;
int32_t block_size_;
};
void DepthToSpaceForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size,
int block_size);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_
void ArgMinMax(const void *input, void *output, const int *in_shape, ArgMinMaxParameter *param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_PARAMETER_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
struct ArgElement {
uint32_t index_;
union ArgData {
int8_t i8_data_;
int32_t i_data_;
float f_data_;
} data_;
};
struct ArgMinMaxParameter {
OpParameter op_parameter_;
bool out_value_;
bool keep_dims_;
bool get_max_;
int32_t axis_;
int32_t topk_;
int32_t axis_type_;
int32_t dims_size_;
int32_t data_type_; // equals to type_id
int32_t in_strides_[DIMENSION_4D];
int32_t out_strides_[DIMENSION_4D];
ArgElement *arg_elements_;
};
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARG_MIN_MAX_PARAMETER_H_

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/opclib/batch_to_space.h"
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block,

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCH_TO_SPACE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCH_TO_SPACE_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_BATCH_TO_SPACE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_BATCH_TO_SPACE_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/depth_to_space.h"
#include <string.h>
void DepthToSpaceForNHWC(const void *input, void *output, int *in_shape, DepthToSpaceParameter *param) {
int32_t block_size = param->block_size_;
int32_t in_shape_dim2 = in_shape[2];
int32_t in_shape_dim1 = in_shape[1];
size_t copy_size = block_size * param->out_stride_dim2_ * param->data_type_size_;
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_offset_n = i * param->in_stride_dim0_;
size_t out_offset_n = i * param->out_stride_dim0_;
for (int j = 0; j < in_shape_dim1; ++j) {
size_t in_offset_h = in_offset_n + j * param->in_stride_dim1_;
size_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_;
for (int k = 0; k < in_shape_dim2; ++k) {
size_t in_offset_w = in_offset_h + k * param->in_stride_dim2_;
size_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_;
for (int l = 0; l < block_size; ++l) {
size_t out_offset = (out_offset_w + l * param->out_stride_dim1_) * param->data_type_size_;
size_t in_offset = (in_offset_w + l * block_size * param->out_stride_dim2_) * param->data_type_size_;
memcpy(reinterpret_cast<int8_t *>(output) + out_offset, reinterpret_cast<const int8_t *>(input) + in_offset,
copy_size);
}
}
}
}
}

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_DEPTH_TO_SPACE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_DEPTH_TO_SPACE_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
struct DepthToSpaceParameter {
OpParameter op_parameter_;
int32_t block_size_;
int32_t in_stride_dim0_;
int32_t in_stride_dim1_;
int32_t in_stride_dim2_;
int32_t out_stride_dim0_;
int32_t out_stride_dim1_;
int32_t out_stride_dim2_;
uint8_t data_type_size_;
};
void DepthToSpaceForNHWC(const void *input, void *output, int *in_shape, DepthToSpaceParameter *param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_DEPTH_TO_SPACE_H_

View File

@ -14,33 +14,441 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h"
#include <stdlib.h>
#include <float.h>
void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count,
int *after_axis_count) {
*pre_axis_count = 1;
for (int i = 0; i < axis; ++i) {
*pre_axis_count = (*pre_axis_count) * shape[i];
}
int ArgCompareAscFp32(const void *a, const void *b) {
return reinterpret_cast<const ArgElement *>(a)->data_.f_data_
- reinterpret_cast<const ArgElement *>(b)->data_.f_data_;
}
*axis_count = shape[axis];
int ArgCompareDescFp32(const void *a, const void *b) {
return reinterpret_cast<const ArgElement *>(b)->data_.f_data_
- reinterpret_cast<const ArgElement *>(a)->data_.f_data_;
}
*after_axis_count = 1;
for (int i = axis + 1; i < dims_number; ++i) {
*after_axis_count = (*after_axis_count) * shape[i];
void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescFp32);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].data_.f_data_;
}
}
}
void ArgMax(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output) {
int pre_axis_count = 1;
int axis_count = 1;
int after_axis_count = 1;
GetCalcParameter(shape, dims_number, axis, &pre_axis_count, &axis_count, &after_axis_count);
void ArgMaxDim0OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescFp32);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].index_;
}
}
}
void ArgMinDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscFp32);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].data_.f_data_;
}
}
}
void ArgMinDim0OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscFp32);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].index_;
}
}
}
void ArgMaxDim1OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescFp32);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].data_.f_data_;
}
}
}
}
void ArgMaxDim1OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescFp32);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].index_;
}
}
}
}
void ArgMinDim1OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscFp32);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].data_.f_data_;
}
}
}
}
void ArgMinDim1OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscFp32);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].index_;
}
}
}
}
void ArgMaxDim2OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMaxDim2OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMinDim2OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMinDim2OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMaxDim3OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMaxDim3OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMinDim3OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMinDim3OutIndex(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscFp32);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMaxDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim0OutValue(input, output, in_shape, param);
} else {
ArgMaxDim0OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim0OutValue(input, output, in_shape, param);
} else {
ArgMinDim0OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim1OutValue(input, output, in_shape, param);
} else {
ArgMaxDim1OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim1OutValue(input, output, in_shape, param);
} else {
ArgMinDim1OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim2OutValue(input, output, in_shape, param);
} else {
ArgMaxDim2OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim2OutValue(input, output, in_shape, param);
} else {
ArgMinDim2OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim3OutValue(input, output, in_shape, param);
} else {
ArgMaxDim3OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim3OutValue(input, output, in_shape, param);
} else {
ArgMinDim3OutIndex(input, output, in_shape, param);
}
}
void ArgMax(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count) {
bool out_value = param->out_value_;
for (int i = 0; i < pre_axis_count; ++i) {
int64_t output_offset = i * after_axis_count;
int64_t input_offset = output_offset * axis_count;
size_t output_offset = i * after_axis_count;
size_t input_offset = output_offset * axis_count;
for (int j = 0; j < after_axis_count; ++j) {
float value = -FLT_MAX;
float index = 0.0f;
@ -56,15 +464,12 @@ void ArgMax(const float *input, const int *shape, int dims_number, int axis, boo
}
}
void ArgMin(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output) {
int pre_axis_count = 1;
int axis_count = 1;
int after_axis_count = 1;
GetCalcParameter(shape, dims_number, axis, &pre_axis_count, &axis_count, &after_axis_count);
void ArgMin(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count) {
bool out_value = param->out_value_;
for (int i = 0; i < pre_axis_count; ++i) {
int64_t output_offset = i * after_axis_count;
int64_t input_offset = output_offset * axis_count;
size_t output_offset = i * after_axis_count;
size_t input_offset = output_offset * axis_count;
for (int j = 0; j < after_axis_count; ++j) {
float value = FLT_MAX;
float index = 0.0f;
@ -79,4 +484,3 @@ void ArgMin(const float *input, const int *shape, int dims_number, int axis, boo
}
}
}

View File

@ -16,22 +16,18 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/arg_min_max_parameter.h"
// For arg min, arg max.
struct ArgMinMaxParameter {
OpParameter op_parameter_;
int axis_;
int topk_;
int axis_type_;
bool out_value_;
bool keep_dims_;
};
void ArgMax(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output);
void ArgMin(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output);
void ArgMax(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count);
void ArgMin(const float *input, float *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count);
void ArgMaxDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim0(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim1(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim2(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim3(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_

View File

@ -1,43 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h"
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
void DepthToSpaceForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size,
int block_size) {
int *in_strides = (int *)(malloc(sizeof(int) * shape_size));
ComputeStrides(in_shape, in_strides, shape_size);
int *out_strides = (int *)(malloc(sizeof(int) * shape_size));
ComputeStrides(out_shape, out_strides, shape_size);
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_offset_n = i * in_strides[0];
size_t out_offset_n = i * out_strides[0];
for (int j = 0; j < in_shape[1]; ++j) {
size_t in_offset_h = in_offset_n + j * in_strides[1];
size_t out_offset_h = out_offset_n + j * block_size * out_strides[1];
for (int k = 0; k < in_shape[2]; ++k) {
size_t in_offset_w = in_offset_h + k * in_strides[2];
size_t out_offset_w = out_offset_h + k * block_size * out_strides[2];
for (int l = 0; l < block_size; ++l) {
memcpy(output + out_offset_w + l * out_strides[1], input + in_offset_w + l * block_size * out_strides[2],
block_size * out_strides[2] * 4);
}
}
}
}
free(out_strides);
free(in_strides);
}

View File

@ -0,0 +1,488 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/int8/arg_min_max.h"
#define INT8_MAX_VALUE 127
int ArgCompareAscInt8(const void *a, const void *b) {
return reinterpret_cast<const ArgElement *>(a)->data_.i8_data_
- reinterpret_cast<const ArgElement *>(b)->data_.i8_data_;
}
int ArgCompareDescInt8(const void *a, const void *b) {
return reinterpret_cast<const ArgElement *>(b)->data_.i8_data_
- reinterpret_cast<const ArgElement *>(a)->data_.i8_data_;
}
void ArgMaxDim0OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].data_.f_data_;
}
}
}
void ArgMaxDim0OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].index_;
}
}
}
void ArgMinDim0OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].data_.f_data_;
}
}
}
void ArgMinDim0OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
for (int j = 0; j < in_shape[0]; ++j) {
size_t offset = param->in_strides_[0] * j + i;
param->arg_elements_[j].index_ = j;
param->arg_elements_[j].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8);
for (int j = 0; j < param->topk_; ++j) {
size_t out_offset = j * param->out_strides_[0] + i;
output[out_offset] = param->arg_elements_[j].index_;
}
}
}
void ArgMaxDim1OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].data_.f_data_;
}
}
}
}
void ArgMaxDim1OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].index_;
}
}
}
}
void ArgMinDim1OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].data_.f_data_;
}
}
}
}
void ArgMinDim1OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < param->in_strides_[1]; ++j) {
for (int k = 0; k < in_shape1; ++k) {
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
param->arg_elements_[k].index_ = k;
param->arg_elements_[k].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8);
for (int k = 0; k < param->topk_; ++k) {
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
output[out_offset] = param->arg_elements_[k].index_;
}
}
}
}
void ArgMaxDim2OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMaxDim2OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMinDim2OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMinDim2OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < param->in_strides_[2]; ++k) {
for (int l = 0; l < in_shape2; ++l) {
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMaxDim3OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMaxDim3OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMinDim3OutValue(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].data_.f_data_;
}
}
}
}
}
void ArgMinDim3OutIndex(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
int in_shape1 = in_shape[1];
int in_shape2 = in_shape[2];
int in_shape3 = in_shape[3];
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_dim0_offset = i * param->in_strides_[0];
size_t out_dim0_offset = i * param->out_strides_[0];
for (int j = 0; j < in_shape1; ++j) {
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
for (int k = 0; k < in_shape2; ++k) {
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
for (int l = 0; l < in_shape3; ++l) {
size_t offset = l + in_dim2_offset;
param->arg_elements_[l].index_ = l;
param->arg_elements_[l].data_.f_data_ = input[offset];
}
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8);
for (int l = 0; l < param->topk_; ++l) {
size_t out_offset = out_dim2_offset + l;
output[out_offset] = param->arg_elements_[l].index_;
}
}
}
}
}
void ArgMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim0OutValue(input, output, in_shape, param);
} else {
ArgMaxDim0OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim0OutValue(input, output, in_shape, param);
} else {
ArgMinDim0OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim1OutValue(input, output, in_shape, param);
} else {
ArgMaxDim1OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim1OutValue(input, output, in_shape, param);
} else {
ArgMinDim1OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim2OutValue(input, output, in_shape, param);
} else {
ArgMaxDim2OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim2OutValue(input, output, in_shape, param);
} else {
ArgMinDim2OutIndex(input, output, in_shape, param);
}
}
void ArgMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMaxDim3OutValue(input, output, in_shape, param);
} else {
ArgMaxDim3OutIndex(input, output, in_shape, param);
}
}
void ArgMinDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param) {
if (param->out_value_) {
ArgMinDim3OutValue(input, output, in_shape, param);
} else {
ArgMinDim3OutIndex(input, output, in_shape, param);
}
}
void ArgMax(const int8_t *input, int8_t *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count) {
bool out_value = param->out_value_;
for (int i = 0; i < pre_axis_count; ++i) {
size_t output_offset = i * after_axis_count;
size_t input_offset = output_offset * axis_count;
for (int j = 0; j < after_axis_count; ++j) {
int8_t value = -INT8_MAX_VALUE;
int8_t index = 0;
for (int k = 0; k < axis_count; ++k) {
int8_t value_tmp = input[input_offset + k * after_axis_count + j];
if (value_tmp > value) {
value = value_tmp;
index = k;
}
}
output[output_offset + j] = out_value ? value : index;
}
}
}
void ArgMin(const int8_t *input, int8_t *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count) {
bool out_value = param->out_value_;
for (int i = 0; i < pre_axis_count; ++i) {
size_t output_offset = i * after_axis_count;
size_t input_offset = output_offset * axis_count;
for (int j = 0; j < after_axis_count; ++j) {
int8_t value = INT8_MAX_VALUE;
int8_t index = 0;
for (int k = 0; k < axis_count; ++k) {
int8_t value_tmp = input[input_offset + k * after_axis_count + j];
if (value_tmp < value) {
value = value_tmp;
index = k;
}
}
output[output_offset + j] = out_value ? value : index;
}
}
}
#undef INT8_MAX_VALUE

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_ARG_MIN_MAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_ARG_MIN_MAX_H_
#include "src/runtime/kernel/arm/opclib/arg_min_max_parameter.h"
void ArgMax(const int8_t *input, int8_t *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count);
void ArgMin(const int8_t *input, int8_t *output, ArgMinMaxParameter *param, int pre_axis_count, int axis_count,
int after_axis_count);
void ArgMaxDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim0(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim1(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim2(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMaxDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
void ArgMinDim3(const int8_t *input, int8_t *output, const int *in_shape, ArgMinMaxParameter *param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_ARG_MIN_MAX_H_

View File

@ -32,7 +32,7 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
// These are caffe attributes, set to default value.
attr->axisType = 1;
attr->outMaxValue = false;
attr->topK = -1;
attr->topK = 1;
attr->keepDims = false;
if (op != nullptr) {