2018-07-09 11:51:38 +08:00
|
|
|
//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
|
|
|
|
//
|
2019-12-24 01:35:36 +08:00
|
|
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2018-07-09 11:51:38 +08:00
|
|
|
//
|
2019-12-24 01:35:36 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-07-09 11:51:38 +08:00
|
|
|
|
|
|
|
#include "mlir/IR/Builders.h"
|
2018-07-11 01:59:53 +08:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
2019-12-14 04:21:42 +08:00
|
|
|
#include "mlir/IR/Dialect.h"
|
2018-08-08 05:24:38 +08:00
|
|
|
#include "mlir/IR/IntegerSet.h"
|
2019-12-14 04:21:42 +08:00
|
|
|
#include "mlir/IR/Matchers.h"
|
2018-07-09 11:51:38 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-01-04 06:29:52 +08:00
|
|
|
#include "mlir/IR/StandardTypes.h"
|
2019-04-06 03:19:22 +08:00
|
|
|
#include "mlir/Support/Functional.h"
|
2019-12-14 04:21:42 +08:00
|
|
|
#include "llvm/Support/raw_ostream.h"
|
2018-07-09 11:51:38 +08:00
|
|
|
using namespace mlir;
|
|
|
|
|
2019-07-11 01:07:49 +08:00
|
|
|
Builder::Builder(ModuleOp module) : context(module.getContext()) {}
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2018-07-11 01:59:53 +08:00
|
|
|
/// Returns the Identifier for `str`, uniqued in this builder's context.
Identifier Builder::getIdentifier(StringRef str) {
  return Identifier::get(str, getContext());
}
|
|
|
|
|
2018-08-28 12:05:16 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Locations.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-06-26 07:57:32 +08:00
|
|
|
/// Returns the unknown location of this builder's context.
Location Builder::getUnknownLoc() { return UnknownLoc::get(getContext()); }

/// Returns a file/line/column location for `filename` at `line`:`column`.
Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
                                    unsigned column) {
  MLIRContext *ctx = getContext();
  return FileLineColLoc::get(filename, line, column, ctx);
}

/// Returns a location fusing every location in `locs`, tagged with
/// `metadata`.
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
  MLIRContext *ctx = getContext();
  return FusedLoc::get(locs, metadata, ctx);
}
|
|
|
|
|
2018-07-11 01:59:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-07-09 11:51:38 +08:00
|
|
|
// Types.
|
2018-07-11 01:59:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
IndexType Builder::getIndexType() { return IndexType::get(context); }
|
2018-07-09 11:51:38 +08:00
|
|
|
|
2019-02-13 03:08:04 +08:00
|
|
|
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
|
2018-11-29 03:49:26 +08:00
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
IntegerType Builder::getIntegerType(unsigned width) {
|
2019-02-13 03:08:04 +08:00
|
|
|
return IntegerType::get(width, context);
|
2018-07-09 11:51:38 +08:00
|
|
|
}
|
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
|
|
|
|
ArrayRef<Type> results) {
|
2018-07-09 11:51:38 +08:00
|
|
|
return FunctionType::get(inputs, results, context);
|
|
|
|
}
|
|
|
|
|
2019-03-20 01:59:02 +08:00
|
|
|
TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
|
|
|
|
return TupleType::get(elementTypes, context);
|
|
|
|
}
|
|
|
|
|
2019-04-28 09:35:04 +08:00
|
|
|
NoneType Builder::getNoneType() { return NoneType::get(context); }
|
|
|
|
|
2018-07-11 01:59:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Attributes.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-03-01 08:45:30 +08:00
|
|
|
NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
|
|
|
|
return NamedAttribute(getIdentifier(name), val);
|
|
|
|
}
|
|
|
|
|
2019-04-26 00:56:09 +08:00
|
|
|
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
BoolAttr Builder::getBoolAttr(bool value) {
|
2018-07-11 01:59:53 +08:00
|
|
|
return BoolAttr::get(value, context);
|
|
|
|
}
|
|
|
|
|
2019-06-01 00:24:48 +08:00
|
|
|
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
|
|
|
|
return DictionaryAttr::get(value, context);
|
|
|
|
}
|
|
|
|
|
2018-12-27 03:48:58 +08:00
|
|
|
/// Returns an IntegerAttr of type i64 holding `value`.
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
  IntegerType i64Type = getIntegerType(64);
  APInt bits(64, value);
  return IntegerAttr::get(i64Type, bits);
}
|
|
|
|
|
2019-12-02 23:51:27 +08:00
|
|
|
/// Returns a dense attribute of type vector<N x i32> holding `values`,
/// where N is values.size().
DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
  auto vectorType =
      VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32));
  DenseElementsAttr elements = DenseElementsAttr::get(vectorType, values);
  return elements.cast<DenseIntElementsAttr>();
}
|
|
|
|
|
2018-12-27 03:48:58 +08:00
|
|
|
// The fixed-width signed builders below construct their APInt with
// `isSigned=true`. Each `value` is a signed C++ integer. It is sign-extended
// to a 64-bit word on its way into APInt and then narrowed to the target
// width.
//
// With APInt's default (unsigned) interpretation, a negative input only
// produces the right bits because APInt discards the high bits without
// complaint. Newer LLVM APInt constructors diagnose that implicit truncation.
// Declaring the value signed produces the same bit pattern and states the
// intent.

/// Returns an IntegerAttr of type i32 holding the signed `value`.
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
  return IntegerAttr::get(getIntegerType(32),
                          APInt(32, value, /*isSigned=*/true));
}

/// Returns an IntegerAttr of type i16 holding the signed `value`.
IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
  return IntegerAttr::get(getIntegerType(16),
                          APInt(16, value, /*isSigned=*/true));
}

/// Returns an IntegerAttr of type i8 holding the signed `value`.
IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
  return IntegerAttr::get(getIntegerType(8),
                          APInt(8, value, /*isSigned=*/true));
}
|
|
|
|
|
2018-11-16 09:53:51 +08:00
|
|
|
/// Returns an IntegerAttr of `type` holding `value`.
///
/// The APInt width depends on `type`:
/// - `index` is given a 64-bit APInt.
/// - Every other type uses its int-or-float bit width, so only the low bits
///   of `value` that fit that width are kept.
///
/// NOTE(review): for types wider than 64 bits, the APInt is constructed with
/// the default unsigned interpretation. A negative `value` is then
/// zero-extended rather than sign-extended — confirm this is intended.
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
  unsigned bitWidth = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
  return IntegerAttr::get(type, APInt(bitWidth, value));
}

/// Returns an IntegerAttr of `type` holding the arbitrary-precision `value`.
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
  return IntegerAttr::get(type, value);
}
|
|
|
|
|
2018-12-27 03:48:58 +08:00
|
|
|
/// Returns an f64 FloatAttr holding `value`.
FloatAttr Builder::getF64FloatAttr(double value) {
  return getFloatAttr(getF64Type(), APFloat(value));
}

/// Returns an f32 FloatAttr holding `value`.
FloatAttr Builder::getF32FloatAttr(float value) {
  return getFloatAttr(getF32Type(), APFloat(value));
}

/// Returns an f16 FloatAttr for `value`.
///
/// The value is widened to double and handed to FloatAttr::get(Type, double).
/// That function is relied on to convert it to the f16 format.
FloatAttr Builder::getF16FloatAttr(float value) {
  return getFloatAttr(getF16Type(), static_cast<double>(value));
}

/// Returns a FloatAttr of `type` built from the double `value`.
FloatAttr Builder::getFloatAttr(Type type, double value) {
  return FloatAttr::get(type, value);
}

/// Returns a FloatAttr of `type` built from the APFloat `value`.
FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
  return FloatAttr::get(type, value);
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
StringAttr Builder::getStringAttr(StringRef bytes) {
|
2018-07-11 01:59:53 +08:00
|
|
|
return StringAttr::get(bytes, context);
|
|
|
|
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
|
2018-07-11 01:59:53 +08:00
|
|
|
return ArrayAttr::get(value, context);
|
|
|
|
}
|
|
|
|
|
2019-11-12 10:18:02 +08:00
|
|
|
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
|
2019-07-12 02:41:04 +08:00
|
|
|
auto symName =
|
|
|
|
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
|
|
|
assert(symName && "value does not have a valid symbol name");
|
|
|
|
return getSymbolRefAttr(symName.getValue());
|
|
|
|
}
|
2019-11-12 10:18:02 +08:00
|
|
|
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
|
2019-07-12 02:41:04 +08:00
|
|
|
return SymbolRefAttr::get(value, getContext());
|
2019-05-23 04:41:23 +08:00
|
|
|
}
|
2019-11-12 10:18:02 +08:00
|
|
|
SymbolRefAttr
|
|
|
|
Builder::getSymbolRefAttr(StringRef value,
|
|
|
|
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
|
|
|
|
return SymbolRefAttr::get(value, nestedReferences, getContext());
|
|
|
|
}
|
2018-08-20 12:17:22 +08:00
|
|
|
|
2019-04-06 03:19:22 +08:00
|
|
|
/// Returns an ArrayAttr with one i32 IntegerAttr per entry of `values`.
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int32_t v : values)
    elements.push_back(getI32IntegerAttr(v));
  return getArrayAttr(elements);
}

/// Returns an ArrayAttr with one i64 IntegerAttr per entry of `values`.
ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int64_t v : values)
    elements.push_back(getI64IntegerAttr(v));
  return getArrayAttr(elements);
}

/// Returns an ArrayAttr with one `index`-typed IntegerAttr per entry of
/// `values`.
ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
  Type indexType = getIndexType();
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int64_t v : values)
    elements.push_back(getIntegerAttr(indexType, v));
  return getArrayAttr(elements);
}

/// Returns an ArrayAttr with one f32 FloatAttr per entry of `values`.
ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (float v : values)
    elements.push_back(getF32FloatAttr(v));
  return getArrayAttr(elements);
}

/// Returns an ArrayAttr with one f64 FloatAttr per entry of `values`.
ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (double v : values)
    elements.push_back(getF64FloatAttr(v));
  return getArrayAttr(elements);
}

/// Returns an ArrayAttr with one StringAttr per entry of `values`.
ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (StringRef v : values)
    elements.push_back(getStringAttr(v));
  return getArrayAttr(elements);
}
|
|
|
|
|
Add a generic Linalg op
This CL introduces a linalg.generic op to represent generic tensor contraction operations on views.
A linalg.generic operation requires a number of attributes that are sufficient to emit the computation in scalar form, as well as to compute the appropriate subviews to enable tiling and fusion.
These attributes are very similar to the attributes for existing operations such as linalg.matmul etc and existing operations can be implemented with the generic form.
In the future, most existing operations can be implemented using the generic form.
This CL starts by splitting out most of the functionality of the linalg::NInputsAndOutputs trait into a ViewTrait that queries the per-instance properties of the op. This allows using the attribute information.
This exposes a verifier-ordering issue: ViewTrait::verify uses attributes whose verifiers have not yet been run. The desired behavior would be for the verifiers of the attributes specified in the builder to execute first, but that is not currently the case. As a consequence, to emit proper error messages and avoid crashing, some of the
linalg.generic methods are defensive as such:
```
unsigned getNumInputs() {
// This is redundant with the `n_views` attribute verifier but ordering of verifiers
// may exhibit cases where we crash instead of emitting an error message.
if (!getAttr("n_views") || n_views().getValue().size() != 2)
return 0;
```
In pretty-printed form, the specific attributes required for linalg.generic are factored out into an independent dictionary named "_". When parsing, its content is flattened and the "_" name is dropped. This allows using aliasing to reduce boilerplate at each linalg.generic invocation while benefiting from the Tablegen'd verifier for each named attribute in the dictionary.
For instance, implementing linalg.matmul in terms of linalg.generic resembles:
```
func @mac(%a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
return %e: f32
}
#matmul_accesses = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
]
#matmul_trait = {
doc = "C(m, n) += A(m, k) * B(k, n)",
fun = @mac,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
n_loop_types = [2, 1, 0]
}
```
And can be used in multiple places as:
```
linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
```
In the future it would be great to have a mechanism to alias / register a new
linalg.op as a pair of linalg.generic, #trait.
Also, note that one could theoretically specify only the `doc` string and parse all the attributes from it.
PiperOrigin-RevId: 261338740
2019-08-03 00:53:08 +08:00
|
|
|
/// Returns an ArrayAttr with one AffineMapAttr per map in `values`.
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
  SmallVector<Attribute, 8> mapAttrs;
  mapAttrs.reserve(values.size());
  for (AffineMap map : values)
    mapAttrs.push_back(AffineMapAttr::get(map));
  return getArrayAttr(mapAttrs);
}
|
|
|
|
|
2019-01-12 01:12:11 +08:00
|
|
|
/// Returns the zero value of `type` as an attribute. Returns a null attribute
/// when no zero representation is supported for `type`.
///
/// - bf16/f16/f32/f64 yield a FloatAttr holding 0.0.
/// - `index` yields an IntegerAttr holding 0.
/// - i1 yields the BoolAttr `false`.
/// - Other integer widths yield an IntegerAttr holding 0 of that width.
/// - Vectors and ranked tensors yield a splat DenseElementsAttr of the
///   element type's zero, when that element type has one.
Attribute Builder::getZeroAttr(Type type) {
  switch (type.getKind()) {
  case StandardTypes::BF16:
  case StandardTypes::F16:
  case StandardTypes::F32:
  case StandardTypes::F64:
    return getFloatAttr(type, 0.0);
  case StandardTypes::Index:
    // getIntegerAttr(Type, int64_t) models `index` with a 64-bit APInt.
    // Without this case, scalar `index` fell through to the null result.
    return getIntegerAttr(type, /*value=*/0);
  case StandardTypes::Integer: {
    auto width = type.cast<IntegerType>().getWidth();
    if (width == 1)
      return getBoolAttr(false);
    return getIntegerAttr(type, APInt(width, 0));
  }
  case StandardTypes::Vector:
  case StandardTypes::RankedTensor: {
    auto vtType = type.cast<ShapedType>();
    // Shaped types of `index` produced no zero attribute before the scalar
    // `index` case existed; keep that behavior here.
    // NOTE(review): revisit once DenseElementsAttr is confirmed to accept
    // `index` elements.
    if (vtType.getElementType().isIndex())
      return {};
    auto element = getZeroAttr(vtType.getElementType());
    if (!element)
      return {};
    return DenseElementsAttr::get(vtType, element);
  }
  default:
    break;
  }
  return {};
}
|
|
|
|
|
2018-07-11 01:59:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-10-20 15:11:03 +08:00
|
|
|
// Affine Expressions, Affine Maps, and Integer Sets.
|
2018-07-11 01:59:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-09 04:47:18 +08:00
|
|
|
/// Returns the affine dimension expression d<position>.
AffineExpr Builder::getAffineDimExpr(unsigned position) {
  // Qualify with `mlir::` so the free function is called, not this member.
  return mlir::getAffineDimExpr(position, getContext());
}

/// Returns the affine symbol expression s<position>.
AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
  return mlir::getAffineSymbolExpr(position, getContext());
}

/// Returns the affine constant expression `constant`.
AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
  return mlir::getAffineConstantExpr(constant, getContext());
}
|
|
|
|
|
2019-08-08 01:31:14 +08:00
|
|
|
/// Returns the empty affine map of this builder's context.
AffineMap Builder::getEmptyAffineMap() {
  return AffineMap::get(getContext());
}

/// Returns the map () -> (val).
AffineMap Builder::getConstantAffineMap(int64_t val) {
  AffineExpr constant = getAffineConstantExpr(val);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, {constant});
}

/// Returns the map (d0) -> (d0).
AffineMap Builder::getDimIdentityMap() {
  AffineExpr d0 = getAffineDimExpr(0);
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {d0});
}

/// Returns the map (d0, ..., d<rank-1>) -> (d0, ..., d<rank-1>).
AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
  SmallVector<AffineExpr, 4> dims;
  dims.reserve(rank);
  for (unsigned d = 0; d != rank; ++d)
    dims.push_back(getAffineDimExpr(d));
  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dims);
}

/// Returns the map ()[s0] -> (s0).
AffineMap Builder::getSymbolIdentityMap() {
  AffineExpr s0 = getAffineSymbolExpr(0);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, {s0});
}
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
/// Returns the map (d0) -> (d0 + shift).
AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
  AffineExpr shifted = getAffineDimExpr(0) + shift;
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {shifted});
}

/// Returns `map` with `shift` added to each of its results. The dimension and
/// symbol counts are unchanged.
AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
  SmallVector<AffineExpr, 4> shifted;
  shifted.reserve(map.getNumResults());
  for (AffineExpr result : map.getResults())
    shifted.push_back(result + shift);
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shifted);
}
|
|
|
|
|
2018-07-20 00:52:39 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-05 10:18:23 +08:00
|
|
|
// OpBuilder.
|
2018-07-25 01:15:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-06-05 10:18:23 +08:00
|
|
|
/// Out-of-line destructor; nothing beyond member destruction is required.
OpBuilder::~OpBuilder() = default;
|
2019-05-18 06:57:49 +08:00
|
|
|
|
2019-12-12 08:26:08 +08:00
|
|
|
/// Insert the given operation at the current insertion point and return it.
|
|
|
|
Operation *OpBuilder::insert(Operation *op) {
|
|
|
|
if (block)
|
|
|
|
block->getOperations().insert(insertPoint, op);
|
|
|
|
return op;
|
|
|
|
}
|
|
|
|
|
2019-07-13 01:43:11 +08:00
|
|
|
/// Add new block and set the insertion point to the end of it. The block is
|
|
|
|
/// inserted at the provided insertion point of 'parent'.
|
|
|
|
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
|
|
|
|
assert(parent && "expected valid parent region");
|
|
|
|
if (insertPt == Region::iterator())
|
|
|
|
insertPt = parent->end();
|
2018-08-25 12:13:19 +08:00
|
|
|
|
2019-07-13 01:43:11 +08:00
|
|
|
Block *b = new Block();
|
|
|
|
parent->getBlocks().insert(insertPt, b);
|
2018-12-28 07:06:22 +08:00
|
|
|
setInsertionPointToEnd(b);
|
2018-07-25 01:15:13 +08:00
|
|
|
return b;
|
|
|
|
}
|
|
|
|
|
2019-07-13 01:43:11 +08:00
|
|
|
/// Add new block and set the insertion point to the end of it. The block is
|
|
|
|
/// placed before 'insertBefore'.
|
|
|
|
Block *OpBuilder::createBlock(Block *insertBefore) {
|
|
|
|
assert(insertBefore && "expected valid insertion block");
|
|
|
|
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
|
|
|
|
}
|
|
|
|
|
2018-08-08 00:12:35 +08:00
|
|
|
/// Create an operation given the fields represented as an OperationState.
|
2019-06-05 10:18:23 +08:00
|
|
|
Operation *OpBuilder::createOperation(const OperationState &state) {
|
2019-12-12 08:26:08 +08:00
|
|
|
return insert(Operation::create(state));
|
2018-08-08 00:12:35 +08:00
|
|
|
}
|
2019-06-06 01:50:10 +08:00
|
|
|
|
|
|
|
/// Attempts to fold the given operation and places new results within
|
2019-12-14 04:21:42 +08:00
|
|
|
/// 'results'. Returns success if the operation was folded, failure otherwise.
|
|
|
|
/// Note: This function does not erase the operation on a successful fold.
|
|
|
|
LogicalResult OpBuilder::tryFold(Operation *op,
|
2019-12-24 06:45:01 +08:00
|
|
|
SmallVectorImpl<Value> &results) {
|
2019-06-06 01:50:10 +08:00
|
|
|
results.reserve(op->getNumResults());
|
2019-12-14 04:21:42 +08:00
|
|
|
auto cleanupFailure = [&] {
|
|
|
|
results.assign(op->result_begin(), op->result_end());
|
|
|
|
return failure();
|
2019-06-06 01:50:10 +08:00
|
|
|
};
|
|
|
|
|
2019-12-14 04:21:42 +08:00
|
|
|
// If this operation is already a constant, there is nothing to do.
|
|
|
|
Attribute unused;
|
|
|
|
if (matchPattern(op, m_Constant(&unused)))
|
|
|
|
return cleanupFailure();
|
|
|
|
|
|
|
|
// Check to see if any operands to the operation is constant and whether
|
|
|
|
// the operation knows how to constant fold itself.
|
2019-06-06 01:50:10 +08:00
|
|
|
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
|
2019-12-14 04:21:42 +08:00
|
|
|
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
|
|
|
|
matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
|
|
|
|
|
|
|
|
// Try to fold the operation.
|
|
|
|
SmallVector<OpFoldResult, 4> foldResults;
|
|
|
|
if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
|
|
|
|
return cleanupFailure();
|
|
|
|
|
|
|
|
// A temporary builder used for creating constants during folding.
|
|
|
|
OpBuilder cstBuilder(context);
|
|
|
|
SmallVector<Operation *, 1> generatedConstants;
|
|
|
|
|
|
|
|
// Populate the results with the folded results.
|
|
|
|
Dialect *dialect = op->getDialect();
|
|
|
|
for (auto &it : llvm::enumerate(foldResults)) {
|
|
|
|
// Normal values get pushed back directly.
|
2019-12-24 06:45:01 +08:00
|
|
|
if (auto value = it.value().dyn_cast<Value>()) {
|
2019-12-14 04:21:42 +08:00
|
|
|
results.push_back(value);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, try to materialize a constant operation.
|
|
|
|
if (!dialect)
|
|
|
|
return cleanupFailure();
|
|
|
|
|
|
|
|
// Ask the dialect to materialize a constant operation for this value.
|
|
|
|
Attribute attr = it.value().get<Attribute>();
|
|
|
|
auto *constOp = dialect->materializeConstant(
|
2020-01-12 00:54:04 +08:00
|
|
|
cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
|
2019-12-14 04:21:42 +08:00
|
|
|
if (!constOp) {
|
|
|
|
// Erase any generated constants.
|
|
|
|
for (Operation *cst : generatedConstants)
|
|
|
|
cst->erase();
|
|
|
|
return cleanupFailure();
|
|
|
|
}
|
|
|
|
assert(matchPattern(constOp, m_Constant(&attr)));
|
|
|
|
|
|
|
|
generatedConstants.push_back(constOp);
|
|
|
|
results.push_back(constOp->getResult(0));
|
2019-06-06 01:50:10 +08:00
|
|
|
}
|
|
|
|
|
2019-12-14 04:21:42 +08:00
|
|
|
// If we were successful, insert any generated constants.
|
|
|
|
for (Operation *cst : generatedConstants)
|
|
|
|
insert(cst);
|
|
|
|
|
|
|
|
return success();
|
2019-06-06 01:50:10 +08:00
|
|
|
}
|