forked from mindspore-Ecosystem/mindspore

Fix Switch op InferShape bug and add dataType check in arithmetic ops

commit ce4fe0bcf9
parent 36ae3950eb
@@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);

 int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
+  for (size_t i = 0; i < outputs_.size() / 2; i++) {
+    auto *input = inputs_[i + 1];
+    auto *output_true = outputs_[i];
+    auto *output_false = outputs_[i + outputs_.size() / 2];
+    if (input == nullptr) {
+      MS_LOG(ERROR) << "input tensor is nullptr";
+      return RET_ERROR;
+    }
+    if (output_true == nullptr || output_false == nullptr) {
+      MS_LOG(ERROR) << "output tensor is nullptr";
+      return RET_ERROR;
+    }
+    output_true->set_data_type(input->data_type());
+    output_false->set_data_type(input->data_type());
+    output_true->set_format(input->format());
+    output_false->set_format(input->format());
+    auto data_type = input->data_type();
+    if (data_type != kObjectTypeTensorType) {
+      continue;
+    } else {
+      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
+      auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true);
+      auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false);
+      output_true_tensorlist->set_element_shape(input_tensorlist->element_shape());
+      output_false_tensorlist->set_element_shape(input_tensorlist->element_shape());
+      output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
+      output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
+      output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
+      output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
+    }
+  }
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
@@ -88,12 +119,8 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
       MS_LOG(ERROR) << "output tensor is nullptr";
       return RET_ERROR;
     }
-    output_true->set_data_type(input->data_type());
-    output_false->set_data_type(input->data_type());
     output_true->set_shape(input->shape());
     output_false->set_shape(input->shape());
-    output_true->set_format(input->format());
-    output_false->set_format(input->format());
     auto data_type = input->data_type();
     if (data_type != kObjectTypeTensorType) {
       continue;
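The pre-pass added above pairs each data input with a true-branch output at outputs_[i] and a false-branch output at outputs_[i + outputs_.size() / 2], and copies type and format to both halves before the infer_flag() early return (inputs_[0], presumably the switch condition, is skipped). A minimal standalone sketch of that pairing, using a hypothetical MiniTensor struct instead of the real Tensor/TensorList classes:

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the real Tensor class: only the fields the
// pre-pass touches (type and format) are modeled here.
struct MiniTensor {
  int data_type;
  int format;
};

// Copy type/format from each data input (inputs[1..N]) to its true-branch
// output (outputs[i]) and false-branch output (outputs[i + N]), mirroring
// the loop added to Switch::InferShape above.
void PropagateSwitchTypes(const std::vector<MiniTensor *> &inputs,
                          const std::vector<MiniTensor *> &outputs) {
  assert(2 * (inputs.size() - 1) == outputs.size());
  size_t n = outputs.size() / 2;
  for (size_t i = 0; i < n; ++i) {
    const MiniTensor *in = inputs[i + 1];  // inputs[0] is the switch condition
    outputs[i]->data_type = in->data_type;
    outputs[i + n]->data_type = in->data_type;
    outputs[i]->format = in->format;
    outputs[i + n]->format = in->format;
  }
}

int main() {
  MiniTensor cond{0, 0}, x{43, 1}, t0{0, 0}, f0{0, 0};
  std::vector<MiniTensor *> inputs{&cond, &x}, outputs{&t0, &f0};
  PropagateSwitchTypes(inputs, outputs);
  std::printf("true branch dtype=%d, false branch dtype=%d\n", t0.data_type, f0.data_type);
  return 0;
}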
@@ -118,7 +118,24 @@ void ArithmeticFP16CPUKernel::InitParam() {
   return;
 }

+int ArithmeticFP16CPUKernel::CheckDataType() {
+  auto in0_dataType = in_tensors_.at(0)->data_type();
+  auto in1_dataType = in_tensors_.at(1)->data_type();
+  if ((in0_dataType != kNumberTypeFloat16 && in0_dataType != kNumberTypeFloat32) ||
+      (in1_dataType != kNumberTypeFloat16 && in1_dataType != kNumberTypeFloat32)) {
+    MS_LOG(ERROR)
+      << "The dataTypes of input tensor0 and input tensor1 should be any of float16 and float32, otherwise got error.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int ArithmeticFP16CPUKernel::ReSize() {
+  if (CheckDataType() != RET_OK) {
+    MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed.";
+    return RET_ERROR;
+  }
+
   InitParam();

   if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
@@ -131,6 +148,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
     MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
     return RET_ERROR;
   }
+
   if (param_->broadcasting_) {
     outside_ = 1;
     for (int i = param_->ndim_ - 1; i >= 0; --i) {
@@ -46,6 +46,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
   int Init() override;
   int ReSize() override;
   int Run() override;
+  int CheckDataType();
   int DoArithmetic(int task_id);
   int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
                    int out_thread_stride);
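The FP16 kernel's rule is that each input must be float16 or float32 (mixed fp16/fp32 pairs are still accepted). A tiny self-contained sketch of that predicate, using placeholder enum values rather than MindSpore's real TypeId definitions:

#include <cstdio>

// Placeholder TypeId values for illustration only; the real enum lives in MindSpore.
enum TypeId { kNumberTypeInt32 = 1, kNumberTypeFloat32 = 2, kNumberTypeFloat16 = 3 };

// Mirrors the intent of ArithmeticFP16CPUKernel::CheckDataType above: both
// inputs must be float16 or float32, otherwise the kernel refuses to resize.
bool Fp16InputsSupported(TypeId in0, TypeId in1) {
  auto is_float = [](TypeId t) { return t == kNumberTypeFloat16 || t == kNumberTypeFloat32; };
  return is_float(in0) && is_float(in1);
}

int main() {
  std::printf("%d\n", Fp16InputsSupported(kNumberTypeFloat16, kNumberTypeFloat32));  // 1 (accepted)
  std::printf("%d\n", Fp16InputsSupported(kNumberTypeFloat16, kNumberTypeInt32));    // 0 (rejected)
  return 0;
}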
@@ -46,7 +46,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
    * and all need-broadcast-node are const
    * broadcast in resize */

-  if (arithmeticParameter_->broadcasting_ == false) {
+  if (!arithmeticParameter_->broadcasting_) {
     return RET_OK;
   }

@@ -183,7 +183,21 @@ void ArithmeticCPUKernel::InitParam() {
   return;
 }

+int ArithmeticCPUKernel::CheckDataType() {
+  auto in0_dataType = in_tensors_.at(0)->data_type();
+  auto in1_dataType = in_tensors_.at(1)->data_type();
+  if (in0_dataType != in1_dataType) {
+    MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int ArithmeticCPUKernel::ReSize() {
+  if (CheckDataType() != RET_OK) {
+    MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed.";
+    return RET_ERROR;
+  }
   InitParam();
   return InitBroadCastCase();
 }
@@ -80,9 +80,9 @@ class ArithmeticCPUKernel : public LiteKernel {

  private:
   void InitRunFunction();
   void InitOptRunFunction();
   void InitParam();
   void FreeTmpPtr();
+  int CheckDataType();
   int InitBroadCastCase();
   void InitParamInRunTime();
   bool CanBatchScalar();
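The float32 kernel's check only requires the two inputs to share a dtype, and ReSize now fails before InitParam or InitBroadCastCase run. A small sketch of that fail-fast ordering, with stubbed-out steps and placeholder return codes (not MindSpore's real error-code values):

#include <cstdio>

constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;  // placeholder values, not MindSpore's real codes

// Stubbed kernel illustrating the ReSize ordering introduced above: the dtype
// check runs first, so InitParam/InitBroadCastCase never see mismatched inputs.
struct FakeArithmeticKernel {
  int in0_dtype;
  int in1_dtype;

  int CheckDataType() const { return in0_dtype == in1_dtype ? RET_OK : RET_ERROR; }
  void InitParam() { std::puts("InitParam"); }
  int InitBroadCastCase() { std::puts("InitBroadCastCase"); return RET_OK; }

  int ReSize() {
    if (CheckDataType() != RET_OK) {
      std::puts("resize failed: dtype mismatch");
      return RET_ERROR;
    }
    InitParam();
    return InitBroadCastCase();
  }
};

int main() {
  FakeArithmeticKernel ok{43, 43}, bad{43, 34};
  ok.ReSize();   // runs InitParam and InitBroadCastCase
  bad.ReSize();  // stops at the dtype check
  return 0;
}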