Code check

This commit is contained in:
zhangzhaoju 2021-09-26 19:55:52 +08:00
parent 830f4cf441
commit d94c1410b4
10 changed files with 103 additions and 193 deletions

View File

@ -63,6 +63,9 @@ void CNode::add_input(const AnfNodePtr &input) {
}
void CNode::set_input(size_t i, const AnfNodePtr &new_input) {
if (i >= inputs_.size()) {
MS_LOG(EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
}
inputs_[i] = new_input;
input_tensor_num_ = -1;
}
@ -74,7 +77,7 @@ void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
const AnfNodePtr &CNode::input(size_t i) const {
if (i >= inputs_.size()) {
MS_LOG(EXCEPTION) << "i:" << i << "out of range:" << inputs_.size() << ",cnode:" << DebugString();
MS_LOG(EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
}
return inputs_.at(i);
}
@ -130,7 +133,7 @@ ParamInfoPtr Parameter::param_info() const {
std::string ValueNode::ToString() const {
MS_EXCEPTION_IF_NULL(value_);
if (value_->isa<FuncGraph>()) {
return value_->cast<FuncGraphPtr>()->ToString();
return value_->ToString();
}
std::ostringstream buffer;
buffer << AnfNode::ToString();
@ -168,10 +171,7 @@ bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
}
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto cnode = dyn_cast<CNode>(node);
if (cnode != nullptr) {
if (cnode->size() > 0) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -187,8 +187,8 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
}
AnfNodePtr valuenode = cnode->input(0);
if (valuenode->isa<ValueNode>()) {
auto value = GetValueNode(valuenode);
auto value = GetValueNode(valuenode);
if (value != nullptr) {
// check whether the valuenode is primitive
if (value->isa<Primitive>()) {
return value->cast<PrimitivePtr>()->name();

View File

@ -38,19 +38,14 @@ bool AnfUtils::IsNodeOutputDynamicShape(const CNodePtr &node) {
MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
return false;
}
if (base_shape->isa<abstract::Shape>()) {
if (IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
return true;
}
if (base_shape->isa<abstract::Shape>() && IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
return true;
} else if (base_shape->isa<abstract::TupleShape>()) {
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
for (size_t i = 0; i < tuple_shape->size(); i++) {
auto b_shape = (*tuple_shape)[i];
if (!b_shape->isa<abstract::Shape>()) {
continue;
}
if (IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
if (b_shape->isa<abstract::Shape>() && IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
return true;
}
}

View File

@ -24,7 +24,7 @@ namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
std::ostringstream oss;
bool begin = true;
int cnt = 0;
size_t cnt = 0;
// write 'Tuple[Bool, Bool, Bool, Int, Float, Float]' as 'Tuple[Bool...3, Int, Float...2]'
for (size_t i = 0; i < elements.size(); ++i) {
TypePtr elem = elements[i];
@ -95,25 +95,13 @@ Class::Class(const Named &tag, const ClassAttrVector &attributes,
const std::unordered_map<std::string, ValuePtr> &methods)
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
std::string List::ToString() const {
std::string List::DumpContent(bool is_dumptext) const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "List";
} else {
buffer << "List[";
buffer << DumpTypeVector(elements_, false);
buffer << "]";
}
return buffer.str();
}
std::string List::DumpText() const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "List";
} else {
buffer << "List[";
buffer << DumpTypeVector(elements_, true);
buffer << DumpTypeVector(elements_, is_dumptext);
buffer << "]";
}
return buffer.str();
@ -133,27 +121,7 @@ TypePtr Class::DeepCopy() const {
}
}
std::string Class::ToString() const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "cls";
} else {
bool begin = true;
buffer << "cls." << tag_ << "[";
for (auto &attr : attributes_) {
if (!begin) {
buffer << ", ";
} else {
begin = false;
}
buffer << attr.first << ":" << attr.second->ToString();
}
buffer << "]";
}
return buffer.str();
}
std::string Class::DumpText() const {
std::string Class::DumpContent(bool is_dumptext) const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "Cls";
@ -166,7 +134,8 @@ std::string Class::DumpText() const {
} else {
begin = false;
}
buffer << attr.first << ":" << attr.second->DumpText();
auto sub_content = is_dumptext ? attr.second->DumpText() : attr.second->ToString();
buffer << attr.first << ":" << sub_content;
}
buffer << "]";
}
@ -208,25 +177,13 @@ const TypePtr Tuple::operator[](std::size_t dim) const {
return elements_[dim];
}
std::string Tuple::ToString() const {
std::string Tuple::DumpContent(bool is_dumptext) const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "Tuple";
} else {
buffer << "Tuple[";
buffer << DumpTypeVector(elements_, false);
buffer << "]";
}
return buffer.str();
}
std::string Tuple::DumpText() const {
std::ostringstream buffer;
if (IsGeneric()) {
buffer << "Tuple";
} else {
buffer << "Tuple[";
buffer << DumpTypeVector(elements_, true);
buffer << DumpTypeVector(elements_, is_dumptext);
buffer << "]";
}
return buffer.str();
@ -252,7 +209,7 @@ std::string DumpKeyVector(std::vector<std::string> keys) {
return buffer.str();
}
std::string Dictionary::ToString() const {
std::string Dictionary::DumpContent(bool) const {
std::ostringstream buffer;
std::vector<std::string> keys;
std::vector<TypePtr> values;
@ -271,8 +228,6 @@ std::string Dictionary::ToString() const {
return buffer.str();
}
std::string Dictionary::DumpText() const { return ToString(); }
bool Dictionary::operator==(const mindspore::Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;

View File

@ -54,11 +54,12 @@ class MS_CORE_API List : public Object {
bool operator==(const Type &other) const override;
std::size_t size() const { return elements_.size(); }
TypePtrList elements() const { return elements_; }
std::string ToString() const override;
std::string ToReprString() const override { return "list_"; }
std::string DumpText() const override;
std::string ToString() const override { return DumpContent(false); }
std::string DumpText() const override { return DumpContent(true); };
private:
std::string DumpContent(bool is_dumptext) const;
TypePtrList elements_;
};
using ListPtr = std::shared_ptr<List>;
@ -76,8 +77,8 @@ class MS_CORE_API Class : public Object {
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
std::string ToString() const override { return DumpContent(false); }
std::string DumpText() const override { return DumpContent(true); };
void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
Named tag() { return tag_; }
@ -88,6 +89,7 @@ class MS_CORE_API Class : public Object {
ClassAttrVector attributes_;
private:
std::string DumpContent(bool is_dumptext) const;
Named tag_;
std::unordered_map<std::string, ValuePtr> methods_;
// For AbstractClass build value
@ -111,9 +113,9 @@ class MS_CORE_API Tuple : public Object {
TypeId generic_type_id() const override { return kObjectTypeTuple; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override { return "tuple_"; }
std::string DumpText() const override;
std::string ToString() const override { return DumpContent(false); }
std::string DumpText() const override { return DumpContent(true); };
const TypePtr operator[](size_t dim) const;
bool operator==(const Type &other) const override;
@ -121,6 +123,7 @@ class MS_CORE_API Tuple : public Object {
std::size_t size() const { return elements_.size(); }
private:
std::string DumpContent(bool is_dumptext) const;
TypePtrList elements_;
};
using TuplePtr = std::shared_ptr<Tuple>;
@ -138,10 +141,11 @@ class MS_CORE_API Dictionary : public Object {
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
std::string ToString() const override { return DumpContent(false); }
std::string DumpText() const override { return DumpContent(true); };
private:
std::string DumpContent(bool is_dumptext) const;
std::vector<std::pair<std::string, TypePtr>> key_values_;
};
using DictionaryPtr = std::shared_ptr<Dictionary>;

View File

@ -30,25 +30,28 @@ bool Number::operator==(const Type &other) const {
}
Int::Int(const int nbits) : Number(IntBitsToTypeId(nbits), nbits, false) {
if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
if (nbits != static_cast<int>(BitsNum::eBits8) && nbits != static_cast<int>(BitsNum::eBits16) &&
nbits != static_cast<int>(BitsNum::eBits32) && nbits != static_cast<int>(BitsNum::eBits64)) {
MS_LOG(EXCEPTION) << "Wrong number of bits.";
}
}
UInt::UInt(const int nbits) : Number(UIntBitsToTypeId(nbits), nbits, false) {
if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
if (nbits != static_cast<int>(BitsNum::eBits8) && nbits != static_cast<int>(BitsNum::eBits16) &&
nbits != static_cast<int>(BitsNum::eBits32) && nbits != static_cast<int>(BitsNum::eBits64)) {
MS_LOG(EXCEPTION) << "Wrong number of bits.";
}
}
Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
if (nbits != 16 && nbits != 32 && nbits != 64) {
if (nbits != static_cast<int>(BitsNum::eBits16) && nbits != static_cast<int>(BitsNum::eBits32) &&
nbits != static_cast<int>(BitsNum::eBits64)) {
MS_LOG(EXCEPTION) << "Wrong number of bits.";
}
}
Complex::Complex(const int nbits) : Number(ComplexBitsToTypeId(nbits), nbits, false) {
if (nbits != 64 && nbits != 128) {
if (nbits != static_cast<int>(BitsNum::eBits64) && nbits != static_cast<int>(BitsNum::eBits128)) {
MS_LOG(EXCEPTION) << "Wrong number of bits.";
}
}

View File

@ -171,23 +171,23 @@ class MS_CORE_API Complex : public Number {
};
inline const TypePtr kBool = std::make_shared<Bool>();
inline const TypePtr kInt8 = std::make_shared<Int>(8);
inline const TypePtr kInt16 = std::make_shared<Int>(16);
inline const TypePtr kInt32 = std::make_shared<Int>(32);
inline const TypePtr kInt64 = std::make_shared<Int>(64);
inline const TypePtr kUInt8 = std::make_shared<UInt>(8);
inline const TypePtr kUInt16 = std::make_shared<UInt>(16);
inline const TypePtr kUInt32 = std::make_shared<UInt>(32);
inline const TypePtr kUInt64 = std::make_shared<UInt>(64);
inline const TypePtr kFloat16 = std::make_shared<Float>(16);
inline const TypePtr kFloat32 = std::make_shared<Float>(32);
inline const TypePtr kFloat64 = std::make_shared<Float>(64);
inline const TypePtr kInt8 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits8));
inline const TypePtr kInt16 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits16));
inline const TypePtr kInt32 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits32));
inline const TypePtr kInt64 = std::make_shared<Int>(static_cast<int>(BitsNum::eBits64));
inline const TypePtr kUInt8 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits8));
inline const TypePtr kUInt16 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits16));
inline const TypePtr kUInt32 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits32));
inline const TypePtr kUInt64 = std::make_shared<UInt>(static_cast<int>(BitsNum::eBits64));
inline const TypePtr kFloat16 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits16));
inline const TypePtr kFloat32 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits32));
inline const TypePtr kFloat64 = std::make_shared<Float>(static_cast<int>(BitsNum::eBits64));
inline const TypePtr kInt = std::make_shared<Int>();
inline const TypePtr kUInt = std::make_shared<UInt>();
inline const TypePtr kFloat = std::make_shared<Float>();
inline const TypePtr kNumber = std::make_shared<Number>();
inline const TypePtr kComplex64 = std::make_shared<Complex>(64);
inline const TypePtr kComplex128 = std::make_shared<Complex>(128);
inline const TypePtr kComplex64 = std::make_shared<Complex>(static_cast<int>(BitsNum::eBits64));
inline const TypePtr kComplex128 = std::make_shared<Complex>(static_cast<int>(BitsNum::eBits128));
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_

View File

@ -82,13 +82,6 @@ static std::unordered_map<TypeId, std::string> g_type_2_lable{
{kObjectTypeIOMonad, MS_TYPE2LABLE(kObjectTypeIOMonad)},
{kMonadTypeEnd, MS_TYPE2LABLE(kMonadTypeEnd)}};
enum class BitsNum : int {
eBits8 = 8,
eBits16 = 16,
eBits32 = 32,
eBits64 = 64,
eBits128 = 128,
};
TypeId IntBitsToTypeId(const int nbits) {
switch (nbits) {
case static_cast<int>(BitsNum::eBits8):

View File

@ -47,9 +47,16 @@ TypeId NormalizeTypeId(const TypeId type_id);
bool IsSameObjectType(const Type &lhs, const Type &rhs);
size_t GetTypeByte(const TypePtr &type_ptr);
enum class BitsNum : int {
eBits8 = 8,
eBits16 = 16,
eBits32 = 32,
eBits64 = 64,
eBits128 = 128,
};
// Base class for all types
// forward declaration.
class MS_CORE_API Type : public Value {
public:
Type() : meta_type_(kMetaTypeType), is_generic_(true) {}

View File

@ -271,86 +271,46 @@ TypePtr FunctionStrToType(const std::string &type_name) {
} // namespace
TypePtr GetTypeByFullString(const std::string &type_name) {
if (type_name == "None") {
return std::make_shared<TypeNone>();
}
if (type_name == "Ellipsis") {
return std::make_shared<TypeEllipsis>();
}
if (type_name == "TypeType") {
return std::make_shared<TypeType>();
}
if (type_name == "SymbolicKeyType") {
return std::make_shared<SymbolicKeyType>();
}
if (type_name == "RefKeyType") {
return std::make_shared<RefKeyType>();
}
if (type_name == "EnvType") {
return std::make_shared<EnvType>();
}
if (type_name == "Number") {
return std::make_shared<Number>();
}
if (type_name == "Bool") {
return std::make_shared<Bool>();
}
if (type_name == "Slice") {
return std::make_shared<Slice>();
}
if (type_name == "Dictionary") {
return std::make_shared<Dictionary>();
}
if (type_name == "String") {
return std::make_shared<String>();
}
if (type_name == "Problem") {
return std::make_shared<Problem>();
}
if (type_name == "mstype") {
return std::make_shared<TypeType>();
}
if (type_name == "UMonad") {
return kUMonadType;
}
if (type_name == "IOMonad") {
return kIOMonadType;
}
return nullptr;
static std::map<std::string, TypePtr> type_map = {{"None", std::make_shared<TypeNone>()},
{"Ellipsis", std::make_shared<TypeEllipsis>()},
{"TypeType", std::make_shared<TypeType>()},
{"SymbolicKeyType", std::make_shared<SymbolicKeyType>()},
{"RefKeyType", std::make_shared<RefKeyType>()},
{"EnvType", std::make_shared<EnvType>()},
{"Number", std::make_shared<Number>()},
{"Bool", std::make_shared<Bool>()},
{"Slice", std::make_shared<Slice>()},
{"Dictionary", std::make_shared<Dictionary>()},
{"String", std::make_shared<String>()},
{"Problem", std::make_shared<Problem>()},
{"mstype", std::make_shared<TypeType>()},
{"UMonad", kUMonadType},
{"IOMonad", kIOMonadType}};
auto iter = type_map.find(type_name);
return iter == type_map.end() ? nullptr : iter->second;
}
TypePtr GetTypeByStringStarts(const std::string &type_name) {
if (type_name.compare(0, strlen("Int"), "Int") == 0) {
return StringToNumberType<Int>(type_name, "Int");
}
if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
return StringToNumberType<UInt>(type_name, "UInt");
}
if (type_name.compare(0, strlen("Float"), "Float") == 0) {
return StringToNumberType<Float>(type_name, "Float");
}
if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
return TensorStrToType(type_name);
}
if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
return UndeterminedStrToType(type_name);
}
if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) {
return RowTensorStrToType(type_name);
}
if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
return SparseTensorStrToType(type_name);
}
if (type_name.compare(0, strlen("List"), "List") == 0) {
return ListStrToType(type_name);
}
if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
return TupleStrToType(type_name);
}
if (type_name.compare(0, strlen("Function"), "Function") == 0) {
return FunctionStrToType(type_name);
}
return nullptr;
struct name_cmp {
bool operator()(const std::string &l, const std::string &r) {
auto cmp_len = std::min(l.length(), r.length());
return r.compare(0, cmp_len, l, 0, cmp_len) < 0;
}
};
static std::map<std::string, std::function<TypePtr(const std::string &type_name)>, name_cmp> type_map = {
{"Int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "Int"); }},
{"UInt", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "UInt"); }},
{"Float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "Float"); }},
{"Tensor", [](const std::string &type_name) -> TypePtr { return TensorStrToType(type_name); }},
{"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
{"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
{"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
{"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
{"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
{"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};
auto iter = type_map.find(type_name);
return iter == type_map.end() ? nullptr : iter->second(type_name);
}
TypePtr StringToType(const std::string &type_name) {
@ -375,17 +335,13 @@ bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
MS_LOG(ERROR) << "Type is nullptr.";
return false;
}
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
auto type_id = base_type->type_id();
if (type_id == kTypeUnknown || x->type_id() == kTypeUnknown) {
return false;
} else if (!(base_type->IsGeneric())) {
return *(base_type) == *(x);
} else if (base_type->type_id() == x->type_id()) {
return true;
} else if (base_type->type_id() == x->generic_type_id()) {
return true;
} else if (base_type->type_id() == x->object_type()) {
return true;
} else if (base_type->type_id() == x->meta_type()) {
} else if (type_id == x->type_id() || type_id == x->generic_type_id() || type_id == x->object_type() ||
type_id == x->meta_type()) {
return true;
} else {
return false;

View File

@ -37,16 +37,13 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType
Primitive::Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs)
: Named(name),
attrs_(attrs),
is_base_(true),
has_signature_(false),
prim_type_(kPrimTypeBuiltIn),
record_evaluate_add_attr_(false),
is_const_prim_(false),
id_(MakeId()) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
}
id_(MakeId()) {}
Primitive::Primitive(const Primitive &prim)
: Named(prim),