Add folding rule for spv.CompositeExtract

If the composite is a constant, we can fold it away. This only
supports vector and array constants for now, given that struct
constant is not supported in spv.constant yet.

PiperOrigin-RevId: 268350340
This commit is contained in:
Lei Zhang 2019-09-10 17:47:37 -07:00 committed by A. Unique TensorFlower
parent cf0a782339
commit ee8cbccacf
3 changed files with 77 additions and 1 deletions

View File

@ -160,6 +160,8 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
let results = (outs
SPV_Type:$component
);
let hasFolder = 1;
}
// -----

View File

@ -27,6 +27,7 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/StringExtras.h"
using namespace mlir;
@ -311,6 +312,28 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if error happens.
static Attribute extractCompositeElement(Attribute composite,
ArrayRef<unsigned> indices) {
// Return composite itself if we reach the end of the index chain.
if (indices.empty())
return composite;
if (auto vector = composite.dyn_cast<ElementsAttr>()) {
assert(indices.size() == 1 && "must have exactly one index for a vector");
return vector.getValue({indices[0]});
}
if (auto array = composite.dyn_cast<ArrayAttr>()) {
assert(!indices.empty() && "must have at least one index for an array");
return extractCompositeElement(array.getValue()[indices[0]],
indices.drop_front());
}
return {};
}
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
@ -700,6 +723,16 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
return success();
}
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
auto indexVector = functional::map(
[](Attribute attr) {
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
},
indices());
return extractCompositeElement(operands[0], indexVector);
}
//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
@ -768,7 +801,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
}
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
assert(operands.empty() && "spv.constant has no operands");
return value();
}

View File

@ -1,5 +1,45 @@
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.CompositeExtract
//===----------------------------------------------------------------------===//
// CHECK-LABEL: extract_vector
func @extract_vector() -> (i32, i32, i32) {
// CHECK: spv.constant 42 : i32
// CHECK: spv.constant -33 : i32
// CHECK: spv.constant 6 : i32
%0 = spv.constant dense<[42, -33, 6]> : vector<3xi32>
%1 = spv.CompositeExtract %0[0 : i32] : vector<3xi32>
%2 = spv.CompositeExtract %0[1 : i32] : vector<3xi32>
%3 = spv.CompositeExtract %0[2 : i32] : vector<3xi32>
return %1, %2, %3 : i32, i32, i32
}
// -----
// CHECK-LABEL: extract_array_final
func @extract_array_final() -> (i32, i32) {
// CHECK: spv.constant 4 : i32
// CHECK: spv.constant -5 : i32
%0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
%1 = spv.CompositeExtract %0[0 : i32, 0 : i32] : !spv.array<1 x vector<2 x i32>>
%2 = spv.CompositeExtract %0[0 : i32, 1 : i32] : !spv.array<1 x vector<2 x i32>>
return %1, %2 : i32, i32
}
// -----
// CHECK-LABEL: extract_array_interm
func @extract_array_interm() -> (vector<2xi32>) {
// CHECK: spv.constant dense<[4, -5]> : vector<2xi32>
%0 = spv.constant [dense<[4, -5]> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x vector<2 x i32>>
return %1 : vector<2xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
@ -33,3 +73,4 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
// CHECK-NEXT: return %[[CST]], %[[CST]]
return %0, %1 : !spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>
}