diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h index 6fc3f5d200d7..116765588d78 100644 --- a/mlir/include/mlir/EDSC/Helpers.h +++ b/mlir/include/mlir/EDSC/Helpers.h @@ -28,7 +28,9 @@ namespace mlir { namespace edsc { -class IndexedValue; +template class TemplatedIndexedValue; + +using IndexedValue = TemplatedIndexedValue; /// An IndexHandle is a simple wrapper around a ValueHandle. /// IndexHandles are ubiquitous enough to justify a new type to allow simple @@ -89,6 +91,7 @@ public: operator ArrayRef() { return values; } private: + ValueHandleArray() = default; llvm::SmallVector values; }; @@ -164,52 +167,47 @@ private: /// /// Assigning to an IndexedValue emits an actual store operation, while using /// converting an IndexedValue to a ValueHandle emits an actual load operation. -struct IndexedValue { - explicit IndexedValue(Type t) : base(t) {} - explicit IndexedValue(Value *v) : IndexedValue(ValueHandle(v)) {} - explicit IndexedValue(ValueHandle v) : base(v) {} +template struct TemplatedIndexedValue { + explicit TemplatedIndexedValue(Type t) : base(t) {} + explicit TemplatedIndexedValue(Value *v) + : TemplatedIndexedValue(ValueHandle(v)) {} + explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} - IndexedValue(const IndexedValue &rhs) = default; + TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; ValueHandle operator()() { return ValueHandle(*this); } - /// Returns a new `IndexedValue`. - IndexedValue operator()(ValueHandle index) { - IndexedValue res(base); + /// Returns a new `TemplatedIndexedValue`. + TemplatedIndexedValue operator()(ValueHandle index) { + TemplatedIndexedValue res(base); res.indices.push_back(index); return res; } template - IndexedValue operator()(ValueHandle index, Args... indices) { - return IndexedValue(base, index).append(indices...); + TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { + return TemplatedIndexedValue(base, index).append(indices...); } - IndexedValue operator()(llvm::ArrayRef indices) { - return IndexedValue(base, indices); + TemplatedIndexedValue operator()(llvm::ArrayRef indices) { + return TemplatedIndexedValue(base, indices); } - IndexedValue operator()(llvm::ArrayRef indices) { - return IndexedValue( + TemplatedIndexedValue operator()(llvm::ArrayRef indices) { + return TemplatedIndexedValue( base, llvm::ArrayRef(indices.begin(), indices.end())); } /// Emits a `store`. // NOLINTNEXTLINE: unconventional-assign-operator - InstructionHandle operator=(const IndexedValue &rhs) { + InstructionHandle operator=(const TemplatedIndexedValue &rhs) { ValueHandle rrhs(rhs); - assert(getBase().getType().cast().getRank() == indices.size() && - "Unexpected number of indices to store in MemRef"); - return intrinsics::store(rrhs, getBase(), ValueHandleArray(indices)); + return Store(rrhs, getBase(), ValueHandleArray(indices)); } // NOLINTNEXTLINE: unconventional-assign-operator InstructionHandle operator=(ValueHandle rhs) { - assert(getBase().getType().cast().getRank() == indices.size() && - "Unexpected number of indices to store in MemRef"); - return intrinsics::store(rhs, getBase(), ValueHandleArray(indices)); + return Store(rhs, getBase(), ValueHandleArray(indices)); } /// Emits a `load` when converting to a ValueHandle. operator ValueHandle() const { - assert(getBase().getType().cast().getRank() == indices.size() && - "Unexpected number of indices to store in MemRef"); - return intrinsics::load(getBase(), ValueHandleArray(indices)); + return Load(getBase(), ValueHandleArray(indices)); } ValueHandle getBase() const { return base; } @@ -223,39 +221,39 @@ struct IndexedValue { InstructionHandle operator-=(ValueHandle e); InstructionHandle operator*=(ValueHandle e); InstructionHandle operator/=(ValueHandle e); - ValueHandle operator+(IndexedValue e) { + ValueHandle operator+(TemplatedIndexedValue e) { return *this + static_cast(e); } - ValueHandle operator-(IndexedValue e) { + ValueHandle operator-(TemplatedIndexedValue e) { return *this - static_cast(e); } - ValueHandle operator*(IndexedValue e) { + ValueHandle operator*(TemplatedIndexedValue e) { return *this * static_cast(e); } - ValueHandle operator/(IndexedValue e) { + ValueHandle operator/(TemplatedIndexedValue e) { return *this / static_cast(e); } - InstructionHandle operator+=(IndexedValue e) { + InstructionHandle operator+=(TemplatedIndexedValue e) { return this->operator+=(static_cast(e)); } - InstructionHandle operator-=(IndexedValue e) { + InstructionHandle operator-=(TemplatedIndexedValue e) { return this->operator-=(static_cast(e)); } - InstructionHandle operator*=(IndexedValue e) { + InstructionHandle operator*=(TemplatedIndexedValue e) { return this->operator*=(static_cast(e)); } - InstructionHandle operator/=(IndexedValue e) { + InstructionHandle operator/=(TemplatedIndexedValue e) { return this->operator/=(static_cast(e)); } private: - IndexedValue(ValueHandle base, ArrayRef indices) + TemplatedIndexedValue(ValueHandle base, ArrayRef indices) : base(base), indices(indices.begin(), indices.end()) {} - IndexedValue &append() { return *this; } + TemplatedIndexedValue &append() { return *this; } template - IndexedValue &append(T index, Args... indices) { + TemplatedIndexedValue &append(T index, Args... indices) { this->indices.push_back(static_cast(index)); append(indices...); return *this; @@ -264,6 +262,53 @@ private: llvm::SmallVector indices; }; +/// Operator overloadings. +template +ValueHandle TemplatedIndexedValue::operator+(ValueHandle e) { + using op::operator+; + return static_cast(*this) + e; +} +template +ValueHandle TemplatedIndexedValue::operator-(ValueHandle e) { + using op::operator-; + return static_cast(*this) - e; +} +template +ValueHandle TemplatedIndexedValue::operator*(ValueHandle e) { + using op::operator*; + return static_cast(*this) * e; +} +template +ValueHandle TemplatedIndexedValue::operator/(ValueHandle e) { + using op::operator/; + return static_cast(*this) / e; +} + +template +InstructionHandle +TemplatedIndexedValue::operator+=(ValueHandle e) { + using op::operator+; + return Store(*this + e, getBase(), ValueHandleArray(indices)); +} +template +InstructionHandle +TemplatedIndexedValue::operator-=(ValueHandle e) { + using op::operator-; + return Store(*this - e, getBase(), ValueHandleArray(indices)); +} +template +InstructionHandle +TemplatedIndexedValue::operator*=(ValueHandle e) { + using op::operator*; + return Store(*this * e, getBase(), ValueHandleArray(indices)); +} +template +InstructionHandle +TemplatedIndexedValue::operator/=(ValueHandle e) { + using op::operator/; + return Store(*this / e, getBase(), ValueHandleArray(indices)); +} + } // namespace edsc } // namespace mlir diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index bac28bf93a26..c33086e151c6 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -62,38 +62,3 @@ mlir::edsc::VectorView::VectorView(Value *v) : base(v) { steps.push_back(1); } } - -/// Operator overloadings. -ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) { - using op::operator+; - return static_cast(*this) + e; -} -ValueHandle mlir::edsc::IndexedValue::operator-(ValueHandle e) { - using op::operator-; - return static_cast(*this) - e; -} -ValueHandle mlir::edsc::IndexedValue::operator*(ValueHandle e) { - using op::operator*; - return static_cast(*this) * e; -} -ValueHandle mlir::edsc::IndexedValue::operator/(ValueHandle e) { - using op::operator/; - return static_cast(*this) / e; -} - -InstructionHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) { - using op::operator+; - return intrinsics::store(*this + e, getBase(), ValueHandleArray(indices)); -} -InstructionHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) { - using op::operator-; - return intrinsics::store(*this - e, getBase(), ValueHandleArray(indices)); -} -InstructionHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) { - using op::operator*; - return intrinsics::store(*this * e, getBase(), ValueHandleArray(indices)); -} -InstructionHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) { - using op::operator/; - return intrinsics::store(*this / e, getBase(), ValueHandleArray(indices)); -}