forked from OSchip/llvm-project
Linalg portion of the tutorial - part 1
The first part of the Linalg tutorial introduces: 1. the RangeType and ViewType; 2. operations on those, namely RangeOp, ViewOp and SliceOp; 3. programmatic examples to test MLIR construction involving these types, ops and affine.for loops (with a mock custom op called "some_consumer"). -- PiperOrigin-RevId: 241409949
This commit is contained in:
parent
9089911daa
commit
d7296a4ae3
|
@ -0,0 +1,110 @@
|
|||
//===- Example.cpp - Our running example ----------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
// RUN: %p/test | FileCheck %s
|
||||
|
||||
#include "TestHarness.h"
|
||||
|
||||
#include "linalg/Common.h"
|
||||
#include "linalg/Ops.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "linalg/SliceOp.h"
|
||||
#include "linalg/Types.h"
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/ViewType.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
||||
using namespace linalg;
|
||||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
TEST_FUNC(view_op) {
|
||||
Function *f = makeFunction("view_op", {});
|
||||
|
||||
ScopedContext scope(f);
|
||||
|
||||
// Let's be lazy and define some custom ops that prevent DCE.
|
||||
CustomOperation<OperationHandle> some_consumer("some_consumer");
|
||||
|
||||
// clang-format off
|
||||
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)),
|
||||
A0 = alloc(floatMemRefType<0>()),
|
||||
A1 = alloc(floatMemRefType<1>(), ArrayRef<ValueHandle>{M}),
|
||||
A2 = alloc(floatMemRefType<2>(), ArrayRef<ValueHandle>{M, N}),
|
||||
r0 = range(constant_index(3), constant_index(17), constant_index(1)),
|
||||
v0 = view(A0),
|
||||
v1 = view(A1, ArrayRef<ValueHandle>{r0}),
|
||||
v2 = view(A2, ArrayRef<ValueHandle>{r0, r0});
|
||||
some_consumer(ArrayRef<ValueHandle>{v0, v1, v2});
|
||||
ret();
|
||||
// CHECK-LABEL: func @view_op(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
|
||||
// CHECK-NEXT: {{.*}} = linalg.view {{.*}}[] : !linalg<"view<0xf32>">
|
||||
// CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view<f32xf32>">
|
||||
// clang-format on
|
||||
|
||||
cleanupAndPrintFunction(f);
|
||||
}
|
||||
|
||||
TEST_FUNC(slice_op) {
|
||||
Function *f = makeFunction("slice_op", {});
|
||||
|
||||
ScopedContext scope(f);
|
||||
|
||||
// Let's be lazy and define some custom op that prevents DCE.
|
||||
CustomOperation<OperationHandle> some_consumer("some_consumer");
|
||||
|
||||
// clang-format off
|
||||
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)),
|
||||
A = alloc(floatMemRefType<2>(), {M, N}),
|
||||
r1 = range(constant_index(3), constant_index(17), constant_index(1)),
|
||||
r2 = range(constant_index(0), N, constant_index(1));
|
||||
ViewOp vA = view(A, {r1, r2}).getValue()->getDefiningOp()->cast<ViewOp>();
|
||||
IndexHandle i, j;
|
||||
LoopNestRangeBuilder({&i, &j}, vA.getRanges())({
|
||||
some_consumer(slice(vA, i, 1)),
|
||||
some_consumer(slice(slice(vA, j, 0), i, 0)),
|
||||
});
|
||||
ret();
|
||||
// CHECK-LABEL: func @slice_op(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: %[[ALLOC:.*]] = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
// CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:{{.*}}:{{.*}} : !linalg<"range">
|
||||
// CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%arg1:{{.*}} : !linalg<"range">
|
||||
// CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view<f32xf32>">
|
||||
// CHECK-NEXT: for %i0 = 3 to 17 {
|
||||
// CHECK-NEXT: for %i1 = 0 to (d0) -> (d0)(%arg1) {
|
||||
// CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg<"view<f32>">) -> ()
|
||||
// CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] { dim : 0 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: %[[S3:.*]] = linalg.slice %[[S2]][%i0] { dim : 0 } : !linalg<"view<0xf32>">
|
||||
// CHECK-NEXT: "some_consumer"(%[[S3]]) : (!linalg<"view<0xf32>">) -> ()
|
||||
// clang-format on
|
||||
|
||||
cleanupAndPrintFunction(f);
|
||||
}
|
||||
|
||||
int main() {
|
||||
RUN_TESTS();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
//===- TestHarness.h - Minimal test harness for exercising the linalg API -===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_TEST_HARNESS_H
|
||||
#define LINALG_TEST_HARNESS_H
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
namespace test_detail {
|
||||
// Returns a mutable list of known test functions. Used internally by test
|
||||
// macros to add and run tests. This function is static to ensure it creates a
|
||||
// new list in each test file.
|
||||
static std::vector<std::function<void()>> &tests() {
|
||||
static std::vector<std::function<void()>> list;
|
||||
return list;
|
||||
}
|
||||
|
||||
// Test registration class. Used internally by test macros to register tests
|
||||
// during static allocation.
|
||||
struct TestRegistration {
|
||||
explicit TestRegistration(std::function<void()> func) {
|
||||
test_detail::tests().push_back(func);
|
||||
}
|
||||
};
|
||||
} // end namespace test_detail
|
||||
|
||||
/// Declares a test function with the given name and adds it to the list of
|
||||
/// known tests. The body of the function must follow immediately. Example:
|
||||
///
|
||||
/// TEST_FUNC(mytest) {
|
||||
/// // CHECK: expected-output-here
|
||||
/// emitSomethingToStdOut();
|
||||
/// }
|
||||
///
|
||||
#define TEST_FUNC(name) \
|
||||
void name(); \
|
||||
static test_detail::TestRegistration name##Registration(name); \
|
||||
void name()
|
||||
|
||||
/// Runs all registered tests. Example:
|
||||
///
|
||||
/// int main() {
|
||||
/// RUN_TESTS();
|
||||
/// return 0;
|
||||
/// }
|
||||
#define RUN_TESTS \
|
||||
[]() { \
|
||||
for (auto f : test_detail::tests()) \
|
||||
f(); \
|
||||
}
|
||||
|
||||
#endif // LINALG_TEST_HARNESS_H
|
|
@ -0,0 +1,151 @@
|
|||
//===- Common.h - Linalg dialect RangeOp operation -----------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_COMMON_H_
|
||||
#define LINALG_COMMON_H_
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
namespace linalg {
|
||||
namespace common {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Define a few boilerplate objects used across all linalg examples.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The unique MLIRContext, similar to an llvm::Context.
|
||||
inline mlir::MLIRContext &globalContext() {
|
||||
static mlir::MLIRContext context;
|
||||
return context;
|
||||
}
|
||||
|
||||
// The unique Module, similar to an llvm::Module.
|
||||
inline mlir::Module &globalModule() {
|
||||
static mlir::Module module(&globalContext());
|
||||
return module;
|
||||
}
|
||||
|
||||
/// Shortcut notation for types that we use globally.
|
||||
/// The index type is the type that must be used with affine operations:
|
||||
/// (`affine.apply`, `affine.for`, `affine.load`, `affine.store`).
|
||||
inline mlir::IndexType indexType() {
|
||||
return mlir::IndexType::get(&globalContext());
|
||||
}
|
||||
|
||||
/// Common f32 type.
|
||||
inline mlir::FloatType f32Type() {
|
||||
return mlir::FloatType::getF32(&globalContext());
|
||||
}
|
||||
|
||||
/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic
|
||||
/// sizes.
|
||||
template <int N>
|
||||
inline mlir::MemRefType floatMemRefType(unsigned memorySpace = 0) {
|
||||
llvm::SmallVector<int64_t, 4> shape(N, -1);
|
||||
return mlir::MemRefType::get(shape, f32Type(), {}, memorySpace);
|
||||
}
|
||||
|
||||
/// The simple function, taking 4 parameters of type index, that we will use
|
||||
/// throughout this tutorial:
|
||||
///
|
||||
/// ```mlir
|
||||
/// func @name(%M: index, %N: index, %K: index, %P: index)
|
||||
/// ```
|
||||
inline mlir::Function *makeFunction(llvm::StringRef name,
|
||||
llvm::ArrayRef<mlir::Type> resultTypes) {
|
||||
auto &ctx = globalContext();
|
||||
auto *function =
|
||||
new mlir::Function(mlir::UnknownLoc::get(&ctx), name,
|
||||
mlir::FunctionType::get({indexType(), indexType(),
|
||||
indexType(), indexType()},
|
||||
resultTypes, &ctx));
|
||||
function->addEntryBlock();
|
||||
globalModule().getFunctions().push_back(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
/// A basic pass manager pre-populated with cleanup passes.
|
||||
inline mlir::PassManager &cleanupPassManager() {
|
||||
static bool inited = false;
|
||||
static mlir::PassManager pm;
|
||||
if (!inited) {
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createSimplifyAffineStructuresPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
inited = true;
|
||||
}
|
||||
return pm;
|
||||
}
|
||||
|
||||
/// A simple function to verify and cleanup the IR before printing it to
|
||||
/// llvm::outs() for FileCheck'ing.
|
||||
/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
|
||||
/// which will make the associated FileCheck test fail.
|
||||
inline void cleanupAndPrintFunction(mlir::Function *f) {
|
||||
bool printToOuts = true;
|
||||
auto check = [f, &printToOuts](mlir::LogicalResult result) {
|
||||
if (failed(result)) {
|
||||
f->dump();
|
||||
llvm::errs() << "Failure!\n";
|
||||
printToOuts = false;
|
||||
}
|
||||
};
|
||||
check(mlir::failure(f->getModule()->verify()));
|
||||
check(cleanupPassManager().run(f->getModule()));
|
||||
if (printToOuts)
|
||||
f->print(llvm::outs());
|
||||
}
|
||||
|
||||
/// Helper class to sugar building loop nests from indexings that appear in
|
||||
/// ViewOp and SliceOp.
|
||||
class LoopNestRangeBuilder {
|
||||
public:
|
||||
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
|
||||
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
|
||||
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
|
||||
llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::edsc::ValueHandle
|
||||
operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
|
||||
|
||||
private:
|
||||
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_COMMON_H_
|
|
@ -0,0 +1,42 @@
|
|||
//===- Dialect.h - Definition of the Linalg dialect -----------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_DIALECT_H_
|
||||
#define LINALG_DIALECT_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// The Linalg Dialect is not exposed to the outside world. It is registered by
|
||||
/// linking and accessed via generic MLIR accessors.
|
||||
class LinalgDialect : public mlir::Dialect {
|
||||
public:
|
||||
/// Create a new Dialect that is registered on construction and adds the
|
||||
/// relevant types and operations.
|
||||
explicit LinalgDialect(mlir::MLIRContext *context);
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
mlir::Type parseType(llvm::StringRef spec, mlir::Location loc) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_DIALECT_H_
|
|
@ -0,0 +1,59 @@
|
|||
//===- Ops.h - Linalg Ops forward declarations ------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_OPS_H_
|
||||
#define LINALG_OPS_H_
|
||||
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
class MatmulOp;
|
||||
class RangeOp;
|
||||
class SliceOp;
|
||||
class ViewOp;
|
||||
class ViewType;
|
||||
|
||||
struct ViewOrSliceOp {
|
||||
public:
|
||||
ViewOrSliceOp(mlir::Value *v) : v(v) {}
|
||||
ViewOp view();
|
||||
SliceOp slice();
|
||||
operator bool();
|
||||
unsigned getRank();
|
||||
ViewType getViewType();
|
||||
/// Get the indexing at `dim` by recursing into the parent.
|
||||
/// Returns the indexing as well as its actual dimension, which may have
|
||||
/// shifted from the originally requested `dim`.
|
||||
std::pair<mlir::Value *, unsigned> getRootIndexing(unsigned dim);
|
||||
// Get all the indexings without recursing.
|
||||
mlir::Operation::operand_range getIndexings();
|
||||
mlir::Value *getSupportingMemRef();
|
||||
|
||||
private:
|
||||
mlir::Value *v;
|
||||
};
|
||||
|
||||
namespace intrinsics {
|
||||
using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
|
||||
using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
|
||||
using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
|
||||
} // namespace intrinsics
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_OPS_H_
|
|
@ -0,0 +1,56 @@
|
|||
//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_RANGEOP_H_
|
||||
#define LINALG_RANGEOP_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// A RangeOp is used to create a value of RangeType from 3 values of type index
|
||||
/// that represent the min, max and step values of the range.
|
||||
/// Note: step must be an mlir::ConstantIndexOp for now due to current
|
||||
/// `affine.for` limitations.
|
||||
class RangeOp : public mlir::Op<RangeOp, mlir::OpTrait::NOperands<3>::Impl,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
static llvm::StringRef getOperationName() { return "linalg.range"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *min, mlir::Value *max, mlir::Value *step);
|
||||
bool verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
mlir::Value *getMin() { return getOperand(0); }
|
||||
mlir::Value *getMax() { return getOperand(1); }
|
||||
mlir::Value *getStep() { return getOperand(2); }
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_RANGEOP_H_
|
|
@ -0,0 +1,47 @@
|
|||
//===- RangeType.h - Linalg RangeType definition --------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_RANGETYPE_H_
|
||||
#define LINALG_RANGETYPE_H_
|
||||
|
||||
#include "linalg/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
}
|
||||
|
||||
namespace linalg {
|
||||
/// A RangeType is the simplest possible form of a type in MLIR. It represents
|
||||
/// a minimal range abstraction (min, max, step). Since RangeType is constructed
|
||||
/// without any additional argument, this example illustrates the minimal
|
||||
/// amount of information required to implement a new custom MLIR type.
|
||||
class RangeType : public mlir::Type::TypeBase<RangeType, mlir::Type> {
|
||||
public:
|
||||
// Used to implement llvm-style cast.
|
||||
using Base::Base;
|
||||
/// Construction hook.
|
||||
static RangeType get(mlir::MLIRContext *context) {
|
||||
/// Custom, uniqu'ed construction in the mlir::MLIRContext.
|
||||
return Base::get(context, LinalgTypes::Range);
|
||||
}
|
||||
/// Used to implement llvm-style cast.
|
||||
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_RANGETYPE_H_
|
|
@ -0,0 +1,102 @@
|
|||
//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_SLICEOP_H_
|
||||
#define LINALG_SLICEOP_H_
|
||||
|
||||
#include "linalg/Types.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a
|
||||
/// new ViewType which is contained within its parent ViewType.
|
||||
class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::NOperands<2>::Impl,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
static llvm::StringRef getOperationName() { return "linalg.slice"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *view, mlir::Value *indexing, unsigned dim);
|
||||
bool verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
/// Returns the attribute name that describes which dimension of the input
|
||||
/// view that this SliceOp slices.
|
||||
static llvm::StringRef getSlicingDimAttrName() { return "dim"; }
|
||||
/// Returns the unique result of the parent SliceOp of ViewOp instruction that
|
||||
/// created the view on which this SliceOp operates.
|
||||
mlir::Value *getParentView() { return getOperand(0); }
|
||||
/// Returns the indexing operand of the current SliceOp.
|
||||
/// This operands may either be:
|
||||
/// 1. A range, in which case the operand comes from a RangeOp. This SliceOp
|
||||
/// does not reduce the dimension of the input ViewType.
|
||||
/// 2. An index, in which case the operand comes from any possible producer
|
||||
/// of an index. This SliceOp reduces the dimension of the input ViewType
|
||||
/// by 1.
|
||||
mlir::Value *getIndexing() { return getOperand(1); }
|
||||
/// Returns the dim of the parent ViewType that is sliced by this SliceOp.
|
||||
unsigned getSlicingDim() {
|
||||
return getAttrOfType<mlir::IntegerAttr>(getSlicingDimAttrName()).getInt();
|
||||
}
|
||||
/// Returns the ViewType resulting from this SliceOp.
|
||||
ViewType getViewType();
|
||||
/// Returns the rank of the current ViewType.
|
||||
unsigned getRank();
|
||||
/// Return the element type of the current ViewType.
|
||||
mlir::Type getElementType();
|
||||
|
||||
/// Returns the ViewType of `getParentView()`.
|
||||
ViewType getParentViewType();
|
||||
/// Returns the rank of the ViewType of `getParentView()`.
|
||||
unsigned getParentRank();
|
||||
/// Returns the element Type of the ViewType of `getParentView()`.
|
||||
mlir::Type getParentElementType();
|
||||
|
||||
/// Walks the SliceOp chain until it encounters the base ViewOp.
|
||||
/// Returns the single return value of the ViewOp.
|
||||
mlir::Value *getBaseView();
|
||||
|
||||
/// Returns the MemRef backing the base ViewOp.
|
||||
// May be another data type than a MemRef in the future.
|
||||
mlir::Value *getSupportingMemRef();
|
||||
|
||||
/// Extracts the indexing from the original ViewOp that this slice restricts
|
||||
/// along `dim`. Walks back the chain of SliceOp and determines the first
|
||||
/// slice that constrains `dim`.
|
||||
/// Returns the indexing as well as its actual dimension which may have
|
||||
/// shifted from the originally requested `dim`.
|
||||
std::pair<mlir::Value *, unsigned> getRootIndexing(unsigned dim);
|
||||
|
||||
// Get all the indexings in this slice.
|
||||
mlir::Operation::operand_range getIndexings();
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_SLICEOP_H_
|
|
@ -0,0 +1,37 @@
|
|||
//===- Types.h - Linalg Types forward declarations ------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_TYPES_H_
|
||||
#define LINALG_TYPES_H_
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
class RangeType;
|
||||
class ViewType;
|
||||
class ViewTypeStorage;
|
||||
|
||||
enum LinalgTypes {
|
||||
Range = mlir::Type::FIRST_LINALG_TYPE,
|
||||
View,
|
||||
LAST_USED_LINALG_TYPE = View,
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_TYPES_H_
|
|
@ -0,0 +1,71 @@
|
|||
//===- ViewOp.h - Linalg dialect ViewOp operation definition ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_VIEWOP_H_
|
||||
#define LINALG_VIEWOP_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
class ViewType;
|
||||
|
||||
/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range
|
||||
/// abstraction on top of an underlying data type. For now we use the existing
|
||||
/// mlir::MemRef for the underlying data type.
|
||||
class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
static llvm::StringRef getOperationName() { return "linalg.view"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *memRef,
|
||||
llvm::ArrayRef<mlir::Value *> indexings = {});
|
||||
bool verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
unsigned getRank();
|
||||
mlir::Type getElementType();
|
||||
ViewType getViewType();
|
||||
// May be something else than a MemRef in the future.
|
||||
mlir::Value *getSupportingMemRef();
|
||||
// Get the underlying indexing at a given rank.
|
||||
mlir::Value *getIndexing(unsigned rank);
|
||||
// A ViewOp is a root, its root indexing is trivial.
|
||||
std::pair<mlir::Value *, unsigned> getRootIndexing(unsigned rank) {
|
||||
return std::make_pair(getIndexing(rank), rank);
|
||||
}
|
||||
// Get all the indexings of type RangeOp.
|
||||
llvm::SmallVector<mlir::Value *, 8> getRanges();
|
||||
// Get all the indexings in this view.
|
||||
mlir::Operation::operand_range getIndexings();
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_VIEWOP_H_
|
|
@ -0,0 +1,54 @@
|
|||
//===- ViewType.h - Linalg ViewType definition --------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG_VIEWTYPE_H_
|
||||
#define LINALG_VIEWTYPE_H_
|
||||
|
||||
#include "linalg/Types.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// A ViewType represents a range abstraction on top of an underlying storage
|
||||
/// type. It is parameterizable by the underlying element type and the rank of
|
||||
/// the view.
|
||||
class ViewType
|
||||
: public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> {
|
||||
public:
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this type.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used to implement llvm-style cast.
|
||||
using Base::Base;
|
||||
// Used to implement llvm-style cast.
|
||||
static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
|
||||
/// Construction hook.
|
||||
static ViewType get(mlir::MLIRContext *context, mlir::Type elementType,
|
||||
unsigned rank);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Type-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Return the underlying elemental type.
|
||||
mlir::Type getElementType();
|
||||
/// Return the rank of the view.
|
||||
/// This is the number of indexings needed to reach an underlying element.
|
||||
unsigned getRank();
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG_VIEWTYPE_H_
|
|
@ -0,0 +1,70 @@
|
|||
//===- Common.cpp - Implementation of common supporting functions ---------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple IR operation to create a new RangeType in the
|
||||
// linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/Common.h"
|
||||
#include "linalg/Ops.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/ViewType.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using mlir::ConstantIndexOp;
|
||||
using mlir::edsc::CapturableHandle;
|
||||
using mlir::edsc::ValueHandle;
|
||||
using mlir::edsc::intrinsics::alloc;
|
||||
using mlir::edsc::intrinsics::ret;
|
||||
|
||||
using namespace linalg;
|
||||
|
||||
linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||
llvm::ArrayRef<ValueHandle *> ivs, llvm::ArrayRef<ValueHandle> indexings) {
|
||||
assert(ivs.size() == indexings.size());
|
||||
for (unsigned i = 0, e = indexings.size(); i < e; ++i) {
|
||||
auto rangeOp =
|
||||
indexings[i].getValue()->getDefiningOp()->dyn_cast<RangeOp>();
|
||||
if (!rangeOp) {
|
||||
continue;
|
||||
}
|
||||
auto lb = rangeOp.getMin();
|
||||
auto ub = rangeOp.getMax();
|
||||
// This must be a constexpr index until we relax the affine.for constraint
|
||||
auto step =
|
||||
rangeOp.getStep()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
|
||||
loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
|
||||
}
|
||||
}
|
||||
|
||||
linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||
llvm::ArrayRef<ValueHandle *> ivs, llvm::ArrayRef<mlir::Value *> indexings)
|
||||
: LoopNestRangeBuilder(ivs, llvm::SmallVector<ValueHandle, 4>(
|
||||
indexings.begin(), indexings.end())) {}
|
||||
|
||||
ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
|
||||
llvm::ArrayRef<CapturableHandle> stmts) {
|
||||
for (auto &lit : llvm::reverse(loops)) {
|
||||
lit({});
|
||||
}
|
||||
return ValueHandle::null();
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
//===- Dialect.cpp - Implementation of the linalg dialect -----------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple Linalg dialect to which we gradually add
|
||||
// complexity.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/Dialect.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "linalg/SliceOp.h"
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/ViewType.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using llvm::raw_ostream;
|
||||
using llvm::StringRef;
|
||||
using mlir::Location;
|
||||
using mlir::Type;
|
||||
|
||||
using namespace linalg;
|
||||
|
||||
Type LinalgDialect::parseType(StringRef spec, Location loc) const {
|
||||
llvm_unreachable("Unhandled linalg dialect parsing");
|
||||
return Type();
|
||||
}
|
||||
|
||||
/// RangeType prints as just "range".
|
||||
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
|
||||
|
||||
/// ViewType prints as:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// view<i8xf32xi1>
|
||||
/// ```
|
||||
///
|
||||
/// or
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// view<0xf32>
|
||||
/// ```
|
||||
///
|
||||
/// for 0-D views (a.k.a pointer to a scalar value).
|
||||
static void print(linalg::ViewType rt, raw_ostream &os) {
|
||||
os << "view<";
|
||||
if (rt.getRank() > 0) {
|
||||
for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
|
||||
os << rt.getElementType() << ((i == e - 1) ? "" : "x");
|
||||
}
|
||||
} else {
|
||||
os << "0x" << rt.getElementType();
|
||||
}
|
||||
os << ">";
|
||||
}
|
||||
|
||||
void LinalgDialect::printType(Type type, raw_ostream &os) const {
|
||||
switch (type.getKind()) {
|
||||
default:
|
||||
llvm_unreachable("Unhandled linalg type");
|
||||
case LinalgTypes::Range:
|
||||
print(type.cast<RangeType>(), os);
|
||||
break;
|
||||
case linalg::LinalgTypes::View:
|
||||
print(type.cast<linalg::ViewType>(), os);
|
||||
break;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file registers the Linalg dialect and should live in a standalone
|
||||
// library. Linking with this library will create a static global object that
|
||||
// performs dialect registration.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/Dialect.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "linalg/SliceOp.h"
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/ViewType.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
|
||||
LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect("linalg", context) {
|
||||
addTypes<RangeType, ViewType>();
|
||||
addOperations<RangeOp, SliceOp, ViewOp>();
|
||||
}
|
||||
|
||||
// Dialect registration triggers the creation of a `LinalgDialect` object which
|
||||
// adds the proper types and operations to the dialect.
|
||||
static mlir::DialectRegistration<LinalgDialect> LinalgOps;
|
|
@ -0,0 +1,68 @@
|
|||
//===- RangeOp.cpp - Implementation of the linalg RangeOp operation -------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple IR operation to create a new RangeType in the
|
||||
// linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using mlir::Builder;
|
||||
using mlir::IndexType;
|
||||
using mlir::OpAsmParser;
|
||||
using mlir::OpAsmPrinter;
|
||||
using mlir::OperationState;
|
||||
using mlir::Value;
|
||||
|
||||
// Minimal example for a new RangeOp operating on RangeType.
|
||||
void linalg::RangeOp::build(Builder *b, OperationState *result, Value *min,
|
||||
Value *max, Value *step) {
|
||||
result->addOperands({min, max, step});
|
||||
result->addTypes({linalg::RangeType::get(b->getContext())});
|
||||
}
|
||||
|
||||
// Verification is simply that a RangeOp takes 3 index ssa-value.
|
||||
bool linalg::RangeOp::verify() {
|
||||
if (!getMin() || !getMin()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!getMax() || !getMax()->getType().isa<IndexType>())
|
||||
return emitOpError("second operand should be of type index");
|
||||
if (!getStep() || !getStep()->getType().isa<IndexType>())
|
||||
return emitOpError("third operand should be of type index");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
assert(false && "NYI");
|
||||
return false;
|
||||
}
|
||||
|
||||
// A RangeOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.range %arg0:%arg1:%c42 : !linalg<"range">
|
||||
// ```
|
||||
void linalg::RangeOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getMin() << ":" << *getMax() << ":"
|
||||
<< *getStep() << " : " << getType();
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
//===- SliceOp.cpp - Implementation of the linalg SliceOp operation -------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements an IR operation to extract a "sub-View" from a ViewType
|
||||
// in the Linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/SliceOp.h"
|
||||
#include "linalg/Ops.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/ViewType.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using mlir::Builder;
|
||||
using mlir::IndexType;
|
||||
using mlir::OpAsmParser;
|
||||
using mlir::OpAsmPrinter;
|
||||
using mlir::OperationState;
|
||||
using mlir::Type;
|
||||
using mlir::Value;
|
||||
|
||||
using namespace linalg;
|
||||
|
||||
ViewOp linalg::ViewOrSliceOp::view() {
|
||||
return v->getDefiningOp()->dyn_cast<ViewOp>();
|
||||
}
|
||||
SliceOp linalg::ViewOrSliceOp::slice() {
|
||||
return v->getDefiningOp()->dyn_cast<SliceOp>();
|
||||
}
|
||||
linalg::ViewOrSliceOp::operator bool() {
|
||||
return static_cast<bool>(view()) || static_cast<bool>(slice());
|
||||
}
|
||||
unsigned linalg::ViewOrSliceOp::getRank() {
|
||||
assert(*this && "Not a ViewOp or a SliceOp!");
|
||||
return view() ? view().getRank() : slice().getRank();
|
||||
}
|
||||
ViewType linalg::ViewOrSliceOp::getViewType() {
|
||||
assert(*this && "Not a ViewOp or a SliceOp!");
|
||||
return view() ? view().getViewType() : slice().getViewType();
|
||||
}
|
||||
std::pair<Value *, unsigned>
|
||||
linalg::ViewOrSliceOp::getRootIndexing(unsigned dim) {
|
||||
assert(*this && "Not a ViewOp or a SliceOp!");
|
||||
return view() ? view().getRootIndexing(dim) : slice().getRootIndexing(dim);
|
||||
}
|
||||
llvm::iterator_range<mlir::Operation::operand_iterator>
|
||||
linalg::ViewOrSliceOp::getIndexings() {
|
||||
assert(*this && "Not a ViewOp or a SliceOp!");
|
||||
return view() ? view().getIndexings() : slice().getIndexings();
|
||||
}
|
||||
Value *linalg::ViewOrSliceOp::getSupportingMemRef() {
|
||||
assert(*this && "Not a ViewOp or a SliceOp!");
|
||||
return view() ? view().getSupportingMemRef() : slice().getSupportingMemRef();
|
||||
}
|
||||
|
||||
// A view may itself be coming either from a ViewOp or from a SliceOp.
|
||||
// TODO assert statically or dynamically that indexing is within the bounds of
|
||||
// view.
|
||||
void linalg::SliceOp::build(Builder *b, OperationState *result, Value *view,
|
||||
Value *indexing, unsigned dim) {
|
||||
// Early sanity checks + extract rank.
|
||||
ViewOrSliceOp op(view);
|
||||
unsigned rank = op.getRank();
|
||||
ViewType viewType = op.getViewType();
|
||||
Type elementType = viewType.getElementType();
|
||||
|
||||
result->addOperands({view, indexing});
|
||||
result->addAttribute(getSlicingDimAttrName(),
|
||||
b->getIntegerAttr(b->getIndexType(), dim));
|
||||
if (indexing->getType().isa<RangeType>()) {
|
||||
// Taking a range slice does not decrease the rank, the view has the same
|
||||
// type.
|
||||
result->addTypes({viewType});
|
||||
} else {
|
||||
assert(indexing->getType().cast<IndexType>());
|
||||
result->addTypes(
|
||||
{linalg::ViewType::get(b->getContext(), elementType, rank - 1)});
|
||||
}
|
||||
}
|
||||
|
||||
bool linalg::SliceOp::verify() {
|
||||
unsigned dim = getSlicingDim();
|
||||
if (dim >= getParentRank())
|
||||
return emitOpError("slicing dim must be in the [0 .. parent_rank) range");
|
||||
ViewOrSliceOp op(getOperand(0));
|
||||
if (!op)
|
||||
return emitOpError(
|
||||
"first operand must be of ViewType (i.e. a ViewOp or a SliceOp)");
|
||||
auto type = getOperand(1)->getType().dyn_cast<IndexType>();
|
||||
auto *inst = getOperand(1)->getDefiningOp();
|
||||
auto range = inst ? inst->dyn_cast<RangeOp>() : RangeOp();
|
||||
if (!range && !type)
|
||||
return emitOpError(
|
||||
"second operand must be of RangeType (i.e. a RangeOp) or IndexType");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
assert(false && "NYI");
|
||||
return false;
|
||||
}
|
||||
|
||||
// A SliceOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.slice %0[*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a `view<f32xf32>`, %i0 is an ssa-value
|
||||
// holding an index.
|
||||
void linalg::SliceOp::print(OpAsmPrinter *p) {
|
||||
unsigned dim = getSlicingDim();
|
||||
*p << getOperationName() << " " << *getParentView() << "[";
|
||||
for (unsigned idx = 0, rank = getParentRank(); idx < rank; ++idx) {
|
||||
if (idx != dim) {
|
||||
*p << "*";
|
||||
} else {
|
||||
auto *v = getIndexing();
|
||||
if (v->getDefiningOp() && v->getDefiningOp()->isa<RangeOp>()) {
|
||||
*p << *v << "..";
|
||||
} else {
|
||||
*p << *v;
|
||||
}
|
||||
}
|
||||
*p << ((idx == rank - 1) ? "" : ", ");
|
||||
}
|
||||
*p << "] { " << getSlicingDimAttrName() << " : " << dim << " }"
|
||||
<< " : " << getViewType();
|
||||
}
|
||||
|
||||
ViewType linalg::SliceOp::getViewType() { return getType().cast<ViewType>(); }
|
||||
|
||||
unsigned linalg::SliceOp::getRank() { return getViewType().getRank(); }
|
||||
|
||||
mlir::Type linalg::SliceOp::getElementType() {
|
||||
return getViewType().getElementType();
|
||||
}
|
||||
|
||||
ViewType linalg::SliceOp::getParentViewType() {
|
||||
ViewOrSliceOp op(getParentView());
|
||||
return op.getViewType();
|
||||
}
|
||||
|
||||
unsigned linalg::SliceOp::getParentRank() {
|
||||
return getParentViewType().getRank();
|
||||
}
|
||||
|
||||
mlir::Type linalg::SliceOp::getParentElementType() {
|
||||
return getParentViewType().getElementType();
|
||||
}
|
||||
|
||||
Value *linalg::SliceOp::getBaseView() {
|
||||
Value *parent = getParentView();
|
||||
while (!parent->getDefiningOp()->isa<ViewOp>()) {
|
||||
parent = parent->getDefiningOp()->cast<SliceOp>().getParentView();
|
||||
}
|
||||
assert(parent && "null parent");
|
||||
return parent;
|
||||
}
|
||||
|
||||
// We want to extract the range from the original ViewOp that this slice
|
||||
// captures along `dim`. To achieve this, we want to walk back the chain of
|
||||
// SliceOp and determine the first slice that constrains `dim`.
|
||||
std::pair<Value *, unsigned> linalg::SliceOp::getRootIndexing(unsigned dim) {
|
||||
assert(dim < getRank());
|
||||
auto *view = getParentView();
|
||||
unsigned sliceDim = getSlicingDim();
|
||||
auto *indexing = getIndexing();
|
||||
if (indexing->getDefiningOp()) {
|
||||
if (auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>()) {
|
||||
// If I sliced with a range and I sliced at this dim, then I'm it.
|
||||
if (dim == sliceDim) {
|
||||
return make_pair(rangeOp.getResult(), dim);
|
||||
}
|
||||
// Otherwise, I did not change the rank, just go look for `dim` into my
|
||||
// parent.
|
||||
ViewOrSliceOp op(view);
|
||||
return op.getRootIndexing(dim);
|
||||
}
|
||||
}
|
||||
assert(indexing->getType().isa<IndexType>());
|
||||
// If I get here, I indexed and reduced along the dim `sliceDim` from my
|
||||
// parent. I need to query my parent for `dim` or `dim + 1` depending on
|
||||
// whether dim > sliceDim or not.
|
||||
unsigned parentDim = dim > sliceDim ? dim + 1 : dim;
|
||||
ViewOrSliceOp op(view);
|
||||
return op.getRootIndexing(parentDim);
|
||||
}
|
||||
|
||||
Value *linalg::SliceOp::getSupportingMemRef() {
|
||||
auto view = getBaseView()->getDefiningOp()->cast<ViewOp>();
|
||||
return view.getSupportingMemRef();
|
||||
}
|
||||
|
||||
mlir::Operation::operand_range linalg::SliceOp::getIndexings() {
|
||||
return {this->getOperation()->operand_begin() + SliceOp::FirstIndexingOperand,
|
||||
this->getOperation()->operand_end()};
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
//===- ViewOp.cpp - Implementation of the linalg ViewOp operation -------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple IR operation to create a new ViewType in the
|
||||
// linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/ViewOp.h"
|
||||
#include "linalg/Ops.h"
|
||||
#include "linalg/RangeOp.h"
|
||||
#include "linalg/RangeType.h"
|
||||
#include "linalg/ViewType.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::SmallVector;
|
||||
using llvm::Twine;
|
||||
using mlir::Builder;
|
||||
using mlir::IndexType;
|
||||
using mlir::MemRefType;
|
||||
using mlir::OpAsmParser;
|
||||
using mlir::OpAsmPrinter;
|
||||
using mlir::OperationState;
|
||||
using mlir::Type;
|
||||
using mlir::Value;
|
||||
|
||||
using namespace linalg;
|
||||
|
||||
void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef,
|
||||
ArrayRef<Value *> indexings) {
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
result->addOperands({memRef});
|
||||
assert(indexings.size() == memRefType.getRank() &&
|
||||
"unexpected number of indexings (must match the memref rank)");
|
||||
|
||||
result->addOperands(indexings);
|
||||
unsigned rank = memRefType.getRank();
|
||||
for (auto *v : indexings) {
|
||||
if (!v->getType().isa<RangeType>()) {
|
||||
rank--;
|
||||
}
|
||||
}
|
||||
Type elementType = memRefType.getElementType();
|
||||
result->addTypes({linalg::ViewType::get(b->getContext(), elementType, rank)});
|
||||
}
|
||||
|
||||
bool linalg::ViewOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a memref operand followed by 'rank' indices");
|
||||
auto memrefType = getOperand(0)->getType().dyn_cast<MemRefType>();
|
||||
unsigned memrefRank = memrefType.getRank();
|
||||
if (!memrefType)
|
||||
return emitOpError("first operand must be of MemRefType");
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>() &&
|
||||
!indexing->getType().isa<IndexType>()) {
|
||||
return emitOpError(Twine(index) +
|
||||
"^th index must be of range or index type");
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (llvm::size(getIndexings()) != memrefRank) {
|
||||
return emitOpError("requires at least a memref operand followed by " +
|
||||
Twine(memrefRank) + " indices");
|
||||
}
|
||||
unsigned rank = memrefRank;
|
||||
for (auto *v : getIndexings()) {
|
||||
if (!v->getType().isa<RangeType>()) {
|
||||
rank--;
|
||||
}
|
||||
}
|
||||
if (getRank() != rank) {
|
||||
return emitOpError("the rank of the view must be the number of its range "
|
||||
"indices: " +
|
||||
Twine(rank));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
assert(false && "NYI");
|
||||
return false;
|
||||
}
|
||||
|
||||
// A ViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.view %0[%1, %2] : !linalg<"view<f32xf32>">
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
void linalg::ViewOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getSupportingMemRef() << "[";
|
||||
unsigned numRanges = llvm::size(getIndexings());
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
*p << *indexing << ((index++ == numRanges - 1) ? "" : ", ");
|
||||
}
|
||||
*p << "] : " << getType();
|
||||
}
|
||||
|
||||
Type linalg::ViewOp::getElementType() { return getViewType().getElementType(); }
|
||||
|
||||
ViewType linalg::ViewOp::getViewType() { return getType().cast<ViewType>(); }
|
||||
|
||||
unsigned linalg::ViewOp::getRank() { return getViewType().getRank(); }
|
||||
|
||||
// May be something else than a MemRef in the future.
|
||||
Value *linalg::ViewOp::getSupportingMemRef() {
|
||||
auto *res = getOperand(0);
|
||||
assert(res->getType().isa<MemRefType>());
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<mlir::Value *, 8> linalg::ViewOp::getRanges() {
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
for (auto *operand : getIndexings()) {
|
||||
if (!operand->getType().isa<mlir::IndexType>()) {
|
||||
res.push_back(operand);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Value *linalg::ViewOp::getIndexing(unsigned rank) {
|
||||
SmallVector<Value *, 1> ranges(getIndexings().begin(), getIndexings().end());
|
||||
return ranges[rank];
|
||||
}
|
||||
|
||||
mlir::Operation::operand_range linalg::ViewOp::getIndexings() {
|
||||
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
//===- ViewType.h - Implementation of the ViewType custom type ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a custom ViewType in the linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg/ViewType.h"
|
||||
|
||||
using mlir::MLIRContext;
|
||||
using mlir::Type;
|
||||
using mlir::TypeStorage;
|
||||
using mlir::TypeStorageAllocator;
|
||||
|
||||
namespace linalg {
|
||||
|
||||
struct ViewTypeStorage : public mlir::TypeStorage {
|
||||
/// Underlying Key type to transport the payload needed to construct a custom
|
||||
/// type in a generic way.
|
||||
struct Key {
|
||||
Key(Type elementType, unsigned rank)
|
||||
: elementType(elementType), rank(rank) {}
|
||||
Type elementType;
|
||||
unsigned rank;
|
||||
};
|
||||
/// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
|
||||
using KeyTy = Key;
|
||||
|
||||
/// Construction in the llvm::BumpPtrAllocator given a key.
|
||||
static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const Key &key) {
|
||||
return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
|
||||
}
|
||||
|
||||
/// Equality operator for hashing.
|
||||
bool operator==(const Key &key) const {
|
||||
return elementType == key.elementType && rank == key.rank;
|
||||
}
|
||||
|
||||
/// Hashing for unique'ing.
|
||||
static unsigned hashKey(const Key &key) {
|
||||
return llvm::hash_combine(key.elementType, key.rank);
|
||||
}
|
||||
|
||||
unsigned getRank() { return rank; };
|
||||
Type getElementType() { return elementType; };
|
||||
|
||||
private:
|
||||
ViewTypeStorage(const Key &key)
|
||||
: elementType(key.elementType), rank(key.rank) {}
|
||||
|
||||
Type elementType;
|
||||
unsigned rank;
|
||||
};
|
||||
|
||||
ViewType linalg::ViewType::get(MLIRContext *context, Type elementType,
|
||||
unsigned rank) {
|
||||
return Base::get(context, LinalgTypes::View, elementType, rank);
|
||||
}
|
||||
|
||||
Type linalg::ViewType::getElementType() { return getImpl()->getElementType(); }
|
||||
|
||||
unsigned linalg::ViewType::getRank() { return getImpl()->getRank(); }
|
||||
|
||||
} // namespace linalg
|
Loading…
Reference in New Issue