commit c04304337a
!9347 [lite] add reciprocal op and adjust tile, split
From: @xu_anyue
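This merge adds an element-wise Reciprocal operator end to end (nnacl fp16/fp32/int8 kernels, schema, op class, kernel registration, ONNX parser) and reworks how Tile and Split obtain their attributes. As a rough sketch of the float semantics the new kernels implement (standalone C++, not the library's API):

#include <cassert>
#include <cstdio>

// Minimal sketch of the element-wise reciprocal introduced below:
// output[i] = 1 / input[i]; zero inputs are rejected by a debug-time
// assert, mirroring the assert added to the fp16/fp32 kernels.
static void ReciprocalSketch(const float *input, float *output, int n) {
  for (int i = 0; i < n; ++i) {
    assert(input[i] != 0.0f);
    output[i] = 1.f / input[i];
  }
}

int main() {
  const float in[3] = {0.5f, 2.f, 4.f};
  float out[3];
  ReciprocalSketch(in, out, 3);
  printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 2 0.5 0.25
  return 0;
}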
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <assert.h>
 #include <math.h>
 #include "nnacl/fp16/arithmetic_self_fp16.h"

@@ -108,3 +109,11 @@ int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size) {
   }
   return NNACL_OK;
 }
+
+int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; ++i) {
+    assert(input[i] != 0.0f);
+    output[i] = 1.f / input[i];
+  }
+  return NNACL_OK;
+}
@@ -48,6 +48,8 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size);
 int ElementCeilFp16(float16_t *input, float16_t *output, int number);

 int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size);
 #ifdef __cplusplus
 }
 #endif
@@ -16,6 +16,7 @@

 #include <string.h>
 #include <math.h>
+#include <assert.h>
 #include "nnacl/fp32/arithmetic_self_fp32.h"

 // abs:

@@ -128,3 +129,11 @@ int ElementNegative(const float *input, float *output, const int element_size) {
   }
   return NNACL_OK;
 }
+
+int ElementReciprocal(const float *input, float *output, const int element_size) {
+  for (int i = 0; i < element_size; ++i) {
+    assert(input[i] != 0.0f);
+    output[i] = 1.f / input[i];
+  }
+  return NNACL_OK;
+}
@@ -51,6 +51,8 @@ int ElementFloor(const float *input, float *output, const int element_size);
 int ElementCeil(const float *input, float *output, const int number);

 int ElementNegative(const float *input, float *output, const int element_size);
+
+int ElementReciprocal(const float *input, float *output, const int element_size);
 #ifdef __cplusplus
 }
 #endif
@@ -15,6 +15,7 @@
  */

 #include <math.h>
+#include <assert.h>
 #include "nnacl/int8/arithmetic_self_int8.h"
 #ifdef ENABLE_NEON
 #include <arm_neon.h>

@@ -278,3 +279,24 @@ int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, Arith
   }
   return NNACL_OK;
 }
+
+int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    float input_f32 = input[i] * in_scale + bias;
+    assert(input_f32 != 0.0f);
+    int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
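The int8 path above cannot invert the quantized value directly: it dequantizes with the input scale and zero point, takes the reciprocal, requantizes with the output scale and zero point, and clamps to the activation range. A standalone sketch mirroring that arithmetic (the quantization parameters in main() are made-up example values, not values from the library):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Sketch mirroring Int8ElementReciprocal above: dequantize, invert,
// requantize, then clamp to the activation range.
int8_t ReciprocalInt8Sketch(int8_t q, float in_scale, int32_t in_zp, float out_scale, int32_t out_zp,
                            int32_t act_min, int32_t act_max) {
  float bias = in_zp * in_scale;  // zero point folded into a bias, as in the kernel
  float x = q * in_scale + bias;  // dequantized input
  int32_t out = static_cast<int32_t>(std::round(1.f / (x * out_scale))) + out_zp;
  if (out > act_max) out = act_max;  // clamp to the activation range
  if (out < act_min) out = act_min;
  return static_cast<int8_t>(out);
}

int main() {
  // q = 20 with scale 0.1 represents 2.0; 1 / 2.0 = 0.5, i.e. 10 at scale 0.05.
  printf("%d\n", ReciprocalInt8Sketch(20, 0.1f, 0, 0.05f, 0, -128, 127));  // prints: 10
  return 0;
}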
@@ -50,6 +50,8 @@ int Int8ElementSquare(int8_t *input, int8_t *output, int element_size, ArithSelf

 int Int8ElementLogicalNot(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);

+int Int8ElementReciprocal(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
 #ifdef __cplusplus
 }
 #endif
@@ -253,7 +253,8 @@ union PrimitiveType {
     All,
     Assert,
     Adder,
-    SparseSoftmaxCrossEntropy
+    SparseSoftmaxCrossEntropy,
+    Reciprocal,
 }

 enum QuantType: int {
@@ -1203,3 +1203,6 @@ table All {
 table Assert {
     summarize : int;
 }
+
+table Reciprocal {
+}
@@ -375,11 +375,11 @@ void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output

 int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   if (inputs_.size() != 2 && inputs_.size() != 3) {
-    MS_LOG(ERROR) << "Add should has two or three inputs";
+    MS_LOG(ERROR) << "Conv2d should has two or three inputs";
     return RET_ERROR;
   }
   if (outputs_.size() != 1) {
-    MS_LOG(ERROR) << "Add should has one outputs";
+    MS_LOG(ERROR) << "Conv2d should has one outputs";
     return RET_ERROR;
   }
   auto *input_tensor = inputs_.front();
@@ -47,6 +47,7 @@ Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateA
 Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf);
 Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf);
 Registry RoundParameterRegistry(schema::PrimitiveType_Round, PopulateArithmeticSelf);
+Registry ReciprocalParameterRegistry(schema::PrimitiveType_Reciprocal, PopulateArithmeticSelf);

 }  // namespace lite
 }  // namespace mindspore
@@ -31,7 +31,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
   memset(split_param, 0, sizeof(SplitParameter));
   auto param = reinterpret_cast<mindspore::lite::Split *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   split_param->op_parameter_.type_ = primitive->Type();
-  split_param->num_split_ = param->GetNumberSplit();
+  split_param->num_split_ = param->num_split();
   if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
     MS_LOG(ERROR) << "The value of split_param->num_split_ is too big";
     return nullptr;

@@ -44,7 +44,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
   }
   memset(split_sizes, 0, split_param->num_split_ * sizeof(int));
   split_param->split_sizes_ = split_sizes;
-  auto split_sizes_vector_ = param->GetSizeSplits();
+  auto split_sizes_vector_ = param->size_splits();
   int i = 0;
   for (int &iter : split_sizes_vector_) {
     split_param->split_sizes_[i++] = iter;
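The guard above rejects a num_split_ so large that num_split_ * sizeof(int) would overflow int before the split-sizes buffer is allocated. A minimal sketch of that defensive pattern (the helper name is hypothetical):

#include <climits>
#include <cstdio>
#include <cstdlib>

// Sketch of the overflow guard: reject a count whose byte size would
// overflow int before allocating the split-sizes array.
int *AllocSplitSizes(int num_split) {
  if (num_split <= 0 || num_split > INT_MAX / static_cast<int>(sizeof(int))) {
    return nullptr;  // the real code logs an error and returns nullptr
  }
  return static_cast<int *>(calloc(num_split, sizeof(int)));
}

int main() {
  int *sizes = AllocSplitSizes(4);
  printf("%s\n", sizes != nullptr ? "allocated" : "rejected");
  free(sizes);
  return 0;
}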
@@ -43,8 +43,10 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive)
   for (size_t i = 0; i < kDimension_4d; ++i) {
     tile_param->multiples_[i] = 1;
   }
-  for (size_t i = 0; i < dims.size(); ++i) {
-    tile_param->multiples_[dims.at(i)] = multiples.at(i);
+  if (!dims.empty() && !multiples.empty()) {
+    for (size_t i = 0; i < dims.size(); ++i) {
+      tile_param->multiples_[dims[i]] = multiples[i];
+    }
   }
 #endif
   return reinterpret_cast<OpParameter *>(tile_param);
@@ -148,6 +148,7 @@
 #include "src/ops/while.h"
 #include "src/ops/oneslike.h"
 #include "src/ops/unsorted_segment_sum.h"
+#include "src/ops/reciprocal.h"

 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"

@@ -888,6 +889,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) Quant(primitive);
     case schema::PrimitiveType_OnnxInt8Dequantize:
       return new (std::nothrow) Dequant(primitive);
+    case schema::PrimitiveType_Reciprocal:
+      return new (std::nothrow) Reciprocal(primitive);

 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/reciprocal.h"
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifndef PRIMITIVE_WRITEABLE
+PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) {
+  return PrimitiveC::NewPrimitiveC<Reciprocal>(primitive);
+}
+Registry ReciprocalRegistry(schema::PrimitiveType_Reciprocal, ReciprocalCreator);
+#endif
+
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+
+#include "src/ops/arithmetic_self.h"
+
+namespace mindspore {
+namespace lite {
+class Reciprocal : public ArithmeticSelf {
+ public:
+  Reciprocal() = default;
+  ~Reciprocal() = default;
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf);
+  explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+#else
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
+    MS_ASSERT(nullptr != primitive);
+    MS_ASSERT(nullptr != fbb);
+    auto val_offset = schema::CreateReciprocal(*fbb);
+    auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reciprocal, val_offset.o);
+    fbb->Finish(prim_offset);
+    return RET_OK;
+  }
+#endif
+};
+
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
 int Split::GetNumberSplit() const { return this->primitive_->value.AsSplit()->numberSplit; }
-std::vector<int> Split::GetSizeSplits() const { return this->primitive_->value.AsSplit()->sizeSplits; }
+std::vector<int> Split::GetSizeSplit() const { return this->primitive_->value.AsSplit()->sizeSplits; }
 int Split::GetSplitDim() const { return this->primitive_->value.AsSplit()->splitDim; }

 void Split::SetNumberSplit(int number_split) { this->primitive_->value.AsSplit()->numberSplit = number_split; }

@@ -67,7 +67,7 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
 #else

 int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
-std::vector<int> Split::GetSizeSplits() const {
+std::vector<int> Split::GetSizeSplit() const {
   auto fb_vector = this->primitive_->value_as_Split()->sizeSplits();
   return std::vector<int>(fb_vector->begin(), fb_vector->end());
 }
@@ -108,42 +108,50 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
     MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum;
     return RET_ERROR;
   }
-  auto output = outputs_.front();
-  if (output == nullptr) {
-    MS_LOG(ERROR) << "output null pointer dereferencing.";
+  if (outputs_.empty()) {
+    MS_LOG(ERROR) << "split has no output.";
     return RET_ERROR;
   }
-  int number_split = GetNumberSplit();
-  if (static_cast<int>(outputs_.size()) != number_split) {
-    MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
-    return RET_ERROR;
-  }
-  for (int i = 0; i < number_split; ++i) {
-    outputs_.at(i)->set_data_type(input->data_type());
-    outputs_.at(i)->set_format(input->format());
+  for (auto &output : outputs_) {
+    output->set_data_type(input->data_type());
+    output->set_format(input->format());
   }
+  size_splits_ = GetSizeSplit();
+  num_split_ = GetNumberSplit() == 0 ? static_cast<int>(outputs_.size()) : GetNumberSplit();
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
-  size_t split_dim = GetSplitDim() == -1 ? input->shape().size() - 1 : GetSplitDim();
+  size_t split_dim = GetSplitDim() < 0 ? input->shape().size() + GetSplitDim() : GetSplitDim();
   std::vector<int> input_shape = input->shape();
-  std::vector<int> size_split;
-  for (size_t i = 0; i < GetSizeSplits().size(); ++i) {
-    size_split.push_back(GetSizeSplits().at(i));
+  if (split_dim > input_shape.size()) {
+    MS_LOG(ERROR) << "split dim is out of range, which is " << input_shape.size();
+    return RET_INPUT_PARAM_INVALID;
   }
-  for (int i = 0; i < number_split; ++i) {
+  if (static_cast<int>(outputs_.size()) != num_split_) {
+    MS_LOG(ERROR) << "outputs number is not equal to " << num_split_;
+    return RET_ERROR;
+  }
+  if (size_splits_.empty()) {
+    if (input_shape[split_dim] % num_split_ != 0) {
+      MS_LOG(ERROR) << "cannot split to equal size, which dim is " << input_shape[split_dim] << ", num split is "
+                    << num_split_;
+      return RET_INPUT_PARAM_INVALID;
+    }
+    for (int i = 0; i < num_split_; ++i) {
+      size_splits_.push_back(input_shape[split_dim] / num_split_);
+    }
+  }
+  for (int i = 0; i < num_split_; ++i) {
     std::vector<int> output_shape;
     output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
     int split_dim_i = input_shape.at(split_dim);
     // support split size is -1 in the end.
-    if (size_split.empty()) {
-      split_dim_i = input_shape.at(split_dim) / number_split;
-    } else if (i == number_split - 1 && size_split.at(i) == -1) {
-      for (size_t j = 0; j < size_split.size() - 1; ++j) {
-        split_dim_i -= size_split.at(j);
+    if (i == num_split_ - 1 && size_splits_[i] == -1) {
+      for (size_t j = 0; j < size_splits_.size() - 1; ++j) {
+        split_dim_i -= size_splits_[j];
       }
     } else {
-      split_dim_i = size_split.at(i);
+      split_dim_i = size_splits_[i];
    }
     output_shape.at(split_dim) = split_dim_i;
     outputs_.at(i)->set_shape(output_shape);
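Two behaviors of the rewritten InferShape are worth calling out: a negative split dim now counts from the last axis, and an empty size_splits_ defaults to an even split of that axis. A small standalone sketch of that defaulting logic (the function name is illustrative, not the class API):

#include <cstdio>
#include <vector>

// Sketch of the defaulting added above: a negative split_dim counts from
// the end, and an empty size_splits means "split the axis evenly".
std::vector<int> DefaultSplitSizes(const std::vector<int> &shape, int split_dim, int num_split) {
  size_t dim = split_dim < 0 ? shape.size() + split_dim : split_dim;  // as in the kernel code
  std::vector<int> sizes;
  if (num_split <= 0 || shape[dim] % num_split != 0) return sizes;  // real code logs and errors out
  for (int i = 0; i < num_split; ++i) sizes.push_back(shape[dim] / num_split);
  return sizes;
}

int main() {
  auto sizes = DefaultSplitSizes({4, 6, 8}, -1, 4);  // axis 2 (size 8) into 4 parts
  for (int s : sizes) printf("%d ", s);              // prints: 2 2 2 2
  printf("\n");
  return 0;
}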
@@ -42,8 +42,14 @@ class Split : public PrimitiveC {
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
   int GetNumberSplit() const;
-  std::vector<int> GetSizeSplits() const;
+  std::vector<int> GetSizeSplit() const;
   int GetSplitDim() const;
+  int num_split() const { return num_split_; }
+  std::vector<int> size_splits() const { return size_splits_; }
+
+ protected:
+  int num_split_ = 0;
+  std::vector<int> size_splits_;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -139,8 +139,22 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
   }

   std::vector<int> out_shape;
-  std::vector<int> multiples = GetMultiples();
-
+  std::vector<int> multiples;
+  if (inputs_.size() == 2) {
+    if (inputs_[1]->data_c() == nullptr) {
+      MS_LOG(INFO) << "Do infer shape in runtime.";
+      return RET_INFER_INVALID;
+    }
+    int data_num = inputs_[1]->ElementsNum();
+    if (data_num > static_cast<int>(input->shape().size())) {
+      MS_LOG(ERROR) << "multiples data num cannot be larger than input shape size.";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+    multiples.resize(data_num);
+    memcpy(multiples.data(), inputs_[1]->data_c(), inputs_[1]->Size());
+  } else {
+    multiples = GetMultiples();
+  }
 #ifdef SUPPORT_TRAIN
   const size_t in_dims = input->shape().size();
   const size_t delta_dims = in_dims - multiples.size();

@@ -156,6 +170,11 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
   }
 #else
   std::vector<int> dims = GetDims();
+  if (inputs_.size() == 2 && dims.empty()) {
+    for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
+      dims.push_back(dim);
+    }
+  }
   const size_t in_dims = input->shape().size();

   MS_ASSERT(multiples.size() == dims.size());
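After this change, Tile's multiples can come either from the primitive attribute or, when a second input tensor is present, from that tensor's data, with dims defaulting to 0..n-1. A sketch of the resulting shape computation (a hypothetical helper, not the kernel itself):

#include <cstdio>
#include <vector>

// Sketch of tile shape inference after this change: each multiple
// scales the corresponding input dimension.
std::vector<int> TiledShape(const std::vector<int> &in_shape, const std::vector<int> &multiples) {
  std::vector<int> out = in_shape;
  for (size_t i = 0; i < multiples.size() && i < out.size(); ++i) {
    out[i] *= multiples[i];  // dims default to 0..n-1 when read from input 1
  }
  return out;
}

int main() {
  auto out = TiledShape({2, 3}, {2, 4});
  printf("%d x %d\n", out[0], out[1]);  // prints: 4 x 12
  return 0;
}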
@@ -38,7 +38,7 @@ int SplitBaseCPUKernel::ReSize() {
   auto input_shape = in_tensor->shape();

   MS_ASSERT(param);
-  MS_ASSERT(input_shape.size() >= 2 && input_shape.size() <= SPLIT_STRIDES_SIZE);
+  MS_ASSERT(input_shape.size() >= 1 && input_shape.size() <= SPLIT_STRIDES_SIZE);
   param->strides_[input_shape.size() - 1] = 1;
   for (int i = input_shape.size() - 2; i >= 0; i--) {
     param->strides_[i] = param->strides_[i + 1] * input_shape.at(i + 1);

@@ -50,8 +50,8 @@ int SplitBaseCPUKernel::ReSize() {
   param->n_dims_ = input_shape.size();

   if (param->split_sizes_[0] == 0) {
-    MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) < input_shape.size());
-    if (input_shape.at(param->split_dim_) % param->num_split_ != 0) {
+    MS_ASSERT(param->num_split_ > 0 && static_cast<int>(param->num_split_) <= input_shape[param->split_dim_]);
+    if (input_shape[param->split_dim_] % param->num_split_ != 0) {
       MS_LOG(ERROR) << "Default split size is not usable.";
       return RET_ERROR;
     }
@@ -43,7 +43,8 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int
     {mindspore::schema::PrimitiveType_Floor, ElementFloorFp16},
     {mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16},
     {mindspore::schema::PrimitiveType_Round, ElementRoundFp16},
-    {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}};
+    {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16},
+    {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16}};
   for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
     if (type_func_table[i].primitive_type_ == primitive_type) {
       return type_func_table[i].func_;

@@ -139,4 +140,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Floor, CpuArithmeticSelfFp16K
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, CpuArithmeticSelfFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, CpuArithmeticSelfFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reciprocal, CpuArithmeticSelfFp16KernelCreator)
 }  // namespace mindspore::kernel
@@ -41,7 +41,8 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
     {mindspore::schema::PrimitiveType_Floor, ElementFloor},
     {mindspore::schema::PrimitiveType_Ceil, ElementCeil},
     {mindspore::schema::PrimitiveType_Round, ElementRound},
-    {mindspore::schema::PrimitiveType_Neg, ElementNegative}};
+    {mindspore::schema::PrimitiveType_Neg, ElementNegative},
+    {mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocal}};
   for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
     if (type_func_table[i].primitive_type_ == primitive_type) {
       return type_func_table[i].func_;

@@ -152,4 +153,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32K
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Neg, CpuArithmeticSelfFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reciprocal, CpuArithmeticSelfFp32KernelCreator)
 }  // namespace mindspore::kernel
@@ -26,6 +26,7 @@ using mindspore::schema::PrimitiveType_Floor;
 using mindspore::schema::PrimitiveType_Log;
 using mindspore::schema::PrimitiveType_LogicalNot;
 using mindspore::schema::PrimitiveType_Neg;
+using mindspore::schema::PrimitiveType_Reciprocal;
 using mindspore::schema::PrimitiveType_Round;
 using mindspore::schema::PrimitiveType_Rsqrt;
 using mindspore::schema::PrimitiveType_Sin;
@@ -24,6 +24,9 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Tile;

 namespace mindspore::kernel {
+namespace {
+constexpr size_t kDoubleInputsSize = 2;
+}
 int TileCPUKernel::Init() {
   if (!InferShapeDone()) {
     return RET_OK;

@@ -42,6 +45,17 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
 int TileCPUKernel::ReSize() {
   auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
   MS_ASSERT(tile_parameter_);
+  if (in_tensors_.size() == kDoubleInputsSize) {
+    if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
+      MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size.";
+      return false;
+    }
+    auto input1_addr = reinterpret_cast<int *>(in_tensors_[1]->data_c());
+    for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
+      tile_parameter_->dims_[i] = i;
+      tile_parameter_->multiples_[i] = input1_addr[i];
+    }
+  }
   tile_parameter_->in_dim_ = in_tensors_.at(0)->shape().size();
   for (int i = 0; i < tile_parameter_->in_dim_; ++i) {
     tile_parameter_->in_shape_[i] = in_tensors_.at(0)->shape().at(i);

@@ -93,4 +107,5 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector<lite::Tensor *> &
 }

 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
 }  // namespace mindspore::kernel
@@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reciprocal, CpuArithmeticSelfInt8KernelCreator)
 }  // namespace mindspore::kernel
@@ -31,6 +31,7 @@ using mindspore::schema::PrimitiveType_Cos;
 using mindspore::schema::PrimitiveType_Floor;
 using mindspore::schema::PrimitiveType_Log;
 using mindspore::schema::PrimitiveType_LogicalNot;
+using mindspore::schema::PrimitiveType_Reciprocal;
 using mindspore::schema::PrimitiveType_Round;
 using mindspore::schema::PrimitiveType_Rsqrt;
 using mindspore::schema::PrimitiveType_Sin;

@@ -80,6 +81,8 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel {
       case PrimitiveType_LogicalNot:
         arithmeticSelf_run_ = Int8ElementLogicalNot;
         break;
+      case PrimitiveType_Reciprocal:
+        arithmeticSelf_run_ = Int8ElementReciprocal;
       default:
         break;
     }
@@ -690,6 +690,29 @@ STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
   op->primitive->value.value = attr.release();
   return RET_OK;
 }
+
+STATUS OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                                   schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx ReciprocalParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  auto attr = std::make_unique<schema::ReciprocalT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_Reciprocal;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
 OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
 OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
 OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());

@@ -720,5 +743,6 @@ OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser());
 OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser());
 OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser());
 OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser());
+OnnxNodeRegistrar g_onnxReciprocalParser("Reciprocal", new OnnxReciprocalParser());
 }  // namespace lite
 }  // namespace mindspore
@@ -217,6 +217,13 @@ class OnnxRoundParser : public OnnxNodeParser {
   ~OnnxRoundParser() override = default;
   STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
 };
+
+class OnnxReciprocalParser : public OnnxNodeParser {
+ public:
+  OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {}
+  ~OnnxReciprocalParser() override = default;
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
 }  // namespace lite
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
@@ -41,8 +41,11 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "to") {
-      attr->dstT = static_cast<int32_t>(
-        OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
+      auto dst_type = OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i()));
+      if (dst_type == kNumberTypeInt64) {
+        dst_type = kNumberTypeInt32;
+      }
+      attr->dstT = static_cast<int>(dst_type);
     }
   }
@@ -105,7 +105,7 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
     MS_LOG(ERROR) << "new tensor failed";
     return RET_ERROR;
   }
-  tensor->dataType = data_type;
+  tensor->dataType = data_type == kNumberTypeInt64 ? kNumberTypeInt32 : data_type;
   tensor->dims = GetDimsFromOnnxValue(proto);
   tensor->format = schema::Format::Format_NCHW;
   tensor->nodeType = schema::NodeType::NodeType_ValueNode;

@@ -370,7 +370,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
     MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name;
     return;
   }
-  quant_param->inited = true;
   int argNum = 0;
   for (const auto &onnx_node_attr : node.attribute()) {
     if (onnx_node_attr.name() == "Y_scale") {

@@ -382,11 +381,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
     }
   }
   if (argNum != 2) {
-    quant_param->scale = FLT_MAX;
-    quant_param->zeroPoint = 0;
-    quant_param->min = FLT_MAX;
-    quant_param->max = FLT_MAX;
-    quant_param->inited = false;
+    continue;
   }
   dst_tensor->quantParams.emplace_back(std::move(quant_param));
   if (argNum == 2) {
@@ -15,7 +15,9 @@
  */

 #include "tools/converter/parser/onnx/onnx_slice_parser.h"
+#include <functional>
 #include <memory>
+#include <numeric>
 #include <vector>
 #include <string>

@@ -46,6 +48,35 @@ STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std
   return RET_OK;
 }

+STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) {
+  if (onnx_val == nullptr) {
+    MS_LOG(ERROR) << "input vector is nullptr.";
+    return RET_ERROR;
+  }
+  if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) {
+    MS_LOG(ERROR) << "cannot get tensorcache.";
+    return RET_ERROR;
+  }
+  int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name);
+  if (index == -1) {
+    MS_LOG(ERROR) << "can not find node: " << name;
+    return RET_ERROR;
+  }
+  auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
+  if (input_tensor->data.empty()) {
+    MS_LOG(DEBUG) << "data is empty.";
+    return RET_NO_CHANGE;
+  }
+  int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>());
+  onnx_val->resize(data_num);
+  if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) !=
+      EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                               schema::CNodeT *op) {
   MS_LOG(DEBUG) << "onnx SliceParser";

@@ -97,6 +128,36 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
       }
     }
   }
+  int status = RET_OK;
+  switch (onnx_node.input_size()) {
+    case 5: {
+      if (steps.empty()) {
+        status = GetInputTensor(&steps, onnx_node.input(4));
+      }
+    }
+    case 4: {
+      if (status != RET_ERROR && axes.empty()) {
+        status = GetInputTensor(&axes, onnx_node.input(3));
+      }
+    }
+    case 3: {
+      if (status != RET_ERROR && ends.empty()) {
+        status = GetInputTensor(&ends, onnx_node.input(2));
+      }
+    }
+    case 2: {
+      if (status != RET_ERROR && starts.empty()) {
+        status = GetInputTensor(&starts, onnx_node.input(1));
+      }
+    }
+    default: {
+      if (status == RET_ERROR) {
+        MS_LOG(ERROR) << "onnx slice inputs are invalid.";
+        return RET_INPUT_TENSOR_ERROR;
+      }
+    }
+  }
+
   if (axes.empty()) {
     for (size_t i = 0; i < starts.size(); ++i) {
       axes.push_back(i);

@@ -112,7 +173,6 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
     }
   }
   int insert_num = 5 - onnx_node.input_size();
-  int status = RET_OK;
   switch (insert_num) {
     case 4: {
       std::string name = "slice/starts/";
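Note that the switch in the new input-reading code relies on deliberate case fallthrough: starting at the case for the number of inputs actually present, it falls through so each earlier optional input (steps, axes, ends, starts) is read in turn. A compact sketch of the pattern:

#include <cstdio>

// Compact sketch of the deliberate fallthrough above: enter at the case
// matching the number of inputs present, then fall through so every
// earlier optional input is read as well.
void ReadOptionalSliceInputs(int input_size) {
  switch (input_size) {
    case 5: printf("read steps\n");   // fall through
    case 4: printf("read axes\n");    // fall through
    case 3: printf("read ends\n");    // fall through
    case 2: printf("read starts\n");  // fall through
    default: break;
  }
}

int main() {
  ReadOptionalSliceInputs(4);  // reads axes, ends, starts
  return 0;
}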
@@ -32,6 +32,7 @@ class OnnxSliceParser : public OnnxNodeParser {

   STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
   STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+  STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name);
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -38,6 +38,7 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
     return RET_NULL_PTR;
   }

+  attr->splitDim = 0;
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "axis") {
@@ -16,9 +16,7 @@

 #include "tools/converter/parser/onnx/onnx_tile_parser.h"
 #include <memory>
-#include <numeric>
 #include <vector>
-#include "tools/converter/parser/onnx/onnx_tensor_parser.h"

 namespace mindspore {
 namespace lite {

@@ -39,26 +37,6 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
-  const auto &onnx_tile_multiple = onnx_node.input(1);
-  int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple);
-  if (index == -1) {
-    MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple;
-    return RET_ERROR;
-  }
-  auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
-  if (tile_attr->data.data() == nullptr) {
-    MS_LOG(ERROR) << "power's attr pow can't be obtained.";
-    return RET_INVALID_OP_ATTR;
-  }
-  int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies<int>());
-  std::vector<int> multiples;
-  std::vector<int> dims;
-  for (int i = 0; i < element_size; ++i) {
-    multiples.push_back(reinterpret_cast<int *>(tile_attr->data.data())[i]);
-    dims.push_back(i);
-  }
-  attr->multiples = multiples;
-  attr->dims = dims;
   op->primitive->value.type = schema::PrimitiveType_Tile;
   op->primitive->value.value = attr.release();
   return RET_OK;
@@ -23,7 +23,6 @@ namespace mindspore {
 namespace lite {
 PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
                                                  const std::unique_ptr<tflite::ModelT> &tflite_model) {
-  auto &tflite_subgraph = tflite_model->subgraphs.front();
   auto primitive = std::make_unique<schema::PrimitiveT>();
   if (primitive == nullptr) {
     MS_LOG(ERROR) << "primitive is null";

@@ -35,16 +34,6 @@ PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr<tflite::O
     MS_LOG(ERROR) << "new op failed";
     return nullptr;
   }
-
-  if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->multiples)) {
-    MS_LOG(ERROR) << "get tile -> multiples failed";
-    return nullptr;
-  }
-  std::vector<int> dims(attr->multiples.size(), 0);
-  for (size_t i = 0; i < dims.size(); ++i) {
-    dims[i] = i;
-  }
-  attr->dims = dims;
   primitive->value.type = schema::PrimitiveType_Tile;
   primitive->value.value = attr.release();
   return PrimitiveC::Create(primitive.release());
@@ -133,6 +133,10 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_no
   if (output_tensors.size() != 1) {
     for (size_t k = 0; k < output_tensors.size(); k++) {
       auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
+      if (used_node_list->empty()) {
+        MS_LOG(DEBUG) << "this output don't be used by other node.";
+        continue;
+      }
       if (used_node_list->size() != 1) {
         MS_LOG(ERROR) << " output must tuple_getitem";
         return lite::RET_ERROR;