move signature to primitivepy and bprop_func to utils

This commit is contained in:
leopz 2020-05-20 17:37:44 +08:00
parent 183144e135
commit 04763b8b76
11 changed files with 328 additions and 196 deletions

View File

@ -24,75 +24,13 @@
#include "pipeline/parse/data_converter.h"
#include "pybind11/pytypes.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
using mindspore::abstract::AbstractFunction;
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
return prim_func;
}
static py::function GetBpropFunctionByObj(py::object obj) {
static const std::string get_bprop_fn = "get_bprop_fn";
static const std::string ad_module = "mindspore.ops._grad";
py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj);
return fn;
}
py::function Primitive::GetBpropFunction() {
auto fn = GetBpropFunctionByObj(py::str(name()));
if (fn.is_none()) {
MS_LOG(WARNING) << "Can't find bprop function for " << name();
}
return fn;
}
py::function Primitive::GetComputeFunction() {
static const std::string module = "mindspore._extends.builtin_operations";
py::module mod = py::module::import(common::SafeCStr(module));
if (!py::hasattr(mod, common::SafeCStr(name()))) {
PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name()));
// If raise AttributeError, user can't understand. This case need raise NotImplementedError.
throw py::error_already_set();
}
py::object fn = mod.attr(common::SafeCStr(name()));
return fn;
}
bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive &>(other);
return *this == other_prim;
} else {
return false;
}
}
bool Primitive::operator==(const Primitive &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) {
return false;
}
auto iter = other.attrs_.find(item.first);
if (iter == other.attrs_.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
}
void Primitive::set_signatures(
void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto &signature : signatures) {
@ -104,27 +42,7 @@ void Primitive::set_signatures(
std::tie(name, rw, kind, default_value, dtype) = signature;
signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype));
}
}
std::string Primitive::GetAttrsText() const {
if (attrs_.empty()) {
return "";
}
std::ostringstream oss;
oss << "[";
bool is_first = true;
for (auto &attr : attrs_) {
if (is_first) {
is_first = false;
} else {
oss << ", ";
}
oss << attr.first << "=" << attr.second->DumpText();
}
oss << "]";
return oss.str();
set_has_signature(true);
}
py::function PrimitivePy::GetBpropFunction() {
@ -158,7 +76,7 @@ py::function PrimitivePy::GetComputeFunction() {
if (py::isinstance<py::none>(vm_fn)) {
MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
vm_fn = Primitive::GetComputeFunction();
vm_fn = mindspore::GetComputeFunction(Primitive::name());
}
return vm_fn;
}

View File

@ -22,59 +22,26 @@
#include <memory>
#include <string>
#include <tuple>
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"
#include "ir/primitive_base.h"
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
namespace py = pybind11;
namespace mindspore {
using abstract::AbstractBasePtr;
using abstract::AbstractBasePtrList;
// Supported meta type
enum PrimType {
kPrimTypeUnknown = 0,
kPrimTypeBegin = kTypeUnknown,
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInferShape, // Primitive operator defined by custom
kPrimTypePyInferTensor, // Primitive operator defined by custom
kPrimTypeUserCustom
};
class Primitive : public Named {
class PrimitivePy : public Primitive {
public:
explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), signatures_(), prim_type_(prim_type) {}
Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
signatures_(prim.signatures_),
instance_name_(prim.instance_name_),
prim_type_(prim.prim_type_) {}
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
virtual py::function GetBpropFunction();
virtual py::function GetComputeFunction();
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr;
return *this;
}
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
return *this;
}
PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {}
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
py::function GetComputeFunction();
void set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
@ -82,52 +49,6 @@ class Primitive : public Named {
const std::vector<Signature> &signatures() const { return signatures_; }
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return iter == attrs_.cend() ? nullptr : iter->second;
}
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); }
bool HasAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return !(iter == attrs_.cend());
}
void set_prim_type(const PrimType t) { prim_type_ = t; }
void set_instance_name(const std::string s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const;
bool operator==(const Value &other) const override;
bool operator==(const Primitive &other) const;
~Primitive() override = default;
protected:
std::unordered_map<std::string, ValuePtr> attrs_;
private:
std::vector<Signature> signatures_;
std::string instance_name_;
PrimType prim_type_;
};
class PrimitivePy : public Primitive {
public:
PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {}
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction() override;
py::function GetComputeFunction() override;
void AddPyAttr(const py::str &name, const py::object &obj);
py::dict GetAttrDict();
@ -138,25 +59,9 @@ class PrimitivePy : public Primitive {
private:
py::object python_obj_;
std::vector<Signature> signatures_;
};
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
os << *p;
return os;
}
struct PrimitiveEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->name() == t2->name();
}
};
struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_

View File

@ -0,0 +1,71 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_base.h"
#include <utility>
namespace mindspore {
bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive &>(other);
return *this == other_prim;
} else {
return false;
}
}
bool Primitive::operator==(const Primitive &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) {
return false;
}
auto iter = other.attrs_.find(item.first);
if (iter == other.attrs_.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
}
std::string Primitive::GetAttrsText() const {
if (attrs_.empty()) {
return "";
}
std::ostringstream oss;
oss << "[";
bool is_first = true;
for (auto &attr : attrs_) {
if (is_first) {
is_first = false;
} else {
oss << ", ";
}
oss << attr.first << "=" << attr.second->DumpText();
}
oss << "]";
return oss.str();
}
} // namespace mindspore

View File

@ -0,0 +1,128 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include "ir/dtype/type.h"
namespace mindspore {
// Supported meta type
enum PrimType {
kPrimTypeUnknown = 0,
kPrimTypeBegin = kTypeUnknown,
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInferShape, // Primitive operator defined by custom
kPrimTypePyInferTensor, // Primitive operator defined by custom
kPrimTypeUserCustom
};
class Primitive : public Named {
public:
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {}
Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_) {}
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr;
return *this;
}
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
return *this;
}
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return iter == attrs_.cend() ? nullptr : iter->second;
}
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); }
bool HasAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return !(iter == attrs_.cend());
}
void set_prim_type(const PrimType t) { prim_type_ = t; }
void set_instance_name(const std::string s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const;
bool operator==(const Value &other) const override;
bool operator==(const Primitive &other) const;
~Primitive() override = default;
void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
bool has_signature() const { return has_signature_; }
bool is_base() const { return is_base_; }
protected:
std::unordered_map<std::string, ValuePtr> attrs_;
private:
std::string instance_name_;
bool is_base_;
bool has_signature_;
PrimType prim_type_;
};
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
os << *p;
return os;
}
struct PrimitiveEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->name() == t2->name();
}
};
struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/primitive_base.h"
#include "pipeline/static_analysis/abstract_function.h"
namespace mindspore {
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
return prim_func;
}
} // namespace mindspore

View File

@ -36,8 +36,8 @@ using PatternListType = std::initializer_list<BaseRef>;
const std::vector<Signature> &GetSignature(const ValuePtr &function) {
static const auto empty = std::vector<Signature>();
if (function->isa<Primitive>()) {
return function->cast<PrimitivePtr>()->signatures();
if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) {
return function->cast<PrimitivePyPtr>()->signatures();
} else if (function->isa<MetaFuncGraph>()) {
return function->cast<MetaFuncGraphPtr>()->signatures();
}

View File

@ -20,6 +20,7 @@
#include <string>
#include <utility>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
@ -30,6 +31,7 @@
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "debug/info.h"
#include "debug/trace.h"
@ -49,7 +51,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
py::function fn = prim->GetBpropFunction();
py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
if (fn == nullptr || py::isinstance<py::none>(fn)) {
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
return nullptr;

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/primitive_utils.h"
#include "pipeline/parse/python_adapter.h"
#include "utils/log_adapter.h"
#include "common/utils.h"
namespace mindspore {
py::function GetBpropFunctionByObj(py::object obj) {
static const std::string get_bprop_fn = "get_bprop_fn";
static const std::string ad_module = "mindspore.ops._grad";
py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj);
return fn;
}
py::function GetBpropFunction(std::string name) {
auto fn = GetBpropFunctionByObj(py::str(name));
if (fn.is_none()) {
MS_LOG(WARNING) << "Can't find bprop function for " << name;
}
return fn;
}
py::function GetComputeFunction(std::string name) {
static const std::string module = "mindspore._extends.builtin_operations";
py::module mod = py::module::import(common::SafeCStr(module));
if (!py::hasattr(mod, common::SafeCStr(name))) {
PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name));
// If raise AttributeError, user can't understand. This case need raise NotImplementedError.
throw py::error_already_set();
}
py::object fn = mod.attr(common::SafeCStr(name));
return fn;
}
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
#define MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
#include <string>
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mindspore {
py::function GetBpropFunctionByObj(py::object obj);
py::function GetBpropFunction(std::string name);
py::function GetComputeFunction(std::string name);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_

View File

@ -31,6 +31,7 @@
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
#include "debug/draw.h"
namespace mindspore {
@ -443,7 +444,7 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim);
MS_LOG(DEBUG) << "operation start " << prim->name();
auto func = operation != nullptr ? operation->GetComputeFunction() : prim->GetComputeFunction();
auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {
MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented";
}

View File

@ -390,7 +390,7 @@ TEST_F(TestOps, Conv2dAttrTest) {
}
TEST_F(TestOps, CustomOpAttrTest) {
Primitive prim("CustomOp", kPrimTypePyInferShape);
Primitive prim("CustomOp", true, kPrimTypePyInferShape);
prim.SetAttrs({
{"attr1", MakeValue(3)},
{"attr2", MakeValue(1)},