forked from mindspore-Ecosystem/mindspore
!15385 remove the useless prim type
From: @lianliguang Reviewed-by: @ginfung,@zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
490d2e1efb
|
@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
|
|||
const AbstractBasePtrList &args) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||
if (prim_py == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
|
||||
}
|
||||
// Call checking method 'infer_value' for python primitive
|
||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||
|
@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
|
|||
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||
if (prim_py == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
|
||||
}
|
||||
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||
|
@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
}
|
||||
}
|
||||
|
||||
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
||||
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
|
||||
return EvalPyCheckPrim(engine, args);
|
||||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
|
|
|
@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
||||
return;
|
||||
}
|
||||
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
||||
if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
|
||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
|||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
||||
.value("py_infer_shape", PrimType::kPrimTypePyInfer)
|
||||
.value("user_custom", PrimType::kPrimTypeUserCustom)
|
||||
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
|
||||
.value("py_infer_check", PrimType::kPrimTypePyCheck);
|
||||
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
|
||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
|
||||
.def(py::init<py::str &>())
|
||||
|
|
|
@ -32,11 +32,10 @@ namespace mindspore {
|
|||
enum PrimType {
|
||||
kPrimTypeUnknown = 0,
|
||||
kPrimTypeBegin = kTypeUnknown,
|
||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||
kPrimTypePyInferShape, // Primitive operator defined by custom
|
||||
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||
kPrimTypePyInfer, // Primitive operator defined by custom
|
||||
kPrimTypeUserCustom,
|
||||
kPrimTypePyInferCheck // Primitive operator with input args checking method
|
||||
kPrimTypePyCheck // Primitive operator with input args checking method
|
||||
};
|
||||
|
||||
class Primitive : public Named {
|
||||
|
@ -100,8 +99,7 @@ class Primitive : public Named {
|
|||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
|
||||
void set_instance_name(const std::string &s) { instance_name_ = s; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||
|
||||
PrimType prim_type() const { return prim_type_; }
|
||||
|
|
|
@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) {
|
|||
}
|
||||
|
||||
TEST_F(TestOps, CustomOpAttrTest) {
|
||||
Primitive prim("CustomOp", true, kPrimTypePyInferShape);
|
||||
Primitive prim("CustomOp", true, kPrimTypePyInfer);
|
||||
prim.SetAttrs({
|
||||
{"attr1", MakeValue(static_cast<int64_t>(3))},
|
||||
{"attr2", MakeValue(static_cast<int64_t>(1))},
|
||||
});
|
||||
ASSERT_EQ(prim.name(), std::string("CustomOp"));
|
||||
ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
|
||||
ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer);
|
||||
|
||||
auto attrs = prim.attrs();
|
||||
for (auto attr : attrs) {
|
||||
|
|
Loading…
Reference in New Issue