From ee8cbccacfc8755d1692ff64ad98876917a08b30 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 10 Sep 2019 17:47:37 -0700 Subject: [PATCH] Add folding rule for spv.CompositeExtract If the composite is a constant, we can fold it away. This only supports vector and array constants for now, given that struct constant is not supported in spv.constant yet. PiperOrigin-RevId: 268350340 --- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 2 + mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 35 +++++++++++++++++- mlir/test/Dialect/SPIRV/canonicalize.mlir | 41 +++++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 6aad60009aff..8d1a19a6ef85 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -160,6 +160,8 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let results = (outs SPV_Type:$component ); + + let hasFolder = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 0873eb0c9a01..817618d13aba 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/Functional.h" #include "mlir/Support/StringExtras.h" using namespace mlir; @@ -311,6 +312,28 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer, printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); } +// Extracts an element from the given `composite` by following the given +// `indices`. Returns a null Attribute if error happens. +static Attribute extractCompositeElement(Attribute composite, + ArrayRef indices) { + // Return composite itself if we reach the end of the index chain. + if (indices.empty()) + return composite; + + if (auto vector = composite.dyn_cast()) { + assert(indices.size() == 1 && "must have exactly one index for a vector"); + return vector.getValue({indices[0]}); + } + + if (auto array = composite.dyn_cast()) { + assert(!indices.empty() && "must have at least one index for an array"); + return extractCompositeElement(array.getValue()[indices[0]], + indices.drop_front()); + } + + return {}; +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// @@ -700,6 +723,16 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) { return success(); } +OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "spv.CompositeExtract expects one operand"); + auto indexVector = functional::map( + [](Attribute attr) { + return static_cast(attr.cast().getInt()); + }, + indices()); + return extractCompositeElement(operands[0], indexVector); +} + //===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// @@ -768,7 +801,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) { } OpFoldResult spirv::ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); + assert(operands.empty() && "spv.constant has no operands"); return value(); } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 24e4e9967c7e..a91286053b01 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -1,5 +1,45 @@ // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s +//===----------------------------------------------------------------------===// +// spv.CompositeExtract +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: extract_vector +func @extract_vector() -> (i32, i32, i32) { + // CHECK: spv.constant 42 : i32 + // CHECK: spv.constant -33 : i32 + // CHECK: spv.constant 6 : i32 + %0 = spv.constant dense<[42, -33, 6]> : vector<3xi32> + %1 = spv.CompositeExtract %0[0 : i32] : vector<3xi32> + %2 = spv.CompositeExtract %0[1 : i32] : vector<3xi32> + %3 = spv.CompositeExtract %0[2 : i32] : vector<3xi32> + return %1, %2, %3 : i32, i32, i32 +} + +// ----- + +// CHECK-LABEL: extract_array_final +func @extract_array_final() -> (i32, i32) { + // CHECK: spv.constant 4 : i32 + // CHECK: spv.constant -5 : i32 + %0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>> + %1 = spv.CompositeExtract %0[0 : i32, 0 : i32] : !spv.array<1 x vector<2 x i32>> + %2 = spv.CompositeExtract %0[0 : i32, 1 : i32] : !spv.array<1 x vector<2 x i32>> + return %1, %2 : i32, i32 +} + +// ----- + +// CHECK-LABEL: extract_array_interm +func @extract_array_interm() -> (vector<2xi32>) { + // CHECK: spv.constant dense<[4, -5]> : vector<2xi32> + %0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x vector<2 x i32>> + return %1 : vector<2xi32> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// @@ -33,3 +73,4 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a // CHECK-NEXT: return %[[CST]], %[[CST]] return %0, %1 : !spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>> } +