!3866 add register func for mean op

Merge pull request !3866 from fuzhiye/mindspore
mindspore-ci-bot 2020-08-03 17:15:53 +08:00 committed by Gitee
commit 201bcdd9af
18 changed files with 236 additions and 40 deletions

View File

@@ -0,0 +1,79 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
} // namespace
int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
return RET_ERROR;
}
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
return RET_NULL_PTR;
}
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
auto mean_prim = this->primitive->value_as_Mean();
bool keep_dims = static_cast<bool>(mean_prim->keepDims());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
const auto &axes = mean_prim->axis();
auto num_axes = axes->size();
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
for (size_t i = 0; i < in_shape.size(); i++) {
out_shape.push_back(1);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
// reduce on selected axes
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (size_t idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i) {
reduce_axis = true;
break;
}
}
if (reduce_axis) {
if (keep_dims) {
out_shape.push_back(1);
}
} else {
out_shape.push_back(in_shape[i]);
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite
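
To see what this shape rule produces, here is a standalone sketch that mirrors the arithmetic of Mean::InferShape (the infer_mean_shape helper is hypothetical, for illustration only, not part of the patch):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Mirrors Mean::InferShape: reduced axes are dropped, or kept as 1 with keep_dims.
    std::vector<int> infer_mean_shape(const std::vector<int> &in_shape,
                                      const std::vector<int> &axes, bool keep_dims) {
      std::vector<int> out_shape;
      if (axes.empty()) {  // reduce on all axes
        if (keep_dims) out_shape.assign(in_shape.size(), 1);
        return out_shape;  // rank-0 result when keep_dims is false
      }
      for (size_t i = 0; i < in_shape.size(); i++) {
        bool reduce_axis = false;
        for (int a : axes) {
          if (static_cast<size_t>(a) == i) {
            reduce_axis = true;
            break;
          }
        }
        if (reduce_axis) {
          if (keep_dims) out_shape.push_back(1);
        } else {
          out_shape.push_back(in_shape[i]);
        }
      }
      return out_shape;
    }

    int main() {
      // Input [2, 3, 4], mean over axis 1: [2, 4] without keep_dims, [2, 1, 4] with it.
      for (int d : infer_mean_shape({2, 3, 4}, {1}, false)) std::cout << d << ' ';
      std::cout << '\n';
      for (int d : infer_mean_shape({2, 3, 4}, {1}, true)) std::cout << d << ' ';
      std::cout << '\n';
      return 0;
    }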

View File

@@ -384,6 +384,13 @@ class Fill : public Primitive {
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
class Mean : public Primitive {
public:
explicit Mean(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::Mean *GetAttribute() const { return this->primitive->value_as_Mean(); }
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
class ArgMax : public Primitive {
public:
explicit ArgMax(schema::Primitive *primitive) : Primitive(primitive) {}
@@ -601,10 +608,11 @@ class SpaceToBatch : public Primitive {
explicit SpaceToBatch(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::SpaceToBatch *GetAttribute() const { return this->primitive->value_as_SpaceToBatch(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
std::vector<int> BlockSizes() { return block_sizes_; }
std::vector<int> Paddings() { return paddings_; }
std::vector<int> InShape() { return block_sizes_; }
std::vector<int> PaddedInShape() { return block_sizes_; }
private:
std::vector<int> block_sizes_;
std::vector<int> paddings_;

View File

@@ -18,6 +18,7 @@
#include <float.h>
#include "src/ops/ops.h"
#include "utils/log_adapter.h"
#include "schema/ops_generated.h"
#include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h" #include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h"
#include "src/runtime/kernel/arm/opclib/fp32/cast.h" #include "src/runtime/kernel/arm/opclib/fp32/cast.h"
@ -391,6 +392,30 @@ OpParameter *PopulateReduceParameter(const lite::Primitive *primitive) {
return reinterpret_cast<OpParameter *>(reduce_param); return reinterpret_cast<OpParameter *>(reduce_param);
} }
OpParameter *PopulateMeanParameter(const lite::Primitive *primitive) {
ReduceParameter *mean_param = new (std::nothrow) ReduceParameter();
if (mean_param == nullptr) {
MS_LOG(ERROR) << "new ReduceParameter failed.";
return nullptr;
}
mean_param->op_parameter_.type_ = primitive->Type();
auto mean = primitive->Value()->value_as_Mean();
mean_param->keep_dims_ = mean->keepDims();
auto axisVector = mean->axis();
if (axisVector->size() > REDUCE_MAX_AXES_NUM) {
MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM;
delete mean_param;
return nullptr;
}
mean_param->num_axes_ = static_cast<int>(axisVector->size());
int i = 0;
for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) {
mean_param->axes_[i++] = *iter;
}
mean_param->mode_ = static_cast<int>(schema::ReduceMode_ReduceMean);
return reinterpret_cast<OpParameter *>(mean_param);
}
OpParameter *PopulatePadParameter(const lite::Primitive *primitive) {
PadParameter *pad_param = new (std::nothrow) PadParameter();
if (pad_param == nullptr) {
@@ -1131,6 +1156,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter;
populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter;
populate_parameter_funcs_[schema::PrimitiveType_Mean] = PopulateMeanParameter;
populate_parameter_funcs_[schema::PrimitiveType_Pooling] = PopulatePoolingParameter;
populate_parameter_funcs_[schema::PrimitiveType_DepthwiseConv2D] = PopulateConvDwParameter;
populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter;
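
The registration line added above is the whole integration point: the registry is a table from primitive type to a populate function, consulted when the graph is scheduled. A minimal sketch of the idiom (MiniRegistry and its key type are hypothetical; MindSpore's actual container may differ):

    #include <unordered_map>

    struct OpParameter { int type_; };
    using ParameterCreator = OpParameter *(*)(const void *primitive);

    // Hypothetical miniature of PopulateParameterRegistry: creators keyed by op type.
    class MiniRegistry {
     public:
      void Register(int prim_type, ParameterCreator func) { funcs_[prim_type] = func; }
      ParameterCreator Get(int prim_type) const {
        auto it = funcs_.find(prim_type);
        return it == funcs_.end() ? nullptr : it->second;
      }

     private:
      std::unordered_map<int, ParameterCreator> funcs_;
    };

Note the design choice visible in PopulateMeanParameter: Mean does not get its own parameter struct; it fills a ReduceParameter with mode_ set to ReduceMode_ReduceMean, so the existing reduce kernel can execute it unchanged.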

View File

@@ -92,7 +92,7 @@ kernel::LiteKernel *CpuReshapeFp32KernelCreator(const std::vector<lite::tensor::
MS_ASSERT(desc.type == schema::PrimitiveType_Reshape);
auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ReshapeCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();

View File

@@ -99,8 +99,10 @@ kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector<lite::tenso
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
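
The `delete kernel` additions in this and the following files all close the same leak: once `new (std::nothrow)` has succeeded, every failure path must free the kernel before returning nullptr. The pattern in isolation (MiniKernel is a hypothetical stand-in):

    #include <new>

    struct MiniKernel {
      int Init() { return 0; }  // 0 stands in for RET_OK
    };

    MiniKernel *CreateKernel() {
      auto *kernel = new (std::nothrow) MiniKernel();
      if (kernel == nullptr) {
        return nullptr;  // allocation failed, nothing to free
      }
      if (kernel->Init() != 0) {
        delete kernel;  // Init failed after allocation: free before bailing out
        return nullptr;
      }
      return kernel;
    }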

View File

@@ -61,6 +61,7 @@ kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tenso
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
@@ -70,4 +71,3 @@ kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tenso
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator)
} // namespace mindspore::kernel

View File

@@ -46,12 +46,13 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::tensor::T
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator)
} // namespace mindspore::kernel

View File

@@ -74,8 +74,10 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}

View File

@@ -26,6 +26,7 @@ using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Mean;
using mindspore::schema::PrimitiveType_Reduce;
using mindspore::schema::ReduceMode;
using mindspore::schema::ReduceMode_ReduceMax;
@@ -195,6 +196,27 @@ int ReduceCPUKernel::Run() {
return RET_OK;
}
int ReduceCPUKernel::MallocTmpBuffer() {
auto input_shape = inputs_.at(0)->shape();
for (auto i = 0; i < num_axes_ - 1; i++) {
int axis = axes_[i];
size_t size = 1;
for (size_t j = 0; j < input_shape.size(); j++) {
if (static_cast<size_t>(axis) != j) {
size *= input_shape[j];
}
}
float *buffer = reinterpret_cast<float *>(malloc(size * sizeof(float)));
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
return RET_ERROR;
}
data_buffers_.emplace_back(buffer);
input_shape[axis] = 1;
}
return RET_OK;
}
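
For intuition about the sizing above: a reduce over k axes runs as k passes, and every pass but the last writes a partially reduced tensor into scratch memory whose volume is the input volume with the already-processed axes collapsed to 1. A worked sketch (the scratch_sizes helper is hypothetical, mirroring the loop above):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Mirrors MallocTmpBuffer's sizing: one scratch size per axis except the last,
    // with each processed axis collapsed to 1 before the next pass.
    std::vector<size_t> scratch_sizes(std::vector<int> shape, const std::vector<int> &axes) {
      std::vector<size_t> sizes;
      for (size_t i = 0; i + 1 < axes.size(); i++) {
        int axis = axes[i];
        size_t size = 1;
        for (size_t j = 0; j < shape.size(); j++) {
          if (static_cast<size_t>(axis) != j) size *= shape[j];
        }
        sizes.push_back(size);
        shape[axis] = 1;  // this axis is reduced before the next pass
      }
      return sizes;
    }

    int main() {
      // Shape [2,3,4,5] reduced over axes {1,3}: one scratch buffer of 2*4*5 = 40 floats.
      for (size_t s : scratch_sizes({2, 3, 4, 5}, {1, 3})) std::cout << s << '\n';
      return 0;
    }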
kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
@@ -219,30 +241,42 @@ kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::tensor::T
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuMeanFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Mean);
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Mean opParameter nullptr";
return nullptr;
}
if (desc.type != schema::PrimitiveType_Mean) {
MS_LOG(ERROR) << "Mean op desc.type should be PrimitiveType_Mean, got " << desc.type;
return nullptr;
}
auto *kernel =
new (std::nothrow) ReduceCPUKernel(reinterpret_cast<ReduceParameter *>(opParameter), inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ReduceCPUKernel failed.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reduce, CpuReduceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mean, CpuMeanFp32KernelCreator)
} // namespace mindspore::kernel
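
The new REG_KERNEL line is what makes CpuMeanFp32KernelCreator reachable at runtime: it relies on static registration, where a file-scope object's constructor runs before main() and records the creator in a global table. A hedged sketch of the idiom (the real KernelRegistrar macro and key type differ; all names here are illustrative):

    #include <functional>
    #include <map>
    #include <string>

    using KernelCreatorFn = std::function<void *()>;

    // Meyers-singleton registry: safe to touch from static constructors.
    std::map<std::string, KernelCreatorFn> &KernelTable() {
      static std::map<std::string, KernelCreatorFn> table;
      return table;
    }

    struct RegistrarSketch {
      RegistrarSketch(const std::string &key, KernelCreatorFn fn) {
        KernelTable()[key] = std::move(fn);
      }
    };

    // What a REG_KERNEL-style macro boils down to, conceptually:
    static RegistrarSketch g_mean_fp32_reg("kCPU/Float32/Mean",
                                           [] { return static_cast<void *>(nullptr); });  // placeholder creator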

View File

@@ -69,7 +69,7 @@ kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector<lite::tensor::Te
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;

View File

@@ -73,7 +73,7 @@ kernel::LiteKernel *CpuStridedSliceFp32KernelCreator(const std::vector<lite::ten
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;

View File

@@ -82,7 +82,7 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor
MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();

View File

@@ -30,9 +30,9 @@ using mindspore::schema::PrimitiveType_Activation;
namespace mindspore::kernel {
kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *parameter, const lite::Context *ctx,
const KernelKey &desc) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr";
return nullptr;
@@ -56,8 +56,10 @@ kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector<lite::tenso
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
return nullptr;
}
return kernel;
}

View File

@@ -23,6 +23,7 @@
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Add;
namespace mindspore::kernel {
@@ -135,7 +136,7 @@ kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::tensor::Tens
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
@@ -146,4 +147,3 @@ kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::tensor::Tens
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Add, CpuAddInt8KernelCreator)
} // namespace mindspore::kernel

View File

@@ -18,8 +18,10 @@
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
#include "src/runtime/kernel/arm/opclib/errorcode.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BiasAdd;
namespace mindspore::kernel {
@@ -71,7 +73,7 @@ kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector<lite::tensor::
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
@@ -82,4 +84,3 @@ kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector<lite::tensor::
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, CpuBiasAddInt8KernelCreator)
} // namespace mindspore::kernel

View File

@@ -105,7 +105,46 @@ void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight
#ifndef ENABLE_ARM32
void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t relu6) {
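// Scalar reference path, compiled only when ENABLE_ARM32 is not defined
// (ARM32 builds provide this symbol elsewhere, presumably in assembly).
// The indexing below implies the packed layouts: input tiles as
// [step][ic4][tile][c4], with TILE_NUM tiles of C4NUM channels each, and
// weights as [oc4 block][step][ic4][c4 in][c4 out]. Bias add and the
// optional relu/relu6 clamps are fused after the accumulation.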
for (int i = 0; i < TILE_NUM; i++) {
int input_tile_offset = i * C4NUM;
int output_tile_offset = i * output_channel;
for (int j = 0; j < output_channel; j++) {
int oc4_block = j / C4NUM;
int oc4_res = j % C4NUM;
int weight_oc_offset = oc4_block * step * ic4 * C4NUM * C4NUM + oc4_res;
int out_oc_offset = output_tile_offset + j;
float acc = 0;
for (int n = 0; n < step; n++) {
int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM;
int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C4NUM;
for (int k = 0; k < ic4; k++) {
int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM;
int weight_ic4_offset = weight_kw_offset + k * C4NUM * C4NUM;
for (int m = 0; m < C4NUM; m++) {
int input_ic_offset = input_ic4_offset + m;
int weight_ic_offset = weight_ic4_offset + m * C4NUM;
acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0];
}
}
}
acc += bias[j];
if (relu) {
acc = acc > 0 ? acc : 0;
} else if (relu6) {
if (acc < 0) {
acc = 0;
} else if (acc > 6) {
acc = 6;
}
}
(output + out_oc_offset)[0] = acc;
}
}
}
#endif
int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }

View File

@@ -175,9 +175,11 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, key);
if (kernel != nullptr) {
MS_LOG(INFO) << "Get fp16 op success.";
kernel->set_desc(desc);
return kernel;
}
MS_LOG(INFO) << "Get fp16 op failed, back to fp32 op.";
kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc);
} else {
kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context_, desc);

View File

@@ -73,7 +73,7 @@ TEST_F(InferTest, TestConvNode) {
auto buf = new char *[1];
//================================================================
size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin";
ReadFile(weight_path.c_str(), &weight_size, buf);
ASSERT_NE(nullptr, buf[0]);
auto weight_data_temp = reinterpret_cast<float *>(buf[0]);
@@ -118,7 +118,7 @@ TEST_F(InferTest, TestConvNode) {
auto data = inTensor->MutableData();
//===================================================
size_t input_size;
std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin";
ReadFile(input_path.c_str(), &input_size, buf);
ASSERT_NE(nullptr, buf[0]);
auto input_data = reinterpret_cast<float *>(buf[0]);
@@ -138,7 +138,7 @@ TEST_F(InferTest, TestConvNode) {
ASSERT_NE(nullptr, outData);
//===================================================
size_t output_size;
std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin";
ReadFile(output_path.c_str(), &output_size, buf);
ASSERT_NE(nullptr, buf[0]);
auto output_data = reinterpret_cast<float *>(buf[0]);
@@ -146,7 +146,7 @@ TEST_F(InferTest, TestConvNode) {
//===================================================
ASSERT_EQ(output_size, outTensor->Size());
for (size_t i = 0; i < outTensor->ElementsNum(); i++) {
ASSERT_LE(fabs(output_data[i] - outData[i]), 0.001);
}
MS_LOG(INFO) << "Passed";
}
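
The final assertion change swaps exact float equality for an absolute-error bound, the standard way to compare kernel outputs that can differ in rounding; note the difference must be taken as a magnitude, since a raw signed difference would let arbitrarily large negative errors pass. gtest expresses the same check directly as ASSERT_NEAR(output_data[i], outData[i], 0.001); as a plain predicate it looks like this:

    #include <cmath>

    // Absolute-tolerance comparison, the predicate behind ASSERT_NEAR.
    inline bool almost_equal(float expected, float actual, float abs_tol = 1e-3f) {
      return std::fabs(expected - actual) <= abs_tol;
    }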