forked from OSchip/llvm-project
[fir] Add fir.save_result op
Add the fir.save_result operation. It is use to save an array, box, or record function result SSA-value to a memory location Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D110407 Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: Valentin Clement <clementval@gmail.com>
This commit is contained in:
parent
764d9aa979
commit
5b5ef2e265
|
@ -352,6 +352,55 @@ def fir_LoadOp : fir_OneResultOp<"load"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def fir_SaveResultOp : fir_Op<"save_result", [AttrSizedOperandSegments]> {
|
||||
let summary = [{
|
||||
save an array, box, or record function result SSA-value to a memory location
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Save the result of a function returning an array, box, or record type value
|
||||
into a memory location given the shape and length parameters of the result.
|
||||
|
||||
Function results of type fir.box, fir.array, or fir.rec are abstract values
|
||||
that require a storage to be manipulated on the caller side. This operation
|
||||
allows associating such abstract result to a storage. In later lowering of
|
||||
the function interfaces, this storage might be used to pass the result in
|
||||
memory.
|
||||
|
||||
For arrays, result, it is required to provide the shape of the result. For
|
||||
character arrays and derived types with length parameters, the length
|
||||
parameter values must be provided.
|
||||
|
||||
The fir.save_result associated to a function call must immediately follow
|
||||
the call and be in the same block.
|
||||
|
||||
```mlir
|
||||
%buffer = fir.alloca fir.array<?xf32>, %c100
|
||||
%shape = fir.shape %c100
|
||||
%array_result = fir.call @foo() : () -> fir.array<?xf32>
|
||||
fir.save_result %array_result to %buffer(%shape)
|
||||
%coor = fir.array_coor %buffer%(%shape), %c5
|
||||
%fifth_element = fir.load %coor : f32
|
||||
```
|
||||
|
||||
The above fir.save_result allows saving a fir.array function result into
|
||||
a buffer to later access its 5th element.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins ArrayOrBoxOrRecord:$value,
|
||||
Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
|
||||
Optional<AnyShapeType>:$shape,
|
||||
Variadic<AnyIntegerType>:$typeparams);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$value `to` $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
|
||||
attr-dict `:` type(operands)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def fir_StoreOp : fir_Op<"store", []> {
|
||||
let summary = "store an SSA-value to a memory location";
|
||||
|
||||
|
|
|
@ -551,4 +551,9 @@ def AnyCoordinateType : Type<AnyCoordinateLike.predicate, "coordinate type">;
|
|||
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
|
||||
FunctionType.predicate]>, "any addressable">;
|
||||
|
||||
def ArrayOrBoxOrRecord : TypeConstraint<Or<[fir_SequenceType.predicate,
|
||||
fir_BoxType.predicate, fir_RecordType.predicate]>,
|
||||
"fir.box, fir.array or fir.type">;
|
||||
|
||||
|
||||
#endif // FIR_DIALECT_FIR_TYPES
|
||||
|
|
|
@ -1361,6 +1361,63 @@ static mlir::LogicalResult verify(fir::ResultOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SaveResultOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::LogicalResult verify(fir::SaveResultOp op) {
|
||||
auto resultType = op.value().getType();
|
||||
if (resultType != fir::dyn_cast_ptrEleTy(op.memref().getType()))
|
||||
return op.emitOpError("value type must match memory reference type");
|
||||
if (fir::isa_unknown_size_box(resultType))
|
||||
return op.emitOpError("cannot save !fir.box of unknown rank or type");
|
||||
|
||||
if (resultType.isa<fir::BoxType>()) {
|
||||
if (op.shape() || !op.typeparams().empty())
|
||||
return op.emitOpError(
|
||||
"must not have shape or length operands if the value is a fir.box");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// fir.record or fir.array case.
|
||||
unsigned shapeTyRank = 0;
|
||||
if (auto shapeOp = op.shape()) {
|
||||
auto shapeTy = shapeOp.getType();
|
||||
if (auto s = shapeTy.dyn_cast<fir::ShapeType>())
|
||||
shapeTyRank = s.getRank();
|
||||
else
|
||||
shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank();
|
||||
}
|
||||
|
||||
auto eleTy = resultType;
|
||||
if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) {
|
||||
if (seqTy.getDimension() != shapeTyRank)
|
||||
op.emitOpError("shape operand must be provided and have the value rank "
|
||||
"when the value is a fir.array");
|
||||
eleTy = seqTy.getEleTy();
|
||||
} else {
|
||||
if (shapeTyRank != 0)
|
||||
op.emitOpError(
|
||||
"shape operand should only be provided if the value is a fir.array");
|
||||
}
|
||||
|
||||
if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
|
||||
if (recTy.getNumLenParams() != op.typeparams().size())
|
||||
op.emitOpError("length parameters number must match with the value type "
|
||||
"length parameters");
|
||||
} else if (auto charTy = eleTy.dyn_cast<fir::CharacterType>()) {
|
||||
if (op.typeparams().size() > 1)
|
||||
op.emitOpError("no more than one length parameter must be provided for "
|
||||
"character value");
|
||||
} else {
|
||||
if (!op.typeparams().empty())
|
||||
op.emitOpError(
|
||||
"length parameters must not be provided for this value type");
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -671,3 +671,14 @@ func @test_rebox(%arg0: !fir.box<!fir.array<?xf32>>) {
|
|||
fir.call @bar_rebox_test(%4) : (!fir.box<!fir.array<?x?xf32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_save_result(
|
||||
func @test_save_result(%buffer: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
|
||||
%c100 = constant 100 : index
|
||||
%c50 = constant 50 : index
|
||||
%shape = fir.shape %c100 : (index) -> !fir.shape<1>
|
||||
%res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
|
||||
// CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
|
||||
fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
|
||||
return
|
||||
}
|
||||
|
|
|
@ -417,3 +417,80 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
|
|||
%2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
|
||||
fir.has_value %2 : !fir.array<32x32xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf64>>, %n :index) {
|
||||
%res = fir.call @array_func() : () -> !fir.array<?xf32>
|
||||
%shape = fir.shape %n : (index) -> !fir.shape<1>
|
||||
// expected-error@+1 {{'fir.save_result' op value type must match memory reference type}}
|
||||
fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf64>>, !fir.shape<1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<*:f32>>>) {
|
||||
%res = fir.call @array_func() : () -> !fir.box<!fir.array<*:f32>>
|
||||
// expected-error@+1 {{'fir.save_result' op cannot save !fir.box of unknown rank or type}}
|
||||
fir.save_result %res to %buffer : !fir.box<!fir.array<*:f32>>, !fir.ref<!fir.box<!fir.array<*:f32>>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<f64>) {
|
||||
%res = fir.call @array_func() : () -> f64
|
||||
// expected-error@+1 {{'fir.save_result' op operand #0 must be fir.box, fir.array or fir.type, but got 'f64'}}
|
||||
fir.save_result %res to %buffer : f64, !fir.ref<f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<?xf32>>>, %n : index) {
|
||||
%res = fir.call @array_func() : () -> !fir.box<!fir.array<?xf32>>
|
||||
%shape = fir.shape %n : (index) -> !fir.shape<1>
|
||||
// expected-error@+1 {{'fir.save_result' op must not have shape or length operands if the value is a fir.box}}
|
||||
fir.save_result %res to %buffer(%shape) : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.shape<1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
|
||||
%res = fir.call @array_func() : () -> !fir.array<?xf32>
|
||||
%shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
|
||||
// expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
|
||||
fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<2>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
|
||||
%res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
|
||||
%shape = fir.shape %n : (index) -> !fir.shape<1>
|
||||
// expected-error@+1 {{'fir.save_result' op shape operand should only be provided if the value is a fir.array}}
|
||||
fir.save_result %res to %buffer(%shape) : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, !fir.shape<1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
|
||||
%res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
|
||||
// expected-error@+1 {{'fir.save_result' op length parameters number must match with the value type length parameters}}
|
||||
fir.save_result %res to %buffer typeparams %n : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, index
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
|
||||
%res = fir.call @array_func() : () -> !fir.array<?xf32>
|
||||
%shape = fir.shape %n : (index) -> !fir.shape<1>
|
||||
// expected-error@+1 {{'fir.save_result' op length parameters must not be provided for this value type}}
|
||||
fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue