fix switch op infershape bug && adding dataType check in arithmetic op

This commit is contained in:
fuzhiye 2021-02-20 15:05:15 +08:00
parent 36ae3950eb
commit ce4fe0bcf9
5 changed files with 66 additions and 6 deletions

View File

@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
for (size_t i = 0; i < outputs_.size() / 2; i++) {
auto *input = inputs_[i + 1];
auto *output_true = outputs_[i];
auto *output_false = outputs_[i + outputs_.size() / 2];
if (input == nullptr) {
MS_LOG(ERROR) << "input tensor is nullptr";
return RET_ERROR;
}
if (output_true == nullptr || output_false == nullptr) {
MS_LOG(ERROR) << "output tensor is nullptr";
return RET_ERROR;
}
output_true->set_data_type(input->data_type());
output_false->set_data_type(input->data_type());
output_true->set_format(input->format());
output_false->set_format(input->format());
auto data_type = input->data_type();
if (data_type != kObjectTypeTensorType) {
continue;
} else {
auto input_tensorlist = reinterpret_cast<TensorList *>(input);
auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true);
auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false);
output_true_tensorlist->set_element_shape(input_tensorlist->element_shape());
output_false_tensorlist->set_element_shape(input_tensorlist->element_shape());
output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
}
}
if (!infer_flag()) {
return RET_INFER_INVALID;
}
@ -88,12 +119,8 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
MS_LOG(ERROR) << "output tensor is nullptr";
return RET_ERROR;
}
output_true->set_data_type(input->data_type());
output_false->set_data_type(input->data_type());
output_true->set_shape(input->shape());
output_false->set_shape(input->shape());
output_true->set_format(input->format());
output_false->set_format(input->format());
auto data_type = input->data_type();
if (data_type != kObjectTypeTensorType) {
continue;

View File

@ -118,7 +118,24 @@ void ArithmeticFP16CPUKernel::InitParam() {
return;
}
int ArithmeticFP16CPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if ((in0_dataType != kNumberTypeFloat16 && in0_dataType != kNumberTypeFloat32) ||
(in1_dataType != kNumberTypeFloat16 && in1_dataType != kNumberTypeFloat32)) {
MS_LOG(ERROR)
<< "The dataTypes of input tensor0 and input tensor1 should be any of float16 and float32, otherwise got error.";
return RET_ERROR;
}
return RET_OK;
}
int ArithmeticFP16CPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed.";
return RET_ERROR;
}
InitParam();
if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
@ -131,6 +148,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
return RET_ERROR;
}
if (param_->broadcasting_) {
outside_ = 1;
for (int i = param_->ndim_ - 1; i >= 0; --i) {

View File

@ -46,6 +46,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override;
int CheckDataType();
int DoArithmetic(int task_id);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride);

View File

@ -46,7 +46,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
* and all need-broadcast-node are const
* broadcast in resize */
if (arithmeticParameter_->broadcasting_ == false) {
if (!arithmeticParameter_->broadcasting_) {
return RET_OK;
}
@ -183,7 +183,21 @@ void ArithmeticCPUKernel::InitParam() {
return;
}
int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
return RET_ERROR;
}
return RET_OK;
}
int ArithmeticCPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed.";
return RET_ERROR;
}
InitParam();
return InitBroadCastCase();
}

View File

@ -80,9 +80,9 @@ class ArithmeticCPUKernel : public LiteKernel {
private:
void InitRunFunction();
void InitOptRunFunction();
void InitParam();
void FreeTmpPtr();
int CheckDataType();
int InitBroadCastCase();
void InitParamInRunTime();
bool CanBatchScalar();