forked from mindspore-Ecosystem/mindspore
!9989 [MS][LITE]fix controlflow ops
From: @YeFeng_24 Reviewed-by: @zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
8e6571044b
|
@ -890,6 +890,14 @@ int ElementLogicalAnd(const float *input0, const float *input1, float *output, c
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size) {
|
||||
int index = 0;
|
||||
for (; index < element_size; index++) {
|
||||
output[index] = (int)((int)(input0[index]) & (int)(input1[index]));
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size) {
|
||||
ElementSub(input0, input1, output, element_size);
|
||||
return ElementMul(output, output, output, element_size);
|
||||
|
|
|
@ -92,6 +92,7 @@ int BroadcastDiv(const float *input0, const float *input1, float *tile_input0, f
|
|||
int element_size, ArithmeticParameter *param);
|
||||
|
||||
int ElementLogicalAnd(const float *input0, const float *input1, float *output, const int element_size);
|
||||
int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size);
|
||||
int BroadcastLogicalAnd(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
|
||||
int element_size, ArithmeticParameter *param);
|
||||
|
||||
|
|
|
@ -68,8 +68,12 @@ Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
|
|||
|
||||
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(inputs_.size() == 2 * outputs_.size());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < inputs_.size() / 2; i++) {
|
||||
outputs_[i]->set_data_type(inputs_[i]->data_type());
|
||||
outputs_[i]->set_shape(inputs_[i]->shape());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -70,6 +70,20 @@ PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return Primitive
|
|||
Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
|
||||
#endif
|
||||
|
||||
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
|
||||
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
for (size_t i = 0; i < outputs_.size() / 2; i++) {
|
||||
outputs_[i]->set_data_type(inputs_[i + 1]->data_type());
|
||||
outputs_[i + outputs_.size() / 2]->set_data_type(inputs_[i + 1]->data_type());
|
||||
outputs_[i]->set_shape(inputs_[i + 1]->shape());
|
||||
outputs_[i + outputs_.size() / 2]->set_shape(inputs_[i + 1]->shape());
|
||||
outputs_[i]->set_format(inputs_[i + 1]->format());
|
||||
outputs_[i + outputs_.size() / 2]->set_format(inputs_[i + 1]->format());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,7 @@ int SwitchCPUKernel::PostProcess() {
|
|||
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
|
||||
MS_ASSERT(bool_tensor->shape().size() == 1);
|
||||
MS_ASSERT(bool_tensor->shape().front() == 1);
|
||||
auto *active = static_cast<bool *>(bool_tensor->data_c());
|
||||
auto active = static_cast<bool *>(bool_tensor->data_c());
|
||||
if (active == nullptr) {
|
||||
MS_LOG(ERROR) << "data of bool tensor is nullptr";
|
||||
return lite::RET_NULL_PTR;
|
||||
|
|
|
@ -132,6 +132,7 @@ kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite
|
|||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Less, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticCompareFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticCompareFp32KernelCreator)
|
||||
|
|
|
@ -190,6 +190,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
|
|||
break;
|
||||
case PrimitiveType_LogicalAnd:
|
||||
arithmetic_run_ = ElementLogicalAnd;
|
||||
arithmetic_run_int_ = ElementLogicalAndInt;
|
||||
break;
|
||||
case PrimitiveType_LogicalOr:
|
||||
arithmetic_run_ = ElementLogicalOr;
|
||||
|
@ -544,6 +545,7 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCre
|
|||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)
|
||||
|
|
|
@ -85,4 +85,5 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::Tensor *
|
|||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -340,8 +340,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
|
|||
}
|
||||
// when merge is removed, this if is removed automatically
|
||||
if (kernel->Type() == schema::PrimitiveType_Merge) {
|
||||
MS_ASSERT(kernel->in_kernels().size() == 2);
|
||||
return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]);
|
||||
return MergeOpIsReady(kernel, is_kernel_finish);
|
||||
} else {
|
||||
return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
|
||||
[&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
|
||||
|
@ -371,6 +370,28 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel,
|
||||
std::map<const kernel::LiteKernel *, bool> is_kernel_finish) {
|
||||
std::map<const lite::Tensor *, bool> merge_in_tensors_map;
|
||||
for (auto merge_in_tensor : kernel->in_tensors()) {
|
||||
merge_in_tensors_map[merge_in_tensor] = false;
|
||||
if (merge_in_tensor->category() == Tensor::CONST_TENSOR || merge_in_tensor->category() == Tensor::CONST_SCALAR) {
|
||||
merge_in_tensors_map[merge_in_tensor] = true;
|
||||
}
|
||||
for (auto merge_in_kernel : kernel->in_kernels()) {
|
||||
for (auto tensor : merge_in_kernel->out_tensors()) {
|
||||
if (tensor == merge_in_tensor && is_kernel_finish[merge_in_kernel]) {
|
||||
merge_in_tensors_map[merge_in_tensor] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto kernel_in_tensors_num = kernel->in_tensors().size();
|
||||
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().begin() + kernel_in_tensors_num / 2,
|
||||
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }) ||
|
||||
std::all_of(kernel->in_tensors().begin() + kernel_in_tensors_num / 2, kernel->in_tensors().end(),
|
||||
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; });
|
||||
}
|
||||
|
||||
kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
|
||||
kernel::SubGraphType type) {
|
||||
|
|
|
@ -65,6 +65,8 @@ class Scheduler {
|
|||
kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
|
||||
kernel::SubGraphType type);
|
||||
|
||||
bool MergeOpIsReady(const kernel::LiteKernel *kernel, std::map<const kernel::LiteKernel *, bool> is_kernel_finish);
|
||||
|
||||
std::vector<kernel::LiteKernel *> FindAllSubGraphKernels(
|
||||
kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);
|
||||
|
||||
|
|
Loading…
Reference in New Issue