Fix Linalg3 lowering to use the floating-point element type of the view

The element type used to be hardcoded to f32, but the Toy tutorial uses f64.

--

PiperOrigin-RevId: 242370172
This commit is contained in:
Mehdi Amini 2019-04-07 13:12:48 -07:00 committed by Mehdi Amini
parent 364b7e624e
commit fea0560816
1 changed file with 21 additions and 6 deletions

View File

@@ -64,9 +64,14 @@ void linalg::DotOp::emitScalarImplementation(
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
.getElementType()
.cast<FloatType>();
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
ValueHandle zerof =
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
IndexHandle r_i(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
ValueHandle cond = (r_i == zero);
@@ -129,11 +134,16 @@ void linalg::MatvecOp::emitScalarImplementation(
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
.getElementType()
.cast<FloatType>();
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
ValueHandle zerof =
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
ValueHandle cond = (r_j == zero);
ValueHandle scalarC = select(cond, zerof, *C(i));
C(i) = scalarC + A(i, r_j) * B(r_j);
@@ -198,11 +208,16 @@ void linalg::MatmulOp::emitScalarImplementation(
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
FloatType fTy = getOperand(0)
->getType()
.cast<ViewType>()
.getElementType()
.cast<FloatType>();
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
ValueHandle zerof =
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
ValueHandle cond = r_k == zero;
ValueHandle scalarC = select(cond, zerof, *C(i, j));
C(i, j) = scalarC + A(i, r_k) * B(r_k, j);