!37190 fix some ops infer
Merge pull request !37190 from lianliguang/master
This commit is contained in:
commit
96d9911a90
|
@ -49,7 +49,8 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
(void)types.emplace("x", input_args[0]->BuildType());
|
(void)types.emplace("x", input_args[0]->BuildType());
|
||||||
(void)types.emplace("y", input_args[1]->BuildType());
|
(void)types.emplace("y", input_args[1]->BuildType());
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||||
|
return input_args[0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
(void)types.emplace("x", input_args[0]->BuildType());
|
(void)types.emplace("x", input_args[0]->BuildType());
|
||||||
(void)types.emplace("y", input_args[1]->BuildType());
|
(void)types.emplace("y", input_args[1]->BuildType());
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||||
|
return std::make_shared<TensorType>(kBool);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -47,8 +48,8 @@ AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
(void)InferType(primitive, input_args);
|
auto type = InferType(primitive, input_args);
|
||||||
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool);
|
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), type);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -49,7 +49,8 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
(void)types.emplace("x", input_args[0]->BuildType());
|
(void)types.emplace("x", input_args[0]->BuildType());
|
||||||
(void)types.emplace("y", input_args[1]->BuildType());
|
(void)types.emplace("y", input_args[1]->BuildType());
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||||
|
return input_args[0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue