forked from mindspore-Ecosystem/mindspore
173 lines
6.1 KiB
C++
173 lines
6.1 KiB
C++
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_
|
|
#define MINDSPORE_CORE_IR_PRIMITIVE_H_
|
|
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <tuple>
|
|
|
|
#include "ir/dtype/type.h"
|
|
#include "abstract/abstract_value.h"
|
|
#include "base/base_ref.h"
|
|
|
|
namespace mindspore {
|
|
// Supported meta type
|
|
enum PrimType {
|
|
kPrimTypeUnknown = 0,
|
|
kPrimTypeBegin = kTypeUnknown,
|
|
kPrimTypeBuiltIn, // Built-in primitive operator
|
|
kPrimTypePyInfer, // Primitive operator defined by custom
|
|
kPrimTypeUserCustom,
|
|
kPrimTypePyCheck // Primitive operator with input args checking method
|
|
};
|
|
|
|
// A named IR value describing an operator primitive.
//
// Besides the name inherited from Named, a Primitive carries a string-keyed table of
// attribute values (attrs_), a second table recording attributes added while
// "recording" is switched on (evaluate_added_attrs_), and a handful of flags
// (primitive kind, signature presence, const-ness, ...).
//
// Constructors, ToAbstract, ToPrimAbstract, GetAttrsText and the equality operators
// are defined out of line and are not visible in this header.
class MS_CORE_API Primitive : public Named {
 public:
  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
  Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs);
  Primitive(const Primitive &prim);
  MS_DECLARE_PARENT(Primitive, Named);
  abstract::AbstractBasePtr ToAbstract() override;
  abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
  // The textual form of a primitive is simply its name.
  std::string ToString() const override { return name(); }

  // Starts a fresh recording session: forgets previously recorded attributes and
  // makes subsequent AddAttr calls mirror their writes into evaluate_added_attrs_.
  void BeginRecordAddAttr() {
    evaluate_added_attrs_.clear();
    record_evaluate_add_attr_ = true;
  }
  // Stops recording; attributes recorded so far are kept.
  void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }

  // Inserts or overwrites attribute `name`. While recording is active the same
  // entry is also stored in evaluate_added_attrs_. Returns *this for chaining.
  Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
    attrs_[name] = attr;
    if (record_evaluate_add_attr_) {
      evaluate_added_attrs_[name] = attr;
    }
    return *this;
  }

  // Removes attribute `name` if present (no-op otherwise). Returns *this for chaining.
  Primitive &DelAttr(const std::string &name) {
    (void)attrs_.erase(name);
    return *this;
  }

  // Inserts or overwrites every entry of `attrs`; entries not mentioned are kept.
  // Unlike AddAttr this never touches evaluate_added_attrs_.
  Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
    for (const auto &attr : attrs) {
      attrs_[attr.first] = attr.second;
    }
    return *this;
  }

  // Inserts or overwrites a single attribute without recording it.
  void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
  // Removes attribute `attrName` if present.
  void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
  // Default implementation ignores `args` and returns a BaseRef built from nullptr.
  virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }

  // Returns the value stored under `attrName`, or nullptr when there is no such attribute.
  ValuePtr GetAttr(const std::string &attrName) const {
    auto iter = attrs_.find(attrName);
    return iter == attrs_.cend() ? nullptr : iter->second;
  }

  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
  const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
  // Inserts or overwrites every entry of `attrs` into attrs_, logging each key.
  // NOTE(review): despite the name this writes to attrs_, not evaluate_added_attrs_;
  // behaviour is preserved as-is -- confirm that this is intended.
  void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
    for (const auto &attr : attrs) {
      MS_LOG(DEBUG) << " set evalu attrl " << name() << attr.first;
      attrs_[attr.first] = attr.second;
    }
  }

  // True when the primitive has at least one attribute.
  bool HasAttr() const { return !attrs_.empty(); }
  // True when an attribute named `attrName` exists.
  bool HasAttr(const std::string &attrName) const { return attrs_.find(attrName) != attrs_.cend(); }
  void set_prim_type(const PrimType t) { prim_type_ = t; }
  // Copy-constructs a new Primitive from *this and returns it in a shared pointer.
  virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
  void set_instance_name(const std::string &s) { instance_name_ = s; }
  // True for the kPrimTypePyInfer and kPrimTypeUserCustom kinds.
  bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
  bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }

  PrimType prim_type() const { return prim_type_; }
  std::string instance_name() const { return instance_name_; }
  std::string GetAttrsText() const;
  bool operator==(const Value &other) const override;
  bool operator==(const Primitive &other) const;
  ~Primitive() override = default;

  void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
  bool has_signature() const { return has_signature_; }
  bool is_base() const { return is_base_; }
  // Placeholder meant to be overridden. The base version reports an error through
  // MS_LOG(EXCEPTION) (presumably throwing -- the macro is defined elsewhere); the
  // trailing return only exists to give the function a value on every path.
  virtual BaseRef RunHookFunction(const VectorRef &args) const {
    MS_LOG(EXCEPTION) << "call a empty function!";
    BaseRef result;
    return result;
  }
  // Placeholder meant to be overridden; the base version only reports an error.
  virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
  void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; }
  bool is_const_prim() const { return is_const_prim_; }
  void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
    const_input_indexes_ = const_input_indexes;
  }
  // Read-only accessor; const-qualified so it can be used through a const Primitive.
  const std::vector<size_t> &get_const_input_indexes() const { return const_input_indexes_; }
  std::string id() const { return id_; }

 protected:
  // All attributes of the primitive, keyed by attribute name.
  std::unordered_map<std::string, ValuePtr> attrs_;
  // Attributes added through AddAttr while recording was switched on.
  std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;

 private:
  // The scalar members below carry in-class defaults so they never hold an
  // indeterminate value; the defaults mirror the default arguments of the primary
  // constructor. Constructors defined out of line may still override them.
  std::string instance_name_;
  bool is_base_{true};
  bool has_signature_{false};
  PrimType prim_type_{kPrimTypeBuiltIn};
  bool record_evaluate_add_attr_{false};
  bool is_const_prim_{false};
  std::vector<size_t> const_input_indexes_;
  std::string id_{""};
};
|
|
|
|
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
|
os << *p;
|
|
return os;
|
|
}
|
|
|
|
// Equality functor over primitive pointers: two primitives compare equal when they
// are the very same object or when their names match. Attributes are not compared.
// Both pointers must be non-null (checked with MS_EXCEPTION_IF_NULL).
struct MS_CORE_API PrimitiveEqual {
  bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
    MS_EXCEPTION_IF_NULL(t1);
    MS_EXCEPTION_IF_NULL(t2);
    // Identity short-circuits the name comparison.
    if (t1 == t2) {
      return true;
    }
    return t1->name() == t2->name();
  }
};
|
|
|
|
// Hash functor over primitive pointers: delegates to the pointee's Hash() member
// (declared elsewhere). The pointer must be non-null (checked with MS_EXCEPTION_IF_NULL).
struct MS_CORE_API PrimitiveHasher {
  std::size_t operator()(PrimitivePtr const &prim) const {
    MS_EXCEPTION_IF_NULL(prim);
    const std::size_t hash_value = prim->Hash();
    return hash_value;
  }
};
|
|
|
|
// Equality functor over primitive pointers that compares the pointed-to objects with
// their own operator== rather than by name only.
// Both pointers must be non-null (checked with MS_EXCEPTION_IF_NULL).
struct MS_CORE_API PrimitiveTotalEqual {
  bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
    MS_EXCEPTION_IF_NULL(t1);
    MS_EXCEPTION_IF_NULL(t2);
    auto &lhs = *t1;
    auto &rhs = *t2;
    return lhs == rhs;
  }
};
|
|
} // namespace mindspore
|
|
#endif // MINDSPORE_CORE_IR_PRIMITIVE_H_
|