forked from mindspore-Ecosystem/mindspore
Code check
This commit is contained in:
parent
830f4cf441
commit
d94c1410b4
|
@ -63,6 +63,9 @@ void CNode::add_input(const AnfNodePtr &input) {
|
|||
}
|
||||
|
||||
void CNode::set_input(size_t i, const AnfNodePtr &new_input) {
|
||||
if (i >= inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
|
||||
}
|
||||
inputs_[i] = new_input;
|
||||
input_tensor_num_ = -1;
|
||||
}
|
||||
|
@ -74,7 +77,7 @@ void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
|
|||
|
||||
const AnfNodePtr &CNode::input(size_t i) const {
|
||||
if (i >= inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "i:" << i << "out of range:" << inputs_.size() << ",cnode:" << DebugString();
|
||||
MS_LOG(EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
|
||||
}
|
||||
return inputs_.at(i);
|
||||
}
|
||||
|
@ -130,7 +133,7 @@ ParamInfoPtr Parameter::param_info() const {
|
|||
std::string ValueNode::ToString() const {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
if (value_->isa<FuncGraph>()) {
|
||||
return value_->cast<FuncGraphPtr>()->ToString();
|
||||
return value_->ToString();
|
||||
}
|
||||
std::ostringstream buffer;
|
||||
buffer << AnfNode::ToString();
|
||||
|
@ -168,10 +171,7 @@ bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
|
|||
}
|
||||
|
||||
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode != nullptr) {
|
||||
if (cnode->size() > 0) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -187,8 +187,8 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
|
|||
}
|
||||
|
||||
AnfNodePtr valuenode = cnode->input(0);
|
||||
if (valuenode->isa<ValueNode>()) {
|
||||
auto value = GetValueNode(valuenode);
|
||||
auto value = GetValueNode(valuenode);
|
||||
if (value != nullptr) {
|
||||
// check whether the valuenode is primitive
|
||||
if (value->isa<Primitive>()) {
|
||||
return value->cast<PrimitivePtr>()->name();
|
||||
|
|
|
@ -38,19 +38,14 @@ bool AnfUtils::IsNodeOutputDynamicShape(const CNodePtr &node) {
|
|||
MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
if (base_shape->isa<abstract::Shape>()) {
|
||||
if (IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
|
||||
return true;
|
||||
}
|
||||
if (base_shape->isa<abstract::Shape>() && IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
|
||||
return true;
|
||||
} else if (base_shape->isa<abstract::TupleShape>()) {
|
||||
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_shape);
|
||||
for (size_t i = 0; i < tuple_shape->size(); i++) {
|
||||
auto b_shape = (*tuple_shape)[i];
|
||||
if (!b_shape->isa<abstract::Shape>()) {
|
||||
continue;
|
||||
}
|
||||
if (IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
|
||||
if (b_shape->isa<abstract::Shape>() && IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
|
||||
std::ostringstream oss;
|
||||
bool begin = true;
|
||||
int cnt = 0;
|
||||
size_t cnt = 0;
|
||||
// write 'Tuple[Bool, Bool, Bool, Int, Float, Float]' as 'Tuple[Bool...3, Int, Float...2]'
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
TypePtr elem = elements[i];
|
||||
|
@ -95,25 +95,13 @@ Class::Class(const Named &tag, const ClassAttrVector &attributes,
|
|||
const std::unordered_map<std::string, ValuePtr> &methods)
|
||||
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
|
||||
|
||||
std::string List::ToString() const {
|
||||
std::string List::DumpContent(bool is_dumptext) const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "List";
|
||||
} else {
|
||||
buffer << "List[";
|
||||
buffer << DumpTypeVector(elements_, false);
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string List::DumpText() const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "List";
|
||||
} else {
|
||||
buffer << "List[";
|
||||
buffer << DumpTypeVector(elements_, true);
|
||||
buffer << DumpTypeVector(elements_, is_dumptext);
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
|
@ -133,27 +121,7 @@ TypePtr Class::DeepCopy() const {
|
|||
}
|
||||
}
|
||||
|
||||
std::string Class::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "cls";
|
||||
} else {
|
||||
bool begin = true;
|
||||
buffer << "cls." << tag_ << "[";
|
||||
for (auto &attr : attributes_) {
|
||||
if (!begin) {
|
||||
buffer << ", ";
|
||||
} else {
|
||||
begin = false;
|
||||
}
|
||||
buffer << attr.first << ":" << attr.second->ToString();
|
||||
}
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Class::DumpText() const {
|
||||
std::string Class::DumpContent(bool is_dumptext) const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "Cls";
|
||||
|
@ -166,7 +134,8 @@ std::string Class::DumpText() const {
|
|||
} else {
|
||||
begin = false;
|
||||
}
|
||||
buffer << attr.first << ":" << attr.second->DumpText();
|
||||
auto sub_content = is_dumptext ? attr.second->DumpText() : attr.second->ToString();
|
||||
buffer << attr.first << ":" << sub_content;
|
||||
}
|
||||
buffer << "]";
|
||||
}
|
||||
|
@ -208,25 +177,13 @@ const TypePtr Tuple::operator[](std::size_t dim) const {
|
|||
return elements_[dim];
|
||||
}
|
||||
|
||||
std::string Tuple::ToString() const {
|
||||
std::string Tuple::DumpContent(bool is_dumptext) const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "Tuple";
|
||||
} else {
|
||||
buffer << "Tuple[";
|
||||
buffer << DumpTypeVector(elements_, false);
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Tuple::DumpText() const {
|
||||
std::ostringstream buffer;
|
||||
if (IsGeneric()) {
|
||||
buffer << "Tuple";
|
||||
} else {
|
||||
buffer << "Tuple[";
|
||||
buffer << DumpTypeVector(elements_, true);
|
||||
buffer << DumpTypeVector(elements_, is_dumptext);
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
|
@ -252,7 +209,7 @@ std::string DumpKeyVector(std::vector<std::string> keys) {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Dictionary::ToString() const {
|
||||
std::string Dictionary::DumpContent(bool) const {
|
||||
std::ostringstream buffer;
|
||||
std::vector<std::string> keys;
|
||||
std::vector<TypePtr> values;
|
||||
|
@ -271,8 +228,6 @@ std::string Dictionary::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Dictionary::DumpText() const { return ToString(); }
|
||||
|
||||
bool Dictionary::operator==(const mindspore::Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
|
|
|
@ -54,11 +54,12 @@ class MS_CORE_API List : public Object {
|
|||
bool operator==(const Type &other) const override;
|
||||
std::size_t size() const { return elements_.size(); }
|
||||
TypePtrList elements() const { return elements_; }
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override { return "list_"; }
|
||||
std::string DumpText() const override;
|
||||
std::string ToString() const override { return DumpContent(false); }
|
||||
std::string DumpText() const override { return DumpContent(true); };
|
||||
|
||||
private:
|
||||
std::string DumpContent(bool is_dumptext) const;
|
||||
TypePtrList elements_;
|
||||
};
|
||||
using ListPtr = std::shared_ptr<List>;
|
||||
|
@ -76,8 +77,8 @@ class MS_CORE_API Class : public Object {
|
|||
|
||||
bool operator==(const Type &other) const override;
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string DumpText() const override;
|
||||
std::string ToString() const override { return DumpContent(false); }
|
||||
std::string DumpText() const override { return DumpContent(true); };
|
||||
void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
|
||||
|
||||
Named tag() { return tag_; }
|
||||
|
@ -88,6 +89,7 @@ class MS_CORE_API Class : public Object {
|
|||
ClassAttrVector attributes_;
|
||||
|
||||
private:
|
||||
std::string DumpContent(bool is_dumptext) const;
|
||||
Named tag_;
|
||||
std::unordered_map<std::string, ValuePtr> methods_;
|
||||
// For AbstractClass build value
|
||||
|
@ -111,9 +113,9 @@ class MS_CORE_API Tuple : public Object {
|
|||
TypeId generic_type_id() const override { return kObjectTypeTuple; }
|
||||
TypePtr DeepCopy() const override;
|
||||
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override { return "tuple_"; }
|
||||
std::string DumpText() const override;
|
||||
std::string ToString() const override { return DumpContent(false); }
|
||||
std::string DumpText() const override { return DumpContent(true); };
|
||||
const TypePtr operator[](size_t dim) const;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
|
@ -121,6 +123,7 @@ class MS_CORE_API Tuple : public Object {
|
|||
std::size_t size() const { return elements_.size(); }
|
||||
|
||||
private:
|
||||
std::string DumpContent(bool is_dumptext) const;
|
||||
TypePtrList elements_;
|
||||
};
|
||||
using TuplePtr = std::shared_ptr<Tuple>;
|
||||
|
@ -138,10 +141,11 @@ class MS_CORE_API Dictionary : public Object {
|
|||
|
||||
bool operator==(const Type &other) const override;
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string DumpText() const override;
|
||||
std::string ToString() const override { return DumpContent(false); }
|
||||
std::string DumpText() const override { return DumpContent(true); };
|
||||
|
||||
private:
|
||||
std::string DumpContent(bool is_dumptext) const;
|
||||
std::vector<std::pair<std::string, TypePtr>> key_values_;
|
||||
};
|
||||
using DictionaryPtr = std::shared_ptr<Dictionary>;
|
||||
|
|
|
@ -30,25 +30,28 @@ bool Number::operator==(const Type &other) const {
|
|||
}
|
||||
|
||||
Int::Int(const int nbits) : Number(IntBitsToTypeId(nbits), nbits, false) {
|
||||
if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
|
||||
if (nbits != static_cast<int>(BitsNum::eBits8) && nbits != static_cast<int>(BitsNum::eBits16) &&
|
||||
nbits != static_cast<int>(BitsNum::eBits32) && nbits != static_cast<int>(BitsNum::eBits64)) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
UInt::UInt(const int nbits) : Number(UIntBitsToTypeId(nbits), nbits, false) {
|
||||
if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
|
||||
if (nbits != static_cast<int>(BitsNum::eBits8) && nbits != static_cast<int>(BitsNum::eBits16) &&
|
||||
nbits != static_cast<int>(BitsNum::eBits32) && nbits != static_cast<int>(BitsNum::eBits64)) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
|
||||
if (nbits != 16 && nbits != 32 && nbits != 64) {
|
||||
if (nbits != static_cast<int>(BitsNum::eBits16) && nbits != static_cast<int>(BitsNum::eBits32) &&
|
||||
nbits != static_cast<int>(BitsNum::eBits64)) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
Complex::Complex(const int nbits) : Number(ComplexBitsToTypeId(nbits), nbits, false) {
|
||||
if (nbits != 64 && nbits != 128) {
|
||||
if (nbits != static_cast<int>(BitsNum::eBits64) && nbits != static_cast<int>(BitsNum::eBits128)) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -171,23 +171,23 @@ class MS_CORE_API Complex : public Number {
|
|||
};
|
||||
|
||||
inline const TypePtr kBool = std::make_shared<Bool>();
|
||||
inline const TypePtr kInt8 = std::make_shared<Int>(8);
|
||||
inline const TypePtr kInt16 = std::make_shared<Int>(16);
|
||||
inline const TypePtr kInt32 = std::make_shared<Int>(32);
|
||||
inline const TypePtr kInt64 = std::make_shared<Int>(64);
|
||||
inline const TypePtr kUInt8 = std::make_shared<UInt>(8);
|
||||
inline const TypePtr kUInt16 = std::make_shared<UInt>(16);
|
||||
inline const TypePtr kUInt32 = std::make_shared<UInt>(32);
|
||||
inline const TypePtr kUInt64 = std::make_shared<UInt>(64);
|
||||
inline const TypePtr kFloat16 = std::make_shared<Float>(16);
|
||||
inline const TypePtr kFloat32 = std::make_shared<Float>(32);
|
||||
inline const TypePtr kFloat64 = std::make_shared<Float>(64);
|
||||
inline const TypePtr kInt8 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits8));
|
||||
inline const TypePtr kInt16 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits16));
|
||||
inline const TypePtr kInt32 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits32));
|
||||
inline const TypePtr kInt64 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits64));
|
||||
inline const TypePtr kUInt8 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits8));
|
||||
inline const TypePtr kUInt16 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits16));
|
||||
inline const TypePtr kUInt32 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits32));
|
||||
inline const TypePtr kUInt64 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits64));
|
||||
inline const TypePtr kFloat16 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits16));
|
||||
inline const TypePtr kFloat32 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits32));
|
||||
inline const TypePtr kFloat64 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits64));
|
||||
inline const TypePtr kInt = std::make_shared<Int>();
|
||||
inline const TypePtr kUInt = std::make_shared<UInt>();
|
||||
inline const TypePtr kFloat = std::make_shared<Float>();
|
||||
inline const TypePtr kNumber = std::make_shared<Number>();
|
||||
inline const TypePtr kComplex64 = std::make_shared<Complex>(64);
|
||||
inline const TypePtr kComplex128 = std::make_shared<Complex>(128);
|
||||
inline const TypePtr kComplex64 = std::make_shared<Complex>(static_cast<int>(BitsNum::eBits64));
|
||||
inline const TypePtr kComplex128 = std::make_shared<Complex>(static_cast<int>(BitsNum::eBits128));
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
|
||||
|
|
|
@ -82,13 +82,6 @@ static std::unordered_map<TypeId, std::string> g_type_2_lable{
|
|||
{kObjectTypeIOMonad, MS_TYPE2LABLE(kObjectTypeIOMonad)},
|
||||
{kMonadTypeEnd, MS_TYPE2LABLE(kMonadTypeEnd)}};
|
||||
|
||||
enum class BitsNum : int {
|
||||
eBits8 = 8,
|
||||
eBits16 = 16,
|
||||
eBits32 = 32,
|
||||
eBits64 = 64,
|
||||
eBits128 = 128,
|
||||
};
|
||||
TypeId IntBitsToTypeId(const int nbits) {
|
||||
switch (nbits) {
|
||||
case static_cast<int>(BitsNum::eBits8):
|
||||
|
|
|
@ -47,9 +47,16 @@ TypeId NormalizeTypeId(const TypeId type_id);
|
|||
bool IsSameObjectType(const Type &lhs, const Type &rhs);
|
||||
size_t GetTypeByte(const TypePtr &type_ptr);
|
||||
|
||||
enum class BitsNum : int {
|
||||
eBits8 = 8,
|
||||
eBits16 = 16,
|
||||
eBits32 = 32,
|
||||
eBits64 = 64,
|
||||
eBits128 = 128,
|
||||
};
|
||||
|
||||
// Base class for all types
|
||||
// forward declaration.
|
||||
|
||||
class MS_CORE_API Type : public Value {
|
||||
public:
|
||||
Type() : meta_type_(kMetaTypeType), is_generic_(true) {}
|
||||
|
|
|
@ -271,86 +271,46 @@ TypePtr FunctionStrToType(const std::string &type_name) {
|
|||
} // namespace
|
||||
|
||||
TypePtr GetTypeByFullString(const std::string &type_name) {
|
||||
if (type_name == "None") {
|
||||
return std::make_shared<TypeNone>();
|
||||
}
|
||||
if (type_name == "Ellipsis") {
|
||||
return std::make_shared<TypeEllipsis>();
|
||||
}
|
||||
if (type_name == "TypeType") {
|
||||
return std::make_shared<TypeType>();
|
||||
}
|
||||
if (type_name == "SymbolicKeyType") {
|
||||
return std::make_shared<SymbolicKeyType>();
|
||||
}
|
||||
if (type_name == "RefKeyType") {
|
||||
return std::make_shared<RefKeyType>();
|
||||
}
|
||||
if (type_name == "EnvType") {
|
||||
return std::make_shared<EnvType>();
|
||||
}
|
||||
if (type_name == "Number") {
|
||||
return std::make_shared<Number>();
|
||||
}
|
||||
if (type_name == "Bool") {
|
||||
return std::make_shared<Bool>();
|
||||
}
|
||||
if (type_name == "Slice") {
|
||||
return std::make_shared<Slice>();
|
||||
}
|
||||
if (type_name == "Dictionary") {
|
||||
return std::make_shared<Dictionary>();
|
||||
}
|
||||
if (type_name == "String") {
|
||||
return std::make_shared<String>();
|
||||
}
|
||||
if (type_name == "Problem") {
|
||||
return std::make_shared<Problem>();
|
||||
}
|
||||
if (type_name == "mstype") {
|
||||
return std::make_shared<TypeType>();
|
||||
}
|
||||
if (type_name == "UMonad") {
|
||||
return kUMonadType;
|
||||
}
|
||||
if (type_name == "IOMonad") {
|
||||
return kIOMonadType;
|
||||
}
|
||||
return nullptr;
|
||||
static std::map<std::string, TypePtr> type_map = {{"None", std::make_shared<TypeNone>()},
|
||||
{"Ellipsis", std::make_shared<TypeEllipsis>()},
|
||||
{"TypeType", std::make_shared<TypeType>()},
|
||||
{"SymbolicKeyType", std::make_shared<SymbolicKeyType>()},
|
||||
{"RefKeyType", std::make_shared<RefKeyType>()},
|
||||
{"EnvType", std::make_shared<EnvType>()},
|
||||
{"Number", std::make_shared<Number>()},
|
||||
{"Bool", std::make_shared<Bool>()},
|
||||
{"Slice", std::make_shared<Slice>()},
|
||||
{"Dictionary", std::make_shared<Dictionary>()},
|
||||
{"String", std::make_shared<String>()},
|
||||
{"Problem", std::make_shared<Problem>()},
|
||||
{"mstype", std::make_shared<TypeType>()},
|
||||
{"UMonad", kUMonadType},
|
||||
{"IOMonad", kIOMonadType}};
|
||||
|
||||
auto iter = type_map.find(type_name);
|
||||
return iter == type_map.end() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
TypePtr GetTypeByStringStarts(const std::string &type_name) {
|
||||
if (type_name.compare(0, strlen("Int"), "Int") == 0) {
|
||||
return StringToNumberType<Int>(type_name, "Int");
|
||||
}
|
||||
if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
|
||||
return StringToNumberType<UInt>(type_name, "UInt");
|
||||
}
|
||||
if (type_name.compare(0, strlen("Float"), "Float") == 0) {
|
||||
return StringToNumberType<Float>(type_name, "Float");
|
||||
}
|
||||
if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
|
||||
return TensorStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
|
||||
return UndeterminedStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) {
|
||||
return RowTensorStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
|
||||
return SparseTensorStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
return ListStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
return TupleStrToType(type_name);
|
||||
}
|
||||
if (type_name.compare(0, strlen("Function"), "Function") == 0) {
|
||||
return FunctionStrToType(type_name);
|
||||
}
|
||||
return nullptr;
|
||||
struct name_cmp {
|
||||
bool operator()(const std::string &l, const std::string &r) {
|
||||
auto cmp_len = std::min(l.length(), r.length());
|
||||
return r.compare(0, cmp_len, l, 0, cmp_len) < 0;
|
||||
}
|
||||
};
|
||||
static std::map<std::string, std::function<TypePtr(const std::string &type_name)>, name_cmp> type_map = {
|
||||
{"Int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "Int"); }},
|
||||
{"UInt", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "UInt"); }},
|
||||
{"Float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "Float"); }},
|
||||
{"Tensor", [](const std::string &type_name) -> TypePtr { return TensorStrToType(type_name); }},
|
||||
{"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
|
||||
{"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
|
||||
{"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
|
||||
{"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
|
||||
{"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
|
||||
{"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};
|
||||
auto iter = type_map.find(type_name);
|
||||
return iter == type_map.end() ? nullptr : iter->second(type_name);
|
||||
}
|
||||
|
||||
TypePtr StringToType(const std::string &type_name) {
|
||||
|
@ -375,17 +335,13 @@ bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
|||
MS_LOG(ERROR) << "Type is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
auto type_id = base_type->type_id();
|
||||
if (type_id == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
} else if (!(base_type->IsGeneric())) {
|
||||
return *(base_type) == *(x);
|
||||
} else if (base_type->type_id() == x->type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->generic_type_id()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->object_type()) {
|
||||
return true;
|
||||
} else if (base_type->type_id() == x->meta_type()) {
|
||||
} else if (type_id == x->type_id() || type_id == x->generic_type_id() || type_id == x->object_type() ||
|
||||
type_id == x->meta_type()) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
|
|
|
@ -37,16 +37,13 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType
|
|||
|
||||
Primitive::Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs)
|
||||
: Named(name),
|
||||
attrs_(attrs),
|
||||
is_base_(true),
|
||||
has_signature_(false),
|
||||
prim_type_(kPrimTypeBuiltIn),
|
||||
record_evaluate_add_attr_(false),
|
||||
is_const_prim_(false),
|
||||
id_(MakeId()) {
|
||||
for (auto &attr : attrs) {
|
||||
attrs_[attr.first] = attr.second;
|
||||
}
|
||||
}
|
||||
id_(MakeId()) {}
|
||||
|
||||
Primitive::Primitive(const Primitive &prim)
|
||||
: Named(prim),
|
||||
|
|
Loading…
Reference in New Issue