[mlir][Linalg] NFC - Rename Linalg and Vector EDSCs to avoid collisions

A certain number of EDSCs have a named form (e.g. `linalg.matmul`) and a generic form (e.g. `linalg.generic` with matmul traits).
Despite living in different namespaces, using the same name is confusiong in clients.
Rename them as `linalg_matmul` and `linalg_generic_matmul` respectively.
This commit is contained in:
Nicolas Vasilache 2020-04-02 21:06:45 -04:00
parent 30f18ed387
commit aef0877b1b
7 changed files with 111 additions and 94 deletions

View File

@ -161,31 +161,34 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
/// Unary pointwise operation (with broadcast) entry point.
using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I, StructuredIndexed O);
Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I, StructuredIndexed O);
/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = tanh(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O);
/// Binary pointwise operation (with broadcast) entry point.
using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = I1 + I2`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
Operation *linalg_generic_pointwise_add(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O);
/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
/// proper indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
Operation *linalg_generic_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O);
// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
@ -198,8 +201,9 @@ using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
/// |
/// | C(m, n) += A(m, k) * B(k, n)
/// ```
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@ -209,8 +213,9 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
/// | C(m, n) = sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `C`.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
MatmulRegionBuilder regionBuilder = mulRegionBuilder);
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
MatmulRegionBuilder regionBuilder = mulRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
@ -220,15 +225,17 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
/// | D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `D`.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
RankedTensorType tD,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
RankedTensorType tD,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
template <typename Container>
Operation *linalg_matmul(Container values,
MatmulRegionBuilder regionBuilder = macRegionBuilder) {
Operation *
linalg_generic_matmul(Container values,
MatmulRegionBuilder regionBuilder = macRegionBuilder) {
assert(values.size() == 3 && "Expected exactly 3 values");
return linalg_matmul(values[0], values[1], values[2], regionBuilder);
return linalg_generic_matmul(values[0], values[1], values[2], regionBuilder);
}
/// Build a linalg.generic, under the current ScopedContext, at the current
@ -253,15 +260,17 @@ Operation *linalg_matmul(Container values,
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
Operation *linalg_generic_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO, ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
template <typename Container>
Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
ArrayRef<int> dilations = {}) {
Operation *linalg_generic_conv_nhwc(Container values,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {}) {
assert(values.size() == 3 && "Expected exactly 3 values");
return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
return linalg_generic_conv_nhwc(values[0], values[1], values[2], strides,
dilations);
}
/// Build a linalg.generic, under the current ScopedContext, at the current
@ -286,18 +295,20 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO, int depth_multiplier = 1,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
Operation *linalg_generic_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO,
int depth_multiplier = 1,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
template <typename Container>
Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {}) {
Operation *linalg_generic_dilated_conv_nhwc(Container values,
int depth_multiplier,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {}) {
assert(values.size() == 3 && "Expected exactly 3 values");
return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
depth_multiplier, strides, dilations);
return linalg_generic_dilated_conv_nhwc(values[0], values[1], values[2],
depth_multiplier, strides, dilations);
}
} // namespace ops

View File

@ -26,7 +26,10 @@ ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
namespace intrinsics {
using linalg_copy = OperationBuilder<linalg::CopyOp>;
using linalg_dot = OperationBuilder<linalg::DotOp>;
using linalg_fill = OperationBuilder<linalg::FillOp>;
using linalg_matmul = OperationBuilder<linalg::MatmulOp>;
using linalg_matvec = OperationBuilder<linalg::MatvecOp>;
using linalg_range = ValueBuilder<linalg::RangeOp>;
using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
using linalg_slice = ValueBuilder<linalg::SliceOp>;

View File

@ -44,7 +44,7 @@ Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
/// Prerequisites:
/// A, B and C capture values of proper vector types. For instance
/// `A: vector<4x8xf32>`, `B: vector<8x16f32>` and `C: vector<4x16xf32>`.
Value vector_matmul(Value A, Value B, Value C);
Value vector_contraction_matmul(Value A, Value B, Value C);
} // namespace ops
} // namespace edsc

View File

@ -16,6 +16,7 @@ namespace intrinsics {
using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
using vector_contract = ValueBuilder<vector::ContractionOp>;
using vector_matmul = ValueBuilder<vector::MatmulOp>;
using vector_print = OperationBuilder<vector::PrintOp>;
} // namespace intrinsics

View File

@ -221,9 +221,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
linalg_yield((c + a * b).getValue());
}
Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I,
StructuredIndexed O) {
Operation *mlir::edsc::ops::linalg_generic_pointwise(
UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
if (O.getType().isa<RankedTensorType>()) {
@ -242,18 +241,17 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
}
Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O) {
Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O) {
UnaryPointwiseOpBuilder unOp(
[](ValueHandle a) -> Value { return std_tanh(a); });
return linalg_pointwise(unOp, I, O);
return linalg_generic_pointwise(unOp, I, O);
}
/// Binary pointwise operation (with broadcast) entry point.
Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
Operation *mlir::edsc::ops::linalg_generic_pointwise(
BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1,
StructuredIndexed I2, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
if (O.getType().isa<RankedTensorType>()) {
@ -272,28 +270,29 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
}
Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
using edsc::op::operator+;
BinaryPointwiseOpBuilder binOp(
[](ValueHandle a, ValueHandle b) -> Value { return a + b; });
return linalg_pointwise(binOp, I1, I2, O);
return linalg_generic_pointwise(binOp, I1, I2, O);
}
Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
using edsc::op::operator>;
return std_select(a > b, a, b).getValue();
});
return linalg_pointwise(binOp, I1, I2, O);
return linalg_generic_pointwise(binOp, I1, I2, O);
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC,
MatmulRegionBuilder regionBuilder) {
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@ -306,9 +305,10 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
// clang-format on
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
RankedTensorType tC,
MatmulRegionBuilder regionBuilder) {
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
RankedTensorType tC,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@ -321,9 +321,10 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
// clang-format on
}
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC, RankedTensorType tD,
MatmulRegionBuilder regionBuilder) {
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC, RankedTensorType tD,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
@ -336,10 +337,11 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
// clang-format on
}
Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO,
ArrayRef<int> strides,
ArrayRef<int> dilations) {
Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
ValueHandle vW,
ValueHandle vO,
ArrayRef<int> strides,
ArrayRef<int> dilations) {
MLIRContext *ctx = ScopedContext::getContext();
// TODO(ntv) some template magic to make everything rank-polymorphic.
assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
@ -370,7 +372,7 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
// clang-format on
}
Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
ArrayRef<int> strides, ArrayRef<int> dilations) {
MLIRContext *ctx = ScopedContext::getContext();

View File

@ -30,7 +30,7 @@ Value mlir::edsc::ops::vector_contraction(
ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
}
Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
Value mlir::edsc::ops::vector_contraction_matmul(Value A, Value B, Value C) {
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
return vector_contraction(StructuredIndexed(A, {m, k}),

View File

@ -804,7 +804,7 @@ TEST_FUNC(affine_if_op) {
}
// clang-format off
// CHECK-LABEL: func @linalg_pointwise
// CHECK-LABEL: func @linalg_generic_pointwise
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
@ -822,14 +822,14 @@ TEST_FUNC(affine_if_op) {
// CHECK: tanh
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>
// clang-format on
TEST_FUNC(linalg_pointwise_test) {
TEST_FUNC(linalg_generic_pointwise_test) {
using namespace edsc;
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f = makeFunction("linalg_pointwise", {},
auto f = makeFunction("linalg_generic_pointwise", {},
{memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
@ -838,16 +838,16 @@ TEST_FUNC(linalg_pointwise_test) {
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(C);
linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
f.print(llvm::outs());
f.erase();
}
// clang-format off
// CHECK-LABEL: func @linalg_matmul
// CHECK-LABEL: func @linalg_generic_matmul
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
@ -857,7 +857,7 @@ TEST_FUNC(linalg_pointwise_test) {
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
// clang-format on
TEST_FUNC(linalg_matmul_test) {
TEST_FUNC(linalg_generic_matmul_test) {
using namespace edsc;
using namespace edsc::ops;
@ -865,18 +865,18 @@ TEST_FUNC(linalg_matmul_test) {
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f =
makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
makeFunction("linalg_generic_matmul", {}, {memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
f.print(llvm::outs());
f.erase();
}
// clang-format off
// CHECK-LABEL: func @linalg_conv_nhwc
// CHECK-LABEL: func @linalg_generic_conv_nhwc
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
@ -888,7 +888,7 @@ TEST_FUNC(linalg_matmul_test) {
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
// clang-format on
TEST_FUNC(linalg_conv_nhwc) {
TEST_FUNC(linalg_generic_conv_nhwc) {
using namespace edsc;
using namespace edsc::ops;
@ -897,12 +897,12 @@ TEST_FUNC(linalg_conv_nhwc) {
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
ShapedType::kDynamicSize, ShapedType::kDynamicSize},
f32Type, {}, 0);
auto f = makeFunction("linalg_conv_nhwc", {},
auto f = makeFunction("linalg_generic_conv_nhwc", {},
{memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
linalg_generic_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
f.print(llvm::outs());
@ -910,7 +910,7 @@ TEST_FUNC(linalg_conv_nhwc) {
}
// clang-format off
// CHECK-LABEL: func @linalg_dilated_conv_nhwc
// CHECK-LABEL: func @linalg_generic_dilated_conv_nhwc
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
@ -922,7 +922,7 @@ TEST_FUNC(linalg_conv_nhwc) {
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
// clang-format on
TEST_FUNC(linalg_dilated_conv_nhwc) {
TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
using namespace edsc;
using namespace edsc::ops;
@ -931,12 +931,12 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
ShapedType::kDynamicSize, ShapedType::kDynamicSize},
f32Type, {}, 0);
auto f = makeFunction("linalg_dilated_conv_nhwc", {},
auto f = makeFunction("linalg_generic_dilated_conv_nhwc", {},
{memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
linalg_generic_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
/*depth_multiplier=*/7,
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
@ -1019,11 +1019,11 @@ TEST_FUNC(linalg_tensors_test) {
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(tensorType);
linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
Value o1 = linalg_matmul(A, B, tensorType)->getResult(0);
linalg_matmul(A, B, ValueHandle(o1), tensorType);
linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
linalg_generic_matmul(A, B, ValueHandle(o1), tensorType);
f.print(llvm::outs());
f.erase();
@ -1067,9 +1067,9 @@ TEST_FUNC(memref_vector_matmul_test) {
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
assert(args.size() == 3 && "expected 3 block arguments");
(linalg_yield(vector_matmul(args[0], args[1], args[2])));
(linalg_yield(vector_contraction_matmul(args[0], args[1], args[2])));
};
linalg_matmul(A, B, C, contractionBuilder);
linalg_generic_matmul(A, B, C, contractionBuilder);
f.print(llvm::outs());
f.erase();