diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index d2efe2c30962..6feb8e0fa4a8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -465,6 +465,7 @@ def Vector_ShuffleOp : let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef")> ]; + let hasFolder = 1; let extraClassDeclaration = [{ static StringRef getMaskAttrName() { return "mask"; } VectorType getV1VectorType() { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c2515f706122..b576005c47e9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1803,6 +1803,33 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { return success(); } +OpFoldResult vector::ShuffleOp::fold(ArrayRef operands) { + Attribute lhs = operands.front(), rhs = operands.back(); + if (!lhs || !rhs) + return {}; + + auto lhsType = lhs.getType().cast(); + // Only support 1-D for now to avoid complicated n-D DenseElementsAttr + // manipulation. + if (lhsType.getRank() != 1) + return {}; + int64_t lhsSize = lhsType.getDimSize(0); + + SmallVector results; + auto lhsElements = lhs.cast().getValues(); + auto rhsElements = rhs.cast().getValues(); + for (const auto &index : this->mask().getAsValueRange()) { + int64_t i = index.getZExtValue(); + if (i >= lhsSize) { + results.push_back(rhsElements[i - lhsSize]); + } else { + results.push_back(lhsElements[i]); + } + } + + return DenseElementsAttr::get(getVectorType(), results); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a25b0687ca75..522f8dea8b47 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1254,3 +1254,15 @@ func @splat_fold() -> vector<4xf32> { // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> // CHECK-NEXT: return [[V]] : vector<4xf32> } + +// ----- + +// CHECK-LABEL: func @shuffle_1d +// CHECK: %[[V:.+]] = arith.constant dense<[3, 2, 5, 1]> : vector<4xi32> +// CHECK: return %[[V]] +func @shuffle_1d() -> vector<4xi32> { + %v0 = arith.constant dense<[0, 1, 2]> : vector<3xi32> + %v1 = arith.constant dense<[3, 4, 5]> : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +}