forked from OSchip/llvm-project
[mlir][sparse] Improve sparse_tensor::detail::readCOOValue template
This is a followup to the refactoring of D133462, D133830, D133831, and D133833. Depends On D133833 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133836
This commit is contained in:
parent
7cc39b45fe
commit
68609598e4
|
@ -154,50 +154,36 @@ private:
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
// Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
|
template <typename T>
|
||||||
// also adds the value to its symmetric location.
|
struct is_complex final : public std::false_type {};
|
||||||
template <typename T, typename V>
|
|
||||||
inline void addValue(T *coo, V value, const std::vector<uint64_t> indices,
|
template <typename T>
|
||||||
bool is_symmetric_value) {
|
struct is_complex<std::complex<T>> final : public std::true_type {};
|
||||||
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
|
|
||||||
coo->add(indices, value);
|
/// Reads an element of a non-complex type for the current indices in
|
||||||
// We currently chose to deal with symmetric matrices by fully constructing
|
/// coordinate scheme.
|
||||||
// them. In the future, we may want to make symmetry implicit for storage
|
template <typename V>
|
||||||
// reasons.
|
inline typename std::enable_if<!is_complex<V>::value, V>::type
|
||||||
if (is_symmetric_value)
|
readCOOValue(char **linePtr, bool is_pattern) {
|
||||||
coo->add({indices[1], indices[0]}, value);
|
// The external formats always store these numerical values with the type
|
||||||
|
// double, but we cast these values to the sparse tensor object type.
|
||||||
|
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
|
||||||
|
return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reads an element of a complex type for the current indices in
|
/// Reads an element of a complex type for the current indices in
|
||||||
/// coordinate scheme.
|
/// coordinate scheme.
|
||||||
template <typename V>
|
template <typename V>
|
||||||
inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
|
inline typename std::enable_if<is_complex<V>::value, V>::type
|
||||||
const std::vector<uint64_t> indices, char **linePtr,
|
readCOOValue(char **linePtr, bool is_pattern) {
|
||||||
bool is_pattern, bool add_symmetric_value) {
|
|
||||||
// Read two values to make a complex. The external formats always store
|
// Read two values to make a complex. The external formats always store
|
||||||
// numerical values with the type double, but we cast these values to the
|
// numerical values with the type double, but we cast these values to the
|
||||||
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
|
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
|
||||||
// value 1 for all entries.
|
// value 1 for all entries.
|
||||||
V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||||
V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
||||||
std::complex<V> value = {re, im};
|
// Avoiding brace-notation since that forbids narrowing to `float`.
|
||||||
addValue(coo, value, indices, add_symmetric_value);
|
return V(re, im);
|
||||||
}
|
|
||||||
|
|
||||||
// Reads an element of a non-complex type for the current indices in coordinate
|
|
||||||
// scheme.
|
|
||||||
template <typename V,
|
|
||||||
typename std::enable_if<
|
|
||||||
!std::is_same<std::complex<float>, V>::value &&
|
|
||||||
!std::is_same<std::complex<double>, V>::value>::type * = nullptr>
|
|
||||||
inline void readCOOValue(SparseTensorCOO<V> *coo,
|
|
||||||
const std::vector<uint64_t> indices, char **linePtr,
|
|
||||||
bool is_pattern, bool is_symmetric_value) {
|
|
||||||
// The external formats always store these numerical values with the type
|
|
||||||
// double, but we cast these values to the sparse tensor object type.
|
|
||||||
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
|
|
||||||
double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
|
|
||||||
addValue(coo, value, indices, is_symmetric_value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
@ -232,8 +218,14 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
|
||||||
// Add the 0-based index.
|
// Add the 0-based index.
|
||||||
indices[perm[r]] = idx - 1;
|
indices[perm[r]] = idx - 1;
|
||||||
}
|
}
|
||||||
detail::readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
|
const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
|
||||||
stfile.isSymmetric() && indices[0] != indices[1]);
|
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
|
||||||
|
coo->add(indices, value);
|
||||||
|
// We currently chose to deal with symmetric matrices by fully
|
||||||
|
// constructing them. In the future, we may want to make symmetry
|
||||||
|
// implicit for storage reasons.
|
||||||
|
if (stfile.isSymmetric() && indices[0] != indices[1])
|
||||||
|
coo->add({indices[1], indices[0]}, value);
|
||||||
}
|
}
|
||||||
// Close the file and return tensor.
|
// Close the file and return tensor.
|
||||||
stfile.closeFile();
|
stfile.closeFile();
|
||||||
|
|
Loading…
Reference in New Issue