forked from OSchip/llvm-project
Fix Linalg3 lowering to use the floating point element type matching the view
It used to be hardcoded to f32, but Toy tutorial is using f64. -- PiperOrigin-RevId: 242370172
This commit is contained in:
parent
364b7e624e
commit
fea0560816
|
@ -64,9 +64,14 @@ void linalg::DotOp::emitScalarImplementation(
|
|||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
FloatType fTy = getOperand(0)
|
||||
->getType()
|
||||
.cast<ViewType>()
|
||||
.getElementType()
|
||||
.cast<FloatType>();
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
ValueHandle zerof =
|
||||
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
|
||||
IndexHandle r_i(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
ValueHandle cond = (r_i == zero);
|
||||
|
@ -129,11 +134,16 @@ void linalg::MatvecOp::emitScalarImplementation(
|
|||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
FloatType fTy = getOperand(0)
|
||||
->getType()
|
||||
.cast<ViewType>()
|
||||
.getElementType()
|
||||
.cast<FloatType>();
|
||||
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
ValueHandle zerof =
|
||||
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
|
||||
ValueHandle cond = (r_j == zero);
|
||||
ValueHandle scalarC = select(cond, zerof, *C(i));
|
||||
C(i) = scalarC + A(i, r_j) * B(r_j);
|
||||
|
@ -198,11 +208,16 @@ void linalg::MatmulOp::emitScalarImplementation(
|
|||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
FloatType fTy = getOperand(0)
|
||||
->getType()
|
||||
.cast<ViewType>()
|
||||
.getElementType()
|
||||
.cast<FloatType>();
|
||||
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
ValueHandle zerof =
|
||||
constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
|
||||
ValueHandle cond = r_k == zero;
|
||||
ValueHandle scalarC = select(cond, zerof, *C(i, j));
|
||||
C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
|
||||
|
|
Loading…
Reference in New Issue