forked from mindspore-Ecosystem/mindspore
refactor primitive hook function
This commit is contained in:
parent
0a3bf64b79
commit
9682d08d96
|
@ -24,7 +24,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive_base.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
|
|
|
@ -15,108 +15,57 @@
|
|||
*/
|
||||
|
||||
#include "ir/primitive.h"
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include "ir/signature.h"
|
||||
#include "operator/ops.h"
|
||||
#include "./common.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
static ValuePtr PyArgToValue(const py::object &arg) {
|
||||
if (py::isinstance<SignatureEnumKind>(arg) &&
|
||||
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
||||
return nullptr;
|
||||
}
|
||||
return parse::data_converter::PyDataToValue(arg);
|
||||
}
|
||||
|
||||
void PrimitivePy::set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
|
||||
signatures_.clear();
|
||||
for (auto &signature : signatures) {
|
||||
auto [name, rw, kind, arg_default, dtype] = signature;
|
||||
auto default_value = PyArgToValue(arg_default);
|
||||
signatures_.emplace_back(name, rw, kind, default_value, dtype);
|
||||
}
|
||||
set_has_signature(true);
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetBpropFunction() {
|
||||
static const char *const get_bprop_func_name = "get_bprop";
|
||||
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
||||
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
|
||||
return fn;
|
||||
bool Primitive::operator==(const Value &other) const {
|
||||
if (other.isa<Primitive>()) {
|
||||
auto other_prim = static_cast<const Primitive &>(other);
|
||||
return *this == other_prim;
|
||||
} else {
|
||||
auto fn = GetBpropFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetComputeFunction() {
|
||||
static const char *const compute_func_name = "vm_impl";
|
||||
|
||||
if (py::hasattr(python_obj_, compute_func_name)) {
|
||||
MS_LOG(INFO) << name() << " compute_func_name";
|
||||
py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
|
||||
return fn;
|
||||
bool Primitive::operator==(const Primitive &other) const {
|
||||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static const std::string vm_module = "mindspore.ops.vm_impl_registry";
|
||||
static const std::string get_vm_impl_fn = "get_vm_impl_fn";
|
||||
MS_LOG(INFO) << name() << ": get_vm_impl_fn";
|
||||
py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
|
||||
py::function vm_fn = get_fn(python_obj_);
|
||||
|
||||
if (py::isinstance<py::none>(vm_fn)) {
|
||||
MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
|
||||
vm_fn = mindspore::GetComputeFunction(Primitive::name());
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
return vm_fn;
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
}
|
||||
|
||||
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
|
||||
std::string Primitive::GetAttrsText() const {
|
||||
if (attrs_.empty()) {
|
||||
return "";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
}
|
||||
(void)this->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
|
||||
py::dict PrimitivePy::GetAttrDict() {
|
||||
py::dict attr_dict;
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
bool is_first = true;
|
||||
for (auto &attr : attrs_) {
|
||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||
if (is_first) {
|
||||
is_first = false;
|
||||
} else {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << attr.first << "=" << attr.second->DumpText();
|
||||
}
|
||||
return attr_dict;
|
||||
}
|
||||
oss << "]";
|
||||
|
||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
||||
.value("user_custom", PrimType::kPrimTypeUserCustom);
|
||||
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
||||
.def(py::init<py::str &, py::object>())
|
||||
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
|
||||
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
|
||||
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
|
||||
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
|
||||
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
|
||||
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
|
||||
}));
|
||||
return oss.str();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,45 +23,129 @@
|
|||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ir/dtype/type.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "utils/misc.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/primitive_base.h"
|
||||
#include "ir/signature.h"
|
||||
#include "parallel/ops_info/operator_info.h"
|
||||
|
||||
#include "utils/base_ref_extends.h"
|
||||
namespace mindspore {
|
||||
class PrimitivePy : public Primitive {
|
||||
public:
|
||||
PrimitivePy(const py::str &name, const py::object &python_obj)
|
||||
: Primitive(name, false), python_obj_(python_obj), signatures_() {}
|
||||
~PrimitivePy() override = default;
|
||||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||
py::function GetBpropFunction();
|
||||
py::function GetComputeFunction();
|
||||
|
||||
void set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
||||
signatures);
|
||||
|
||||
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||
|
||||
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||
|
||||
py::dict GetAttrDict();
|
||||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
bool is_tuple_input_ = false;
|
||||
|
||||
private:
|
||||
py::object python_obj_;
|
||||
py::function hook_;
|
||||
std::vector<Signature> signatures_;
|
||||
// Supported meta type
|
||||
enum PrimType {
|
||||
kPrimTypeUnknown = 0,
|
||||
kPrimTypeBegin = kTypeUnknown,
|
||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||
kPrimTypePyInferShape, // Primitive operator defined by custom
|
||||
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
||||
kPrimTypeUserCustom
|
||||
};
|
||||
|
||||
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
||||
class Primitive : public Named {
|
||||
public:
|
||||
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
|
||||
: Named(name),
|
||||
is_base_(is_base),
|
||||
has_signature_(false),
|
||||
prim_type_(prim_type),
|
||||
record_evaluate_add_attr_(false) {}
|
||||
|
||||
Primitive(const Primitive &prim)
|
||||
: Named(prim),
|
||||
attrs_(prim.attrs_),
|
||||
instance_name_(prim.instance_name_),
|
||||
is_base_(prim.is_base_),
|
||||
has_signature_(prim.has_signature_),
|
||||
prim_type_(prim.prim_type_),
|
||||
record_evaluate_add_attr_(false) {}
|
||||
|
||||
MS_DECLARE_PARENT(Primitive, Named);
|
||||
|
||||
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||
std::string ToString() const override { return name(); }
|
||||
void BeginRecordAddAttr() {
|
||||
evaluate_added_attrs_.clear();
|
||||
record_evaluate_add_attr_ = true;
|
||||
}
|
||||
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
|
||||
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||
attrs_[name] = attr;
|
||||
if (record_evaluate_add_attr_) {
|
||||
evaluate_added_attrs_[name] = attr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
for (auto &attr : attrs) {
|
||||
attrs_[attr.first] = attr.second;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
|
||||
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
|
||||
|
||||
ValuePtr GetAttr(const std::string &attrName) const {
|
||||
auto iter = attrs_.find(attrName);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
|
||||
|
||||
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
||||
bool HasAttr() const { return !attrs_.empty(); }
|
||||
bool HasAttr(const std::string &attrName) const {
|
||||
auto iter = attrs_.find(attrName);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||
void set_instance_name(const std::string s) { instance_name_ = s; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||
|
||||
PrimType prim_type() const { return prim_type_; }
|
||||
std::string instance_name() const { return instance_name_; }
|
||||
std::string GetAttrsText() const;
|
||||
bool operator==(const Value &other) const override;
|
||||
bool operator==(const Primitive &other) const;
|
||||
~Primitive() override = default;
|
||||
|
||||
void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
|
||||
bool has_signature() const { return has_signature_; }
|
||||
bool is_base() const { return is_base_; }
|
||||
virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; }
|
||||
virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
|
||||
|
||||
private:
|
||||
std::string instance_name_;
|
||||
bool is_base_;
|
||||
bool has_signature_;
|
||||
PrimType prim_type_;
|
||||
bool record_evaluate_add_attr_;
|
||||
};
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
||||
os << *p;
|
||||
return os;
|
||||
}
|
||||
|
||||
struct PrimitiveEqual {
|
||||
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return t1->name() == t2->name();
|
||||
}
|
||||
};
|
||||
|
||||
struct PrimitiveHasher {
|
||||
std::size_t operator()(PrimitivePtr const &prim) const {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->Hash();
|
||||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/primitive_base.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
bool Primitive::operator==(const Value &other) const {
|
||||
if (other.isa<Primitive>()) {
|
||||
auto other_prim = static_cast<const Primitive &>(other);
|
||||
return *this == other_prim;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool Primitive::operator==(const Primitive &other) const {
|
||||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
}
|
||||
|
||||
std::string Primitive::GetAttrsText() const {
|
||||
if (attrs_.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
bool is_first = true;
|
||||
for (auto &attr : attrs_) {
|
||||
if (is_first) {
|
||||
is_first = false;
|
||||
} else {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << attr.first << "=" << attr.second->DumpText();
|
||||
}
|
||||
oss << "]";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,150 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||
#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ir/dtype/type.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
// Supported meta type
|
||||
enum PrimType {
|
||||
kPrimTypeUnknown = 0,
|
||||
kPrimTypeBegin = kTypeUnknown,
|
||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||
kPrimTypePyInferShape, // Primitive operator defined by custom
|
||||
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
||||
kPrimTypeUserCustom
|
||||
};
|
||||
|
||||
class Primitive : public Named {
|
||||
public:
|
||||
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
|
||||
: Named(name),
|
||||
is_base_(is_base),
|
||||
has_signature_(false),
|
||||
prim_type_(prim_type),
|
||||
record_evaluate_add_attr_(false) {}
|
||||
|
||||
Primitive(const Primitive &prim)
|
||||
: Named(prim),
|
||||
attrs_(prim.attrs_),
|
||||
instance_name_(prim.instance_name_),
|
||||
is_base_(prim.is_base_),
|
||||
has_signature_(prim.has_signature_),
|
||||
prim_type_(prim.prim_type_),
|
||||
record_evaluate_add_attr_(false) {}
|
||||
|
||||
MS_DECLARE_PARENT(Primitive, Named);
|
||||
|
||||
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||
std::string ToString() const override { return name(); }
|
||||
void BeginRecordAddAttr() {
|
||||
evaluate_added_attrs_.clear();
|
||||
record_evaluate_add_attr_ = true;
|
||||
}
|
||||
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
|
||||
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||
attrs_[name] = attr;
|
||||
if (record_evaluate_add_attr_) {
|
||||
evaluate_added_attrs_[name] = attr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
for (auto &attr : attrs) {
|
||||
attrs_[attr.first] = attr.second;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
|
||||
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
|
||||
|
||||
ValuePtr GetAttr(const std::string &attrName) const {
|
||||
auto iter = attrs_.find(attrName);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
|
||||
|
||||
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
||||
bool HasAttr() const { return !attrs_.empty(); }
|
||||
bool HasAttr(const std::string &attrName) const {
|
||||
auto iter = attrs_.find(attrName);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||
void set_instance_name(const std::string s) { instance_name_ = s; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||
|
||||
PrimType prim_type() const { return prim_type_; }
|
||||
std::string instance_name() const { return instance_name_; }
|
||||
std::string GetAttrsText() const;
|
||||
bool operator==(const Value &other) const override;
|
||||
bool operator==(const Primitive &other) const;
|
||||
~Primitive() override = default;
|
||||
|
||||
void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
|
||||
bool has_signature() const { return has_signature_; }
|
||||
bool is_base() const { return is_base_; }
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
|
||||
|
||||
private:
|
||||
std::string instance_name_;
|
||||
bool is_base_;
|
||||
bool has_signature_;
|
||||
PrimType prim_type_;
|
||||
bool record_evaluate_add_attr_;
|
||||
};
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
||||
os << *p;
|
||||
return os;
|
||||
}
|
||||
|
||||
struct PrimitiveEqual {
|
||||
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return t1->name() == t2->name();
|
||||
}
|
||||
};
|
||||
|
||||
struct PrimitiveHasher {
|
||||
std::size_t operator()(PrimitivePtr const &prim) const {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->Hash();
|
||||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/primitive_base.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "pipeline/static_analysis/abstract_function.h"
|
||||
|
||||
namespace mindspore {
|
|
@ -0,0 +1,195 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/primitive_py.h"
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include "ir/signature.h"
|
||||
#include "operator/ops.h"
|
||||
#include "./common.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
#include "utils/base_ref_py.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kBpropAttrName = "bprop";
|
||||
constexpr auto kCellHookAttrName = "cell_hook";
|
||||
constexpr auto kCellIDAttrName = "cell_id";
|
||||
void SyncData(const py::object &arg) {
|
||||
if (py::isinstance<py::tuple>(arg)) {
|
||||
py::tuple arg_list = py::cast<py::tuple>(arg);
|
||||
for (size_t i = 0; i < arg_list.size(); i++) {
|
||||
SyncData(arg_list[i]);
|
||||
}
|
||||
}
|
||||
if (py::isinstance<tensor::Tensor>(arg)) {
|
||||
auto tensor = py::cast<tensor::TensorPtr>(arg);
|
||||
(void)tensor->data_sync();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
std::map<std::string, py::object> PrimitivePy::hook_grad_;
|
||||
static ValuePtr PyArgToValue(const py::object &arg) {
|
||||
if (py::isinstance<SignatureEnumKind>(arg) &&
|
||||
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
|
||||
return nullptr;
|
||||
}
|
||||
return parse::data_converter::PyDataToValue(arg);
|
||||
}
|
||||
|
||||
void PrimitivePy::set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
|
||||
signatures_.clear();
|
||||
for (auto &signature : signatures) {
|
||||
auto [name, rw, kind, arg_default, dtype] = signature;
|
||||
auto default_value = PyArgToValue(arg_default);
|
||||
signatures_.emplace_back(name, rw, kind, default_value, dtype);
|
||||
}
|
||||
set_has_signature(true);
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetBpropFunction() {
|
||||
static const char *const get_bprop_func_name = "get_bprop";
|
||||
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
||||
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
|
||||
return fn;
|
||||
} else {
|
||||
auto fn = GetBpropFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
}
|
||||
}
|
||||
|
||||
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
||||
auto py_args = py::tuple(args.size());
|
||||
size_t i = 0;
|
||||
for (auto &arg : args) {
|
||||
py_args[i] = BaseRefToPyData(arg);
|
||||
MS_LOG(DEBUG) << "arg:" << i << ":";
|
||||
i++;
|
||||
}
|
||||
py::object obj;
|
||||
bool is_bprop = this->HasAttr(kBpropAttrName);
|
||||
if (is_bprop) {
|
||||
SyncData(py_args);
|
||||
obj = hook_(*py_args);
|
||||
return std::make_shared<PyObjectRef>(obj);
|
||||
}
|
||||
SyncData(py_args[2]);
|
||||
bool is_cell = this->HasAttr(kCellHookAttrName);
|
||||
if (is_cell) {
|
||||
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
|
||||
auto iter = hook_grad_.find(cell_id);
|
||||
if (iter != hook_grad_.end()) {
|
||||
auto hook_args = py::tuple(3);
|
||||
hook_args[0] = cell_id;
|
||||
hook_args[1] = py::make_tuple(iter->second);
|
||||
hook_args[2] = py::make_tuple(py_args[2]);
|
||||
obj = hook_(*hook_args);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
}
|
||||
hook_grad_.erase(cell_id);
|
||||
} else {
|
||||
hook_grad_[cell_id] = py_args[2];
|
||||
obj = py_args[2];
|
||||
}
|
||||
} else {
|
||||
// Hook operator for execute variable hook function
|
||||
obj = hook_(py::make_tuple(py_args[2]));
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
}
|
||||
}
|
||||
obj = py::make_tuple(obj);
|
||||
return std::make_shared<PyObjectRef>(obj);
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetComputeFunction() {
|
||||
static const char *const compute_func_name = "vm_impl";
|
||||
|
||||
if (py::hasattr(python_obj_, compute_func_name)) {
|
||||
MS_LOG(INFO) << name() << " compute_func_name";
|
||||
py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
|
||||
return fn;
|
||||
}
|
||||
|
||||
static const std::string vm_module = "mindspore.ops.vm_impl_registry";
|
||||
static const std::string get_vm_impl_fn = "get_vm_impl_fn";
|
||||
MS_LOG(INFO) << name() << ": get_vm_impl_fn";
|
||||
py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
|
||||
py::function vm_fn = get_fn(python_obj_);
|
||||
|
||||
if (py::isinstance<py::none>(vm_fn)) {
|
||||
MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
|
||||
vm_fn = mindspore::GetComputeFunction(Primitive::name());
|
||||
}
|
||||
return vm_fn;
|
||||
}
|
||||
|
||||
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
}
|
||||
(void)this->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
|
||||
py::dict PrimitivePy::GetAttrDict() {
|
||||
py::dict attr_dict;
|
||||
for (auto &attr : attrs_) {
|
||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||
}
|
||||
return attr_dict;
|
||||
}
|
||||
|
||||
void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (!primitive->isa<PrimitivePy>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
|
||||
}
|
||||
auto primitive_py = primitive->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_py);
|
||||
this->set_hook(primitive_py->hook());
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
||||
.value("user_custom", PrimType::kPrimTypeUserCustom);
|
||||
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
||||
.def(py::init<py::str &, py::object>())
|
||||
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
|
||||
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
|
||||
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
|
||||
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
|
||||
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
|
||||
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
|
||||
}));
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
|
||||
#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "utils/misc.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/signature.h"
|
||||
#include "parallel/ops_info/operator_info.h"
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
class PrimitivePy : public Primitive {
|
||||
public:
|
||||
PrimitivePy(const py::str &name, const py::object &python_obj)
|
||||
: Primitive(name, false), python_obj_(python_obj), signatures_() {}
|
||||
~PrimitivePy() override = default;
|
||||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||
py::function GetBpropFunction();
|
||||
py::function GetComputeFunction();
|
||||
|
||||
void set_signatures(
|
||||
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
|
||||
signatures);
|
||||
|
||||
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||
|
||||
void CopyHookFunction(const PrimitivePtr &primitive) override;
|
||||
|
||||
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||
|
||||
py::dict GetAttrDict();
|
||||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
BaseRef RunHookFunction(const VectorRef &args) const override;
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
bool is_tuple_input_ = false;
|
||||
|
||||
private:
|
||||
py::object python_obj_;
|
||||
py::function hook_;
|
||||
std::vector<Signature> signatures_;
|
||||
static std::map<std::string, py::object> hook_grad_;
|
||||
};
|
||||
|
||||
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "kernel/cpu/addn_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include "kernel/cpu/allgather_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "device/cpu/mpi/mpi_adapter.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "kernel/cpu/concat_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "device/cpu/mpi/mpi_adapter.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
#include "kernel/cpu/gather_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
#include "kernel/cpu/slice_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive_base.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support primitive operators
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/manager.h"
|
||||
|
@ -232,10 +232,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
|
|||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
if (!prim->is_base()) {
|
||||
PrimitivePyPtr prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
bprop_cut->set_hook(prim_py->hook());
|
||||
}
|
||||
bprop_cut->CopyHookFunction(prim);
|
||||
|
||||
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
|
||||
if (cell_id != "") {
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "pipeline/static_analysis/analysis_context.h"
|
||||
#include "pipeline/static_analysis/abstract_function.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include "utils/any.h"
|
||||
#include "utils/misc.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
|
|
@ -181,15 +181,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
|||
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return ProcessGraphKernelOp(func_graph, node);
|
||||
} else {
|
||||
// insert cast for single op.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
}
|
||||
// insert cast for single op.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <unordered_set>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -29,7 +29,6 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "ir/value.h"
|
||||
#include "transform/types.h"
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
#ifdef OPEN_SOURCE
|
||||
#include "graph/types.h"
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive_base.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/scalar.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "debug/label.h"
|
||||
|
|
|
@ -648,57 +648,8 @@ void FinalVM::SyncData(const py::object &arg) {
|
|||
|
||||
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
||||
MS_LOG(DEBUG) << "input for operation:";
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
std::size_t args_size = args.size();
|
||||
auto py_args = py::tuple(args_size);
|
||||
size_t i = 0;
|
||||
for (auto &arg : args) {
|
||||
py_args[i] = BaseRefToPyData(arg);
|
||||
MS_LOG(DEBUG) << "arg: " << i << ":";
|
||||
i++;
|
||||
}
|
||||
// Hook operator for execute cell custom bprop function
|
||||
py::object obj;
|
||||
bool is_bprop = prim->HasAttr("bprop");
|
||||
if (is_bprop) {
|
||||
SyncData(py_args);
|
||||
py::function fn_bprop = prim_py->hook();
|
||||
obj = fn_bprop(*py_args);
|
||||
return obj;
|
||||
}
|
||||
// Sync gradient data from device to host
|
||||
SyncData(py_args[2]);
|
||||
bool is_cell = prim->HasAttr("cell_hook");
|
||||
if (is_cell) {
|
||||
// Hook operator for execute cell hook function
|
||||
std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
|
||||
if (_hook_grad.find(cell_id) != _hook_grad.end()) {
|
||||
std::size_t hook_args_size = 3;
|
||||
auto hook_args = py::tuple(hook_args_size);
|
||||
hook_args[0] = cell_id;
|
||||
hook_args[1] = py::make_tuple(_hook_grad[cell_id]);
|
||||
hook_args[2] = py::make_tuple(py_args[2]);
|
||||
py::function fn_hook = prim_py->hook();
|
||||
obj = fn_hook(*hook_args);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
}
|
||||
_hook_grad.erase(cell_id);
|
||||
} else {
|
||||
_hook_grad[cell_id] = py_args[2];
|
||||
obj = py_args[2];
|
||||
}
|
||||
} else {
|
||||
// Hook operator for execute variable hook function
|
||||
py::function fn_hook = prim_py->hook();
|
||||
obj = fn_hook(py::make_tuple(py_args[2]));
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
}
|
||||
}
|
||||
obj = py::make_tuple(obj);
|
||||
return obj;
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->RunHookFunction(args);
|
||||
}
|
||||
|
||||
} // namespace compile
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -161,7 +161,6 @@ class FinalVM {
|
|||
{Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
|
||||
{Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
|
||||
{Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
|
||||
std::map<std::string, py::object> _hook_grad;
|
||||
};
|
||||
|
||||
using FinalVMPtr = std::shared_ptr<FinalVM>;
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "operator/ops.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
#include "debug/draw.h"
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
#include "common/common_test.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "operator/ops.h"
|
||||
#include "./common.h"
|
||||
|
||||
|
|
Loading…
Reference in New Issue