!25072 add core api doc

Merge pull request !25072 from lianliguang/code_docs_api
This commit is contained in:
i-robot 2021-10-20 12:27:27 +00:00 committed by Gitee
commit 169b36ba0e
4 changed files with 354 additions and 29 deletions

View File

@ -24,19 +24,37 @@
#include "ir/anf.h"
namespace mindspore {
/// \beief Named defines an abstract class that records the name and hash_id.
class MS_CORE_API Named : public Value {
public:
/// \beief The constructor for Named.
///
/// \param[in] name The name of object.
explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
/// \brief The constructor for Named, create a Named for another Named.
///
/// \param[in] other The input tensor.
Named(const Named &other) : Value(other) {
this->name_ = other.name_;
hash_id_ = std::hash<std::string>{}(other.name_);
}
/// \brief The destructor of None.
~Named() override = default;
MS_DECLARE_PARENT(Named, Value);
/// \brief Getting name of object.
///
/// \return The name of object.
const std::string &name() const { return name_; }
/// \brief Check whether two Named objects are the same.
///
/// \param[in] other The other Named to be compared with.
/// \return Return true if the same,otherwise return false.
virtual bool operator==(const Named &other) const { return name_ == other.name(); }
bool operator==(const Value &other) const override;
/// \brief Overloads operator '=' for Named.
///
/// \param[in] other An existing Named object.
/// \return A Named object set with the same type, name and hash_id as other.
Named &operator=(const Named &other) {
if (&other != this) {
this->type_ = other.type_;
@ -45,15 +63,23 @@ class MS_CORE_API Named : public Value {
}
return *this;
}
/// \brief Get hash id for named.
///
/// \return The restored hash id of Named.
std::size_t Hash() const { return hash_id_; }
std::size_t hash() const override { return hash_id_; }
/// \brief Overloads operator << for Named.
///
/// \param os The output stream.
/// \param nmd Named to be displayed.
/// \return Output stream that contains the name of Named object.
friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
os << nmd.name();
return os;
}
/// \brief Get name for Named.
///
/// \return The restored name of Named.
std::string ToString() const override { return name(); }
private:
@ -61,43 +87,60 @@ class MS_CORE_API Named : public Value {
std::size_t hash_id_;
};
using NamedPtr = std::shared_ptr<Named>;
/// \brief Implementation of hash operation.
struct MS_CORE_API NamedHasher {
/// \brief Implementation of hash operation.
///
/// \param name The Name object need to be hashed.
/// \return The hash result.
std::size_t operator()(NamedPtr const &name) const {
std::size_t hash = name->Hash();
return hash;
}
};
/// \brief Equal operator for Name.
struct MS_CORE_API NamedEqual {
/// \brief Implementation of Equal operation.
///
/// \param t1 The left Named to compare.
/// \param t2 The right Named to compare.
/// \return The comparison result, Return true if t1 and t2 is the same,else return false.
bool operator()(NamedPtr const &t1, NamedPtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return *t1 == *t2;
}
};
/// \beief None defines interface for none data.
class MS_CORE_API None : public Named {
public:
/// \beief The default constructor for None.
None() : Named("None") {}
/// \brief The destructor of None.
~None() override = default;
MS_DECLARE_PARENT(None, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
inline const NamedPtr kNone = std::make_shared<None>();
/// \beief Null defines interface for null data.
class MS_CORE_API Null : public Named {
public:
/// \beief The default constructor for Null.
Null() : Named("Null") {}
/// \brief The destructor of Null.
~Null() override = default;
MS_DECLARE_PARENT(Null, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
inline const NamedPtr kNull = std::make_shared<Null>();
/// \beief Ellipsis defines interface for ... data.
class MS_CORE_API Ellipsis : public Named {
public:
/// \beief The default constructor for Ellipsis.
Ellipsis() : Named("Ellipsis") {}
/// \brief The destructor of Ellipsis.
~Ellipsis() override = default;
MS_DECLARE_PARENT(Ellipsis, Named);
abstract::AbstractBasePtr ToAbstract() override;

View File

@ -32,26 +32,43 @@ namespace mindspore {
enum PrimType {
kPrimTypeUnknown = 0,
kPrimTypeBegin = kTypeUnknown,
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInfer, // Primitive operator defined by custom
kPrimTypeUserCustom,
kPrimTypePyCheck // Primitive operator with input args checking method
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInfer, // Primitive operator with python infer function
kPrimTypeUserCustom, // Primitive operator defined by custom
kPrimTypePyCheck // Primitive operator with input args checking method
};
/// \brief Primitive defines a operator primitive of MindSpore.
class MS_CORE_API Primitive : public Named {
public:
/// \brief The constructor of Primitive.
///
/// \param[in] name The name of primitive.
/// \param[in] is_base True means the basic Primitive without BProp function inside.
/// \param[in] prim_type The type of primitive.
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs);
/// \brief The constructor for Primitive, create a primitive for another primitive.
///
/// \param[in] prim The input primitive.
Primitive(const Primitive &prim);
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract() override;
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
/// \brief Ready to recording the attribute if the attribute needs to be added when deducing shape and type.
/// This attributes has been recorded needs to add in infer cache.
void BeginRecordAddAttr() {
evaluate_added_attrs_.clear();
record_evaluate_add_attr_ = true;
}
/// \brief End recording attribute.
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
/// \brief Add attribute to primitive attribute map and record the new attribute to evaluate_added_attrs_,
/// if record_evaluate_add_attr_ is true.
///
/// \param[in] name The name of attribute.
/// \param[in] attr The value of attribute.
/// \return The primitive to which attribute has been added.
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr;
if (record_evaluate_add_attr_) {
@ -59,71 +76,160 @@ class MS_CORE_API Primitive : public Named {
}
return *this;
}
/// \brief Delete the attribute.
///
/// \param[in] name The name of attribute to be delete.
/// \return The primitive to which attribute has been added.
Primitive &DelAttr(const std::string &name) {
attrs_.erase(name);
return *this;
}
/// \brief Use add attribute by using a map,all elements of the map will be added in the primitive's attribute map.
///
/// \param[in] attrs The attribute map needs to be added in the primitive attribute.
/// \return The primitive to which attribute has been added.
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
return *this;
}
/// \brief Set attribute to the primitive attribute map.
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
/// \brief Erase attribute to the primitive attribute map.
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
/// \brief Run Primitive's compute function if the compute function has been implemented.
///
/// \param[in] args The arguments of primitive need to compute.
/// \return The primitive's calculation result.
virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }
/// \brief Get Primitive's attribute.
///
/// \param[in] attrName Primitive attribute name.
/// \return The value of attribute in primitive attribute map, if the map is not
ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return iter == attrs_.cend() ? nullptr : iter->second;
}
/// \brief Get Primitive's all attributes.
///
/// \return The Primitive's all attribute.
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
/// \brief Get the attributes added in MindSpore renormalize stage.
///
/// \return Attributes which have been added in MindSpore renormalize stage.
const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
/// \brief Use add attribute using a map,all elements of the map will be added in the primitive's attribute map.
///
/// \param[in] attrs The attribute map needs to be added in the primitive attribute.
void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
MS_LOG(DEBUG) << " set evalu attrl " << name() << attr.first;
attrs_[attr.first] = attr.second;
}
}
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
/// \brief Check if Primitive has any attribute.
/// for example Primitives like scalar_add, return, etc, don't have any attribute.
///
/// \return Return ture, If Primitive has attributes, else return false.
bool HasAttr() const { return !attrs_.empty(); }
/// \brief Check If Primitive has an attribute named attrName.
///
/// \param[in] attrName The name of attribute.
/// \return Return true if Primitive has an attribute named attrName,else return false.
bool HasAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return !(iter == attrs_.cend());
}
/// \brief Set the name of primitive.
///
/// \param t The primitive type that needs to be set.
void set_prim_type(const PrimType t) { prim_type_ = t; }
/// \brief Clone a Primitive.
///
/// \return A Primitive which cloned by current primitive.
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
/// \brief Set primitive instance_name.
///
/// \param[in] s The primitive instance name to be set.
void set_instance_name(const std::string &s) { instance_name_ = s; }
/// \brief Check whether the primitive type if has the Python infer function,
///
/// \return Return true if Primitive's type is kPrimTypePyInfer or kPrimTypeUserCustom, else return false.
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
/// \brief Check whether the primitive type if has the python infer function,
///
/// \return Return true if Primitive's type is kPrimTypeUserCustom, else return false.
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
/// \brief Get Primitive type.
///
/// \return The type of Primitive.
PrimType prim_type() const { return prim_type_; }
/// \brief Get primitive instance name.
///
/// \return The instance name of primitive.
std::string instance_name() const { return instance_name_; }
/// \brief Get primitive attribute debug string.
/// If the attribute name of primitive is a,the value is b
/// The return value of GetAttrsText function is [a=b].
///
/// \return Get attribute debug string of primitive.
std::string GetAttrsText() const;
bool operator==(const Value &other) const override;
/// \brief To compare whether two Primitive objects are equal.
///
/// \param[in] other The other Primitive be compared with.
/// \return return true if the name and attributes of primitives are the same,otherwise return false.
bool operator==(const Primitive &other) const;
/// \brief Destructor of Primitive.
~Primitive() override = default;
/// \brief The flag to be set in primitive.
///
/// \param[in] has_signature Set the flag whether there is a signature for the primitive.
void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
/// \brief Check whether the primitive has signature.
///
/// \return Return true if primitive has signature flag , else return false.
bool has_signature() const { return has_signature_; }
/// \brief Check whether the primitive is a basic primitive.
///
/// \return Return true if the primitive is basic, else return false.
bool is_base() const { return is_base_; }
/// \brief Get Primitive's hook function result.
///
/// \param args The arguments of hook function.
/// \return The result of hookfunction.
virtual BaseRef RunHookFunction(const VectorRef &args) const {
MS_LOG(EXCEPTION) << "call a empty function!";
BaseRef result;
return result;
}
/// \brief Copy a Primitive's hook function to another primitive.
///
/// \param[in] primitive Set primitive's hook function to the current object.
virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
/// \brief Set primitive const flag.
/// If the is_const_prim_ of primitive is true means the primitive will be eliminated in constant folding.
///
/// \param is_const_prim The flag of primitive to be set.
void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; }
/// \brief Check whether the primitive is const primitive.
///
/// \return Return true if primitive is a const primitive, else return false.
bool is_const_prim() const { return is_const_prim_; }
/// \brief Set const input index for primitive.
///
/// \param const_input_indexes The const input index of the primitive to be set.
void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
const_input_indexes_ = const_input_indexes;
}
/// \brief Get const input index of the primitive.
///
/// \return Const input indexes of the primitive.
const std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; }
/// \brief Get Primitive's id.
///
/// \return primitive's Id.
uint64_t id() const { return id_; }
protected:
@ -146,7 +252,13 @@ inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
return os;
}
/// \brief Equal operator for Primitive.
struct MS_CORE_API PrimitiveEqual {
/// \brief Implementation of Equal operation.
///
/// \param t1 The left Primitive to compare.
/// \param t2 The right Primitive to compare.
/// \return The comparison result,Return true if the name and address of t1 and t2 are the same ,else return false.
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
@ -154,14 +266,25 @@ struct MS_CORE_API PrimitiveEqual {
}
};
/// \brief Implementation of hash operation.
struct MS_CORE_API PrimitiveHasher {
/// \brief Implementation of hash operation.
///
/// \param name The PrimitiveHasher to be hashed.
/// \return The hash result.
std::size_t operator()(PrimitivePtr const &prim) const {
MS_EXCEPTION_IF_NULL(prim);
return prim->Hash();
}
};
/// \brief Equal operator for Primitive.
struct MS_CORE_API PrimitiveTotalEqual {
/// \brief Implementation of Equal operation.
///
/// \param t1 The left Primitive to compare.
/// \param t2 The right Primitive to compare.
/// \return The comparison result,Return true if t1 and t2 are the same,else return false.
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);

View File

@ -35,13 +35,25 @@
using std::fabs;
namespace mindspore {
/// \beief Scalar defines interface for scalar data.
class MS_CORE_API Scalar : public Value {
public:
/// \beief The default constructor for Scalar.
Scalar() = default;
/// \brief The constructor for Scalar.
///
/// \param[in] t The type of scalar.
explicit Scalar(const TypePtr t) : Value(t) {}
/// \brief The destructor of Scalar.
~Scalar() override = default;
MS_DECLARE_PARENT(Scalar, Value)
/// \brief Check whether the value of scalar is zero.
///
/// \return Return true if the value of scalar is zero ,else return false.
virtual bool IsZero() = 0;
/// \brief Check whether the value of scalar is zero.
///
/// \return Return true if the value of scalar is zero ,else return false.
virtual bool IsOne() = 0;
abstract::AbstractBasePtr ToAbstract() override;
@ -50,16 +62,28 @@ class MS_CORE_API Scalar : public Value {
};
using ScalarPtr = std::shared_ptr<Scalar>;
/// \beief BoolImm defines interface for bool data.
class MS_CORE_API BoolImm : public Scalar {
public:
/// \brief The constructor of BoolImm.
///
/// \param[in] b The value of bool data.
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
/// \brief The destructor of BoolImm.
~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm, Scalar)
std::size_t hash() const override { return hash_; }
/// \brief Get the value of BoolImm.
///
/// \return Return the value of BoolImm.
bool value() const { return v_; }
bool IsZero() override { return v_ == false; }
bool IsOne() override { return v_ == true; }
bool operator==(const Value &other) const override;
/// \brief Compare two BoolImm objects is equal.
///
/// \param[in] other The other BoolImm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const BoolImm &other) const;
std::string ToString() const override {
if (v_) {
@ -81,25 +105,44 @@ class MS_CORE_API BoolImm : public Scalar {
using BoolImmPtr = std::shared_ptr<BoolImm>;
IMM_TRAITS(BoolImmPtr, bool)
/// \beief IntegerImm defines interface for integer data.
class MS_CORE_API IntergerImm : public Scalar {
public:
/// \beief The default constructor for IntegerImm.
IntergerImm() = default;
/// \brief The constructor for IntegerImm.
///
/// \param[in] t The type of IntegerImm.
explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
/// \brief The destructor of Scalar.
~IntergerImm() override = default;
MS_DECLARE_PARENT(IntergerImm, Scalar)
};
/// \beief Int8Imm defines interface for int8 data.
class MS_CORE_API Int8Imm : public IntergerImm {
public:
/// \beief The default constructor for Int8Imm.
Int8Imm() : IntergerImm(kInt8), v_(0) {}
/// \brief The constructor for Int8Imm.
///
/// \param[in] v The value of Int8Imm.
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
/// \brief The destructor of Int8Imm.
~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of Int8Imm.
///
/// \return Return the value of Int8Imm.
int8_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two Int8Imm objects is equal.
///
/// \param[in] other The other Int8Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const Int8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -114,18 +157,30 @@ class MS_CORE_API Int8Imm : public IntergerImm {
};
using Int8ImmPtr = std::shared_ptr<Int8Imm>;
IMM_TRAITS(Int8ImmPtr, int8_t)
/// \beief Int16Imm defines interface for int16 data.
class MS_CORE_API Int16Imm : public IntergerImm {
public:
/// \beief The default constructor for Int16Imm.
Int16Imm() : IntergerImm(kInt16), v_(0) {}
/// \brief The constructor for Int16Imm.
///
/// \param[in] v The value of Int16Imm.
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
/// \brief The destructor of Int16Imm.
~Int16Imm() override = default;
MS_DECLARE_PARENT(Int16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of Int16Imm.
///
/// \return Return the value of Int16Imm.
int16_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two Int16Imm objects is equal.
///
/// \param[in] other The other Int16Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const Int16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -141,17 +196,30 @@ class MS_CORE_API Int16Imm : public IntergerImm {
using Int16ImmPtr = std::shared_ptr<Int16Imm>;
IMM_TRAITS(Int16ImmPtr, int16_t)
/// \beief Int32Imm defines interface for int32 data.
class MS_CORE_API Int32Imm : public IntergerImm {
public:
/// \beief The default constructor for Int32Imm.
Int32Imm() : IntergerImm(kInt32), v_(0) {}
/// \brief The constructor for Int32Imm.
///
/// \param[in] v The value of Int32Imm.
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
/// \brief The destructor of Int32Imm.
~Int32Imm() override = default;
MS_DECLARE_PARENT(Int32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of Int32Imm.
///
/// \return Return the value of Int32Imm.
int32_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two Int32Imm objects is equal.
///
/// \param[in] other The other Int32Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const Int32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -167,17 +235,30 @@ class MS_CORE_API Int32Imm : public IntergerImm {
using Int32ImmPtr = std::shared_ptr<Int32Imm>;
IMM_TRAITS(Int32ImmPtr, int32_t)
/// \beief Int64Imm defines interface for int64 data.
class MS_CORE_API Int64Imm : public IntergerImm {
public:
/// \beief The default constructor for Int64Imm.
Int64Imm() : IntergerImm(kInt64), v_(0) {}
/// \brief The constructor for Int64Imm.
///
/// \param[in] v The value of Int64Imm.
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
/// \brief The destructor of Int64Imm.
~Int64Imm() override = default;
MS_DECLARE_PARENT(Int64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of Int64Imm.
///
/// \return Return the value of Int64Imm.
int64_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two Int64Imm objects is equal.
///
/// \param[in] other The other Int64Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const Int64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -192,20 +273,32 @@ class MS_CORE_API Int64Imm : public IntergerImm {
};
using Int64ImmPtr = std::shared_ptr<Int64Imm>;
IMM_TRAITS(Int64ImmPtr, int64_t)
/// \beief UInt8Imm defines interface for uint8 data.
class MS_CORE_API UInt8Imm : public IntergerImm {
public:
/// \beief The default constructor for UInt8Imm.
UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
/// \brief The constructor for UInt8Imm.
///
/// \param[in] v The value of UInt8Imm.
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
/// \brief The destructor of UInt8Imm.
~UInt8Imm() override = default;
MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of UInt8Imm.
///
/// \return Return the value of UInt8Imm.
uint8_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two UInt8Imm objects is equal.
///
/// \param[in] other The other UInt8Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const UInt8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -221,19 +314,32 @@ class MS_CORE_API UInt8Imm : public IntergerImm {
using UInt8ImmPtr = std::shared_ptr<UInt8Imm>;
IMM_TRAITS(UInt8ImmPtr, uint8_t);
/// \beief UInt16Imm defines interface for uint16 data.
class MS_CORE_API UInt16Imm : public IntergerImm {
public:
/// \beief The default constructor for UInt16Imm.
UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
/// \brief The constructor for UInt16Imm.
///
/// \param[in] v The value of UInt16Imm.
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
/// \brief The destructor of UInt16Imm.
~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of UInt16Imm.
///
/// \return Return the value of UInt16Imm.
uint16_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two UInt16Imm objects is equal.
///
/// \param[in] other The other UInt16Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const UInt16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -249,19 +355,32 @@ class MS_CORE_API UInt16Imm : public IntergerImm {
using UInt16ImmPtr = std::shared_ptr<UInt16Imm>;
IMM_TRAITS(UInt16ImmPtr, uint16_t);
/// \beief UInt32Imm defines interface for uint32 data.
class MS_CORE_API UInt32Imm : public IntergerImm {
public:
/// \beief The default constructor for UInt32Imm.
UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
/// \brief The constructor for UInt32Imm.
///
/// \param[in] v The value of UInt32Imm.
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
/// \brief The destructor of UInt32Imm.
~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of UInt32Imm.
///
/// \return Return the value of UInt32Imm.
uint32_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two UInt32Imm objects is equal.
///
/// \param[in] other The other UInt32Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const UInt32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -276,20 +395,32 @@ class MS_CORE_API UInt32Imm : public IntergerImm {
};
using UInt32ImmPtr = std::shared_ptr<UInt32Imm>;
IMM_TRAITS(UInt32ImmPtr, uint32_t);
/// \beief UInt64Imm defines interface for uint64 data.
class MS_CORE_API UInt64Imm : public IntergerImm {
public:
/// \beief The default constructor for UInt64Imm.
UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
/// \brief The constructor for UInt64Imm.
///
/// \param[in] v The value of UInt64Imm.
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
}
/// \brief The destructor of UInt64Imm.
~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
/// \brief Get the value of UInt64Imm.
///
/// \return Return the value of UInt64Imm.
uint64_t value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two UInt64Imm objects is equal.
///
/// \param[in] other The other UInt64Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const UInt64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -304,27 +435,45 @@ class MS_CORE_API UInt64Imm : public IntergerImm {
};
using UInt64ImmPtr = std::shared_ptr<UInt64Imm>;
IMM_TRAITS(UInt64ImmPtr, uint64_t);
/// \beief FloatImm defines interface for float data.
class MS_CORE_API FloatImm : public Scalar {
public:
/// \beief The default constructor for FloatImm.
FloatImm() = default;
/// \brief The constructor for FloatImm.
///
/// \param[in] v The value of FloatImm.
explicit FloatImm(const TypePtr &t) : Scalar(t) {}
/// \brief The destructor of FloatImm.
~FloatImm() override = default;
MS_DECLARE_PARENT(FloatImm, Scalar)
};
using FloatImmPtr = std::shared_ptr<FloatImm>;
/// \beief FP32Imm defines interface for float32 data.
class MS_CORE_API FP32Imm : public FloatImm {
public:
/// \beief The default constructor for FP32Imm.
FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
/// \brief The constructor for FP32Imm.
///
/// \param[in] v The value of FP32Imm.
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
/// \brief The destructor of FP32Imm.
~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm, FloatImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
/// \brief Get the value of FP32Imm.
///
/// \return Return the value of FP32Imm.
float value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two FP32Imm objects is equal.
///
/// \param[in] other The other FP32Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const FP32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
@ -339,18 +488,30 @@ class MS_CORE_API FP32Imm : public FloatImm {
};
using FP32ImmPtr = std::shared_ptr<FP32Imm>;
IMM_TRAITS(FP32ImmPtr, float)
/// \beief FP64Imm defines interface for float64 data.
class MS_CORE_API FP64Imm : public FloatImm {
public:
/// \beief The default constructor for FP64Imm.
FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
/// \brief The constructor for FP64Imm.
///
/// \param[in] v The value of FP64Imm.
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
/// \brief The destructor of FP64Imm.
~FP64Imm() override = default;
MS_DECLARE_PARENT(FP64Imm, FloatImm)
std::size_t hash() const override { return hash_; }
bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
/// \brief Get the value of FP64Imm.
///
/// \return Return the value of FP64Imm.
double value() const { return v_; }
bool operator==(const Value &other) const override;
/// \brief Compare two FP64Imm objects is equal.
///
/// \param[in] other The other FP64Imm to be compared with.
/// \return Return true if other's value and the value of current object are the same,else return false.
bool operator==(const FP64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }

View File

@ -568,7 +568,7 @@ class TBERegOp(RegOp):
need_compile (bool): Whether the input needs to be compiled or not. Default: None.
param_type (str): Type of the input. Default: None.
shape (str): Shape of the input. Default: None.
value_depend (str): Whether the input is const value depend. Default: None.
value_depend (str): Whether the input is constant value depend. Default: None.
kwargs (dict): Other information of the input.
"""
param_list = [index, name, need_compile, param_type, shape, value_depend]
@ -611,8 +611,6 @@ class DataType:
Please add it if necessary.
current support:
.. code-block::
None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")