diff --git a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp index c6f402ffc404..a5b094c777e4 100644 --- a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp @@ -64,9 +64,14 @@ void linalg::DotOp::emitScalarImplementation( using edsc::intrinsics::select; ScopedContext scope( // account for affine.terminator in loop. FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); - auto f32 = ScopedContext::getBuilder()->getF32Type(); + FloatType fTy = getOperand(0) + ->getType() + .cast<ViewType>() + .getElementType() + .cast<FloatType>(); IndexHandle zero(constant_index(0)); - ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + ValueHandle zerof = + constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy); IndexHandle r_i(reductionIvs[0]); IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); ValueHandle cond = (r_i == zero); @@ -129,11 +134,16 @@ void linalg::MatvecOp::emitScalarImplementation( using edsc::intrinsics::select; ScopedContext scope( // account for affine.terminator in loop. FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); - auto f32 = ScopedContext::getBuilder()->getF32Type(); + FloatType fTy = getOperand(0) + ->getType() + .cast<ViewType>() + .getElementType() + .cast<FloatType>(); IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]); IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); IndexHandle zero(constant_index(0)); - ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + ValueHandle zerof = + constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy); ValueHandle cond = (r_j == zero); ValueHandle scalarC = select(cond, zerof, *C(i)); C(i) = scalarC + A(i, r_j) * B(r_j); @@ -198,11 +208,16 @@ void linalg::MatmulOp::emitScalarImplementation( using edsc::intrinsics::select; ScopedContext scope( // account for affine.terminator in loop. 
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); - auto f32 = ScopedContext::getBuilder()->getF32Type(); + FloatType fTy = getOperand(0) + ->getType() + .cast<ViewType>() + .getElementType() + .cast<FloatType>(); IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]); IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); IndexHandle zero(constant_index(0)); - ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + ValueHandle zerof = + constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy); ValueHandle cond = r_k == zero; ValueHandle scalarC = select(cond, zerof, *C(i, j)); C(i, j) = scalarC + A(i, r_k) * B(r_k, j);