!9989 [MS][LITE] fix control flow ops

From: @YeFeng_24
Reviewed-by: @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-15 20:04:33 +08:00 committed by Gitee
commit 8e6571044b
10 changed files with 58 additions and 4 deletions

View File

@ -890,6 +890,14 @@ int ElementLogicalAnd(const float *input0, const float *input1, float *output, c
return NNACL_OK;
}
// Element-wise logical AND over int buffers: output[i] is 1 when both
// input0[i] and input1[i] are non-zero, otherwise 0.
// Fix: the previous bitwise AND on the raw int values was wrong for logical
// semantics — e.g. (1 & 2) == 0 although both operands are truthy. Using the
// logical && operator treats any non-zero value as true, matching the
// bool-cast behavior of the float ElementLogicalAnd above.
int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size) {
  int index = 0;
  for (; index < element_size; index++) {
    output[index] = (int)(input0[index] && input1[index]);
  }
  return NNACL_OK;
}
int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size) {
ElementSub(input0, input1, output, element_size);
return ElementMul(output, output, output, element_size);

View File

@ -92,6 +92,7 @@ int BroadcastDiv(const float *input0, const float *input1, float *tile_input0, f
int element_size, ArithmeticParameter *param);
int ElementLogicalAnd(const float *input0, const float *input1, float *output, const int element_size);
int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size);
int BroadcastLogicalAnd(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
int element_size, ArithmeticParameter *param);

View File

@ -68,8 +68,12 @@ Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
// Infer output tensors for the Merge op.
// Merge carries 2 * N inputs (two alternative producers per output slot);
// each of the N outputs mirrors the dtype and shape of the corresponding
// tensor in the first half of the input list.
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(inputs_.size() == 2 * outputs_.size());
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  const size_t half_input_num = inputs_.size() / 2;
  for (size_t idx = 0; idx < half_input_num; ++idx) {
    auto *src = inputs_[idx];
    auto *dst = outputs_[idx];
    dst->set_data_type(src->data_type());
    dst->set_shape(src->shape());
  }
  return RET_OK;
}

View File

@ -70,6 +70,20 @@ PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return Primitive
Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
#endif
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
// Infer output tensors for the Switch op.
// Inputs are {condition, data_0, ..., data_{n-1}}; outputs are the n
// first-branch tensors followed by the n second-branch tensors. Both branch
// outputs mirror the dtype, shape, and format of the matching data input.
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  const size_t branch_num = outputs_.size() / 2;
  for (size_t idx = 0; idx < branch_num; ++idx) {
    auto *data_in = inputs_[idx + 1];  // inputs_[0] is the condition tensor
    Tensor *branch_outs[2] = {outputs_[idx], outputs_[idx + branch_num]};
    for (auto *out : branch_outs) {
      out->set_data_type(data_in->data_type());
      out->set_shape(data_in->shape());
      out->set_format(data_in->format());
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -30,7 +30,7 @@ int SwitchCPUKernel::PostProcess() {
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
MS_ASSERT(bool_tensor->shape().size() == 1);
MS_ASSERT(bool_tensor->shape().front() == 1);
auto *active = static_cast<bool *>(bool_tensor->data_c());
auto active = static_cast<bool *>(bool_tensor->data_c());
if (active == nullptr) {
MS_LOG(ERROR) << "data of bool tensor is nullptr";
return lite::RET_NULL_PTR;

View File

@ -132,6 +132,7 @@ kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator)

View File

@ -190,6 +190,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
arithmetic_run_int_ = ElementLogicalAndInt;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr;
@ -544,6 +545,7 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCre
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)

View File

@ -85,4 +85,5 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::Tensor *
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -340,8 +340,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
}
// when merge is removed, this if is removed automatically
if (kernel->Type() == schema::PrimitiveType_Merge) {
MS_ASSERT(kernel->in_kernels().size() == 2);
return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]);
return MergeOpIsReady(kernel, is_kernel_finish);
} else {
return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
[&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
@ -371,6 +370,28 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
}
return RET_OK;
}
// Decides whether a Merge kernel is ready to be scheduled into the current
// subgraph. A Merge has 2 * k input tensors: the first half comes from one
// control-flow path, the second half from the other. It is ready as soon as
// EITHER complete half is available, since only one path runs at a time.
// NOTE(review): is_kernel_finish is taken by value (matching the declaration
// in scheduler.h), which copies the whole map on every call; a const
// reference would be cheaper — confirm and change declaration and definition
// together.
bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel,
std::map<const kernel::LiteKernel *, bool> is_kernel_finish) {
// Per-tensor availability: an input tensor is available when it is constant
// data, or when some already-finished producer kernel outputs it.
std::map<const lite::Tensor *, bool> merge_in_tensors_map;
for (auto merge_in_tensor : kernel->in_tensors()) {
merge_in_tensors_map[merge_in_tensor] = false;
if (merge_in_tensor->category() == Tensor::CONST_TENSOR || merge_in_tensor->category() == Tensor::CONST_SCALAR) {
merge_in_tensors_map[merge_in_tensor] = true;
}
// Scan every producer kernel's outputs for this tensor; only producers that
// have finished executing make it available.
for (auto merge_in_kernel : kernel->in_kernels()) {
for (auto tensor : merge_in_kernel->out_tensors()) {
if (tensor == merge_in_tensor && is_kernel_finish[merge_in_kernel]) {
merge_in_tensors_map[merge_in_tensor] = true;
}
}
}
}
// Ready when all tensors of the first half OR all tensors of the second
// half are available.
auto kernel_in_tensors_num = kernel->in_tensors().size();
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().begin() + kernel_in_tensors_num / 2,
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }) ||
std::all_of(kernel->in_tensors().begin() + kernel_in_tensors_num / 2, kernel->in_tensors().end(),
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; });
}
kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
kernel::SubGraphType type) {

View File

@ -65,6 +65,8 @@ class Scheduler {
kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
kernel::SubGraphType type);
bool MergeOpIsReady(const kernel::LiteKernel *kernel, std::map<const kernel::LiteKernel *, bool> is_kernel_finish);
std::vector<kernel::LiteKernel *> FindAllSubGraphKernels(
kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);