forked from mindspore-Ecosystem/mindspore
!15385 remove the useless prim type
From: @lianliguang Reviewed-by: @ginfung,@zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
490d2e1efb
|
@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
|
||||||
const AbstractBasePtrList &args) {
|
const AbstractBasePtrList &args) {
|
||||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||||
if (prim_py == nullptr) {
|
if (prim_py == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
|
||||||
}
|
}
|
||||||
// Call checking method 'infer_value' for python primitive
|
// Call checking method 'infer_value' for python primitive
|
||||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||||
|
@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
|
||||||
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||||
if (prim_py == nullptr) {
|
if (prim_py == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
|
||||||
}
|
}
|
||||||
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
||||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||||
|
@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
|
||||||
return EvalPyCheckPrim(engine, args);
|
return EvalPyCheckPrim(engine, args);
|
||||||
}
|
}
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
|
|
|
@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) {
|
||||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
|
||||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
|
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
.value("py_infer_shape", PrimType::kPrimTypePyInfer)
|
||||||
.value("user_custom", PrimType::kPrimTypeUserCustom)
|
.value("user_custom", PrimType::kPrimTypeUserCustom)
|
||||||
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
|
.value("py_infer_check", PrimType::kPrimTypePyCheck);
|
||||||
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
|
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
|
||||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
|
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
|
||||||
.def(py::init<py::str &>())
|
.def(py::init<py::str &>())
|
||||||
|
|
|
@ -32,11 +32,10 @@ namespace mindspore {
|
||||||
enum PrimType {
|
enum PrimType {
|
||||||
kPrimTypeUnknown = 0,
|
kPrimTypeUnknown = 0,
|
||||||
kPrimTypeBegin = kTypeUnknown,
|
kPrimTypeBegin = kTypeUnknown,
|
||||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||||
kPrimTypePyInferShape, // Primitive operator defined by custom
|
kPrimTypePyInfer, // Primitive operator defined by custom
|
||||||
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
|
||||||
kPrimTypeUserCustom,
|
kPrimTypeUserCustom,
|
||||||
kPrimTypePyInferCheck // Primitive operator with input args checking method
|
kPrimTypePyCheck // Primitive operator with input args checking method
|
||||||
};
|
};
|
||||||
|
|
||||||
class Primitive : public Named {
|
class Primitive : public Named {
|
||||||
|
@ -100,8 +99,7 @@ class Primitive : public Named {
|
||||||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||||
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
|
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
|
||||||
void set_instance_name(const std::string &s) { instance_name_ = s; }
|
void set_instance_name(const std::string &s) { instance_name_ = s; }
|
||||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
|
||||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
|
||||||
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||||
|
|
||||||
PrimType prim_type() const { return prim_type_; }
|
PrimType prim_type() const { return prim_type_; }
|
||||||
|
|
|
@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOps, CustomOpAttrTest) {
|
TEST_F(TestOps, CustomOpAttrTest) {
|
||||||
Primitive prim("CustomOp", true, kPrimTypePyInferShape);
|
Primitive prim("CustomOp", true, kPrimTypePyInfer);
|
||||||
prim.SetAttrs({
|
prim.SetAttrs({
|
||||||
{"attr1", MakeValue(static_cast<int64_t>(3))},
|
{"attr1", MakeValue(static_cast<int64_t>(3))},
|
||||||
{"attr2", MakeValue(static_cast<int64_t>(1))},
|
{"attr2", MakeValue(static_cast<int64_t>(1))},
|
||||||
});
|
});
|
||||||
ASSERT_EQ(prim.name(), std::string("CustomOp"));
|
ASSERT_EQ(prim.name(), std::string("CustomOp"));
|
||||||
ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
|
ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer);
|
||||||
|
|
||||||
auto attrs = prim.attrs();
|
auto attrs = prim.attrs();
|
||||||
for (auto attr : attrs) {
|
for (auto attr : attrs) {
|
||||||
|
|
Loading…
Reference in New Issue