diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b616213ada5a..ca801fb3280a 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -1617,18 +1617,24 @@ struct MemorySanitizerVisitor : public InstVisitor { Type *EltTy = Ty->getSequentialElementType(); SmallVector Elements; for (unsigned Idx = 0; Idx < NumElements; ++Idx) { - ConstantInt *Elt = - dyn_cast(ConstArg->getAggregateElement(Idx)); - APInt V = Elt->getValue(); - APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); - Elements.push_back(ConstantInt::get(EltTy, V2)); + if (ConstantInt *Elt = + dyn_cast(ConstArg->getAggregateElement(Idx))) { + APInt V = Elt->getValue(); + APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); + Elements.push_back(ConstantInt::get(EltTy, V2)); + } else { + Elements.push_back(ConstantInt::get(EltTy, 1)); + } } ShadowMul = ConstantVector::get(Elements); } else { - ConstantInt *Elt = dyn_cast(ConstArg); - APInt V = Elt->getValue(); - APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); - ShadowMul = ConstantInt::get(Elt->getType(), V2); + if (ConstantInt *Elt = dyn_cast(ConstArg)) { + APInt V = Elt->getValue(); + APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); + ShadowMul = ConstantInt::get(Ty, V2); + } else { + ShadowMul = ConstantInt::get(Ty, 1); + } } IRBuilder<> IRB(&I); diff --git a/llvm/test/Instrumentation/MemorySanitizer/mul_by_constant.ll b/llvm/test/Instrumentation/MemorySanitizer/mul_by_constant.ll index e068f69ae4ba..7736d94717fe 100644 --- a/llvm/test/Instrumentation/MemorySanitizer/mul_by_constant.ll +++ b/llvm/test/Instrumentation/MemorySanitizer/mul_by_constant.ll @@ -92,3 +92,26 @@ entry: ; CHECK: [[A:%.*]] = load {{.*}} @__msan_param_tls ; CHECK: [[B:%.*]] = mul <4 x i32> [[A]], ; CHECK: store <4 x i32> [[B]], <4 x i32>* {{.*}} @__msan_retval_tls + + +; The constant in multiplication does not have to be a literal integer constant. +@X = linkonce_odr global i8* null +define i64 @MulNonIntegerConst(i64 %a) sanitize_memory { + %mul = mul i64 %a, ptrtoint (i8** @X to i64) + ret i64 %mul +} + +; CHECK-LABEL: @MulNonIntegerConst( +; CHECK: [[A:%.*]] = load {{.*}} @__msan_param_tls +; CHECK: [[B:%.*]] = mul i64 [[A]], 1 +; CHECK: store i64 [[B]], {{.*}}@__msan_retval_tls + +define <2 x i64> @MulNonIntegerVectorConst(<2 x i64> %a) sanitize_memory { + %mul = mul <2 x i64> %a, + ret <2 x i64> %mul +} + +; CHECK-LABEL: @MulNonIntegerVectorConst( +; CHECK: [[A:%.*]] = load {{.*}} @__msan_param_tls +; CHECK: [[B:%.*]] = mul <2 x i64> [[A]], +; CHECK: store <2 x i64> [[B]], {{.*}}@__msan_retval_tls