!15385 remove the useless prim type

From: @lianliguang
Reviewed-by: @ginfung,@zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-04-20 16:07:45 +08:00 committed by Gitee
commit 490d2e1efb
5 changed files with 12 additions and 14 deletions

View File

@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
}
// Call checking method 'infer_value' for python primitive
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
}
}
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
return EvalPyCheckPrim(engine, args);
}
auto context = MsContext::GetInstance();

View File

@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;
}
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
return;
}

View File

@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
.value("builtin", PrimType::kPrimTypeBuiltIn)
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
.value("py_infer_shape", PrimType::kPrimTypePyInfer)
.value("user_custom", PrimType::kPrimTypeUserCustom)
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
.value("py_infer_check", PrimType::kPrimTypePyCheck);
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
.def(py::init<py::str &>())

View File

@ -32,11 +32,10 @@ namespace mindspore {
enum PrimType {
kPrimTypeUnknown = 0,
kPrimTypeBegin = kTypeUnknown,
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInferShape, // Primitive operator defined by custom
kPrimTypePyInferTensor, // Primitive operator defined by custom
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInfer, // Primitive operator defined by custom
kPrimTypeUserCustom,
kPrimTypePyInferCheck // Primitive operator with input args checking method
kPrimTypePyCheck // Primitive operator with input args checking method
};
class Primitive : public Named {
@ -100,8 +99,7 @@ class Primitive : public Named {
void set_prim_type(const PrimType t) { prim_type_ = t; }
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
void set_instance_name(const std::string &s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
PrimType prim_type() const { return prim_type_; }

View File

@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) {
}
TEST_F(TestOps, CustomOpAttrTest) {
Primitive prim("CustomOp", true, kPrimTypePyInferShape);
Primitive prim("CustomOp", true, kPrimTypePyInfer);
prim.SetAttrs({
{"attr1", MakeValue(static_cast<int64_t>(3))},
{"attr2", MakeValue(static_cast<int64_t>(1))},
});
ASSERT_EQ(prim.name(), std::string("CustomOp"));
ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer);
auto attrs = prim.attrs();
for (auto attr : attrs) {