!9347 [lite] add reciprocal op and adjust tile、split

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-08 21:02:28 +08:00 committed by Gitee
commit c04304337a
35 changed files with 336 additions and 82 deletions

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <math.h>
#include "nnacl/fp16/arithmetic_self_fp16.h"
@ -108,3 +109,11 @@ int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size) {
}
return NNACL_OK;
}
int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
assert(input[i] != 0.0f);
output[i] = 1.f / input[i];
}
return NNACL_OK;
}

View File

@ -48,6 +48,8 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size);
int ElementCeilFp16(float16_t *input, float16_t *output, int number);
int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size);
int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size);
#ifdef __cplusplus
}
#endif

View File

@ -16,6 +16,7 @@
#include <string.h>
#include <math.h>
#include <assert.h>
#include "nnacl/fp32/arithmetic_self_fp32.h"
// abs:
@ -128,3 +129,11 @@ int ElementNegative(const float *input, float *output, const int element_size) {
}
return NNACL_OK;
}
int ElementReciprocal(const float *input, float *output, const int element_size) {
for (int i = 0; i < element_size; ++i) {
assert(input[i] != 0.0f);
output[i] = 1.f / input[i];
}
return NNACL_OK;
}

View File

@ -51,6 +51,8 @@ int ElementFloor(const float *input, float *output, const int element_size);
int ElementCeil(const float *input, float *output, const int number);
int ElementNegative(const float *input, float *output, const int element_size);
int ElementReciprocal(const float *input, float *output, const int element_size);
#ifdef __cplusplus
}
#endif

View File

@ -15,6 +15,7 @@
*/
#include <math.h>
#include <assert.h>
#include "nnacl/int8/arithmetic_self_int8.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
@ -278,3 +279,24 @@ int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, Arith
}
return NNACL_OK;
}
int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
float in_scale = para.in_args_.scale_;
int32_t in_zp = para.in_args_.zp_;
float out_scale = para.out_args_.scale_;
int32_t out_zp = para.out_args_.zp_;
float bias = in_zp * in_scale;
for (int i = 0; i < element_size; i++) {
float input_f32 = input[i] * in_scale + bias;
assert(input_f32 != 0.0f);
int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp;
if (output_tmp > para.output_activation_max_) {
output[i] = para.output_activation_max_;
} else if (output_tmp < para.output_activation_min_) {
output[i] = para.output_activation_min_;
} else {
output[i] = (int8_t)output_tmp;
}
}
return NNACL_OK;
}

View File

@ -50,6 +50,8 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelf
int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
#ifdef __cplusplus
}
#endif

View File

@ -253,7 +253,8 @@ union PrimitiveType {
All,
Assert,
Adder,
SparseSoftmaxCrossEntropy
SparseSoftmaxCrossEntropy,
Reciprocal,
}
enum QuantType: int {

View File

@ -1203,3 +1203,6 @@ table All {
table Assert {
summarize : int;
}
table Reciprocal {
}

View File

@ -375,11 +375,11 @@ void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output
int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (inputs_.size() != 2 && inputs_.size() != 3) {
MS_LOG(ERROR) << "Add should has two or three inputs";
MS_LOG(ERROR) << "Conv2d should has two or three inputs";
return RET_ERROR;
}
if (outputs_.size() != 1) {
MS_LOG(ERROR) << "Add should has one outputs";
MS_LOG(ERROR) << "Conv2d should has one outputs";
return RET_ERROR;
}
auto *input_tensor = inputs_.front();

View File

@ -47,6 +47,7 @@ Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateA
Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf);
Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf);
Registry RoundParameterRegistry(schema::PrimitiveType_Round, PopulateArithmeticSelf);
Registry ReciprocalParameterRegistry(schema::PrimitiveType_Reciprocal, PopulateArithmeticSelf);
} // namespace lite
} // namespace mindspore

View File

@ -31,7 +31,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
memset(split_param, 0, sizeof(SplitParameter));
auto param = reinterpret_cast<mindspore::lite::Split *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
split_param->op_parameter_.type_ = primitive->Type();
split_param->num_split_ = param->GetNumberSplit();
split_param->num_split_ = param->num_split();
if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
MS_LOG(ERROR) << "The value of split_param->num_split_ is too big";
return nullptr;
@ -44,7 +44,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
}
memset(split_sizes, 0, split_param->num_split_ * sizeof(int));
split_param->split_sizes_ = split_sizes;
auto split_sizes_vector_ = param->GetSizeSplits();
auto split_sizes_vector_ = param->size_splits();
int i = 0;
for (int &iter : split_sizes_vector_) {
split_param->split_sizes_[i++] = iter;

View File

@ -43,8 +43,10 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive)
for (size_t i = 0; i < kDimension_4d; ++i) {
tile_param->multiples_[i] = 1;
}
for (size_t i = 0; i < dims.size(); ++i) {
tile_param->multiples_[dims.at(i)] = multiples.at(i);
if (!dims.empty() && !multiples.empty()) {
for (size_t i = 0; i < dims.size(); ++i) {
tile_param->multiples_[dims[i]] = multiples[i];
}
}
#endif
return reinterpret_cast<OpParameter *>(tile_param);

View File

@ -148,6 +148,7 @@
#include "src/ops/while.h"
#include "src/ops/oneslike.h"
#include "src/ops/unsorted_segment_sum.h"
#include "src/ops/reciprocal.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -888,6 +889,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Quant(primitive);
case schema::PrimitiveType_OnnxInt8Dequantize:
return new (std::nothrow) Dequant(primitive);
case schema::PrimitiveType_Reciprocal:
return new (std::nothrow) Reciprocal(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

View File

@ -0,0 +1,33 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/reciprocal.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<Reciprocal>(primitive);
}
Registry ReciprocalRegistry(schema::PrimitiveType_Reciprocal, ReciprocalCreator);
#endif
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
#include "src/ops/arithmetic_self.h"
namespace mindspore {
namespace lite {
class Reciprocal : public ArithmeticSelf {
public:
Reciprocal() = default;
~Reciprocal() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf);
explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateReciprocal(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reciprocal, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_

View File

@ -24,7 +24,7 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Split::GetNumberSplit() const { return this->primitive_->value.AsSplit()->numberSplit; }
std::vector<int> Split::GetSizeSplits() const { return this->primitive_->value.AsSplit()->sizeSplits; }
std::vector<int> Split::GetSizeSplit() const { return this->primitive_->value.AsSplit()->sizeSplits; }
int Split::GetSplitDim() const { return this->primitive_->value.AsSplit()->splitDim; }
void Split::SetNumberSplit(int number_split) { this->primitive_->value.AsSplit()->numberSplit = number_split; }
@ -67,7 +67,7 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
#else
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
std::vector<int> Split::GetSizeSplits() const {
std::vector<int> Split::GetSizeSplit() const {
auto fb_vector = this->primitive_->value_as_Split()->sizeSplits();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
@ -108,42 +108,50 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum;
return RET_ERROR;
}
auto output = outputs_.front();
if (output == nullptr) {
MS_LOG(ERROR) << "output null pointer dereferencing.";
if (outputs_.empty()) {
MS_LOG(ERROR) << "split has no output.";
return RET_ERROR;
}
int number_split = GetNumberSplit();
if (static_cast<int>(outputs_.size()) != number_split) {
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
return RET_ERROR;
}
for (int i = 0; i < number_split; ++i) {
outputs_.at(i)->set_data_type(input->data_type());
outputs_.at(i)->set_format(input->format());
for (auto &output : outputs_) {
output->set_data_type(input->data_type());
output->set_format(input->format());
}
size_splits_ = GetSizeSplit();
num_split_ = GetNumberSplit() == 0 ? static_cast<int>(outputs_.size()) : GetNumberSplit();
if (!infer_flag()) {
return RET_INFER_INVALID;
}
size_t split_dim = GetSplitDim() == -1 ? input->shape().size() - 1 : GetSplitDim();
size_t split_dim = GetSplitDim() < 0 ? input->shape().size() + GetSplitDim() : GetSplitDim();
std::vector<int> input_shape = input->shape();
std::vector<int> size_split;
for (size_t i = 0; i < GetSizeSplits().size(); ++i) {
size_split.push_back(GetSizeSplits().at(i));
if (split_dim > input_shape.size()) {
MS_LOG(ERROR) << "split dim is out of range, which is " << input_shape.size();
return RET_INPUT_PARAM_INVALID;
}
for (int i = 0; i < number_split; ++i) {
if (static_cast<int>(outputs_.size()) != num_split_) {
MS_LOG(ERROR) << "outputs number is not equal to " << num_split_;
return RET_ERROR;
}
if (size_splits_.empty()) {
if (input_shape[split_dim] % num_split_ != 0) {
MS_LOG(ERROR) << "cannot split to equal size, which dim is " << input_shape[split_dim] << ", num split is "
<< num_split_;
return RET_INPUT_PARAM_INVALID;
}
for (int i = 0; i < num_split_; ++i) {
size_splits_.push_back(input_shape[split_dim] / num_split_);
}
}
for (int i = 0; i < num_split_; ++i) {
std::vector<int> output_shape;
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
int split_dim_i = input_shape.at(split_dim);
// support split size is -1 in the end.
if (size_split.empty()) {
split_dim_i = input_shape.at(split_dim) / number_split;
} else if (i == number_split - 1 && size_split.at(i) == -1) {
for (size_t j = 0; j < size_split.size() - 1; ++j) {
split_dim_i -= size_split.at(j);
if (i == num_split_ - 1 && size_splits_[i] == -1) {
for (size_t j = 0; j < size_splits_.size() - 1; ++j) {
split_dim_i -= size_splits_[j];
}
} else {
split_dim_i = size_split.at(i);
split_dim_i = size_splits_[i];
}
output_shape.at(split_dim) = split_dim_i;
outputs_.at(i)->set_shape(output_shape);

View File

@ -42,8 +42,14 @@ class Split : public PrimitiveC {
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetNumberSplit() const;
std::vector<int> GetSizeSplits() const;
std::vector<int> GetSizeSplit() const;
int GetSplitDim() const;
int num_split() const { return num_split_; }
std::vector<int> size_splits() const { return size_splits_; }
protected:
int num_split_ = 0;
std::vector<int> size_splits_;
};
} // namespace lite
} // namespace mindspore

View File

@ -139,8 +139,22 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
std::vector<int> out_shape;
std::vector<int> multiples = GetMultiples();
std::vector<int> multiples;
if (inputs_.size() == 2) {
if (inputs_[1]->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
int data_num = inputs_[1]->ElementsNum();
if (data_num > static_cast<int>(input->shape().size())) {
MS_LOG(ERROR) << "multiples data num cannot be larger than input shape size.";
return RET_INPUT_TENSOR_ERROR;
}
multiples.resize(data_num);
memcpy(multiples.data(), inputs_[1]->data_c(), inputs_[1]->Size());
} else {
multiples = GetMultiples();
}
#ifdef SUPPORT_TRAIN
const size_t in_dims = input->shape().size();
const size_t delta_dims = in_dims - multiples.size();
@ -156,6 +170,11 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
#else
std::vector<int> dims = GetDims();
if (inputs_.size() == 2 && dims.empty()) {
for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
dims.push_back(dim);
}
}
const size_t in_dims = input->shape().size();
MS_ASSERT(multiples.size() == dims.size());

View File

@ -38,7 +38,7 @@ int SplitBaseCPUKernel::ReSize() {
auto input_shape = in_tensor->shape();
MS_ASSERT(param);
MS_ASSERT(input_shape.size() >= 2 && input_shape.size() <= SPLIT_STRIDES_SIZE);
MS_ASSERT(input_shape.size() >= 1 && input_shape.size() <= SPLIT_STRIDES_SIZE);
param->strides_[input_shape.size() - 1] = 1;
for (int i = input_shape.size() - 2; i >= 0; i--) {
param->strides_[i] = param->strides_[i + 1] * input_shape.at(i + 1);
@ -50,8 +50,8 @@ int SplitBaseCPUKernel::ReSize() {
param->n_dims_ = input_shape.size();
if (param->split_sizes_[0] == 0) {
MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) < input_shape.size());
if (input_shape.at(param->split_dim_) % param->num_split_ != 0) {
MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) <= input_shape[param->split_dim_]);
if (input_shape[param->split_dim_] % param->num_split_ != 0) {
MS_LOG(ERROR) << "Default split size is not usable.";
return RET_ERROR;
}

View File

@ -43,7 +43,8 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int
{mindspore::schema::PrimitiveType_Floor, ElementFloorFp16},
{mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16},
{mindspore::schema::PrimitiveType_Round, ElementRoundFp16},
{mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}};
{mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16},
{mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16}};
for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
if (type_func_table[i].primitive_type_ == primitive_type) {
return type_func_table[i].func_;
@ -139,4 +140,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Floor, CpuArithmeticSelfFp16K
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, CpuArithmeticSelfFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, CpuArithmeticSelfFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, CpuArithmeticSelfFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reciprocal, CpuArithmeticSelfFp16KernelCreator)
} // namespace mindspore::kernel

View File

@ -41,7 +41,8 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
{mindspore::schema::PrimitiveType_Floor, ElementFloor},
{mindspore::schema::PrimitiveType_Ceil, ElementCeil},
{mindspore::schema::PrimitiveType_Round, ElementRound},
{mindspore::schema::PrimitiveType_Neg, ElementNegative}};
{mindspore::schema::PrimitiveType_Neg, ElementNegative},
{mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocal}};
for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
if (type_func_table[i].primitive_type_ == primitive_type) {
return type_func_table[i].func_;
@ -152,4 +153,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32K
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Neg, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reciprocal, CpuArithmeticSelfFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -26,6 +26,7 @@ using mindspore::schema::PrimitiveType_Floor;
using mindspore::schema::PrimitiveType_Log;
using mindspore::schema::PrimitiveType_LogicalNot;
using mindspore::schema::PrimitiveType_Neg;
using mindspore::schema::PrimitiveType_Reciprocal;
using mindspore::schema::PrimitiveType_Round;
using mindspore::schema::PrimitiveType_Rsqrt;
using mindspore::schema::PrimitiveType_Sin;

View File

@ -24,6 +24,9 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Tile;
namespace mindspore::kernel {
namespace {
constexpr size_t kDoubleInputsSize = 2;
}
int TileCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
@ -42,6 +45,17 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
int TileCPUKernel::ReSize() {
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
MS_ASSERT(tile_parameter_);
if (in_tensors_.size() == kDoubleInputsSize) {
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size.";
return false;
}
auto input1_addr = reinterpret_cast<int *>(in_tensors_[1]->data_c());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
tile_parameter_->dims_[i] = i;
tile_parameter_->multiples_[i] = input1_addr[i];
}
}
tile_parameter_->in_dim_ = in_tensors_.at(0)->shape().size();
for (int i = 0; i < tile_parameter_->in_dim_; ++i) {
tile_parameter_->in_shape_[i] = in_tensors_.at(0)->shape().at(i);
@ -93,4 +107,5 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector<lite::Tensor *> &
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reciprocal, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -31,6 +31,7 @@ using mindspore::schema::PrimitiveType_Cos;
using mindspore::schema::PrimitiveType_Floor;
using mindspore::schema::PrimitiveType_Log;
using mindspore::schema::PrimitiveType_LogicalNot;
using mindspore::schema::PrimitiveType_Reciprocal;
using mindspore::schema::PrimitiveType_Round;
using mindspore::schema::PrimitiveType_Rsqrt;
using mindspore::schema::PrimitiveType_Sin;
@ -80,6 +81,8 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel {
case PrimitiveType_LogicalNot:
arithmeticSelf_run_ = Int8ElementLogicalNot;
break;
case PrimitiveType_Reciprocal:
arithmeticSelf_run_ = Int8ElementReciprocal;
default:
break;
}

View File

@ -690,6 +690,29 @@ STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReciprocalParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ReciprocalT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Reciprocal;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
@ -720,5 +743,6 @@ OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser());
OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser());
OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser());
OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser());
OnnxNodeRegistrar g_onnxReciprocalParser("Reciprocal", new OnnxReciprocalParser());
} // namespace lite
} // namespace mindspore

View File

@ -217,6 +217,13 @@ class OnnxRoundParser : public OnnxNodeParser {
~OnnxRoundParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxReciprocalParser : public OnnxNodeParser {
public:
OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {}
~OnnxReciprocalParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

View File

@ -41,8 +41,11 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
auto dst_type = OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i()));
if (dst_type == kNumberTypeInt64) {
dst_type = kNumberTypeInt32;
}
attr->dstT = static_cast<int>(dst_type);
}
}

View File

@ -105,7 +105,7 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
MS_LOG(ERROR) << "new tensor failed";
return RET_ERROR;
}
tensor->dataType = data_type;
tensor->dataType = data_type == kNumberTypeInt64 ? kNumberTypeInt32 : data_type;
tensor->dims = GetDimsFromOnnxValue(proto);
tensor->format = schema::Format::Format_NCHW;
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
@ -370,7 +370,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name;
return;
}
quant_param->inited = true;
int argNum = 0;
for (const auto &onnx_node_attr : node.attribute()) {
if (onnx_node_attr.name() == "Y_scale") {
@ -382,11 +381,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
}
}
if (argNum != 2) {
quant_param->scale = FLT_MAX;
quant_param->zeroPoint = 0;
quant_param->min = FLT_MAX;
quant_param->max = FLT_MAX;
quant_param->inited = false;
continue;
}
dst_tensor->quantParams.emplace_back(std::move(quant_param));
if (argNum == 2) {

View File

@ -15,7 +15,9 @@
*/
#include "tools/converter/parser/onnx/onnx_slice_parser.h"
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include <string>
@ -46,6 +48,35 @@ STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std
return RET_OK;
}
STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) {
if (onnx_val == nullptr) {
MS_LOG(ERROR) << "input vector is nullptr.";
return RET_ERROR;
}
if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) {
MS_LOG(ERROR) << "cannot get tensorcache.";
return RET_ERROR;
}
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << name;
return RET_ERROR;
}
auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (input_tensor->data.empty()) {
MS_LOG(DEBUG) << "data is empty.";
return RET_NO_CHANGE;
}
int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>());
onnx_val->resize(data_num);
if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) !=
EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
return RET_OK;
}
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser";
@ -97,6 +128,36 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
}
int status = RET_OK;
switch (onnx_node.input_size()) {
case 5: {
if (steps.empty()) {
status = GetInputTensor(&steps, onnx_node.input(4));
}
}
case 4: {
if (status != RET_ERROR && axes.empty()) {
status = GetInputTensor(&axes, onnx_node.input(3));
}
}
case 3: {
if (status != RET_ERROR && ends.empty()) {
status = GetInputTensor(&ends, onnx_node.input(2));
}
}
case 2: {
if (status != RET_ERROR && starts.empty()) {
status = GetInputTensor(&starts, onnx_node.input(1));
}
}
default: {
if (status == RET_ERROR) {
MS_LOG(ERROR) << "onnx slice inputs are invalid.";
return RET_INPUT_TENSOR_ERROR;
}
}
}
if (axes.empty()) {
for (size_t i = 0; i < starts.size(); ++i) {
axes.push_back(i);
@ -112,7 +173,6 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
int insert_num = 5 - onnx_node.input_size();
int status = RET_OK;
switch (insert_num) {
case 4: {
std::string name = "slice/starts/";

View File

@ -32,6 +32,7 @@ class OnnxSliceParser : public OnnxNodeParser {
STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name);
};
} // namespace lite
} // namespace mindspore

View File

@ -38,6 +38,7 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
return RET_NULL_PTR;
}
attr->splitDim = 0;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {

View File

@ -16,9 +16,7 @@
#include "tools/converter/parser/onnx/onnx_tile_parser.h"
#include <memory>
#include <numeric>
#include <vector>
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
namespace mindspore {
namespace lite {
@ -39,26 +37,6 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &onnx_tile_multiple = onnx_node.input(1);
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple;
return RET_ERROR;
}
auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (tile_attr->data.data() == nullptr) {
MS_LOG(ERROR) << "power's attr pow can't be obtained.";
return RET_INVALID_OP_ATTR;
}
int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies<int>());
std::vector<int> multiples;
std::vector<int> dims;
for (int i = 0; i < element_size; ++i) {
multiples.push_back(reinterpret_cast<int *>(tile_attr->data.data())[i]);
dims.push_back(i);
}
attr->multiples = multiples;
attr->dims = dims;
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
return RET_OK;

View File

@ -23,7 +23,6 @@ namespace mindspore {
namespace lite {
PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
@ -35,16 +34,6 @@ PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr<tflite::O
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->multiples)) {
MS_LOG(ERROR) << "get tile -> multiples failed";
return nullptr;
}
std::vector<int> dims(attr->multiples.size(), 0);
for (size_t i = 0; i < dims.size(); ++i) {
dims[i] = i;
}
attr->dims = dims;
primitive->value.type = schema::PrimitiveType_Tile;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());

View File

@ -133,6 +133,10 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_no
if (output_tensors.size() != 1) {
for (size_t k = 0; k < output_tensors.size(); k++) {
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
if (used_node_list->empty()) {
MS_LOG(DEBUG) << "this output don't be used by other node.";
continue;
}
if (used_node_list->size() != 1) {
MS_LOG(ERROR) << " output must tuple_getitem";
return lite::RET_ERROR;