Support LogicalNot bool operation, fix arithmetic

fix bool
gongdaguo 2020-12-01 10:27:37 +08:00
parent e45f19adc8
commit 538b8faa3d
7 changed files with 53 additions and 27 deletions

View File

@@ -91,6 +91,14 @@ int ElementLogicalNot(const float *input, float *output, const int element_size)
return NNACL_OK;
}
// logical_not:
int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = !input[i];
}
return NNACL_OK;
}
// round:
int ElementRound(const float *input, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
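For reference, the following is a minimal standalone sketch of what the new bool variant computes: it flips every element of a bool buffer, mirroring ElementLogicalNotBool above. The kNNaclOk constant is a stand-in for the real NNACL_OK value, which is assumed to be 0.

#include <cstdio>

// Stand-in for NNACL_OK; assumed to be 0 in the real headers.
constexpr int kNNaclOk = 0;

// Mirrors ElementLogicalNotBool: element-wise logical NOT over a bool buffer.
int ElementLogicalNotBoolSketch(const bool *input, bool *output, const int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = !input[i];
  }
  return kNNaclOk;
}

int main() {
  const bool in[4] = {true, false, true, false};
  bool out[4] = {};
  ElementLogicalNotBoolSketch(in, out, 4);
  for (bool b : out) {
    printf("%d ", b ? 1 : 0);  // prints: 0 1 0 1
  }
  printf("\n");
  return 0;
}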

View File

@@ -42,6 +42,8 @@ int ElementSin(const float *input, float *output, const int element_size);
int ElementLogicalNot(const float *input, float *output, const int element_size);
int ElementLogicalNotBool(const bool *input, bool *output, const int element_size);
int ElementRound(const float *input, float *output, const int element_size);
int ElementFloor(const float *input, float *output, const int element_size);

View File

@@ -80,6 +80,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id;
if (out_count <= 0) {
return RET_OK;
}
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
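The new out_count guard matters because outside_ is split across threads with a rounded-up stride, so trailing tasks can be handed a non-positive slice and must return RET_OK without calling BroadcastRun. The same guard is added to ArithmeticCPUKernel::DoArithmetic below. A small standalone sketch with assumed numbers (5 outer slices across 4 threads) shows the case being protected against; UpDiv here mirrors the usual UP_DIV round-up macro.

#include <algorithm>
#include <cstdio>

int UpDiv(int a, int b) { return (a + b - 1) / b; }  // mirrors UP_DIV

int main() {
  const int outside = 5;       // hypothetical number of outer slices
  const int thread_count = 4;  // hypothetical thread pool size
  const int stride = UpDiv(outside, thread_count);  // 2
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int out_count = std::min(stride, outside - stride * task_id);
    // Task 3 gets min(2, 5 - 6) = -1, so it must bail out early instead of
    // running BroadcastRun on a negative element count.
    printf("task %d: out_count = %d%s\n", task_id, out_count,
           out_count <= 0 ? " (early return RET_OK)" : "");
  }
  return 0;
}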

View File

@@ -82,27 +82,12 @@ int ArithmeticCPUKernel::ReSize() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->in_shape0_[i] == -1) {
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()),
in_tensors_[0]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
if (arithmeticParameter_->in_shape1_[i] == -1) {
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()),
in_tensors_[1]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->out_shape_[i] == -1) {
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()),
out_tensors_[0]->shape().size() * sizeof(int));
break;
}
}
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) {
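The ReSize change above stops scanning the previously copied shapes for -1 placeholders and instead always copies the shapes that shape inference recorded on the Arithmetic primitive (InShape0, InShape1, OutputShape). A minimal sketch of that copy pattern follows; CopyShape and kMaxShapeSize are illustrative names only, not the real helpers or macros.

#include <algorithm>
#include <cstring>
#include <vector>

// Illustrative: copy an inferred shape (std::vector<int>) into the fixed int
// array held by ArithmeticParameter, as the memcpy calls above do.
// kMaxShapeSize is an assumed bound, not the real limit used by the kernel.
constexpr size_t kMaxShapeSize = 10;

void CopyShape(int *dst, const std::vector<int> &shape) {
  size_t count = std::min(shape.size(), kMaxShapeSize);  // stay within the fixed array
  std::memcpy(dst, shape.data(), count * sizeof(int));
}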
@@ -244,6 +229,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
if (arithmeticParameter_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),

View File

@@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
return nullptr;
}
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) {
if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) {
return ElementLogicalNotBool;
}
return nullptr;
}
int ArithmeticSelfCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
@@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) {
if (count <= 0) {
return RET_OK;
}
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
int ret = RET_ERROR;
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
return RET_ERROR;
}
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
ret = func_(input_ptr + offset, output_ptr + offset, count);
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
if (func_bool_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
return RET_ERROR;
}
bool *input_ptr = reinterpret_cast<bool *>(in_tensors_.at(0)->data_c());
bool *output_ptr = reinterpret_cast<bool *>(out_tensors_.at(0)->data_c());
ret = func_bool_(input_ptr + offset, output_ptr + offset, count);
} else {
MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << ".";
return RET_ERROR;
}
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
auto ret = func_(input_ptr + offset, output_ptr + offset, count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run failed, illegal input! ";
}
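Taken together, DoExecute now selects between the float and bool element functions by the input tensor's data type. Below is a simplified, self-contained sketch of that dispatch; the DType enum, return codes, and helper names are illustrative stand-ins, not the real lite types.

#include <cstdio>

using ArithmeticSelfFunc = int (*)(const float *input, float *output, int element_size);
using ArithmeticSelfBoolFunc = int (*)(const bool *input, bool *output, int element_size);

enum class DType { kFloat32, kBool };  // stand-in for the lite data-type enum
constexpr int kRetOk = 0;
constexpr int kRetError = 1;

int ElementLogicalNotBoolSketch(const bool *in, bool *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = !in[i];
  return kRetOk;
}

// Mirrors the new branch in DoExecute: float32 inputs go through func_,
// bool inputs through func_bool_, and anything else is rejected.
int DoExecuteSketch(DType dtype, const void *in, void *out, int count,
                    ArithmeticSelfFunc func, ArithmeticSelfBoolFunc func_bool) {
  if (dtype == DType::kFloat32) {
    if (func == nullptr) return kRetError;
    return func(static_cast<const float *>(in), static_cast<float *>(out), count);
  }
  if (dtype == DType::kBool) {
    if (func_bool == nullptr) return kRetError;
    return func_bool(static_cast<const bool *>(in), static_cast<bool *>(out), count);
  }
  return kRetError;  // unsupported input data type
}

int main() {
  bool in[3] = {true, false, true};
  bool out[3] = {};
  DoExecuteSketch(DType::kBool, in, out, 3, /*func=*/nullptr, ElementLogicalNotBoolSketch);
  printf("%d %d %d\n", out[0], out[1], out[2]);  // prints: 0 1 0
  return 0;
}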
@@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)

View File

@@ -34,6 +34,7 @@ using mindspore::schema::PrimitiveType_Square;
namespace mindspore::kernel {
typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size);
typedef int (*ArithmeticSelfBoolFunc)(const bool *input, bool *output, const int element_size);
class ArithmeticSelfCPUKernel : public LiteKernel {
public:
explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -41,6 +42,7 @@ class ArithmeticSelfCPUKernel : public LiteKernel {
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
func_ = GetArithmeticSelfFun(parameter->type_);
func_bool_ = GetArithmeticSelfBoolFun(parameter->type_);
}
~ArithmeticSelfCPUKernel() override = default;
@@ -51,7 +53,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel {
private:
ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type);
ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type);
ArithmeticSelfFunc func_;
ArithmeticSelfBoolFunc func_bool_;
};
int ArithmeticSelfRun(void *cdata, int task_id);
} // namespace mindspore::kernel

View File

@@ -146,5 +146,4 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel