forked from mindspore-Ecosystem/mindspore
fix arithmetic compare, matmul, logicalnot, constant_folding_fusion
This commit is contained in:
parent
14a51ef727
commit
815b7af9ec
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
// Shape inference shared by all comparison primitives.
// Delegates to Arithmetic::InferShape and, only when that reports RET_OK,
// overrides the first output's data type with bool. Any other status is
// returned unchanged and the output's data type is not touched here.
int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  const auto status = Arithmetic::InferShape(inputs_, outputs_);
  if (status != RET_OK) {
    return status;
  }
  // A comparison produces a truth value per element, whatever the operand type.
  outputs_.front()->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Base class for comparison primitives. It derives from Arithmetic and
// overrides only InferShape; everything else is inherited.
class ArithmeticCompare : public Arithmetic {
|
||||
public:
|
||||
ArithmeticCompare() = default;
|
||||
~ArithmeticCompare() = default;
|
||||
// The members below exist only when PRIMITIVE_WRITEABLE is defined.
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic);
|
||||
// Forwards the schema::PrimitiveT pointer to the Arithmetic constructor.
explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#endif
|
||||
// Overrides the inherited shape inference; the definition lives in the matching source file.
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
|
|
@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
// Factory for Equal: builds the primitive from a serialized schema::Primitive via PrimitiveC::NewPrimitiveC<Equal>.
PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
|
||||
Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
|
||||
#endif
|
||||
// Shape inference for Equal.
// The output takes the shape and format of the first input and its element
// type is set to bool. Only the first input is read; always returns RET_OK.
int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *lhs = inputs_.front();
  MS_ASSERT(lhs != nullptr);
  Tensor *result = outputs_.front();
  MS_ASSERT(result != nullptr);

  result->set_shape(lhs->shape());
  result->set_data_type(TypeId::kNumberTypeBool);
  result->set_format(lhs->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,21 +20,20 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Equal : public Arithmetic {
|
||||
class Equal : public ArithmeticCompare {
|
||||
public:
|
||||
Equal() = default;
|
||||
~Equal() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Equal, PrimitiveC);
|
||||
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(Equal, ArithmeticCompare);
|
||||
explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
// Factory for Greater: builds the primitive from a serialized schema::Primitive via PrimitiveC::NewPrimitiveC<Greater>.
PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
|
||||
Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
|
||||
#endif
|
||||
// Shape inference for Greater.
// Copies shape and format from the first input to the first output and marks
// the output as bool. Only the first input is read; always returns RET_OK.
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *first_in = inputs_.front();
  MS_ASSERT(first_in != nullptr);
  Tensor *first_out = outputs_.front();
  MS_ASSERT(first_out != nullptr);

  first_out->set_shape(first_in->shape());
  first_out->set_data_type(TypeId::kNumberTypeBool);
  first_out->set_format(first_in->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,21 +20,20 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Greater : public Arithmetic {
|
||||
class Greater : public ArithmeticCompare {
|
||||
public:
|
||||
Greater() = default;
|
||||
~Greater() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Greater, Arithmetic);
|
||||
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(Greater, ArithmeticCompare);
|
||||
explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
|
|||
Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);
|
||||
|
||||
#endif
|
||||
// Shape inference for GreaterEqual.
// The first output receives the first input's shape and format, with its data
// type set to bool. Only the first input is read; always returns RET_OK.
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *src = inputs_.front();
  MS_ASSERT(src != nullptr);
  Tensor *dst = outputs_.front();
  MS_ASSERT(dst != nullptr);

  dst->set_shape(src->shape());
  dst->set_data_type(TypeId::kNumberTypeBool);
  dst->set_format(src->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,21 +21,20 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class GreaterEqual : public Arithmetic {
|
||||
class GreaterEqual : public ArithmeticCompare {
|
||||
public:
|
||||
GreaterEqual() = default;
|
||||
~GreaterEqual() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
|
||||
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare);
|
||||
explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC:
|
|||
Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);
|
||||
|
||||
#endif
|
||||
// Shape inference for Less.
// Mirrors the first input's shape and format onto the first output and sets
// the output element type to bool. Only the first input is read; always returns RET_OK.
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *operand = inputs_.front();
  MS_ASSERT(operand != nullptr);
  Tensor *mask = outputs_.front();
  MS_ASSERT(mask != nullptr);

  mask->set_shape(operand->shape());
  mask->set_data_type(TypeId::kNumberTypeBool);
  mask->set_format(operand->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,21 +21,20 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Less : public Arithmetic {
|
||||
class Less : public ArithmeticCompare {
|
||||
public:
|
||||
Less() = default;
|
||||
~Less() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Less, Arithmetic);
|
||||
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(Less, ArithmeticCompare);
|
||||
explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
|
|||
}
|
||||
Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
|
||||
#endif
|
||||
// Shape inference for LessEqual.
// Output shape and format come from the first input; the output data type is
// bool. Only the first input is read; always returns RET_OK.
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *in_tensor = inputs_.front();
  MS_ASSERT(in_tensor != nullptr);
  Tensor *out_tensor = outputs_.front();
  MS_ASSERT(out_tensor != nullptr);

  out_tensor->set_shape(in_tensor->shape());
  out_tensor->set_data_type(TypeId::kNumberTypeBool);
  out_tensor->set_format(in_tensor->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,21 +21,20 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class LessEqual : public Arithmetic {
|
||||
class LessEqual : public ArithmeticCompare {
|
||||
public:
|
||||
LessEqual() = default;
|
||||
~LessEqual() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(LessEqual, Arithmetic);
|
||||
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(LessEqual, ArithmeticCompare);
|
||||
explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|||
input0->set_shape(a_shape);
|
||||
}
|
||||
|
||||
if (a_shape.size() < 2 || b_shape.size() < 2) {
|
||||
MS_LOG(ERROR) << "inputs shape is invalid";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
bool del_start = false;
|
||||
bool del_end = false;
|
||||
if (a_shape.size() == 1) {
|
||||
a_shape.insert(a_shape.begin(), 1);
|
||||
input0->set_shape(a_shape);
|
||||
del_start = true;
|
||||
}
|
||||
if (b_shape.size() == 1) {
|
||||
b_shape.push_back(1);
|
||||
input1->set_shape(b_shape);
|
||||
del_end = true;
|
||||
}
|
||||
for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) {
|
||||
if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) {
|
||||
|
@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|||
}
|
||||
std::vector<int> c_shape(a_shape);
|
||||
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
|
||||
if (del_start) {
|
||||
c_shape.erase(c_shape.begin());
|
||||
}
|
||||
if (del_end) {
|
||||
c_shape.pop_back();
|
||||
}
|
||||
output->set_shape(c_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
|
|||
Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);
|
||||
|
||||
#endif
|
||||
// Shape inference for NotEqual.
// Propagates the first input's shape and format to the first output and forces
// a bool data type. Only the first input is read; always returns RET_OK.
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  Tensor *head_in = inputs_.front();
  MS_ASSERT(head_in != nullptr);
  Tensor *head_out = outputs_.front();
  MS_ASSERT(head_out != nullptr);

  head_out->set_shape(head_in->shape());
  head_out->set_data_type(TypeId::kNumberTypeBool);
  head_out->set_format(head_in->format());
  return RET_OK;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,21 +21,20 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
|
||||
#include "src/ops/arithmetic.h"
|
||||
#include "src/ops/arithmetic_compare.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class NotEqual : public Arithmetic {
|
||||
class NotEqual : public ArithmeticCompare {
|
||||
public:
|
||||
NotEqual() = default;
|
||||
~NotEqual() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(NotEqual, Arithmetic);
|
||||
explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
MS_DECLARE_PARENT(NotEqual, ArithmeticCompare);
|
||||
explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme
|
|||
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
|
||||
Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
|
||||
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
|
||||
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
|
||||
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
|
||||
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
|
||||
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
|
||||
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
|
||||
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
|
||||
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
|
||||
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
|
||||
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
|
||||
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
|
||||
|
|
|
@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual;
|
|||
using mindspore::schema::PrimitiveType_NotEqual;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
ArithmeticCompareFp32Func func_;
|
||||
} TYPE_FUNC_INFO;
|
||||
} // namespace
|
||||
|
||||
ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
|
||||
TYPE_FUNC_INFO type_func_table[] = {
|
||||
{PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32},
|
||||
{PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32},
|
||||
{PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
|
||||
for (size_t i = 0; i < sizeof(type_func_table); i++) {
|
||||
if (type_func_table[i].primitive_type_ == primitive_type) {
|
||||
return type_func_table[i].func_;
|
||||
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
||||
int out_thread_stride) {
|
||||
if (dim > break_pos_) {
|
||||
if (data_type_ == kDataTypeInt) {
|
||||
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
|
||||
reinterpret_cast<int *>(input1) + out_thread_stride,
|
||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
||||
}
|
||||
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
|
||||
reinterpret_cast<float *>(input1) + out_thread_stride,
|
||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
||||
}
|
||||
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
|
||||
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
|
||||
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
|
||||
int error_code;
|
||||
if (data_type_ == kDataTypeInt) {
|
||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
|
||||
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
|
||||
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
|
||||
dim + 1, out_count, out_thread_stride);
|
||||
} else {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
|
||||
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
|
||||
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
|
||||
dim + 1, out_count, out_thread_stride);
|
||||
}
|
||||
if (error_code != RET_OK) {
|
||||
return error_code;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
|
||||
int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }
|
||||
MS_ASSERT(thread_count_ != 0);
|
||||
int stride = UP_DIV(element_num, thread_count_);
|
||||
int count = MSMIN(stride, element_num - stride * task_id);
|
||||
|
||||
int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
|
||||
if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) {
|
||||
MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! ";
|
||||
if (func_fp32_ == nullptr) {
|
||||
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int elements_num = in_tensors_.at(0)->ElementsNum();
|
||||
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
|
||||
int offset = task_id * stride;
|
||||
int count = MSMIN(stride, elements_num - offset);
|
||||
if (count <= 0) {
|
||||
return RET_OK;
|
||||
|
||||
int error_code;
|
||||
if (arithmeticParameter_->broadcasting_) { // need broadcast
|
||||
stride = UP_DIV(outside_, thread_count_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(
|
||||
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
} else {
|
||||
error_code = BroadcastRun(
|
||||
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
}
|
||||
} else { // no broadcast, neither is scalar, two same shape
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||
} else {
|
||||
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||
}
|
||||
}
|
||||
if (func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Run function is null! ";
|
||||
if (error_code != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
// two inputs have the same shape, support broadcast later
|
||||
auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
|
||||
auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
|
||||
auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run failed, illegal input! ";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareRun(void *cdata, int task_id) {
|
||||
auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
|
||||
auto ret = kernel->DoExecute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::Run() {
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
|
|
|
@ -18,27 +18,57 @@
|
|||
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/fp32/arithmetic_compare_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
|
||||
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
|
||||
public:
|
||||
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
func_ = GetArithmeticCompareFun(parameter->type_);
|
||||
switch (parameter->type_) {
|
||||
case PrimitiveType_Equal:
|
||||
func_fp32_ = ElementEqualFp32;
|
||||
func_int32_ = ElementEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_NotEqual:
|
||||
func_fp32_ = ElementNotEqualFp32;
|
||||
func_int32_ = ElementNotEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_Less:
|
||||
func_fp32_ = ElementLessFp32;
|
||||
func_int32_ = ElementLessInt32;
|
||||
break;
|
||||
case PrimitiveType_LessEqual:
|
||||
func_fp32_ = ElementLessEqualFp32;
|
||||
func_int32_ = ElementLessEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_Greater:
|
||||
func_fp32_ = ElementGreaterFp32;
|
||||
func_int32_ = ElementGreaterInt32;
|
||||
break;
|
||||
case PrimitiveType_GreaterEqual:
|
||||
func_fp32_ = ElementGreaterEqualFp32;
|
||||
func_int32_ = ElementGreaterEqualInt32;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
|
||||
func_fp32_ = nullptr;
|
||||
func_int32_ = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
~ArithmeticCompareCPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
virtual int DoExecute(int task_id);
|
||||
int DoArithmetic(int task_id) override;
|
||||
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
|
||||
|
||||
private:
|
||||
ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
|
||||
ArithmeticCompareFp32Func func_;
|
||||
ArithmeticCompareFp32Func func_fp32_ = nullptr;
|
||||
ArithmeticCompareIntFunc func_int32_ = nullptr;
|
||||
};
|
||||
int ArithmeticCompareRun(void *cdata, int task_id);
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() {
|
|||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Equal:
|
||||
case PrimitiveType_Less:
|
||||
case PrimitiveType_Greater:
|
||||
case PrimitiveType_NotEqual:
|
||||
case PrimitiveType_LessEqual:
|
||||
case PrimitiveType_GreaterEqual:
|
||||
arithmetic_opt_run_ = nullptr;
|
||||
arithmetic_opt_run_int_ = nullptr;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
int PreProcess() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoArithmetic(int task_id);
|
||||
virtual int DoArithmetic(int task_id);
|
||||
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
|
||||
|
||||
private:
|
||||
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
|
||||
protected:
|
||||
int break_pos_ = 0;
|
||||
int outside_ = 0;
|
||||
int thread_count_ = 1;
|
||||
ArithmeticParameter *arithmeticParameter_ = nullptr;
|
||||
LiteDataType data_type_ = kDataTypeFloat;
|
||||
|
||||
private:
|
||||
ArithmeticRun arithmetic_run_ = nullptr;
|
||||
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
||||
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
||||
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
|
||||
LiteDataType data_type_ = kDataTypeFloat;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
|
||||
|
|
|
@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
|
|||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
||||
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
||||
if (nodeIter == onnx_graph.initializer().end()) {
|
||||
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
|
||||
return RET_ERROR;
|
||||
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
|
||||
} else {
|
||||
std::vector<int> weight_shape;
|
||||
auto size = (*nodeIter).dims_size();
|
||||
weight_shape.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
weight_shape.emplace_back((*nodeIter).dims(i));
|
||||
}
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->channelIn = weight_shape[1] * attr->group;
|
||||
}
|
||||
std::vector<int> weight_shape;
|
||||
auto size = (*nodeIter).dims_size();
|
||||
weight_shape.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
weight_shape.emplace_back((*nodeIter).dims(i));
|
||||
}
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->channelIn = weight_shape[1] * attr->group;
|
||||
} else {
|
||||
auto nodeIter =
|
||||
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
|
||||
|
|
|
@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
output_tensors[m]->AddQuantParam(quant_arg);
|
||||
}
|
||||
}
|
||||
// here, input_tensor's format needs to be transposed to NHWC according to fmkType,
|
||||
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
||||
// Others should be added in future.
|
||||
for (auto &input_tensor : input_tensors) {
|
||||
input_tensor->set_format(schema::Format::Format_NHWC);
|
||||
if (input_tensor->shape().size() == 4) {
|
||||
MS_LOG(INFO) << "init input_tensor format to nhwc";
|
||||
}
|
||||
}
|
||||
lite_primitive->InferShape(input_tensors, output_tensors);
|
||||
auto primitive = lite_primitive.get();
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
|
|
Loading…
Reference in New Issue