forked from mindspore-Ecosystem/mindspore
!3033 decoupling primitive of compute function
Merge pull request !3033 from lianliguang/primi-decoupling-v2
This commit is contained in:
commit
6566b38371
|
@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
}
|
||||
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
|
||||
if (AnfAlgo::IsRealKernel(cnode)) {
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
|
||||
}
|
||||
|
|
|
@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
MS_LOG(INFO) << "RunOpInVM end";
|
||||
return std::move(result);
|
||||
}
|
||||
auto func = op_exec_info->py_primitive->GetComputeFunction();
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
MS_LOG(ERROR) << "VM failed to get func";
|
||||
auto primitive = op_exec_info->py_primitive;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs);
|
||||
if (py::isinstance<py::none>(result)) {
|
||||
MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
|
||||
*status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
|
||||
py::tuple err_ret(0);
|
||||
return std::move(err_ret);
|
||||
}
|
||||
|
||||
// execute op
|
||||
py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
|
||||
py::tuple tuple_result = py::make_tuple(result);
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
MS_LOG(INFO) << "RunOpInVM end";
|
||||
return std::move(result);
|
||||
return std::move(tuple_result);
|
||||
}
|
||||
|
||||
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "utils/primitive_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/utils.h"
|
||||
|
@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) {
|
|||
py::object fn = mod.attr(common::SafeCStr(name));
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
|
||||
auto py_args = py::tuple(args.size());
|
||||
size_t i = 0;
|
||||
for (auto &arg : args) {
|
||||
py_args[i] = BaseRefToPyData(arg);
|
||||
MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
|
||||
i++;
|
||||
}
|
||||
return py_args;
|
||||
}
|
||||
|
||||
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
|
||||
auto func = GetComputeFunction(prim->name());
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
|
||||
}
|
||||
auto py_args = ConvertDatatoPyTuple(args);
|
||||
py::object obj = func(*py_args);
|
||||
return std::make_shared<PyObjectRef>(obj);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/base_ref.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj);
|
|||
py::function GetBpropFunction(std::string name);
|
||||
|
||||
py::function GetComputeFunction(std::string name);
|
||||
|
||||
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
|
||||
|
||||
py::tuple ConvertDatatoPyTuple(const VectorRef &args);
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
|
||||
|
|
|
@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
|
|||
}
|
||||
|
||||
BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
|
||||
PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim);
|
||||
|
||||
MS_LOG(DEBUG) << "operation start " << prim->name();
|
||||
auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name());
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented";
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto result = prim->RunComputeFunction(args);
|
||||
if (result.is_null()) {
|
||||
return RunComputeFunction(prim, args);
|
||||
}
|
||||
|
||||
py::tuple py_args = py::tuple(args.size());
|
||||
MS_LOG(DEBUG) << "input for operation:";
|
||||
size_t i = 0;
|
||||
for (auto &arg : args) {
|
||||
py_args[i] = BaseRefToPyData(arg);
|
||||
MS_LOG(DEBUG) << "arg: " << i << ":";
|
||||
i++;
|
||||
}
|
||||
py::object obj = func(*py_args);
|
||||
MS_LOG(DEBUG) << "result:" << py::str(obj);
|
||||
return obj;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace compile
|
||||
|
|
|
@ -83,6 +83,7 @@ class Primitive : public Named {
|
|||
|
||||
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
|
||||
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
|
||||
virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }
|
||||
|
||||
ValuePtr GetAttr(const std::string &attrName) const {
|
||||
auto iter = attrs_.find(attrName);
|
||||
|
|
|
@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() {
|
|||
}
|
||||
|
||||
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
||||
auto py_args = py::tuple(args.size());
|
||||
size_t i = 0;
|
||||
for (auto &arg : args) {
|
||||
py_args[i] = BaseRefToPyData(arg);
|
||||
MS_LOG(DEBUG) << "arg:" << i << ":";
|
||||
i++;
|
||||
}
|
||||
auto py_args = ConvertDatatoPyTuple(args);
|
||||
py::object obj;
|
||||
bool is_bprop = this->HasAttr(kBpropAttrName);
|
||||
if (is_bprop) {
|
||||
|
@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|||
return std::make_shared<PyObjectRef>(obj);
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetComputeFunction() {
|
||||
py::function PrimitivePy::GetComputeFunction() const {
|
||||
static const char *const compute_func_name = "vm_impl";
|
||||
|
||||
if (py::hasattr(python_obj_, compute_func_name)) {
|
||||
|
@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
|
|||
this->set_hook(primitive_py->hook());
|
||||
}
|
||||
|
||||
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
|
||||
auto py_args = ConvertDatatoPyTuple(args);
|
||||
auto result = this->RunPyComputeFunction(py_args);
|
||||
if (py::isinstance<py::none>(result)) {
|
||||
return std::make_shared<BaseRef>(nullptr);
|
||||
}
|
||||
return std::make_shared<PyObjectRef>(result);
|
||||
}
|
||||
|
||||
py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
|
||||
auto func = this->GetComputeFunction();
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
return py::none();
|
||||
}
|
||||
auto result = func(*py_args);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool PrimitivePy::HasComputeFunction() const {
|
||||
auto func = GetComputeFunction();
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
|
|
|
@ -41,7 +41,6 @@ class PrimitivePy : public Primitive {
|
|||
~PrimitivePy() override = default;
|
||||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||
py::function GetBpropFunction();
|
||||
py::function GetComputeFunction();
|
||||
|
||||
void set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
||||
|
@ -57,11 +56,15 @@ class PrimitivePy : public Primitive {
|
|||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
BaseRef RunHookFunction(const VectorRef &args) const override;
|
||||
BaseRef RunComputeFunction(const VectorRef &args) const override;
|
||||
py::object RunPyComputeFunction(const py::tuple &py_args) const;
|
||||
bool HasComputeFunction() const;
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
bool is_tuple_input_ = false;
|
||||
|
||||
private:
|
||||
py::function GetComputeFunction() const;
|
||||
py::object python_obj_;
|
||||
py::function hook_;
|
||||
std::vector<Signature> signatures_;
|
||||
|
|
|
@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) {
|
|||
ASSERT_TRUE(conv2d_ptr);
|
||||
if (nullptr != conv2d_ptr) {
|
||||
MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name();
|
||||
auto func = conv2d_ptr->GetComputeFunction();
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
if(!conv2d_ptr->HasComputeFunction()){
|
||||
MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented";
|
||||
}
|
||||
|
||||
|
|
|
@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
|||
ASSERT_TRUE(allreduce_ptr);
|
||||
if (nullptr != allreduce_ptr) {
|
||||
MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
|
||||
auto func = allreduce_ptr->GetComputeFunction();
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
if (!allreduce_ptr->HasComputeFunction()) {
|
||||
MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented";
|
||||
}
|
||||
|
||||
|
|
|
@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
|
|||
|
||||
std::vector<BaseRef> todos(splits.size());
|
||||
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
|
||||
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
todos.resize(std::distance(todos.begin(), it));
|
||||
ASSERT_EQ(todos.size(), 1);
|
||||
|
||||
AnfNodePtrList anf_list;
|
||||
AnfNodePtrList anf_list;
|
||||
for (auto &item : utils::cast<VectorRef>(todos[0])) {
|
||||
anf_list.push_back(utils::cast<AnfNodePtr>(item));
|
||||
}
|
||||
|
@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
|
|||
|
||||
std::vector<BaseRef> todos(splits.size());
|
||||
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
|
||||
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
todos.resize(std::distance(todos.begin(), it));
|
||||
ASSERT_EQ(todos.size(), 1);
|
||||
|
||||
AnfNodePtrList anf_list;
|
||||
AnfNodePtrList anf_list;
|
||||
for (auto &item : utils::cast<VectorRef>(todos[0])) {
|
||||
anf_list.push_back(utils::cast<AnfNodePtr>(item));
|
||||
}
|
||||
|
@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) {
|
|||
|
||||
std::vector<BaseRef> todos(splits.size());
|
||||
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
|
||||
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
|
||||
todos.resize(std::distance(todos.begin(), it));
|
||||
ASSERT_EQ(todos.size(), 1);
|
||||
|
||||
AnfNodePtrList anf_list;
|
||||
AnfNodePtrList anf_list;
|
||||
for (auto &item : utils::cast<VectorRef>(todos[0])) {
|
||||
anf_list.push_back(utils::cast<AnfNodePtr>(item));
|
||||
}
|
||||
|
@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {
|
|||
|
||||
TEST_F(TestCompileSegmentRunner, test_RunOperation1) {
|
||||
VectorRef args({1});
|
||||
auto res = RunOperation(prim::kPrimIdentity, args);
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args);
|
||||
ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1);
|
||||
}
|
||||
|
||||
TEST_F(TestCompileSegmentRunner, test_RunOperation2) {
|
||||
VectorRef args({1, 2});
|
||||
auto res = RunOperation(prim::kPrimScalarGt, args);
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args);
|
||||
ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false);
|
||||
}
|
||||
} // namespace compile
|
||||
|
|
Loading…
Reference in New Issue