forked from OSchip/llvm-project
Add folding rule for spv.CompositeExtract
If the composite is a constant, we can fold it away. This only supports vector and array constants for now, given that struct constant is not supported in spv.constant yet. PiperOrigin-RevId: 268350340
This commit is contained in:
parent
cf0a782339
commit
ee8cbccacf
|
@ -160,6 +160,8 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
|
|||
let results = (outs
|
||||
SPV_Type:$component
|
||||
);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -311,6 +312,28 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
|
|||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
}
|
||||
|
||||
// Extracts an element from the given `composite` by following the given
|
||||
// `indices`. Returns a null Attribute if error happens.
|
||||
static Attribute extractCompositeElement(Attribute composite,
|
||||
ArrayRef<unsigned> indices) {
|
||||
// Return composite itself if we reach the end of the index chain.
|
||||
if (indices.empty())
|
||||
return composite;
|
||||
|
||||
if (auto vector = composite.dyn_cast<ElementsAttr>()) {
|
||||
assert(indices.size() == 1 && "must have exactly one index for a vector");
|
||||
return vector.getValue({indices[0]});
|
||||
}
|
||||
|
||||
if (auto array = composite.dyn_cast<ArrayAttr>()) {
|
||||
assert(!indices.empty() && "must have at least one index for an array");
|
||||
return extractCompositeElement(array.getValue()[indices[0]],
|
||||
indices.drop_front());
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.AccessChainOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -700,6 +723,16 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
|
||||
auto indexVector = functional::map(
|
||||
[](Attribute attr) {
|
||||
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
||||
},
|
||||
indices());
|
||||
return extractCompositeElement(operands[0], indexVector);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -768,7 +801,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
|||
}
|
||||
|
||||
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
assert(operands.empty() && "spv.constant has no operands");
|
||||
return value();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,45 @@
|
|||
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CompositeExtract
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: extract_vector
|
||||
func @extract_vector() -> (i32, i32, i32) {
|
||||
// CHECK: spv.constant 42 : i32
|
||||
// CHECK: spv.constant -33 : i32
|
||||
// CHECK: spv.constant 6 : i32
|
||||
%0 = spv.constant dense<[42, -33, 6]> : vector<3xi32>
|
||||
%1 = spv.CompositeExtract %0[0 : i32] : vector<3xi32>
|
||||
%2 = spv.CompositeExtract %0[1 : i32] : vector<3xi32>
|
||||
%3 = spv.CompositeExtract %0[2 : i32] : vector<3xi32>
|
||||
return %1, %2, %3 : i32, i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_array_final
|
||||
func @extract_array_final() -> (i32, i32) {
|
||||
// CHECK: spv.constant 4 : i32
|
||||
// CHECK: spv.constant -5 : i32
|
||||
%0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
|
||||
%1 = spv.CompositeExtract %0[0 : i32, 0 : i32] : !spv.array<1 x vector<2 x i32>>
|
||||
%2 = spv.CompositeExtract %0[0 : i32, 1 : i32] : !spv.array<1 x vector<2 x i32>>
|
||||
return %1, %2 : i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_array_interm
|
||||
func @extract_array_interm() -> (vector<2xi32>) {
|
||||
// CHECK: spv.constant dense<[4, -5]> : vector<2xi32>
|
||||
%0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
|
||||
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x vector<2 x i32>>
|
||||
return %1 : vector<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -33,3 +73,4 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
|
|||
// CHECK-NEXT: return %[[CST]], %[[CST]]
|
||||
return %0, %1 : !spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue