forked from mindspore-Ecosystem/mindspore
parent
e45f19adc8
commit
538b8faa3d
|
@ -91,6 +91,14 @@ int ElementLogicalNot(const float *input, float *output, const int element_size)
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// logical_not:
|
||||
int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = !input[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// round:
|
||||
int ElementRound(const float *input, float *output, const int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
|
|
@ -42,6 +42,8 @@ int ElementSin(const float *input, float *output, const int element_size);
|
|||
|
||||
int ElementLogicalNot(const float *input, float *output, const int element_size);
|
||||
|
||||
int ElementLogicalNotBool(const bool *input, bool *output, const int element_size);
|
||||
|
||||
int ElementRound(const float *input, float *output, const int element_size);
|
||||
|
||||
int ElementFloor(const float *input, float *output, const int element_size);
|
||||
|
|
|
@ -80,6 +80,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
|||
stride = UP_DIV(outside_, thread_count_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(
|
||||
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
|
||||
|
|
|
@ -82,27 +82,12 @@ int ArithmeticCPUKernel::ReSize() {
|
|||
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
|
||||
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
||||
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
||||
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
|
||||
if (arithmeticParameter_->in_shape0_[i] == -1) {
|
||||
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()),
|
||||
in_tensors_[0]->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
|
||||
if (arithmeticParameter_->in_shape1_[i] == -1) {
|
||||
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()),
|
||||
in_tensors_[1]->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
|
||||
if (arithmeticParameter_->out_shape_[i] == -1) {
|
||||
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()),
|
||||
out_tensors_[0]->shape().size() * sizeof(int));
|
||||
break;
|
||||
}
|
||||
}
|
||||
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
|
||||
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
|
||||
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
|
||||
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
|
||||
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
|
||||
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
|
||||
|
||||
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
||||
switch (arithmeticParameter_->op_parameter_.type_) {
|
||||
|
@ -244,6 +229,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
if (arithmeticParameter_->broadcasting_) { // need broadcast
|
||||
stride = UP_DIV(outside_, thread_count_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
|
||||
|
|
|
@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) {
|
||||
if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) {
|
||||
return ElementLogicalNotBool;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int ArithmeticSelfCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
|
@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) {
|
|||
if (count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Run function is null! ";
|
||||
int ret = RET_ERROR;
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||
if (func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Run function is null! ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
|
||||
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
ret = func_(input_ptr + offset, output_ptr + offset, count);
|
||||
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
|
||||
if (func_bool_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Run function is null! ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
bool *input_ptr = reinterpret_cast<bool *>(in_tensors_.at(0)->data_c());
|
||||
bool *output_ptr = reinterpret_cast<bool *>(out_tensors_.at(0)->data_c());
|
||||
ret = func_bool_(input_ptr + offset, output_ptr + offset, count);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << ".";
|
||||
return RET_ERROR;
|
||||
}
|
||||
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||
auto ret = func_(input_ptr + offset, output_ptr + offset, count);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run failed, illegal input! ";
|
||||
}
|
||||
|
@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke
|
|||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)
|
||||
|
|
|
@ -34,6 +34,7 @@ using mindspore::schema::PrimitiveType_Square;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size);
|
||||
typedef int (*ArithmeticSelfBoolFunc)(const bool *input, bool *output, const int element_size);
|
||||
class ArithmeticSelfCPUKernel : public LiteKernel {
|
||||
public:
|
||||
explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
|
@ -41,6 +42,7 @@ class ArithmeticSelfCPUKernel : public LiteKernel {
|
|||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
func_ = GetArithmeticSelfFun(parameter->type_);
|
||||
func_bool_ = GetArithmeticSelfBoolFun(parameter->type_);
|
||||
}
|
||||
~ArithmeticSelfCPUKernel() override = default;
|
||||
|
||||
|
@ -51,7 +53,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel {
|
|||
|
||||
private:
|
||||
ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type);
|
||||
ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type);
|
||||
ArithmeticSelfFunc func_;
|
||||
ArithmeticSelfBoolFunc func_bool_;
|
||||
};
|
||||
int ArithmeticSelfRun(void *cdata, int task_id);
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -146,5 +146,4 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
|
|||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue