forked from OSchip/llvm-project
BEGIN_PUBLIC
[mlir] Add support for unranked case for `tensor_store` and `tensor_load` ops. END_PUBLIC Differential Revision: https://reviews.llvm.org/D85518
This commit is contained in:
parent
87a89e0f77
commit
9c94908320
|
@ -2934,24 +2934,26 @@ def TensorLoadOp : Std_Op<"tensor_load",
|
|||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
|
||||
[MemRead]>:$memref);
|
||||
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
|
||||
"the reference to load from", [MemRead]>:$memref);
|
||||
let results = (outs AnyTensor:$result);
|
||||
// TensorLoadOp is fully verified by traits.
|
||||
let verifier = ?;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value memref", [{
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto resultType = RankedTensorType::get(memrefType.getShape(),
|
||||
memrefType.getElementType());
|
||||
result.addOperands(memref);
|
||||
result.addTypes(resultType);
|
||||
result.addTypes(getTensorTypeFromMemRefType(memref.getType()));
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// The result of a tensor_load is always a tensor.
|
||||
TensorType getType() { return getResult().getType().cast<TensorType>(); }
|
||||
TensorType getType() {
|
||||
Type resultType = getResult().getType();
|
||||
if (resultType.isa<TensorType>())
|
||||
return resultType.cast<TensorType>();
|
||||
return {};
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||
|
@ -2981,9 +2983,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
|
|||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$tensor,
|
||||
Arg<AnyMemRef, "the reference to store to",
|
||||
[MemWrite]>:$memref);
|
||||
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
|
||||
"the reference to store to", [MemWrite]>:$memref);
|
||||
// TensorStoreOp is fully verified by traits.
|
||||
let verifier = ?;
|
||||
|
||||
|
|
|
@ -2985,6 +2985,17 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
|
|||
static Type getTensorTypeFromMemRefType(Type type) {
|
||||
if (auto memref = type.dyn_cast<MemRefType>())
|
||||
return RankedTensorType::get(memref.getShape(), memref.getElementType());
|
||||
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
|
||||
return UnrankedTensorType::get(memref.getElementType());
|
||||
return NoneType::get(type.getContext());
|
||||
}
|
||||
|
||||
static Type getMemRefTypeFromTensorType(Type type) {
|
||||
if (auto tensor = type.dyn_cast<MemRefType>())
|
||||
return MemRefType::get(tensor.getShape(), tensor.getElementType());
|
||||
if (auto tensor = type.dyn_cast<UnrankedMemRefType>())
|
||||
return UnrankedMemRefType::get(tensor.getElementType(),
|
||||
tensor.getMemorySpace());
|
||||
return NoneType::get(type.getContext());
|
||||
}
|
||||
|
||||
|
|
|
@ -813,6 +813,15 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unranked_tensor_load_store
|
||||
func @unranked_tensor_load_store(%0 : memref<*xi32>) {
|
||||
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<*xi32>
|
||||
%1 = tensor_load %0 : memref<*xi32>
|
||||
// CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<*xi32>
|
||||
tensor_store %1, %0 : memref<*xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw
|
||||
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
|
||||
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
|
||||
|
|
Loading…
Reference in New Issue