forked from OSchip/llvm-project
Make edsc::IndexedValue templated - NFC
This allows the indexing sugar to just work naturally with other type of load and store ops than the affine ones we currently have. This is needed for the EuroLLVM tutorial. PiperOrigin-RevId: 239602257
This commit is contained in:
parent
8a761881a0
commit
028530271e
|
@ -28,7 +28,9 @@
|
|||
namespace mlir {
|
||||
namespace edsc {
|
||||
|
||||
class IndexedValue;
|
||||
template <typename Load, typename Store> class TemplatedIndexedValue;
|
||||
|
||||
using IndexedValue = TemplatedIndexedValue<intrinsics::load, intrinsics::store>;
|
||||
|
||||
/// An IndexHandle is a simple wrapper around a ValueHandle.
|
||||
/// IndexHandles are ubiquitous enough to justify a new type to allow simple
|
||||
|
@ -89,6 +91,7 @@ public:
|
|||
operator ArrayRef<Value *>() { return values; }
|
||||
|
||||
private:
|
||||
ValueHandleArray() = default;
|
||||
llvm::SmallVector<Value *, 8> values;
|
||||
};
|
||||
|
||||
|
@ -164,52 +167,47 @@ private:
|
|||
///
|
||||
/// Assigning to an IndexedValue emits an actual store operation, while using
|
||||
/// converting an IndexedValue to a ValueHandle emits an actual load operation.
|
||||
struct IndexedValue {
|
||||
explicit IndexedValue(Type t) : base(t) {}
|
||||
explicit IndexedValue(Value *v) : IndexedValue(ValueHandle(v)) {}
|
||||
explicit IndexedValue(ValueHandle v) : base(v) {}
|
||||
template <typename Load, typename Store> struct TemplatedIndexedValue {
|
||||
explicit TemplatedIndexedValue(Type t) : base(t) {}
|
||||
explicit TemplatedIndexedValue(Value *v)
|
||||
: TemplatedIndexedValue(ValueHandle(v)) {}
|
||||
explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
|
||||
|
||||
IndexedValue(const IndexedValue &rhs) = default;
|
||||
TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
|
||||
|
||||
ValueHandle operator()() { return ValueHandle(*this); }
|
||||
/// Returns a new `IndexedValue`.
|
||||
IndexedValue operator()(ValueHandle index) {
|
||||
IndexedValue res(base);
|
||||
/// Returns a new `TemplatedIndexedValue`.
|
||||
TemplatedIndexedValue operator()(ValueHandle index) {
|
||||
TemplatedIndexedValue res(base);
|
||||
res.indices.push_back(index);
|
||||
return res;
|
||||
}
|
||||
template <typename... Args>
|
||||
IndexedValue operator()(ValueHandle index, Args... indices) {
|
||||
return IndexedValue(base, index).append(indices...);
|
||||
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
|
||||
return TemplatedIndexedValue(base, index).append(indices...);
|
||||
}
|
||||
IndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
|
||||
return IndexedValue(base, indices);
|
||||
TemplatedIndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
|
||||
return TemplatedIndexedValue(base, indices);
|
||||
}
|
||||
IndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
|
||||
return IndexedValue(
|
||||
TemplatedIndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
|
||||
return TemplatedIndexedValue(
|
||||
base, llvm::ArrayRef<ValueHandle>(indices.begin(), indices.end()));
|
||||
}
|
||||
|
||||
/// Emits a `store`.
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
InstructionHandle operator=(const IndexedValue &rhs) {
|
||||
InstructionHandle operator=(const TemplatedIndexedValue &rhs) {
|
||||
ValueHandle rrhs(rhs);
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::store(rrhs, getBase(), ValueHandleArray(indices));
|
||||
return Store(rrhs, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
InstructionHandle operator=(ValueHandle rhs) {
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::store(rhs, getBase(), ValueHandleArray(indices));
|
||||
return Store(rhs, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
|
||||
/// Emits a `load` when converting to a ValueHandle.
|
||||
operator ValueHandle() const {
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::load(getBase(), ValueHandleArray(indices));
|
||||
return Load(getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
|
||||
ValueHandle getBase() const { return base; }
|
||||
|
@ -223,39 +221,39 @@ struct IndexedValue {
|
|||
InstructionHandle operator-=(ValueHandle e);
|
||||
InstructionHandle operator*=(ValueHandle e);
|
||||
InstructionHandle operator/=(ValueHandle e);
|
||||
ValueHandle operator+(IndexedValue e) {
|
||||
ValueHandle operator+(TemplatedIndexedValue e) {
|
||||
return *this + static_cast<ValueHandle>(e);
|
||||
}
|
||||
ValueHandle operator-(IndexedValue e) {
|
||||
ValueHandle operator-(TemplatedIndexedValue e) {
|
||||
return *this - static_cast<ValueHandle>(e);
|
||||
}
|
||||
ValueHandle operator*(IndexedValue e) {
|
||||
ValueHandle operator*(TemplatedIndexedValue e) {
|
||||
return *this * static_cast<ValueHandle>(e);
|
||||
}
|
||||
ValueHandle operator/(IndexedValue e) {
|
||||
ValueHandle operator/(TemplatedIndexedValue e) {
|
||||
return *this / static_cast<ValueHandle>(e);
|
||||
}
|
||||
InstructionHandle operator+=(IndexedValue e) {
|
||||
InstructionHandle operator+=(TemplatedIndexedValue e) {
|
||||
return this->operator+=(static_cast<ValueHandle>(e));
|
||||
}
|
||||
InstructionHandle operator-=(IndexedValue e) {
|
||||
InstructionHandle operator-=(TemplatedIndexedValue e) {
|
||||
return this->operator-=(static_cast<ValueHandle>(e));
|
||||
}
|
||||
InstructionHandle operator*=(IndexedValue e) {
|
||||
InstructionHandle operator*=(TemplatedIndexedValue e) {
|
||||
return this->operator*=(static_cast<ValueHandle>(e));
|
||||
}
|
||||
InstructionHandle operator/=(IndexedValue e) {
|
||||
InstructionHandle operator/=(TemplatedIndexedValue e) {
|
||||
return this->operator/=(static_cast<ValueHandle>(e));
|
||||
}
|
||||
|
||||
private:
|
||||
IndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
|
||||
TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
|
||||
: base(base), indices(indices.begin(), indices.end()) {}
|
||||
|
||||
IndexedValue &append() { return *this; }
|
||||
TemplatedIndexedValue &append() { return *this; }
|
||||
|
||||
template <typename T, typename... Args>
|
||||
IndexedValue &append(T index, Args... indices) {
|
||||
TemplatedIndexedValue &append(T index, Args... indices) {
|
||||
this->indices.push_back(static_cast<ValueHandle>(index));
|
||||
append(indices...);
|
||||
return *this;
|
||||
|
@ -264,6 +262,53 @@ private:
|
|||
llvm::SmallVector<ValueHandle, 8> indices;
|
||||
};
|
||||
|
||||
/// Operator overloadings.
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
|
||||
using op::operator+;
|
||||
return static_cast<ValueHandle>(*this) + e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
|
||||
using op::operator-;
|
||||
return static_cast<ValueHandle>(*this) - e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
|
||||
using op::operator*;
|
||||
return static_cast<ValueHandle>(*this) * e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
|
||||
using op::operator/;
|
||||
return static_cast<ValueHandle>(*this) / e;
|
||||
}
|
||||
|
||||
template <typename Load, typename Store>
|
||||
InstructionHandle
|
||||
TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
|
||||
using op::operator+;
|
||||
return Store(*this + e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
InstructionHandle
|
||||
TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
|
||||
using op::operator-;
|
||||
return Store(*this - e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
InstructionHandle
|
||||
TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
|
||||
using op::operator*;
|
||||
return Store(*this * e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
InstructionHandle
|
||||
TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
|
||||
using op::operator/;
|
||||
return Store(*this / e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -62,38 +62,3 @@ mlir::edsc::VectorView::VectorView(Value *v) : base(v) {
|
|||
steps.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Operator overloadings.
|
||||
ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) {
|
||||
using op::operator+;
|
||||
return static_cast<ValueHandle>(*this) + e;
|
||||
}
|
||||
ValueHandle mlir::edsc::IndexedValue::operator-(ValueHandle e) {
|
||||
using op::operator-;
|
||||
return static_cast<ValueHandle>(*this) - e;
|
||||
}
|
||||
ValueHandle mlir::edsc::IndexedValue::operator*(ValueHandle e) {
|
||||
using op::operator*;
|
||||
return static_cast<ValueHandle>(*this) * e;
|
||||
}
|
||||
ValueHandle mlir::edsc::IndexedValue::operator/(ValueHandle e) {
|
||||
using op::operator/;
|
||||
return static_cast<ValueHandle>(*this) / e;
|
||||
}
|
||||
|
||||
InstructionHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) {
|
||||
using op::operator+;
|
||||
return intrinsics::store(*this + e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
InstructionHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) {
|
||||
using op::operator-;
|
||||
return intrinsics::store(*this - e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
InstructionHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) {
|
||||
using op::operator*;
|
||||
return intrinsics::store(*this * e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
InstructionHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) {
|
||||
using op::operator/;
|
||||
return intrinsics::store(*this / e, getBase(), ValueHandleArray(indices));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue