forked from mindspore-Ecosystem/mindspore
change SparseTensorType constructor: use TypePtrList
This commit is contained in:
parent
00b19b8bf2
commit
1a49773032
|
@ -157,10 +157,10 @@ REGISTER_PYBIND_DEFINE(
|
|||
.def_property_readonly("ElementType", &RowTensorType::element, "Get the RowTensorType's element type.");
|
||||
(void)py::class_<COOTensorType, Type, std::shared_ptr<COOTensorType>>(m_sub, "COOTensorType")
|
||||
.def(py::init())
|
||||
.def_property_readonly("ElementType", &COOTensorType::element, "Get the COOTensorType's element type.");
|
||||
.def_property_readonly("ElementType", &COOTensorType::element_type, "Get the COOTensorType's element type.");
|
||||
(void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType")
|
||||
.def(py::init())
|
||||
.def_property_readonly("ElementType", &CSRTensorType::element, "Get the CSRTensorType's element type.");
|
||||
.def_property_readonly("ElementType", &CSRTensorType::element_type, "Get the CSRTensorType's element type.");
|
||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||
.def(py::init());
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
|
|
|
@ -658,6 +658,13 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
|
|||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
}
|
||||
|
||||
if (type_ptr->isa<SparseTensorType>()) {
|
||||
auto tensor_ptr = type_ptr->cast<SparseTensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
type_ptr = (*tensor_ptr)[output_idx];
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
}
|
||||
|
||||
if (type_ptr->isa<TensorType>()) {
|
||||
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
|
@ -665,15 +672,6 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
|
|||
MS_EXCEPTION_IF_NULL(elem);
|
||||
return elem->type_id();
|
||||
}
|
||||
|
||||
if (type_ptr->isa<CSRTensorType>()) {
|
||||
auto tensor_ptr = type_ptr->cast<CSRTensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
TypePtr elem = tensor_ptr->element();
|
||||
MS_EXCEPTION_IF_NULL(elem);
|
||||
return elem->type_id();
|
||||
}
|
||||
|
||||
return type_ptr->type_id();
|
||||
}
|
||||
|
||||
|
|
|
@ -1431,7 +1431,7 @@ const TypeId AbstractSparseTensor::GetTypeIdAt(size_t index) const {
|
|||
if (index < shape_idx) {
|
||||
auto abs_tensor = GetAbsPtrAt<abstract::AbstractTensorPtr>(index);
|
||||
MS_EXCEPTION_IF_NULL(abs_tensor);
|
||||
return abs_tensor->BuildType()->type_id();
|
||||
return abs_tensor->element()->BuildType()->type_id();
|
||||
} else if (index < shape_idx + shape()->size()) {
|
||||
return shape()->elements()[index - shape_idx]->BuildType()->type_id();
|
||||
}
|
||||
|
@ -1544,9 +1544,13 @@ std::string AbstractRowTensor::ToString() const {
|
|||
|
||||
// COOTensor
|
||||
TypePtr AbstractCOOTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(indices());
|
||||
MS_EXCEPTION_IF_NULL(values());
|
||||
TypePtr element_type = values()->element()->BuildType();
|
||||
return std::make_shared<COOTensorType>(element_type);
|
||||
MS_EXCEPTION_IF_NULL(shape());
|
||||
TypePtrList elements{indices()->element()->BuildType(), values()->element()->BuildType()};
|
||||
std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
|
||||
[](AbstractBasePtr p) { return p->BuildType(); });
|
||||
return std::make_shared<COOTensorType>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractCOOTensor::Clone() const {
|
||||
|
@ -1589,9 +1593,15 @@ const AbstractTensorPtr AbstractCOOTensor::values() const {
|
|||
|
||||
// CSRTensor
|
||||
TypePtr AbstractCSRTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(indptr());
|
||||
MS_EXCEPTION_IF_NULL(indices());
|
||||
MS_EXCEPTION_IF_NULL(values());
|
||||
TypePtr element_type = values()->element()->BuildType();
|
||||
return std::make_shared<CSRTensorType>(element_type);
|
||||
MS_EXCEPTION_IF_NULL(shape());
|
||||
TypePtrList elements{indptr()->element()->BuildType(), indices()->element()->BuildType(),
|
||||
values()->element()->BuildType()};
|
||||
std::transform(shape()->elements().begin(), shape()->elements().end(), std::back_inserter(elements),
|
||||
[](AbstractBasePtr p) { return p->BuildType(); });
|
||||
return std::make_shared<CSRTensorType>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractCSRTensor::Clone() const {
|
||||
|
|
|
@ -105,11 +105,76 @@ bool TensorType::operator==(const Type &other) const {
|
|||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
std::string SparseTensorType::ElementsDtypeStr(const StringType str_type) const {
|
||||
std::ostringstream oss;
|
||||
for (const TypePtr &elem : elements_) {
|
||||
if (str_type == kToString) {
|
||||
oss << elem->ToString();
|
||||
} else if (str_type == kDumpText) {
|
||||
oss << elem->DumpText();
|
||||
} else if (str_type == kReprString) {
|
||||
oss << elem->ToReprString();
|
||||
}
|
||||
oss << ",";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::string SparseTensorType::ToString() const {
|
||||
if (elements_.empty()) {
|
||||
return GetSparseTensorTypeName();
|
||||
}
|
||||
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kToString) + "]";
|
||||
}
|
||||
|
||||
std::string SparseTensorType::DumpText() const {
|
||||
if (elements_.empty()) {
|
||||
return GetSparseTensorTypeName();
|
||||
}
|
||||
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kDumpText) + "]";
|
||||
}
|
||||
|
||||
std::string SparseTensorType::ToReprString() const {
|
||||
if (elements_.empty()) {
|
||||
return GetSparseTensorTypeName();
|
||||
}
|
||||
return GetSparseTensorTypeName() + "[" + ElementsDtypeStr(kReprString) + "]";
|
||||
}
|
||||
|
||||
TypePtr SparseTensorType::DeepCopy() const {
|
||||
if (element_type_ == nullptr || IsGeneric()) {
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<SparseTensorType>();
|
||||
}
|
||||
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
|
||||
TypePtrList new_elements;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(new_elements),
|
||||
[](const TypePtr &ele) { return ele->DeepCopy(); });
|
||||
auto copy = std::make_shared<SparseTensorType>(new_elements);
|
||||
return copy;
|
||||
}
|
||||
|
||||
const TypePtr SparseTensorType::operator[](std::size_t dim) const {
|
||||
if (dim >= size()) {
|
||||
MS_LOG(EXCEPTION) << "Index " << dim << " is out range of the SparseTensorType size " << size() << ".";
|
||||
}
|
||||
return elements_[dim];
|
||||
}
|
||||
|
||||
bool SparseTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
const SparseTensorType &other_sparse = static_cast<const SparseTensorType &>(other);
|
||||
if (!other_sparse.elements().empty()) {
|
||||
if (elements_.size() != other_sparse.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < elements_.size(); ++i) {
|
||||
if (*elements_[i] != *other_sparse.elements()[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
TypePtr RowTensorType::DeepCopy() const {
|
||||
|
@ -155,86 +220,56 @@ bool RowTensorType::operator==(const Type &other) const {
|
|||
}
|
||||
|
||||
TypePtr COOTensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<COOTensorType>();
|
||||
}
|
||||
return std::make_shared<COOTensorType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string COOTensorType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "COOTensor";
|
||||
}
|
||||
return "COOTensor[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string COOTensorType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "COOTensor";
|
||||
}
|
||||
return "COOTensor[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string COOTensorType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "COOTensor";
|
||||
}
|
||||
return "COOTensor[" + element_type_->DumpText() + "]";
|
||||
TypePtrList elements;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
||||
[](const TypePtr &ele) { return ele->DeepCopy(); });
|
||||
auto copy = std::make_shared<COOTensorType>(elements);
|
||||
return copy;
|
||||
}
|
||||
|
||||
bool COOTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const COOTensorType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
const COOTensorType &other_coo = static_cast<const COOTensorType &>(other);
|
||||
if (elements_.size() != other_coo.size()) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
for (size_t i = 0; i < elements_.size(); ++i) {
|
||||
if (*elements_[i] != *other_coo.elements()[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
TypePtr CSRTensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<CSRTensorType>();
|
||||
}
|
||||
return std::make_shared<CSRTensorType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string CSRTensorType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string CSRTensorType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string CSRTensorType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "CSRTensor";
|
||||
}
|
||||
return "CSRTensor[" + element_type_->DumpText() + "]";
|
||||
TypePtrList elements;
|
||||
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
|
||||
[](const TypePtr &ele) { return ele->DeepCopy(); });
|
||||
auto copy = std::make_shared<CSRTensorType>(elements);
|
||||
return copy;
|
||||
}
|
||||
|
||||
bool CSRTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const CSRTensorType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
const CSRTensorType &other_csr = static_cast<const CSRTensorType &>(other);
|
||||
if (elements_.size() != other_csr.size()) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
for (size_t i = 0; i < elements_.size(); ++i) {
|
||||
if (*elements_[i] != *other_csr.elements()[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -117,47 +117,45 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
|
|||
/// \brief SparseTensorType is the base type for all sparse tensors.
|
||||
class MS_CORE_API SparseTensorType : public Object {
|
||||
public:
|
||||
/// \brief Default constructor for SparseTensorType.
|
||||
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
|
||||
|
||||
/// \brief Constructor for SparseTensorType.
|
||||
///
|
||||
/// \param[in] object_type The type id of derived class type.
|
||||
explicit SparseTensorType(const TypeId object_type) : Object(object_type, kObjectTypeUndeterminedType) {}
|
||||
|
||||
/// \brief Constructor for SparseTensorType.
|
||||
///
|
||||
/// \param[in] ele The element of SparseTensorType.
|
||||
explicit SparseTensorType(const TypePtr &ele)
|
||||
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
explicit SparseTensorType(const TypePtrList &objs)
|
||||
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
|
||||
|
||||
/// \brief Constructor for SparseTensorType.
|
||||
///
|
||||
/// \param[in] object_type The type id of derived class type.
|
||||
/// \param[in] ele The element of SparseTensorType.
|
||||
explicit SparseTensorType(const TypeId object_type, const TypePtr &ele)
|
||||
: Object(object_type, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
SparseTensorType(const TypeId object_type, const TypePtrList &objs)
|
||||
: Object(object_type, kObjectTypeUndeterminedType), elements_(objs.begin(), objs.end()) {}
|
||||
|
||||
/// \brief Destructor of SparseTensorType.
|
||||
~SparseTensorType() override = default;
|
||||
MS_DECLARE_PARENT(SparseTensorType, Object)
|
||||
|
||||
enum StringType : int { kToString = 0, kDumpText, kReprString };
|
||||
|
||||
virtual std::string GetSparseTensorTypeName() const { return "SparseTensorType"; }
|
||||
virtual size_t GetElementIndex() { return 0; }
|
||||
virtual TypePtr element_type() {
|
||||
if (elements_.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return elements_[GetElementIndex()];
|
||||
}
|
||||
std::string ElementsDtypeStr(const StringType str_type) const;
|
||||
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
|
||||
|
||||
/// \brief Get the element of SparseTensorType object.
|
||||
///
|
||||
/// \return The element of SparseTensorType object.
|
||||
const TypePtr element() const { return element_type_; }
|
||||
const TypePtr operator[](size_t dim) const;
|
||||
bool operator==(const Type &other) const;
|
||||
TypePtrList elements() const { return elements_; }
|
||||
|
||||
/// \brief Set the element of SparseTensorType object.
|
||||
///
|
||||
/// \param[in] element_type Define the element type to be set.
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::size_t size() const { return elements_.size(); }
|
||||
std::string ToString() const;
|
||||
std::string ToReprString() const;
|
||||
std::string DumpText() const;
|
||||
TypePtr DeepCopy() const;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
TypePtrList elements_;
|
||||
};
|
||||
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
||||
|
||||
|
@ -209,32 +207,22 @@ class MS_CORE_API COOTensorType final : public SparseTensorType {
|
|||
/// \brief Constructor for COOTensorType.
|
||||
///
|
||||
/// \param[in] ele The element of COOTensorType.
|
||||
explicit COOTensorType(const TypePtr &ele) : SparseTensorType(kObjectTypeCOOTensorType, ele), element_type_(ele) {}
|
||||
explicit COOTensorType(const TypePtrList &obj)
|
||||
: SparseTensorType(kObjectTypeCOOTensorType, obj), elements_(obj.begin(), obj.end()) {}
|
||||
|
||||
/// \brief Destructor of COOTensorType.
|
||||
~COOTensorType() override = default;
|
||||
MS_DECLARE_PARENT(COOTensorType, SparseTensorType)
|
||||
|
||||
std::string GetSparseTensorTypeName() const override { return "COOTensor"; }
|
||||
size_t GetElementIndex() override { return 1; }
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeCOOTensorType; }
|
||||
|
||||
/// \brief Get the element of COOTensorType object.
|
||||
///
|
||||
/// \return The element of COOTensorType object.
|
||||
const TypePtr element() const { return element_type_; }
|
||||
|
||||
/// \brief Set the element of COOTensorType object.
|
||||
///
|
||||
/// \param[in] element_type Define the element type to be set.
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
TypePtrList elements_;
|
||||
};
|
||||
using COOTensorTypePtr = std::shared_ptr<COOTensorType>;
|
||||
|
||||
|
@ -247,32 +235,21 @@ class MS_CORE_API CSRTensorType : public SparseTensorType {
|
|||
/// \brief Constructor for CSRTensorType.
|
||||
///
|
||||
/// \param[in] ele The element of CSRTensorType.
|
||||
explicit CSRTensorType(const TypePtr &ele) : SparseTensorType(kObjectTypeCSRTensorType, ele), element_type_(ele) {}
|
||||
explicit CSRTensorType(const TypePtrList &obj)
|
||||
: SparseTensorType(kObjectTypeCSRTensorType, obj), elements_(obj.begin(), obj.end()) {}
|
||||
|
||||
/// \brief Destructor of CSRTensorType.
|
||||
~CSRTensorType() override = default;
|
||||
MS_DECLARE_PARENT(CSRTensorType, SparseTensorType)
|
||||
|
||||
std::string GetSparseTensorTypeName() const override { return "CSRTensor"; }
|
||||
size_t GetElementIndex() override { return 2; }
|
||||
TypeId generic_type_id() const override { return kObjectTypeCSRTensorType; }
|
||||
|
||||
/// \brief Get the element of CSRTensorType object.
|
||||
///
|
||||
/// \return The element of CSRTensorType object.
|
||||
const TypePtr element() const { return element_type_; }
|
||||
|
||||
/// \brief Set the element of CSRTensorType object.
|
||||
///
|
||||
/// \param[in] element_type Define the element type to be set.
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
TypePtrList elements_;
|
||||
};
|
||||
using CSRTensorTypePtr = std::shared_ptr<CSRTensorType>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -233,37 +233,53 @@ TypePtr RowTensorStrToType(const std::string &type_name) {
|
|||
}
|
||||
|
||||
TypePtr COOTensorStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "COOTensor") {
|
||||
return std::make_shared<COOTensorType>();
|
||||
type = std::make_shared<COOTensorType>();
|
||||
} else {
|
||||
size_t start = type_name.find_first_of('[');
|
||||
size_t end = type_name.find_last_of(']');
|
||||
// It's better to using regular expression, now just do simple check.
|
||||
if (start == std::string::npos || end == std::string::npos || end < start) {
|
||||
MS_EXCEPTION(NotSupportError) << "Expect format like 'COOTensor[type1, type2, ...]', but got '" << type_name
|
||||
<< "' that not provide pair of ('[', ']').";
|
||||
}
|
||||
start = start + 1;
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types;
|
||||
auto ret = StringToVectorOfType(element_strs, &element_types);
|
||||
if (!ret) {
|
||||
MS_EXCEPTION(NotSupportError) << "Expect format like 'COOTensor[type1, type2, ...]', but got '" << type_name
|
||||
<< "' that miss typename after ','.";
|
||||
}
|
||||
type = std::make_shared<COOTensorType>(element_types);
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<COOTensorType>(element_type);
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr CSRTensorStrToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name == "CSRTensor") {
|
||||
return std::make_shared<CSRTensorType>();
|
||||
type = std::make_shared<CSRTensorType>();
|
||||
} else {
|
||||
size_t start = type_name.find_first_of('[');
|
||||
size_t end = type_name.find_last_of(']');
|
||||
// It's better to using regular expression, now just do simple check.
|
||||
if (start == std::string::npos || end == std::string::npos || end < start) {
|
||||
MS_EXCEPTION(NotSupportError) << "Expect format like 'CSRTensor[type1, type2, ...]', but got '" << type_name
|
||||
<< "' that not provide pair of ('[', ']').";
|
||||
}
|
||||
start = start + 1;
|
||||
std::string element_strs = type_name.substr(start, end - start);
|
||||
std::vector<TypePtr> element_types;
|
||||
auto ret = StringToVectorOfType(element_strs, &element_types);
|
||||
if (!ret) {
|
||||
MS_EXCEPTION(NotSupportError) << "Expect format like 'CSRTensor[type1, type2, ...]', but got '" << type_name
|
||||
<< "' that miss typename after ','.";
|
||||
}
|
||||
type = std::make_shared<CSRTensorType>(element_types);
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<CSRTensorType>(element_type);
|
||||
return type;
|
||||
}
|
||||
|
||||
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
||||
|
|
|
@ -561,20 +561,18 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
|
|||
TypePtr CheckAndConvertUtils::CheckSparseTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &, const std::string &prim_name) {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (!type->isa<CSRTensorType>() && !type->isa<COOTensorType>()) {
|
||||
if (!type->isa<SparseTensorType>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
|
||||
<< "] must be a CSRTensor or COOTensor, but got " << type->ToString() << ".";
|
||||
} else {
|
||||
auto sparse_type = type->cast<SparseTensorTypePtr>();
|
||||
if (sparse_type != nullptr) {
|
||||
return sparse_type->element_type();
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
|
||||
<< "] cast to SparseTensorTypePtr failed! Get type : " << type->ToString() << ".";
|
||||
}
|
||||
TypePtr element = nullptr;
|
||||
if (type->isa<CSRTensorType>()) {
|
||||
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
|
||||
element = csr_tensor_type->element();
|
||||
} else if (type->isa<COOTensorType>()) {
|
||||
auto coo_tensor_type = type->cast<COOTensorTypePtr>();
|
||||
element = coo_tensor_type->element();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
return element;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
||||
|
|
|
@ -455,9 +455,15 @@ TEST_F(TestTensor, SparseTensor) {
|
|||
ASSERT_TRUE(abs_sparse_tensor->elements()[1]->isa<abstract::AbstractTensor>());
|
||||
|
||||
// SparseTensorType
|
||||
TypePtr sparse_tensor_type = std::make_shared<SparseTensorType>(TypeIdToType(kNumberTypeFloat32));
|
||||
TypePtrList elements{TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeFloat32),
|
||||
TypeIdToType(kNumberTypeInt64), TypeIdToType(kNumberTypeInt64)};
|
||||
TypePtr sparse_tensor_type = std::make_shared<SparseTensorType>(elements);
|
||||
ASSERT_TRUE(sparse_tensor_type->isa<SparseTensorType>());
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[0]->type_id(), kNumberTypeInt32);
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[1]->type_id(), kNumberTypeInt32);
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[2]->type_id(), kNumberTypeFloat32);
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[3]->type_id(), kNumberTypeInt64);
|
||||
ASSERT_EQ(sparse_tensor_type->cast<SparseTensorTypePtr>()->elements()[4]->type_id(), kNumberTypeInt64);
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue