diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index cfcda1f24214..7216e3d2ed5c 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1204,9 +1204,10 @@ def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> { let summary = "cast between index and integer types"; let description = [{ - Casts between integer scalars and 'index' scalars. Index is an integer of - platform-specific bit width. If casting to a wider integer, the value is - sign-extended. If casting to a narrower integer, the value is truncated. + Casts between scalar or vector integers and corresponding 'index' scalar or + vectors. Index is an integer of platform-specific bit width. If casting to + a wider integer, the value is sign-extended. If casting to a narrower + integer, the value is truncated. }]; let hasFolder = 1; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 61074382470e..8e808d75e205 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3104,11 +3104,15 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { IndexCastOpAdaptor transformed(operands); auto targetType = - typeConverter->convertType(indexCastOp.getResult().getType()) + typeConverter->convertType(indexCastOp.getResult().getType()); + auto targetElementType = + typeConverter + ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) .cast(); - auto sourceType = transformed.in().getType().cast(); - unsigned targetBits = targetType.getWidth(); - unsigned sourceBits = sourceType.getWidth(); + auto sourceElementType = + getElementTypeOrSelf(transformed.in()).cast(); + unsigned targetBits = targetElementType.getWidth(); + unsigned sourceBits = sourceElementType.getWidth(); if (targetBits == sourceBits) rewriter.replaceOp(indexCastOp, transformed.in()); diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index b4942de204ab..78a03372904d 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -520,6 +520,15 @@ func @index_cast(%arg0: index, %arg1: i1) { return } +// CHECK-LABEL: @vector_index_cast +func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) { +// CHECK-NEXT: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1> + %0 = index_cast %arg0: vector<2xindex> to vector<2xi1> +// CHECK-NEXT: = llvm.sext %{{.*}} : vector<2xi1> to vector<2xi{{.*}}> + %1 = index_cast %arg1: vector<2xi1> to vector<2xindex> + return +} + // Checking conversion of signed integer types to floating point. // CHECK-LABEL: @sitofp func @sitofp(%arg0 : i32, %arg1 : i64) {