Add support for array-typed constants.

PiperOrigin-RevId: 267121729
This commit is contained in:
MLIR Team 2019-09-04 03:45:38 -07:00 committed by A. Unique TensorFlower
parent 71d27dfc3b
commit 2f13df13b0
2 changed files with 29 additions and 11 deletions

View File

@ -85,23 +85,35 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
return functionMapping.lookup(funcAttr.getValue());
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
auto *child = getLLVMConstant(vectorType->getElementType(),
splatAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
auto *sequentialType = cast<llvm::SequentialType>(llvmType);
auto elementType = sequentialType->getElementType();
uint64_t numElements = sequentialType->getNumElements();
auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc);
if (llvmType->isVectorTy())
return llvm::ConstantVector::getSplat(numElements, child);
if (llvmType->isArrayTy()) {
auto arrayType = llvm::ArrayType::get(elementType, numElements);
SmallVector<llvm::Constant *, 8> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
}
}
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
auto *sequentialType = cast<llvm::SequentialType>(llvmType);
auto elementType = sequentialType->getElementType();
uint64_t numElements = sequentialType->getNumElements();
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);
for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back(
getLLVMConstant(vectorType->getElementType(), n, loc));
constants.push_back(getLLVMConstant(elementType, n, loc));
if (!constants.back())
return nullptr;
}
return llvm::ConstantVector::get(constants);
if (llvmType->isVectorTy())
return llvm::ConstantVector::get(constants);
if (llvmType->isArrayTy()) {
auto arrayType = llvm::ArrayType::get(elementType, numElements);
return llvm::ConstantArray::get(arrayType, constants);
}
}
if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
return llvm::ConstantDataArray::get(

View File

@ -3,12 +3,18 @@
// CHECK: @i32_global = internal global i32 42
llvm.mlir.global @i32_global(42: i32) : !llvm.i32
// CHECK: @i32_global_const = internal constant i53 52
llvm.mlir.global constant @i32_global_const(52: i53) : !llvm.i53
// CHECK: @i32_const = internal constant i53 52
llvm.mlir.global constant @i32_const(52: i53) : !llvm.i53
// CHECK: @int_global_array = internal global [3 x i32] [i32 62, i32 62, i32 62]
llvm.mlir.global @int_global_array(dense<62> : vector<3xi32>) : !llvm<"[3 x i32]">
// CHECK: @float_global = internal global float 0.000000e+00
llvm.mlir.global @float_global(0.0: f32) : !llvm.float
// CHECK: @float_global_array = internal global [1 x float] [float -5.000000e+00]
llvm.mlir.global @float_global_array(dense<[-5.0]> : vector<1xf32>) : !llvm<"[1 x float]">
// CHECK: @string_const = internal constant [6 x i8] c"foobar"
llvm.mlir.global constant @string_const("foobar") : !llvm<"[6 x i8]">