!3088 decoupling getobj from primitivepy
Merge pull request !3088 from lianliguang/primitive-decoupling
This commit is contained in:
commit
863f4e4fbe
|
@ -523,14 +523,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
return iter->second;
|
||||
}
|
||||
auto py_args = PreparePyInputs(prim_py_, args);
|
||||
|
||||
auto pyobj = prim_py_->GetPyObj();
|
||||
if (pyobj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
|
||||
}
|
||||
auto infer_fuc = pyobj.attr("__infer__");
|
||||
prim_py_->BeginRecordAddAttr();
|
||||
py::dict output = infer_fuc(*py_args);
|
||||
py::dict output = prim_py_->RunInfer(py_args);
|
||||
prim_py_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_py_->evaluate_added_attrs();
|
||||
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
||||
|
|
|
@ -654,17 +654,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c
|
|||
}
|
||||
}
|
||||
if (!is_attr_same) {
|
||||
if (prim->isa<PrimitivePy>()) {
|
||||
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||
auto clone_fn = prim_py->GetPyObj().attr("_clone");
|
||||
py::object new_obj = clone_fn();
|
||||
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
|
||||
for (auto &item : *attrs) {
|
||||
cloned_prim->AddAttr(item.first, item.second);
|
||||
}
|
||||
return cloned_prim;
|
||||
}
|
||||
auto cloned_prim = std::make_shared<Primitive>(*prim);
|
||||
auto cloned_prim = prim->Clone();
|
||||
for (auto &item : *attrs) {
|
||||
cloned_prim->AddAttr(item.first, item.second);
|
||||
}
|
||||
|
|
|
@ -280,8 +280,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
ValuePtr input_value = PyAttrValue(py_args[i]);
|
||||
args_spec_list.emplace_back(abstract::FromValueInside(
|
||||
input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()));
|
||||
args_spec_list.emplace_back(
|
||||
abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa<tensor::Tensor>()));
|
||||
}
|
||||
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
|
||||
op_exec_info->abstract = infer_res;
|
||||
|
@ -296,8 +296,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
|
|||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
|
||||
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
auto pyobj = prim->GetPyObj();
|
||||
if (pyobj == nullptr) {
|
||||
if (!prim->HasPyObj()) {
|
||||
MS_LOG(EXCEPTION) << "pyobj is empty";
|
||||
}
|
||||
|
||||
|
@ -708,7 +707,7 @@ py::tuple RunOpInner(const py::args &args) {
|
|||
value_ret[0] = output["value"];
|
||||
return value_ret;
|
||||
}
|
||||
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
|
||||
if (op_exec_info->py_primitive->ObjHasAttr("const_value")) {
|
||||
py::tuple value_ret(1);
|
||||
value_ret[0] = "";
|
||||
return value_ret;
|
||||
|
|
|
@ -100,6 +100,7 @@ class Primitive : public Named {
|
|||
return !(iter == attrs_.cend());
|
||||
}
|
||||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
|
||||
void set_instance_name(const std::string s) { instance_name_ = s; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||
|
|
|
@ -196,6 +196,21 @@ bool PrimitivePy::HasComputeFunction() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
PrimitivePtr PrimitivePy::Clone() {
|
||||
auto clone_fn = python_obj_.attr("_clone");
|
||||
py::object new_obj = clone_fn();
|
||||
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
|
||||
return cloned_prim;
|
||||
}
|
||||
|
||||
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
|
||||
if (!HasPyObj()) {
|
||||
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
|
||||
}
|
||||
auto infer_fuc = python_obj_.attr("__infer__");
|
||||
return infer_fuc(*args);
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
|
|
|
@ -61,6 +61,10 @@ class PrimitivePy : public Primitive {
|
|||
bool HasComputeFunction() const;
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
py::dict RunInfer(const py::tuple &args);
|
||||
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
|
||||
bool HasPyObj() { return python_obj_ != nullptr; }
|
||||
PrimitivePtr Clone() override;
|
||||
bool is_tuple_input_ = false;
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue