fix arithmetic compare, matmul, logicalnot, constant_folding_fusion

This commit is contained in:
gongdaguo 2020-11-25 22:06:39 +08:00
parent 14a51ef727
commit 815b7af9ec
23 changed files with 244 additions and 184 deletions

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
// Shape inference for comparison-style arithmetic ops (Equal, Less, ...).
// Reuses the broadcast-aware inference from the Arithmetic base class and
// then forces the output tensor's data type to bool, because a comparison
// always produces boolean values regardless of the input data types.
// Returns RET_OK on success, otherwise the error code from the base class.
int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  outputs_.front()->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
namespace mindspore {
namespace lite {
// ArithmeticCompare: common base class for the comparison operators
// (Equal, NotEqual, Less, LessEqual, Greater, GreaterEqual).
// It inherits Arithmetic's two-input shape handling and overrides
// InferShape so the output tensor's data type is always bool.
class ArithmeticCompare : public Arithmetic {
public:
ArithmeticCompare() = default;
~ArithmeticCompare() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic);
// Writeable build: wrap a flatbuffer primitive table (ownership semantics
// follow the Arithmetic base class -- TODO confirm against PrimitiveC).
explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
// Runs Arithmetic::InferShape, then sets the output data type to bool.
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_

View File

@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
#endif
// Infers the output tensor shape for the Equal op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -20,21 +20,20 @@
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class Equal : public Arithmetic {
class Equal : public ArithmeticCompare {
public:
Equal() = default;
~Equal() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, PrimitiveC);
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Equal, ArithmeticCompare);
explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
#endif
// Infers the output tensor shape for the Greater op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -20,21 +20,20 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class Greater : public Arithmetic {
class Greater : public ArithmeticCompare {
public:
Greater() = default;
~Greater() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, Arithmetic);
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Greater, ArithmeticCompare);
explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);
#endif
// Infers the output tensor shape for the GreaterEqual op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -21,21 +21,20 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class GreaterEqual : public Arithmetic {
class GreaterEqual : public ArithmeticCompare {
public:
GreaterEqual() = default;
~GreaterEqual() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare);
explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);
#endif
// Infers the output tensor shape for the Less op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -21,21 +21,20 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class Less : public Arithmetic {
class Less : public ArithmeticCompare {
public:
Less() = default;
~Less() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Less, Arithmetic);
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Less, ArithmeticCompare);
explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
}
Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
#endif
// Infers the output tensor shape for the LessEqual op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -21,21 +21,20 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class LessEqual : public Arithmetic {
class LessEqual : public ArithmeticCompare {
public:
LessEqual() = default;
~LessEqual() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LessEqual, Arithmetic);
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(LessEqual, ArithmeticCompare);
explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
input0->set_shape(a_shape);
}
if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid";
return RET_INPUT_TENSOR_ERROR;
bool del_start = false;
bool del_end = false;
if (a_shape.size() == 1) {
a_shape.insert(a_shape.begin(), 1);
input0->set_shape(a_shape);
del_start = true;
}
if (b_shape.size() == 1) {
b_shape.push_back(1);
input1->set_shape(b_shape);
del_end = true;
}
for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) {
if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) {
@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
}
std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
if (del_start) {
c_shape.erase(c_shape.begin());
}
if (del_end) {
c_shape.pop_back();
}
output->set_shape(c_shape);
return RET_OK;
}

View File

@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);
#endif
// Infers the output tensor shape for the NotEqual op.
// Delegates to Arithmetic::InferShape so the two input shapes are broadcast
// against each other; the previous code copied input0's shape verbatim,
// which is wrong whenever the inputs have different (broadcastable) shapes.
// The output data type is forced to bool since comparisons yield booleans.
// Returns RET_OK on success, otherwise the base-class error code.
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    return ret;
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -21,21 +21,20 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
class NotEqual : public Arithmetic {
class NotEqual : public ArithmeticCompare {
public:
NotEqual() = default;
~NotEqual() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NotEqual, Arithmetic);
explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(NotEqual, ArithmeticCompare);
explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);

View File

@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual;
namespace mindspore::kernel {
namespace {
typedef struct {
int primitive_type_;
ArithmeticCompareFp32Func func_;
} TYPE_FUNC_INFO;
} // namespace
ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
TYPE_FUNC_INFO type_func_table[] = {
{PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32},
{PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32},
{PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
for (size_t i = 0; i < sizeof(type_func_table); i++) {
if (type_func_table[i].primitive_type_ == primitive_type) {
return type_func_table[i].func_;
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
if (data_type_ == kDataTypeInt) {
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
}
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code;
if (data_type_ == kDataTypeInt) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
}
if (error_code != RET_OK) {
return error_code;
}
}
return nullptr;
return RET_OK;
}
int ArithmeticCompareCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();
int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) {
MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! ";
if (func_fp32_ == nullptr) {
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
return RET_ERROR;
}
int elements_num = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
int offset = task_id * stride;
int count = MSMIN(stride, elements_num - offset);
if (count <= 0) {
return RET_OK;
int error_code;
if (arithmeticParameter_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
} else { // no broadcast, neither is scalar, two same shape
if (data_type_ == kDataTypeFloat) {
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
}
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
if (error_code != RET_OK) {
return RET_ERROR;
}
// two inputs have the same shape, support broadcast later
auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run failed, illegal input! ";
}
return ret;
}
int ArithmeticCompareRun(void *cdata, int task_id) {
auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
auto ret = kernel->DoExecute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
}
return ret;
}
int ArithmeticCompareCPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
}
return ret;
return RET_OK;
}
kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,

View File

@ -18,27 +18,57 @@
#include <vector>
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/arithmetic_compare_fp32.h"
namespace mindspore::kernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
public:
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
func_ = GetArithmeticCompareFun(parameter->type_);
switch (parameter->type_) {
case PrimitiveType_Equal:
func_fp32_ = ElementEqualFp32;
func_int32_ = ElementEqualInt32;
break;
case PrimitiveType_NotEqual:
func_fp32_ = ElementNotEqualFp32;
func_int32_ = ElementNotEqualInt32;
break;
case PrimitiveType_Less:
func_fp32_ = ElementLessFp32;
func_int32_ = ElementLessInt32;
break;
case PrimitiveType_LessEqual:
func_fp32_ = ElementLessEqualFp32;
func_int32_ = ElementLessEqualInt32;
break;
case PrimitiveType_Greater:
func_fp32_ = ElementGreaterFp32;
func_int32_ = ElementGreaterInt32;
break;
case PrimitiveType_GreaterEqual:
func_fp32_ = ElementGreaterEqualFp32;
func_int32_ = ElementGreaterEqualInt32;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
func_fp32_ = nullptr;
func_int32_ = nullptr;
break;
}
}
~ArithmeticCompareCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
virtual int DoExecute(int task_id);
int DoArithmetic(int task_id) override;
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
private:
ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
ArithmeticCompareFp32Func func_;
ArithmeticCompareFp32Func func_fp32_ = nullptr;
ArithmeticCompareIntFunc func_int32_ = nullptr;
};
int ArithmeticCompareRun(void *cdata, int task_id);
} // namespace mindspore::kernel

View File

@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() {
break;
}
break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
default:
break;
}

View File

@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel {
int PreProcess() override;
int ReSize() override;
int Run() override;
int DoArithmetic(int task_id);
virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
private:
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
protected:
int break_pos_ = 0;
int outside_ = 0;
int thread_count_ = 1;
ArithmeticParameter *arithmeticParameter_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
private:
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr;
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

View File

@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) {
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
return RET_ERROR;
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
} else {
std::vector<int> weight_shape;
auto size = (*nodeIter).dims_size();
weight_shape.reserve(size);
for (int i = 0; i < size; ++i) {
weight_shape.emplace_back((*nodeIter).dims(i));
}
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
}
std::vector<int> weight_shape;
auto size = (*nodeIter).dims_size();
weight_shape.reserve(size);
for (int i = 0; i < size; ++i) {
weight_shape.emplace_back((*nodeIter).dims(i));
}
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
} else {
auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),

View File

@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
output_tensors[m]->AddQuantParam(quant_arg);
}
}
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.
for (auto &input_tensor : input_tensors) {
input_tensor->set_format(schema::Format::Format_NHWC);
if (input_tensor->shape().size() == 4) {
MS_LOG(INFO) << "init input_tensor format to nhwc";
}
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto primitive = lite_primitive.get();
MS_ASSERT(primitive != nullptr);