BEGIN_PUBLIC

[mlir] Add support for unranked case for `tensor_store` and `tensor_load` ops.
END_PUBLIC

Differential Revision: https://reviews.llvm.org/D85518
This commit is contained in:
Alexander Belyaev 2020-08-07 14:31:02 +02:00
parent 87a89e0f77
commit 9c94908320
3 changed files with 31 additions and 10 deletions

View File

@@ -2934,24 +2934,26 @@ def TensorLoadOp : Std_Op<"tensor_load",
```
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref);
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
"the reference to load from", [MemRead]>:$memref);
let results = (outs AnyTensor:$result);
// TensorLoadOp is fully verified by traits.
let verifier = ?;
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value memref", [{
auto memrefType = memref.getType().cast<MemRefType>();
auto resultType = RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType());
result.addOperands(memref);
result.addTypes(resultType);
result.addTypes(getTensorTypeFromMemRefType(memref.getType()));
}]>];
let extraClassDeclaration = [{
/// The result of a tensor_load is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
TensorType getType() {
Type resultType = getResult().getType();
if (resultType.isa<TensorType>())
return resultType.cast<TensorType>();
return {};
}
}];
let assemblyFormat = "$memref attr-dict `:` type($memref)";
@@ -2981,9 +2983,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
```
}];
let arguments = (ins AnyTensor:$tensor,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref);
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
"the reference to store to", [MemWrite]>:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;

View File

@@ -2985,6 +2985,17 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
/// Maps a memref type to the tensor type with the same shape and element
/// type: ranked memrefs map to ranked tensors, unranked memrefs map to
/// unranked tensors. Any non-memref input yields NoneType as a sentinel,
/// which the callers treat as "not a memref".
static Type getTensorTypeFromMemRefType(Type type) {
  auto rankedMemRef = type.dyn_cast<MemRefType>();
  if (rankedMemRef)
    return RankedTensorType::get(rankedMemRef.getShape(),
                                 rankedMemRef.getElementType());
  auto unrankedMemRef = type.dyn_cast<UnrankedMemRefType>();
  if (unrankedMemRef)
    return UnrankedTensorType::get(unrankedMemRef.getElementType());
  // Not a memref at all: signal failure with NoneType rather than asserting.
  return NoneType::get(type.getContext());
}
/// Maps a tensor type to the memref type with the same shape and element
/// type: ranked tensors map to ranked memrefs, unranked tensors map to
/// unranked memrefs. Any non-tensor input yields NoneType as a sentinel.
static Type getMemRefTypeFromTensorType(Type type) {
  // Bug fix: a tensor->memref conversion must dyn_cast the input to the
  // *tensor* types. The previous casts targeted MemRefType /
  // UnrankedMemRefType, so a genuine tensor input always fell through to
  // the NoneType failure path, while a memref input was "converted" from
  // itself.
  if (auto tensor = type.dyn_cast<RankedTensorType>())
    return MemRefType::get(tensor.getShape(), tensor.getElementType());
  if (auto tensor = type.dyn_cast<UnrankedTensorType>())
    // Tensors carry no memory space, so place the result in the default
    // memory space 0.
    return UnrankedMemRefType::get(tensor.getElementType(),
                                   /*memorySpace=*/0);
  return NoneType::get(type.getContext());
}

View File

@@ -813,6 +813,15 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
return
}
// Round-trips an unranked memref through tensor_load / tensor_store,
// exercising the newly-added unranked (memref<*x...>) support for both ops.
// CHECK-LABEL: func @unranked_tensor_load_store
func @unranked_tensor_load_store(%0 : memref<*xi32>) {
  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<*xi32>
  %1 = tensor_load %0 : memref<*xi32>
  // CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<*xi32>
  tensor_store %1, %0 : memref<*xi32>
  return
}
// CHECK-LABEL: func @atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {