change SparseTensorType constructor: use TypePtrList

This commit is contained in:
wangrao124 2022-06-15 14:41:09 +08:00
parent 00b19b8bf2
commit 1a49773032
8 changed files with 209 additions and 169 deletions

View File

@ -157,10 +157,10 @@ REGISTER_PYBIND_DEFINE(
.def_property_readonly("ElementType", &RowTensorType::element, "Get the RowTensorType's element type.");
(void)py::class_<COOTensorType, Type, std::shared_ptr<COOTensorType>>(m_sub, "COOTensorType")
.def(py::init())
.def_property_readonly("ElementType", &COOTensorType::element, "Get the COOTensorType's element type.");
.def_property_readonly("ElementType", &COOTensorType::element_type, "Get the COOTensorType's element type.");
(void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType")
.def(py::init())
.def_property_readonly("ElementType", &CSRTensorType::element, "Get the CSRTensorType's element type.");
.def_property_readonly("ElementType", &CSRTensorType::element_type, "Get the CSRTensorType's element type.");
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
.def(py::init());
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")

View File

@ -658,6 +658,13 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
MS_EXCEPTION_IF_NULL(type_ptr);
}
if (type_ptr->isa<SparseTensorType>()) {
auto tensor_ptr = type_ptr->cast<SparseTensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
type_ptr = (*tensor_ptr)[output_idx];
MS_EXCEPTION_IF_NULL(type_ptr);
}
if (type_ptr->isa<TensorType>()) {
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
@ -665,15 +672,6 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
MS_EXCEPTION_IF_NULL(elem);
return elem->type_id();
}
if (type_ptr->isa<CSRTensorType>()) {
auto tensor_ptr = type_ptr->cast<CSRTensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
TypePtr elem = tensor_ptr->element();
MS_EXCEPTION_IF_NULL(elem);
return elem->type_id();
}
return type_ptr->type_id();
}

View File

@ -1431,7 +1431,7 @@ const TypeId AbstractSparseTensor::GetTypeIdAt(size_t index) const {
if (index < shape_idx) {
auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
MS_EXCEPTION_IF_NULL(abs_tensor);
return abs_tensor->BuildType()->type_id();
return abs_tensor->element()->BuildType()->type_id();
} else if (index < shape_idx + shape()->size()) {
return shape()->elements()[index - shape_idx]->BuildType()->type_id();
}
@ -1544,9 +1544,13 @@ std::string AbstractRowTensor::ToString() const {
// COOTensor
TypePtr AbstractCOOTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(indices());
MS_EXCEPTION_IF_NULL(values());
TypePtr element_type = values()->element()->BuildType();
return std::make_shared<COOTensorType>(element_type);
MS_EXCEPTION_IF_NULL(shape());
TypePtrList elements{indices()->element()->BuildType(), values()->element()->BuildType()};
std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
[](AbstractBasePtr p) { return p->BuildType(); });
return std::make_shared<COOTensorType>(elements);
}
AbstractBasePtr AbstractCOOTensor::Clone() const {
@ -1589,9 +1593,15 @@ const AbstractTensorPtr AbstractCOOTensor::values() const {
// CSRTensor
TypePtr AbstractCSRTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(indptr());
MS_EXCEPTION_IF_NULL(indices());
MS_EXCEPTION_IF_NULL(values());
TypePtr element_type = values()->element()->BuildType();
return std::make_shared<CSRTensorType>(element_type);
MS_EXCEPTION_IF_NULL(shape());
TypePtrList elements{indptr()->element()->BuildType(), indices()->element()->BuildType(),
values()->element()->BuildType()};
std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
[](AbstractBasePtr p) { return p->BuildType(); });
return std::make_shared<CSRTensorType>(elements);
}
AbstractBasePtr AbstractCSRTensor::Clone() const {

View File

@ -105,11 +105,76 @@ bool TensorType::operator==(const Type &other) const {
return *element_type_ == *other_elem_type;
}
std::string SparseTensorType::ElementsDtypeStr(const StringType str_type) const {
std::ostringstream oss;
for (const TypePtr &elem : elements_) {
if (str_type == kToString) {
oss << elem->ToString();
} else if (str_type == kDumpText) {
oss << elem->DumpText();
} else if (str_type == kReprString) {
oss << elem->ToReprString();
}
oss << ",";
}
return oss.str();
}
std::string SparseTensorType::ToString() const {
if (elements_.empty()) {
return GetSparseTensorTypeName();
}
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kToString) + "]";
}
std::string SparseTensorType::DumpText() const {
if (elements_.empty()) {
return GetSparseTensorTypeName();
}
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kDumpText) + "]";
}
std::string SparseTensorType::ToReprString() const {
if (elements_.empty()) {
return GetSparseTensorTypeName();
}
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kReprString) + "]";
}
TypePtr SparseTensorType::DeepCopy() const {
if (element_type_ == nullptr || IsGeneric()) {
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
TypePtrList new_elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(new_elements),
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<SparseTensorType>(new_elements);
return copy;
}
const TypePtr SparseTensorType::operator[](std::size_t dim) const {
if (dim >= size()) {
MS_LOG(EXCEPTION) << "Index " << dim << " is out range of the SparseTensorType size " << size() << ".";
}
return elements_[dim];
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const SparseTensorType &other_sparse = static_cast<const SparseTensorType &>(other);
if (!other_sparse.elements().empty()) {
if (elements_.size() != other_sparse.size()) {
return false;
}
for (size_t i = 0; i < elements_.size(); ++i) {
if (*elements_[i] != *other_sparse.elements()[i]) {
return false;
}
}
}
return true;
}
TypePtr RowTensorType::DeepCopy() const {
@ -155,86 +220,56 @@ bool RowTensorType::operator==(const Type &other) const {
}
TypePtr COOTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<COOTensorType>();
}
return std::make_shared<COOTensorType>(element_type_->DeepCopy());
}
std::string COOTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "COOTensor";
}
return "COOTensor[" + element_type_->ToReprString() + "]";
}
std::string COOTensorType::ToString() const {
if (element_type_ == nullptr) {
return "COOTensor";
}
return "COOTensor[" + element_type_->ToString() + "]";
}
std::string COOTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "COOTensor";
}
return "COOTensor[" + element_type_->DumpText() + "]";
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<COOTensorType>(elements);
return copy;
}
bool COOTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const COOTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
const COOTensorType &other_coo = static_cast<const COOTensorType &>(other);
if (elements_.size() != other_coo.size()) {
return false;
}
return *element_type_ == *other_elem_type;
for (size_t i = 0; i < elements_.size(); ++i) {
if (*elements_[i] != *other_coo.elements()[i]) {
return false;
}
}
return true;
}
TypePtr CSRTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<CSRTensorType>();
}
return std::make_shared<CSRTensorType>(element_type_->DeepCopy());
}
std::string CSRTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->ToReprString() + "]";
}
std::string CSRTensorType::ToString() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->ToString() + "]";
}
std::string CSRTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "CSRTensor";
}
return "CSRTensor[" + element_type_->DumpText() + "]";
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<CSRTensorType>(elements);
return copy;
}
bool CSRTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const CSRTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
const CSRTensorType &other_csr = static_cast<const CSRTensorType &>(other);
if (elements_.size() != other_csr.size()) {
return false;
}
return *element_type_ == *other_elem_type;
for (size_t i = 0; i < elements_.size(); ++i) {
if (*elements_[i] != *other_csr.elements()[i]) {
return false;
}
}
return true;
}
} // namespace mindspore

View File

@ -117,47 +117,45 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
/// \brief SparseTensorType is the base type for all sparse tensors.
class MS_CORE_API SparseTensorType : public Object {
public:
/// \brief Default constructor for SparseTensorType.
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
/// \brief Constructor for SparseTensorType.
///
/// \param[in] object_type The type id of derived class type.
explicit SparseTensorType(const TypeId object_type) : Object(object_type, kObjectTypeUndeterminedType) {}
/// \brief Constructor for SparseTensorType.
///
/// \param[in] ele The element of SparseTensorType.
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
explicit SparseTensorType(const TypePtrList &objs)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
/// \brief Constructor for SparseTensorType.
///
/// \param[in] object_type The type id of derived class type.
/// \param[in] ele The element of SparseTensorType.
explicit SparseTensorType(const TypeId object_type, const TypePtr &ele)
: Object(object_type, kObjectTypeUndeterminedType, false), element_type_(ele) {}
SparseTensorType(const TypeId object_type, const TypePtrList &objs)
: Object(object_type, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
/// \brief Destructor of SparseTensorType.
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
enum StringType : int { kToString = 0, kDumpText, kReprString };
virtual std::string GetSparseTensorTypeName() const { return "SparseTensorType"; }
virtual size_t GetElementIndex() { return 0; }
virtual TypePtr element_type() {
if (elements_.empty()) {
return nullptr;
}
return elements_[GetElementIndex()];
}
std::string ElementsDtypeStr(const StringType str_type) const;
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
/// \brief Get the element of SparseTensorType object.
///
/// \return The element of SparseTensorType object.
const TypePtr element() const { return element_type_; }
const TypePtr operator[](size_t dim) const;
bool operator==(const Type &other) const;
TypePtrList elements() const { return elements_; }
/// \brief Set the element of SparseTensorType object.
///
/// \param[in] element_type Define the element type to be set.
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::size_t size() const { return elements_.size(); }
std::string ToString() const;
std::string ToReprString() const;
std::string DumpText() const;
TypePtr DeepCopy() const;
private:
TypePtr element_type_;
TypePtrList elements_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
@ -209,32 +207,22 @@ class MS_CORE_API COOTensorType final : public SparseTensorType {
/// \brief Constructor for COOTensorType.
///
/// \param[in] ele The element of COOTensorType.
explicit COOTensorType(const TypePtr &ele) : SparseTensorType(kObjectTypeCOOTensorType, ele), element_type_(ele) {}
explicit COOTensorType(const TypePtrList &obj)
: SparseTensorType(kObjectTypeCOOTensorType, obj), elements_(obj.begin(), obj.end()) {}
/// \brief Destructor of COOTensorType.
~COOTensorType() override = default;
MS_DECLARE_PARENT(COOTensorType, SparseTensorType)
std::string GetSparseTensorTypeName() const override { return "COOTensor"; }
size_t GetElementIndex() override { return 1; }
TypeId generic_type_id() const override { return kObjectTypeCOOTensorType; }
/// \brief Get the element of COOTensorType object.
///
/// \return The element of COOTensorType object.
const TypePtr element() const { return element_type_; }
/// \brief Set the element of COOTensorType object.
///
/// \param[in] element_type Define the element type to be set.
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
TypePtrList elements_;
};
using COOTensorTypePtr = std::shared_ptr<COOTensorType>;
@ -247,32 +235,21 @@ class MS_CORE_API CSRTensorType : public SparseTensorType {
/// \brief Constructor for CSRTensorType.
///
/// \param[in] ele The element of CSRTensorType.
explicit CSRTensorType(const TypePtr &ele) : SparseTensorType(kObjectTypeCSRTensorType, ele), element_type_(ele) {}
explicit CSRTensorType(const TypePtrList &obj)
: SparseTensorType(kObjectTypeCSRTensorType, obj), elements_(obj.begin(), obj.end()) {}
/// \brief Destructor of CSRTensorType.
~CSRTensorType() override = default;
MS_DECLARE_PARENT(CSRTensorType, SparseTensorType)
std::string GetSparseTensorTypeName() const override { return "CSRTensor"; }
size_t GetElementIndex() override { return 2; }
TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; }
/// \brief Get the element of CSRTensorType object.
///
/// \return The element of CSRTensorType object.
const TypePtr element() const { return element_type_; }
/// \brief Set the element of CSRTensorType object.
///
/// \param[in] element_type Define the element type to be set.
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
TypePtrList elements_;
};
using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>;
} // namespace mindspore

View File

@ -233,37 +233,53 @@ TypePtr RowTensorStrToType(const std::string &type_name) {
}
TypePtr COOTensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "COOTensor") {
return std::make_shared<COOTensorType>();
type = std::make_shared<COOTensorType>();
} else {
size_t start = type_name.find_first_of('[');
size_t end = type_name.find_last_of(']');
// It's better to using regular expression, now just do simple check.
if (start == std::string::npos || end == std::string::npos || end < start) {
MS_EXCEPTION(NotSupportError) << "Expect format like 'COOTensor[type1, type2, ...]', but got '" << type_name
<< "' that not provide pair of ('[', ']').";
}
start = start + 1;
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types;
auto ret = StringToVectorOfType(element_strs, &element_types);
if (!ret) {
MS_EXCEPTION(NotSupportError) << "Expect format like 'COOTensor[type1, type2, ...]', but got '" << type_name
<< "' that miss typename after ','.";
}
type = std::make_shared<COOTensorType>(element_types);
}
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
return std::make_shared<COOTensorType>(element_type);
return type;
}
TypePtr CSRTensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "CSRTensor") {
return std::make_shared<CSRTensorType>();
type = std::make_shared<CSRTensorType>();
} else {
size_t start = type_name.find_first_of('[');
size_t end = type_name.find_last_of(']');
// It's better to using regular expression, now just do simple check.
if (start == std::string::npos || end == std::string::npos || end < start) {
MS_EXCEPTION(NotSupportError) << "Expect format like 'CSRTensor[type1, type2, ...]', but got '" << type_name
<< "' that not provide pair of ('[', ']').";
}
start = start + 1;
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types;
auto ret = StringToVectorOfType(element_strs, &element_types);
if (!ret) {
MS_EXCEPTION(NotSupportError) << "Expect format like 'CSRTensor[type1, type2, ...]', but got '" << type_name
<< "' that miss typename after ','.";
}
type = std::make_shared<CSRTensorType>(element_types);
}
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
return std::make_shared<CSRTensorType>(element_type);
return type;
}
TypePtr UndeterminedStrToType(const std::string &type_name) {

View File

@ -561,20 +561,18 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
TypePtr CheckAndConvertUtils::CheckSparseTensorTypeValid(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &, const std::string &prim_name) {
MS_EXCEPTION_IF_NULL(type);
if (!type->isa<CSRTensorType>() && !type->isa<COOTensorType>()) {
if (!type->isa<SparseTensorType>()) {
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
<< "] must be a CSRTensor or COOTensor, but got " << type->ToString() << ".";
} else {
auto sparse_type = type->cast<SparseTensorTypePtr>();
if (sparse_type != nullptr) {
return sparse_type->element_type();
}
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
<< "] cast to SparseTensorTypePtr failed! Get type : " << type->ToString() << ".";
}
TypePtr element = nullptr;
if (type->isa<CSRTensorType>()) {
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
element = csr_tensor_type->element();
} else if (type->isa<COOTensorType>()) {
auto coo_tensor_type = type->cast<COOTensorTypePtr>();
element = coo_tensor_type->element();
}
MS_EXCEPTION_IF_NULL(element);
return element;
return nullptr;
}
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,

View File

@ -455,9 +455,15 @@ TEST_F(TestTensor, SparseTensor) {
ASSERT_TRUE(abs_sparse_tensor->elements()[1]->isa<abstract::AbstractTensor>());
// SparseTensorType
TypePtr sparse_tensor_type = std::make_shared<SparseTensorType>(TypeIdToType(kNumberTypeFloat32));
TypePtrList elements{TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeFloat32),
TypeIdToType(kNumberTypeInt64), TypeIdToType(kNumberTypeInt64)};
TypePtr sparse_tensor_type = std::make_shared<SparseTensorType>(elements);
ASSERT_TRUE(sparse_tensor_type->isa<SparseTensorType>());
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[0]->type_id(), kNumberTypeInt32);
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[1]->type_id(), kNumberTypeInt32);
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[2]->type_id(), kNumberTypeFloat32);
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[3]->type_id(), kNumberTypeInt64);
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[4]->type_id(), kNumberTypeInt64);
}
} // namespace tensor
} // namespace mindspore