forked from OSchip/llvm-project
[mlir][sparse] fix bug in reduction chain
Found with exhaustive testing, it is possible that a while loop appears in between chainable for loops. As long as we don't scalarize reductions in while loops, this means we need to terminate the chain at the while. This also refactors the reduction code into more readable helper methods. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D97886
This commit is contained in:
parent
dbf41ddaa3
commit
553cb6d473
|
@ -775,6 +775,39 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
|
||||||
return rewriter.create<AddIOp>(loc, mul, i);
|
return rewriter.create<AddIOp>(loc, mul, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates start of a reduction.
///
/// Produces the initial value that a reduction loop is seeded with:
///  - if `codegen.redVal` is already set, that value is returned unchanged so
///    the new loop continues the chain started by a previous for-loop;
///  - otherwise, if `codegen.curVecLength > 1`, a zero constant of the vector
///    type derived from the buffer of the reduction expression is created;
///  - otherwise the value is obtained through `genTensorLoad` for
///    `codegen.redExp`.
///
/// This function only reads `codegen.redVal`; it never assigns it.
|
||||||
|
static Value genReductionStart(Merger &merger, CodeGen &codegen,
|
||||||
|
PatternRewriter &rewriter,
|
||||||
|
linalg::GenericOp op) {
|
||||||
|
if (codegen.redVal)
|
||||||
|
return codegen.redVal; // chained with previous for-loop
|
||||||
|
// Vector case: start from an all-zero vector rather than from a load.
if (codegen.curVecLength > 1) {
|
||||||
|
// TODO: assumes + reductions for now
|
||||||
|
VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
|
||||||
|
return rewriter.create<ConstantOp>(op.getLoc(), vtp,
|
||||||
|
rewriter.getZeroAttr(vtp));
|
||||||
|
}
|
||||||
|
// Scalar case: the starting value comes from the reduction expression itself.
return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates end of a reduction.
///
/// Does nothing when no reduction value is pending (`codegen.redVal` is
/// unset). Otherwise it clears both `codegen.redVal` and the value cached in
/// the merger for `codegen.redExp`, which ends the reduction chain, and then
/// writes the pending value back through `genTensorStore`.
///
/// When `codegen.curVecLength > 1`, the pending value is first combined with a
/// freshly loaded value through a `vector::ReductionOp` of kind "add". As a
/// side effect `codegen.curVecLength` is reset to 1 in that case.
|
||||||
|
static void genReductionEnd(Merger &merger, CodeGen &codegen,
|
||||||
|
PatternRewriter &rewriter, linalg::GenericOp op) {
|
||||||
|
Value red = codegen.redVal;
|
||||||
|
if (!red)
|
||||||
|
return;
|
||||||
|
codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
|
||||||
|
// NOTE(review): presumably the output tensor is the last shaped operand of
// the op -- confirm against the linalg::GenericOp operand layout.
unsigned lhs = op.getNumShapedOperands() - 1;
|
||||||
|
if (codegen.curVecLength > 1) {
|
||||||
|
// TODO: assumes + reductions for now
|
||||||
|
// The vector length is reset before the load below is generated.
// NOTE(review): presumably so that genTensorLoad emits a scalar load -- confirm.
codegen.curVecLength = 1;
|
||||||
|
Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
|
||||||
|
// NOTE(review): `ld` is passed as an extra operand after the vector value;
// presumably it is the accumulator the reduced vector is added to -- confirm
// against the vector.reduction op definition.
red = rewriter.create<vector::ReductionOp>(
|
||||||
|
op.getLoc(), ld.getType(), rewriter.getStringAttr("add"), red, ld);
|
||||||
|
}
|
||||||
|
genTensorStore(merger, codegen, rewriter, op, lhs, red);
|
||||||
|
}
|
||||||
|
|
||||||
/// Recursively generates tensor expression.
|
/// Recursively generates tensor expression.
|
||||||
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
|
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
|
||||||
linalg::GenericOp op, unsigned exp) {
|
linalg::GenericOp op, unsigned exp) {
|
||||||
|
@ -952,16 +985,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
|
||||||
bool scalarRed = isInner && codegen.redExp != -1u;
|
bool scalarRed = isInner && codegen.redExp != -1u;
|
||||||
SmallVector<Value, 4> operands;
|
SmallVector<Value, 4> operands;
|
||||||
if (scalarRed) {
|
if (scalarRed) {
|
||||||
Value load;
|
Value load = genReductionStart(merger, codegen, rewriter, op);
|
||||||
if (codegen.redVal) {
|
|
||||||
load = codegen.redVal; // chained with previous for-loop
|
|
||||||
} else if (isVector) {
|
|
||||||
// TODO: assumes + reductions for now
|
|
||||||
VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
|
|
||||||
load = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
|
|
||||||
} else {
|
|
||||||
load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
|
|
||||||
}
|
|
||||||
operands.push_back(load);
|
operands.push_back(load);
|
||||||
}
|
}
|
||||||
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
|
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
|
||||||
|
@ -1049,6 +1073,7 @@ static Operation *genLoop(Merger &merger, CodeGen &codegen,
|
||||||
return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
|
return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
|
||||||
indices);
|
indices);
|
||||||
}
|
}
|
||||||
|
genReductionEnd(merger, codegen, rewriter, op); // cannot chain
|
||||||
return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
|
return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1251,18 +1276,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap-up loop sequence.
|
// Wrap-up loop sequence.
|
||||||
Value red = codegen.redVal;
|
genReductionEnd(merger, codegen, rewriter, op);
|
||||||
if (red) {
|
|
||||||
codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
|
|
||||||
unsigned lhs = op.getNumShapedOperands() - 1;
|
|
||||||
if (codegen.curVecLength > 1) {
|
|
||||||
codegen.curVecLength = 1;
|
|
||||||
Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
|
|
||||||
red = rewriter.create<vector::ReductionOp>(
|
|
||||||
loc, ld.getType(), rewriter.getStringAttr("add"), red, ld);
|
|
||||||
}
|
|
||||||
genTensorStore(merger, codegen, rewriter, op, lhs, red);
|
|
||||||
}
|
|
||||||
genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
|
genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
|
||||||
codegen.loops[idx] = Value();
|
codegen.loops[idx] = Value();
|
||||||
codegen.curVecLength = 1;
|
codegen.curVecLength = 1;
|
||||||
|
|
|
@ -1346,3 +1346,326 @@ func @four_tensors_op(%arga: tensor<?xf64>,
|
||||||
} -> tensor<?xf64>
|
} -> tensor<?xf64>
|
||||||
return %r : tensor<?xf64>
|
return %r : tensor<?xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trait for a 1-D reduction over three inputs into a rank-0 output.
// The first three indexing maps address each input with (i); the fourth map
// yields () for the output. The `sparse` list carries one "S" entry per
// input and an empty entry for the output (presumably "S" marks the
// dimension as sparse -- confirm against the sparse annotation docs).
#trait_red3s = {
|
||||||
|
indexing_maps = [
|
||||||
|
affine_map<(i) -> (i)>,
|
||||||
|
affine_map<(i) -> (i)>,
|
||||||
|
affine_map<(i) -> (i)>,
|
||||||
|
affine_map<(i) -> ()>
|
||||||
|
],
|
||||||
|
sparse = [
|
||||||
|
["S"],
|
||||||
|
["S"],
|
||||||
|
["S"],
|
||||||
|
[]
|
||||||
|
],
|
||||||
|
iterator_types = ["reduction"],
|
||||||
|
doc = "x += a(i) + b(i) + c(i)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @red3s(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?xf64>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?xf64>,
|
||||||
|
// CHECK-SAME: %[[VAL_2:.*2]]: tensor<?xf64>,
|
||||||
|
// CHECK-SAME: %[[VAL_3:.*3]]: tensor<f64>) -> tensor<f64> {
|
||||||
|
// CHECK: %[[VAL_4:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[VAL_5:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_6:.*]] = linalg.sparse_pointers %[[VAL_0]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = linalg.sparse_indices %[[VAL_0]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = linalg.sparse_values %[[VAL_0]] : tensor<?xf64> to memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<?xf64> to memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = linalg.sparse_pointers %[[VAL_2]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = linalg.sparse_indices %[[VAL_2]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = linalg.sparse_values %[[VAL_2]] : tensor<?xf64> to memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tensor_to_memref %[[VAL_3]] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = alloc() : memref<f64>
|
||||||
|
// CHECK: linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<f64>, memref<f64>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_21:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_4]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_22:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_23:.*]]:3 = scf.while (%[[VAL_24:.*]] = %[[VAL_17]], %[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_21]]) : (index, index, index) -> (index, index, index) {
|
||||||
|
// CHECK: %[[VAL_27:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_18]] : index
|
||||||
|
// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[VAL_25]], %[[VAL_20]] : index
|
||||||
|
// CHECK: %[[VAL_29:.*]] = and %[[VAL_27]], %[[VAL_28]] : i1
|
||||||
|
// CHECK: %[[VAL_30:.*]] = cmpi ult, %[[VAL_26]], %[[VAL_22]] : index
|
||||||
|
// CHECK: %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1
|
||||||
|
// CHECK: scf.condition(%[[VAL_31]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, index
|
||||||
|
// CHECK: } do {
|
||||||
|
// CHECK: ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
|
||||||
|
// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_32]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_36:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_33]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_37:.*]] = cmpi ult, %[[VAL_36]], %[[VAL_35]] : index
|
||||||
|
// CHECK: %[[VAL_38:.*]] = select %[[VAL_37]], %[[VAL_36]], %[[VAL_35]] : index
|
||||||
|
// CHECK: %[[VAL_39:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_34]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_40:.*]] = cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
|
||||||
|
// CHECK: %[[VAL_41:.*]] = select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
|
||||||
|
// CHECK: %[[VAL_42:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_43:.*]] = cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_44:.*]] = and %[[VAL_42]], %[[VAL_43]] : i1
|
||||||
|
// CHECK: %[[VAL_45:.*]] = cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_46:.*]] = and %[[VAL_44]], %[[VAL_45]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_46]] {
|
||||||
|
// CHECK: %[[VAL_47:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_48:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_49:.*]] = addf %[[VAL_47]], %[[VAL_48]] : f64
|
||||||
|
// CHECK: %[[VAL_50:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_51:.*]] = addf %[[VAL_49]], %[[VAL_50]] : f64
|
||||||
|
// CHECK: %[[VAL_52:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_53:.*]] = addf %[[VAL_51]], %[[VAL_52]] : f64
|
||||||
|
// CHECK: store %[[VAL_53]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_54:.*]] = cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_55:.*]] = cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_56:.*]] = and %[[VAL_54]], %[[VAL_55]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_56]] {
|
||||||
|
// CHECK: %[[VAL_57:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_58:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_59:.*]] = addf %[[VAL_57]], %[[VAL_58]] : f64
|
||||||
|
// CHECK: %[[VAL_60:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_61:.*]] = addf %[[VAL_59]], %[[VAL_60]] : f64
|
||||||
|
// CHECK: store %[[VAL_61]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_62:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_63:.*]] = cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_64:.*]] = and %[[VAL_62]], %[[VAL_63]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_64]] {
|
||||||
|
// CHECK: %[[VAL_65:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_66:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_67:.*]] = addf %[[VAL_65]], %[[VAL_66]] : f64
|
||||||
|
// CHECK: %[[VAL_68:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_69:.*]] = addf %[[VAL_67]], %[[VAL_68]] : f64
|
||||||
|
// CHECK: store %[[VAL_69]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_70:.*]] = cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_70]] {
|
||||||
|
// CHECK: %[[VAL_71:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_72:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_73:.*]] = addf %[[VAL_71]], %[[VAL_72]] : f64
|
||||||
|
// CHECK: store %[[VAL_73]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_74:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_75:.*]] = cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_76:.*]] = and %[[VAL_74]], %[[VAL_75]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_76]] {
|
||||||
|
// CHECK: %[[VAL_77:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_78:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_79:.*]] = addf %[[VAL_77]], %[[VAL_78]] : f64
|
||||||
|
// CHECK: %[[VAL_80:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_81:.*]] = addf %[[VAL_79]], %[[VAL_80]] : f64
|
||||||
|
// CHECK: store %[[VAL_81]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_82:.*]] = cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_82]] {
|
||||||
|
// CHECK: %[[VAL_83:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_84:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_85:.*]] = addf %[[VAL_83]], %[[VAL_84]] : f64
|
||||||
|
// CHECK: store %[[VAL_85]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_86:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_86]] {
|
||||||
|
// CHECK: %[[VAL_87:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_88:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_32]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_89:.*]] = addf %[[VAL_87]], %[[VAL_88]] : f64
|
||||||
|
// CHECK: store %[[VAL_89]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_90:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_91:.*]] = addi %[[VAL_32]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_92:.*]] = select %[[VAL_90]], %[[VAL_91]], %[[VAL_32]] : index
|
||||||
|
// CHECK: %[[VAL_93:.*]] = cmpi eq, %[[VAL_36]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_94:.*]] = addi %[[VAL_33]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_95:.*]] = select %[[VAL_93]], %[[VAL_94]], %[[VAL_33]] : index
|
||||||
|
// CHECK: %[[VAL_96:.*]] = cmpi eq, %[[VAL_39]], %[[VAL_41]] : index
|
||||||
|
// CHECK: %[[VAL_97:.*]] = addi %[[VAL_34]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_98:.*]] = select %[[VAL_96]], %[[VAL_97]], %[[VAL_34]] : index
|
||||||
|
// CHECK: scf.yield %[[VAL_92]], %[[VAL_95]], %[[VAL_98]] : index, index, index
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_99:.*]]:2 = scf.while (%[[VAL_100:.*]] = %[[VAL_101:.*]]#1, %[[VAL_102:.*]] = %[[VAL_101]]#2) : (index, index) -> (index, index) {
|
||||||
|
// CHECK: %[[VAL_103:.*]] = cmpi ult, %[[VAL_100]], %[[VAL_20]] : index
|
||||||
|
// CHECK: %[[VAL_104:.*]] = cmpi ult, %[[VAL_102]], %[[VAL_22]] : index
|
||||||
|
// CHECK: %[[VAL_105:.*]] = and %[[VAL_103]], %[[VAL_104]] : i1
|
||||||
|
// CHECK: scf.condition(%[[VAL_105]]) %[[VAL_100]], %[[VAL_102]] : index, index
|
||||||
|
// CHECK: } do {
|
||||||
|
// CHECK: ^bb0(%[[VAL_106:.*]]: index, %[[VAL_107:.*]]: index):
|
||||||
|
// CHECK: %[[VAL_108:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_106]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_109:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_107]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_110:.*]] = cmpi ult, %[[VAL_109]], %[[VAL_108]] : index
|
||||||
|
// CHECK: %[[VAL_111:.*]] = select %[[VAL_110]], %[[VAL_109]], %[[VAL_108]] : index
|
||||||
|
// CHECK: %[[VAL_112:.*]] = cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
|
||||||
|
// CHECK: %[[VAL_113:.*]] = cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
|
||||||
|
// CHECK: %[[VAL_114:.*]] = and %[[VAL_112]], %[[VAL_113]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_114]] {
|
||||||
|
// CHECK: %[[VAL_115:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_116:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_117:.*]] = addf %[[VAL_115]], %[[VAL_116]] : f64
|
||||||
|
// CHECK: %[[VAL_118:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_107]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_119:.*]] = addf %[[VAL_117]], %[[VAL_118]] : f64
|
||||||
|
// CHECK: store %[[VAL_119]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_120:.*]] = cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_120]] {
|
||||||
|
// CHECK: %[[VAL_121:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_122:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_107]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_123:.*]] = addf %[[VAL_121]], %[[VAL_122]] : f64
|
||||||
|
// CHECK: store %[[VAL_123]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_124:.*]] = cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_124]] {
|
||||||
|
// CHECK: %[[VAL_125:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_126:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_127:.*]] = addf %[[VAL_125]], %[[VAL_126]] : f64
|
||||||
|
// CHECK: store %[[VAL_127]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_128:.*]] = cmpi eq, %[[VAL_108]], %[[VAL_111]] : index
|
||||||
|
// CHECK: %[[VAL_129:.*]] = addi %[[VAL_106]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_130:.*]] = select %[[VAL_128]], %[[VAL_129]], %[[VAL_106]] : index
|
||||||
|
// CHECK: %[[VAL_131:.*]] = cmpi eq, %[[VAL_109]], %[[VAL_111]] : index
|
||||||
|
// CHECK: %[[VAL_132:.*]] = addi %[[VAL_107]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_133:.*]] = select %[[VAL_131]], %[[VAL_132]], %[[VAL_107]] : index
|
||||||
|
// CHECK: scf.yield %[[VAL_130]], %[[VAL_133]] : index, index
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_134:.*]]:2 = scf.while (%[[VAL_135:.*]] = %[[VAL_136:.*]]#0, %[[VAL_137:.*]] = %[[VAL_138:.*]]#1) : (index, index) -> (index, index) {
|
||||||
|
// CHECK: %[[VAL_139:.*]] = cmpi ult, %[[VAL_135]], %[[VAL_18]] : index
|
||||||
|
// CHECK: %[[VAL_140:.*]] = cmpi ult, %[[VAL_137]], %[[VAL_22]] : index
|
||||||
|
// CHECK: %[[VAL_141:.*]] = and %[[VAL_139]], %[[VAL_140]] : i1
|
||||||
|
// CHECK: scf.condition(%[[VAL_141]]) %[[VAL_135]], %[[VAL_137]] : index, index
|
||||||
|
// CHECK: } do {
|
||||||
|
// CHECK: ^bb0(%[[VAL_142:.*]]: index, %[[VAL_143:.*]]: index):
|
||||||
|
// CHECK: %[[VAL_144:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_142]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_145:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_143]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_146:.*]] = cmpi ult, %[[VAL_145]], %[[VAL_144]] : index
|
||||||
|
// CHECK: %[[VAL_147:.*]] = select %[[VAL_146]], %[[VAL_145]], %[[VAL_144]] : index
|
||||||
|
// CHECK: %[[VAL_148:.*]] = cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
|
||||||
|
// CHECK: %[[VAL_149:.*]] = cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
|
||||||
|
// CHECK: %[[VAL_150:.*]] = and %[[VAL_148]], %[[VAL_149]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_150]] {
|
||||||
|
// CHECK: %[[VAL_151:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_152:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_142]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_153:.*]] = addf %[[VAL_151]], %[[VAL_152]] : f64
|
||||||
|
// CHECK: %[[VAL_154:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_143]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_155:.*]] = addf %[[VAL_153]], %[[VAL_154]] : f64
|
||||||
|
// CHECK: store %[[VAL_155]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_156:.*]] = cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_156]] {
|
||||||
|
// CHECK: %[[VAL_157:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_158:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_143]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_159:.*]] = addf %[[VAL_157]], %[[VAL_158]] : f64
|
||||||
|
// CHECK: store %[[VAL_159]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_160:.*]] = cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_160]] {
|
||||||
|
// CHECK: %[[VAL_161:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_162:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_142]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_163:.*]] = addf %[[VAL_161]], %[[VAL_162]] : f64
|
||||||
|
// CHECK: store %[[VAL_163]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_164:.*]] = cmpi eq, %[[VAL_144]], %[[VAL_147]] : index
|
||||||
|
// CHECK: %[[VAL_165:.*]] = addi %[[VAL_142]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_166:.*]] = select %[[VAL_164]], %[[VAL_165]], %[[VAL_142]] : index
|
||||||
|
// CHECK: %[[VAL_167:.*]] = cmpi eq, %[[VAL_145]], %[[VAL_147]] : index
|
||||||
|
// CHECK: %[[VAL_168:.*]] = addi %[[VAL_143]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_169:.*]] = select %[[VAL_167]], %[[VAL_168]], %[[VAL_143]] : index
|
||||||
|
// CHECK: scf.yield %[[VAL_166]], %[[VAL_169]] : index, index
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_170:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_171:.*]] = scf.for %[[VAL_172:.*]] = %[[VAL_173:.*]]#1 to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_174:.*]] = %[[VAL_170]]) -> (f64) {
|
||||||
|
// CHECK: %[[VAL_175:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_172]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_176:.*]] = addf %[[VAL_174]], %[[VAL_175]] : f64
|
||||||
|
// CHECK: scf.yield %[[VAL_176]] : f64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: store %[[VAL_177:.*]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_178:.*]]:2 = scf.while (%[[VAL_179:.*]] = %[[VAL_180:.*]]#0, %[[VAL_181:.*]] = %[[VAL_182:.*]]#0) : (index, index) -> (index, index) {
|
||||||
|
// CHECK: %[[VAL_183:.*]] = cmpi ult, %[[VAL_179]], %[[VAL_18]] : index
|
||||||
|
// CHECK: %[[VAL_184:.*]] = cmpi ult, %[[VAL_181]], %[[VAL_20]] : index
|
||||||
|
// CHECK: %[[VAL_185:.*]] = and %[[VAL_183]], %[[VAL_184]] : i1
|
||||||
|
// CHECK: scf.condition(%[[VAL_185]]) %[[VAL_179]], %[[VAL_181]] : index, index
|
||||||
|
// CHECK: } do {
|
||||||
|
// CHECK: ^bb0(%[[VAL_186:.*]]: index, %[[VAL_187:.*]]: index):
|
||||||
|
// CHECK: %[[VAL_188:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_186]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_189:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_187]]] : memref<?xindex>
|
||||||
|
// CHECK: %[[VAL_190:.*]] = cmpi ult, %[[VAL_189]], %[[VAL_188]] : index
|
||||||
|
// CHECK: %[[VAL_191:.*]] = select %[[VAL_190]], %[[VAL_189]], %[[VAL_188]] : index
|
||||||
|
// CHECK: %[[VAL_192:.*]] = cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
|
||||||
|
// CHECK: %[[VAL_193:.*]] = cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
|
||||||
|
// CHECK: %[[VAL_194:.*]] = and %[[VAL_192]], %[[VAL_193]] : i1
|
||||||
|
// CHECK: scf.if %[[VAL_194]] {
|
||||||
|
// CHECK: %[[VAL_195:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_196:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_186]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_197:.*]] = addf %[[VAL_195]], %[[VAL_196]] : f64
|
||||||
|
// CHECK: %[[VAL_198:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_187]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_199:.*]] = addf %[[VAL_197]], %[[VAL_198]] : f64
|
||||||
|
// CHECK: store %[[VAL_199]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_200:.*]] = cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_200]] {
|
||||||
|
// CHECK: %[[VAL_201:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_202:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_187]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_203:.*]] = addf %[[VAL_201]], %[[VAL_202]] : f64
|
||||||
|
// CHECK: store %[[VAL_203]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_204:.*]] = cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
|
||||||
|
// CHECK: scf.if %[[VAL_204]] {
|
||||||
|
// CHECK: %[[VAL_205:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_206:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_186]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_207:.*]] = addf %[[VAL_205]], %[[VAL_206]] : f64
|
||||||
|
// CHECK: store %[[VAL_207]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_208:.*]] = cmpi eq, %[[VAL_188]], %[[VAL_191]] : index
|
||||||
|
// CHECK: %[[VAL_209:.*]] = addi %[[VAL_186]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_210:.*]] = select %[[VAL_208]], %[[VAL_209]], %[[VAL_186]] : index
|
||||||
|
// CHECK: %[[VAL_211:.*]] = cmpi eq, %[[VAL_189]], %[[VAL_191]] : index
|
||||||
|
// CHECK: %[[VAL_212:.*]] = addi %[[VAL_187]], %[[VAL_5]] : index
|
||||||
|
// CHECK: %[[VAL_213:.*]] = select %[[VAL_211]], %[[VAL_212]], %[[VAL_187]] : index
|
||||||
|
// CHECK: scf.yield %[[VAL_210]], %[[VAL_213]] : index, index
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_214:.*]] = load %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_215:.*]] = scf.for %[[VAL_216:.*]] = %[[VAL_217:.*]]#1 to %[[VAL_20]] step %[[VAL_5]] iter_args(%[[VAL_218:.*]] = %[[VAL_214]]) -> (f64) {
|
||||||
|
// CHECK: %[[VAL_219:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_216]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_220:.*]] = addf %[[VAL_218]], %[[VAL_219]] : f64
|
||||||
|
// CHECK: scf.yield %[[VAL_220]] : f64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_221:.*]] = scf.for %[[VAL_222:.*]] = %[[VAL_223:.*]]#0 to %[[VAL_18]] step %[[VAL_5]] iter_args(%[[VAL_224:.*]] = %[[VAL_225:.*]]) -> (f64) {
|
||||||
|
// CHECK: %[[VAL_226:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_222]]] : memref<?xf64>
|
||||||
|
// CHECK: %[[VAL_227:.*]] = addf %[[VAL_224]], %[[VAL_226]] : f64
|
||||||
|
// CHECK: scf.yield %[[VAL_227]] : f64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: store %[[VAL_228:.*]], %[[VAL_16]][] : memref<f64>
|
||||||
|
// CHECK: %[[VAL_229:.*]] = tensor_load %[[VAL_16]] : memref<f64>
|
||||||
|
// CHECK: return %[[VAL_229]] : tensor<f64>
|
||||||
|
// CHECK: }
|
||||||
|
// Test input: a linalg.generic using #trait_red3s that takes three
// tensor<?xf64> inputs and a tensor<f64> output operand. The region adds
// %a, %b and %c, in that order, onto the incoming output value %x and
// yields the sum; the op result is returned.
func @red3s(%arga: tensor<?xf64>,
|
||||||
|
%argb: tensor<?xf64>,
|
||||||
|
%argc: tensor<?xf64>, %argx: tensor<f64>) ->tensor<f64>{
|
||||||
|
%0 = linalg.generic #trait_red3s
|
||||||
|
ins(%arga, %argb, %argc: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>)
|
||||||
|
outs(%argx: tensor<f64>) {
|
||||||
|
^bb(%a: f64,%b: f64,%c: f64,%x: f64):
|
||||||
|
%0 = addf %x, %a : f64
|
||||||
|
%1 = addf %0, %b : f64
|
||||||
|
%2 = addf %1, %c : f64
|
||||||
|
linalg.yield %2 : f64
|
||||||
|
} -> tensor<f64>
|
||||||
|
return %0 : tensor<f64>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue