forked from OSchip/llvm-project
NFC: Delete the Linalg tutorial.
This part of the tutorial is now covered by a new flow in Toy. Deleting it also removes a point of confusion, since a proper Linalg dialect now exists.

PiperOrigin-RevId: 275338933

parent 0372eb413f
commit dae0ae6879
@@ -1,3 +1 @@
 add_subdirectory(toy)
-add_subdirectory(Linalg)
-
@@ -1,20 +0,0 @@
-include_directories(Linalg1/)
-include_directories(Linalg1/include/)
-include_directories(Linalg2/include/)
-include_directories(Linalg3/include/)
-include_directories(Linalg4/include/)
-
-add_custom_target(Linalg)
-set_target_properties(Linalg PROPERTIES FOLDER Examples)
-add_dependencies(Linalg
-  linalg-conversion-3
-  linalg-example-2
-  linalg-example-3
-  linalg-example-4
-  linalg-execution-3
-)
-
-add_subdirectory(Linalg1)
-add_subdirectory(Linalg2)
-add_subdirectory(Linalg3)
-add_subdirectory(Linalg4)
@@ -1 +0,0 @@
-add_subdirectory(lib)
@@ -1,67 +0,0 @@
-//===- TestHarness.h - Minimal test harness for exercising the linalg API -===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_TEST_HARNESS_H
-#define LINALG1_TEST_HARNESS_H
-
-#include <functional>
-#include <vector>
-
-namespace test_detail {
-// Returns a mutable list of known test functions. Used internally by test
-// macros to add and run tests. This function is static to ensure it creates a
-// new list in each test file.
-static std::vector<std::function<void()>> &tests() {
-  static std::vector<std::function<void()>> list;
-  return list;
-}
-
-// Test registration class. Used internally by test macros to register tests
-// during static allocation.
-struct TestRegistration {
-  explicit TestRegistration(std::function<void()> func) {
-    test_detail::tests().push_back(func);
-  }
-};
-} // end namespace test_detail
-
-/// Declares a test function with the given name and adds it to the list of
-/// known tests. The body of the function must follow immediately. Example:
-///
-/// TEST_FUNC(mytest) {
-///   // CHECK: expected-output-here
-///   emitSomethingToStdOut();
-/// }
-///
-#define TEST_FUNC(name)                                                       \
-  void name();                                                                \
-  static test_detail::TestRegistration name##Registration(name);             \
-  void name()
-
-/// Runs all registered tests. Example:
-///
-/// int main() {
-///   RUN_TESTS();
-///   return 0;
-/// }
-#define RUN_TESTS                                                             \
-  []() {                                                                      \
-    for (auto f : test_detail::tests())                                       \
-      f();                                                                    \
-  }
-
-#endif // LINALG1_TEST_HARNESS_H
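For reference, this is how the deleted harness was consumed in the tutorial's test files; a minimal sketch, assuming the TestHarness.h above is on the include path (the test name and printed string are illustrative, not from the commit):

#include "TestHarness.h"

#include <cstdio>

TEST_FUNC(prints_expected_output) {
  // CHECK: hello
  std::printf("hello\n");
}

int main() {
  RUN_TESTS();  // expands to a lambda; the call invokes every registered test
  return 0;
}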
@@ -1,49 +0,0 @@
-//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_ANALYSIS_H_
-#define LINALG1_ANALYSIS_H_
-
-#include "mlir/Support/LLVM.h"
-
-#include <utility>
-
-namespace mlir {
-class Value;
-} // namespace mlir
-
-namespace linalg {
-class ViewOp;
-
-/// Walks the chain of SliceOp until the unique base ViewOp.
-ViewOp getViewBaseViewOp(mlir::Value *view);
-
-/// Walks the chain of SliceOp until the unique base ViewOp and returns the
-/// MemRef upon which the ViewOp is laid.
-mlir::Value *getViewSupportingMemRef(mlir::Value *view);
-
-/// Extracts the indexing from the root ViewOp that this slice constrains
-/// along `dim`. To achieve this, it walks back the chain of SliceOp and
-/// determines the first slice that constrains `dim`.
-/// Note that the dimension in the original ViewOp may shift due to
-/// rank-reducing operations.
-/// Returns a pair, with the indexing as the first element and the actual
-/// dimension, in the root ViewOp, as the second element.
-std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view,
-                                                       unsigned dim);
-
-} // namespace linalg
-
-#endif // LINALG1_ANALYSIS_H_
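Both walk functions above share one pattern: follow the defining ops of SliceOp results until the unique base ViewOp. A standalone sketch of that chain walk with a toy node type (not the MLIR code, just the shape of the algorithm):

#include <cassert>

struct Node {
  enum Kind { View, Slice } kind;
  Node *parent;  // the view a Slice was taken from; nullptr for a base View
};

// Mirrors getViewBaseViewOp: hop from slice to parent until the base view.
Node *getBase(Node *n) {
  while (n->kind == Node::Slice)
    n = n->parent;
  assert(n->kind == Node::View && "chain must terminate at a base View");
  return n;
}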
@@ -1,120 +0,0 @@
-//===- Common.h - Common helpers for the Linalg tutorial examples ---------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_COMMON_H_
-#define LINALG1_COMMON_H_
-
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/LoopUtils.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace linalg {
-namespace common {
-
-////////////////////////////////////////////////////////////////////////////////
-// Define a few boilerplate objects used across all linalg examples.
-////////////////////////////////////////////////////////////////////////////////
-
-/// An N-D abstraction over a flat contiguous memory region of f32 with
-/// symbolic sizes.
-template <int N>
-inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
-                                        unsigned memorySpace = 0) {
-  llvm::SmallVector<int64_t, 4> shape(N, -1);
-  auto f32 = mlir::FloatType::getF32(context);
-  return mlir::MemRefType::get(shape, f32, {}, memorySpace);
-}
-
-/// A basic function builder.
-inline mlir::FuncOp makeFunction(mlir::ModuleOp module, llvm::StringRef name,
-                                 llvm::ArrayRef<mlir::Type> types,
-                                 llvm::ArrayRef<mlir::Type> resultTypes) {
-  auto *context = module.getContext();
-  auto function = mlir::FuncOp::create(
-      mlir::UnknownLoc::get(context), name,
-      mlir::FunctionType::get({types}, resultTypes, context));
-  function.addEntryBlock();
-  module.push_back(function);
-  return function;
-}
-
-/// A basic pass manager pre-populated with cleanup passes.
-inline std::unique_ptr<mlir::PassManager>
-cleanupPassManager(mlir::MLIRContext *ctx) {
-  std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager(ctx));
-  pm->addPass(mlir::createCanonicalizerPass());
-  pm->addPass(mlir::createSimplifyAffineStructuresPass());
-  pm->addPass(mlir::createCSEPass());
-  pm->addPass(mlir::createCanonicalizerPass());
-  return pm;
-}
-
-/// A simple function to verify and cleanup the IR before printing it to
-/// llvm::outs() for FileCheck'ing.
-/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs(),
-/// which will make the associated FileCheck test fail.
-inline void cleanupAndPrintFunction(mlir::FuncOp f) {
-  bool printToOuts = true;
-  auto check = [&f, &printToOuts](mlir::LogicalResult result) {
-    if (failed(result)) {
-      f.emitError("Verification and cleanup passes failed");
-      printToOuts = false;
-    }
-  };
-  auto pm = cleanupPassManager(f.getContext());
-  check(mlir::verify(f.getParentOfType<mlir::ModuleOp>()));
-  check(pm->run(f.getParentOfType<mlir::ModuleOp>()));
-  if (printToOuts)
-    f.print(llvm::outs());
-}
-
-/// Helper class to sugar building loop nests from indexings that appear in
-/// ViewOp and SliceOp.
-class LoopNestRangeBuilder {
-public:
-  LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
-                       llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
-  LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
-                       llvm::ArrayRef<mlir::Value *> indexings);
-  mlir::edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
-
-private:
-  llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
-};
-
-} // namespace common
-} // namespace linalg
-
-#endif // LINALG1_COMMON_H_
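A sketch of how these helpers composed in the tutorial's examples, assuming the deleted headers and the circa-2019 MLIR API they target (the function name below is hypothetical; the helpers are all from the header above):

#include "linalg1/Common.h"

void buildExampleFunction(mlir::MLIRContext &context) {
  auto module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  // Rank-2 memref of f32; the -1 shape entries produced by floatMemRefType
  // mean the sizes are symbolic (dynamic), i.e. memref<?x?xf32>.
  auto memrefTy = linalg::common::floatMemRefType<2>(&context);
  // Create `func @example(memref<?x?xf32>)` with an entry block, then verify,
  // clean up, and print it for FileCheck.
  auto f = linalg::common::makeFunction(module, "example", {memrefTy}, {});
  linalg::common::cleanupAndPrintFunction(f);
}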
@@ -1,58 +0,0 @@
-//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_CONVERTTOLLVMDIALECT_H_
-#define LINALG1_CONVERTTOLLVMDIALECT_H_
-
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/Support/Allocator.h"
-
-#include <memory>
-
-namespace mlir {
-class ConversionPattern;
-class DialectConversion;
-struct LogicalResult;
-class MLIRContext;
-class ModuleOp;
-class RewritePattern;
-class Type;
-class OwningRewritePatternList;
-namespace LLVM {
-class LLVMType;
-} // end namespace LLVM
-} // end namespace mlir
-
-namespace linalg {
-/// Convert the given Linalg dialect type `t` into an LLVM IR dialect type.
-/// Keep all other types unmodified.
-mlir::Type convertLinalgType(mlir::Type t);
-
-/// Get the conversion patterns for RangeOp, ViewOp and SliceOp from the Linalg
-/// dialect to the LLVM IR dialect. The LLVM IR dialect must be registered.
-/// This function can be used to apply multiple conversion patterns in the same
-/// pass. It does not have to be called explicitly before the conversion.
-void populateLinalg1ToLLVMConversionPatterns(
-    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context);
-
-/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
-/// to the LLVM IR dialect types and operations in the given `module`. This is
-/// the main entry point to the conversion.
-mlir::LogicalResult convertToLLVM(mlir::ModuleOp module);
-} // end namespace linalg
-
-#endif // LINALG1_CONVERTTOLLVMDIALECT_H_
@@ -1,43 +0,0 @@
-//===- Dialect.h - Definition of the Linalg dialect -----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_DIALECT_H_
-#define LINALG1_DIALECT_H_
-
-#include "mlir/IR/Dialect.h"
-
-namespace linalg {
-
-/// The Linalg Dialect is not exposed to the outside world. It is registered by
-/// linking and accessed via generic MLIR accessors.
-class LinalgDialect : public mlir::Dialect {
-public:
-  /// Create a new Dialect that is registered on construction and adds the
-  /// relevant types and operations.
-  explicit LinalgDialect(mlir::MLIRContext *context);
-  static llvm::StringRef getDialectNamespace() { return "linalg"; }
-
-  /// Parse a type registered to this dialect.
-  mlir::Type parseType(llvm::StringRef spec,
-                       mlir::Location loc) const override;
-
-  /// Print a type registered to this dialect.
-  void printType(mlir::Type type, llvm::raw_ostream &os) const override;
-};
-
-} // namespace linalg
-
-#endif // LINALG1_DIALECT_H_
@@ -1,32 +0,0 @@
-//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_INTRINSICS_H_
-#define LINALG1_INTRINSICS_H_
-
-#include "linalg1/Ops.h"
-#include "mlir/EDSC/Intrinsics.h"
-
-namespace linalg {
-namespace intrinsics {
-using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
-using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
-using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
-} // namespace intrinsics
-} // namespace linalg
-
-#endif // LINALG1_INTRINSICS_H_
@@ -1,41 +0,0 @@
-//===- LLVMIntrinsics.h - declarative builders for LLVM dialect -*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_LLVMINTRINSICS_H_
-#define LINALG1_LLVMINTRINSICS_H_
-
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Intrinsics.h"
-
-// Expose some LLVM IR instructions to declarative builders.
-namespace intrinsics {
-using undef = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::UndefOp>;
-using insertvalue =
-    mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::InsertValueOp>;
-using extractvalue =
-    mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ExtractValueOp>;
-using constant = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::ConstantOp>;
-using add = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::AddOp>;
-using sub = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::SubOp>;
-using mul = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::MulOp>;
-using load = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::LoadOp>;
-using store = mlir::edsc::intrinsics::OperationBuilder<mlir::LLVM::StoreOp>;
-using gep = mlir::edsc::intrinsics::ValueBuilder<mlir::LLVM::GEPOp>;
-} // end namespace intrinsics
-
-#endif // LINALG1_LLVMINTRINSICS_H_
@@ -1,26 +0,0 @@
-//===- Ops.h - Linalg Ops single entry point ------------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_OPS_H_
-#define LINALG1_OPS_H_
-
-#include "linalg1/RangeOp.h"
-#include "linalg1/SliceOp.h"
-#include "linalg1/Types.h"
-#include "linalg1/ViewOp.h"
-
-#endif // LINALG1_OPS_H_
@@ -1,40 +0,0 @@
-//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This header file defines prototypes that expose pass constructors.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LINALG1_PASSES_H
-#define LINALG1_PASSES_H
-
-#include "mlir/Support/LLVM.h"
-#include <functional>
-#include <limits>
-
-namespace mlir {
-class ModuleOp;
-template <typename T> class OpPassBase;
-} // namespace mlir
-
-namespace linalg {
-
-mlir::OpPassBase<mlir::ModuleOp> *createLowerLinalgToLLVMPass();
-
-} // namespace linalg
-
-#endif // LINALG1_PASSES_H
@@ -1,57 +0,0 @@
-//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_RANGEOP_H_
-#define LINALG1_RANGEOP_H_
-
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Support/LLVM.h"
-
-namespace linalg {
-
-/// A RangeOp is used to create a value of RangeType from 3 values of type
-/// index that represent the min, max and step values of the range.
-/// Note: step must be an mlir::ConstantIndexOp for now due to current
-/// `affine.for` limitations.
-class RangeOp : public mlir::Op<RangeOp, mlir::OpTrait::NOperands<3>::Impl,
-                                mlir::OpTrait::OneResult,
-                                mlir::OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  //////////////////////////////////////////////////////////////////////////
-  // Hooks to customize the behavior of this op.
-  //////////////////////////////////////////////////////////////////////////
-  static llvm::StringRef getOperationName() { return "linalg.range"; }
-  static void build(mlir::Builder *b, mlir::OperationState &result,
-                    mlir::Value *min, mlir::Value *max, mlir::Value *step);
-  mlir::LogicalResult verify();
-  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
-                                 mlir::OperationState &result);
-  void print(mlir::OpAsmPrinter &p);
-
-  //////////////////////////////////////////////////////////////////////////
-  // Op-specific functionality.
-  //////////////////////////////////////////////////////////////////////////
-  mlir::Value *getMin() { return getOperand(0); }
-  mlir::Value *getMax() { return getOperand(1); }
-  mlir::Value *getStep() { return getOperand(2); }
-};
-
-} // namespace linalg
-
-#endif // LINALG1_RANGEOP_H_
@@ -1,49 +0,0 @@
-//===- RangeType.h - Linalg RangeType definition --------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_RANGETYPE_H_
-#define LINALG1_RANGETYPE_H_
-
-#include "linalg1/Types.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-class MLIRContext;
-}
-
-namespace linalg {
-
-/// A RangeType is the simplest possible form of a type in MLIR. It represents
-/// a minimal range abstraction (min, max, step). Since RangeType is
-/// constructed without any additional argument, this example illustrates the
-/// minimal amount of information required to implement a new custom MLIR type.
-class RangeType : public mlir::Type::TypeBase<RangeType, mlir::Type> {
-public:
-  // Used to implement llvm-style cast.
-  using Base::Base;
-  /// Construction hook.
-  static RangeType get(mlir::MLIRContext *context) {
-    /// Custom, uniqu'ed construction in the mlir::MLIRContext.
-    return Base::get(context, LinalgTypes::Range);
-  }
-  /// Used to implement llvm-style cast.
-  static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
-};
-
-} // namespace linalg
-
-#endif // LINALG1_RANGETYPE_H_
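The `kindof` hook is what made llvm-style casts work on this type; a hypothetical use against the circa-2019 `mlir::Type` API (the function name is illustrative, not from the commit):

#include "linalg1/RangeType.h"

bool isRange(mlir::Type t) {
  // isa<> consults RangeType::kindof(t.getKind()) under the hood.
  return t.isa<linalg::RangeType>();
}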
@@ -1,92 +0,0 @@
-//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_SLICEOP_H_
-#define LINALG1_SLICEOP_H_
-
-#include "linalg1/Types.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Support/LLVM.h"
-
-namespace linalg {
-
-/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a
-/// new ViewType which is contained within its parent ViewType.
-class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::NOperands<2>::Impl,
-                                mlir::OpTrait::OneResult,
-                                mlir::OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  //////////////////////////////////////////////////////////////////////////
-  // Hooks to customize the behavior of this op.
-  //////////////////////////////////////////////////////////////////////////
-  static llvm::StringRef getOperationName() { return "linalg.slice"; }
-  static void build(mlir::Builder *b, mlir::OperationState &result,
-                    mlir::Value *view, mlir::Value *indexing, unsigned dim);
-  mlir::LogicalResult verify();
-  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
-                                 mlir::OperationState &result);
-  void print(mlir::OpAsmPrinter &p);
-
-  //////////////////////////////////////////////////////////////////////////
-  // Op-specific functionality.
-  //////////////////////////////////////////////////////////////////////////
-  enum { FirstIndexingOperand = 1 };
-  /// Returns the attribute name that describes which dimension of the input
-  /// view this SliceOp slices.
-  static llvm::StringRef getSlicingDimAttrName() { return "dim"; }
-  /// Returns the unique result of the parent SliceOp or ViewOp instruction
-  /// that created the view on which this SliceOp operates.
-  mlir::Value *getParentView() { return getOperand(0); }
-  /// Returns the indexing operand of the current SliceOp.
-  /// This operand may either be:
-  ///   1. A range, in which case the operand comes from a RangeOp. This
-  ///      SliceOp does not reduce the dimension of the input ViewType.
-  ///   2. An index, in which case the operand comes from any possible
-  ///      producer of an index. This SliceOp reduces the dimension of the
-  ///      input ViewType by 1.
-  mlir::Value *getIndexing() { return getOperand(1); }
-  /// Returns the dim of the parent ViewType that is sliced by this SliceOp.
-  unsigned getSlicingDim() {
-    return getAttrOfType<mlir::IntegerAttr>(getSlicingDimAttrName()).getInt();
-  }
-  /// Returns the ViewType resulting from this SliceOp.
-  ViewType getViewType();
-  /// Returns the rank of the current ViewType.
-  unsigned getRank();
-  /// Returns the element type of the current ViewType.
-  mlir::Type getElementType();
-
-  /// Returns the ViewType of `getParentView()`.
-  ViewType getParentViewType();
-  /// Returns the rank of the ViewType of `getParentView()`.
-  unsigned getParentRank();
-  /// Returns the element Type of the ViewType of `getParentView()`.
-  mlir::Type getParentElementType();
-
-  /// Returns true if the rank of the parent view is greater than the rank of
-  /// the child view.
-  bool isRankDecreasing();
-
-  // Get all the indexings in this slice.
-  mlir::Operation::operand_range getIndexings();
-};
-
-} // namespace linalg
-
-#endif // LINALG1_SLICEOP_H_
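A toy model of the rank rule documented on `getIndexing` above: a range operand keeps the parent shape, while an index operand projects one dimension away. A sketch only, using std::optional to stand in for "range vs. index":

#include <cstdint>
#include <optional>
#include <vector>

// Shape of the view produced by slicing `parent` at `dim`. An empty optional
// models a range indexing (rank preserved); a value models an index indexing
// (rank reduced by one).
std::vector<int64_t> sliceShape(std::vector<int64_t> parent, unsigned dim,
                                std::optional<int64_t> index) {
  if (index)
    parent.erase(parent.begin() + dim);  // rank-decreasing slice
  return parent;
}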
@@ -1,36 +0,0 @@
-//===- Types.h - Linalg Types forward declarations ------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_TYPES_H_
-#define LINALG1_TYPES_H_
-
-#include "mlir/IR/Types.h"
-
-namespace linalg {
-
-enum LinalgTypes {
-  Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
-  View,
-  FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
-};
-
-} // namespace linalg
-
-#include "linalg1/RangeType.h"
-#include "linalg1/ViewType.h"
-
-#endif // LINALG1_TYPES_H_
@@ -1,37 +0,0 @@
-//===- Utils.h - Linalg dialect utility functions definitions -------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_UTILS_H_
-#define LINALG1_UTILS_H_
-
-namespace mlir {
-class Value;
-} // namespace mlir
-
-namespace linalg {
-class ViewOp;
-
-/// Asserts `view` is of ViewType and returns its rank.
-unsigned getViewRank(mlir::Value *view);
-
-/// Helper function to emit and return a new ViewOp from `memRef` that is
-/// assumed to be of MemRefType. This needs to be called under a ScopedContext.
-ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef);
-
-} // namespace linalg
-
-#endif // LINALG1_UTILS_H_
@@ -1,68 +0,0 @@
-//===- ViewOp.h - Linalg dialect ViewOp operation definition --------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_VIEWOP_H_
-#define LINALG1_VIEWOP_H_
-
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Support/LLVM.h"
-
-namespace linalg {
-
-class ViewType;
-
-/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range
-/// abstraction on top of an underlying data type. For now we use the existing
-/// mlir::MemRef for the underlying data type.
-class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
-                               mlir::OpTrait::OneResult,
-                               mlir::OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  //////////////////////////////////////////////////////////////////////////
-  // Hooks to customize the behavior of this op.
-  //////////////////////////////////////////////////////////////////////////
-  static llvm::StringRef getOperationName() { return "linalg.view"; }
-  static void build(mlir::Builder *b, mlir::OperationState &result,
-                    mlir::Value *memRef,
-                    llvm::ArrayRef<mlir::Value *> indexings);
-  mlir::LogicalResult verify();
-  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
-                                 mlir::OperationState &result);
-  void print(mlir::OpAsmPrinter &p);
-
-  //////////////////////////////////////////////////////////////////////////
-  // Op-specific functionality.
-  //////////////////////////////////////////////////////////////////////////
-  enum { FirstIndexingOperand = 1 };
-  unsigned getRank();
-  mlir::Type getElementType();
-  ViewType getViewType();
-  // May be something other than a MemRef in the future.
-  mlir::Value *getSupportingMemRef();
-  // Get the underlying indexing at a given rank.
-  mlir::Value *getIndexing(unsigned rank);
-  // Get all the indexings of type RangeOp.
-  llvm::SmallVector<mlir::Value *, 8> getRanges();
-  // Get all the indexings in this view.
-  mlir::Operation::operand_range getIndexings();
-};
-
-} // namespace linalg
-
-#endif // LINALG1_VIEWOP_H_
@@ -1,57 +0,0 @@
-//===- ViewType.h - Linalg ViewType definition ----------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#ifndef LINALG1_VIEWTYPE_H_
-#define LINALG1_VIEWTYPE_H_
-
-#include "linalg1/Types.h"
-#include "mlir/IR/Types.h"
-
-namespace linalg {
-
-struct ViewTypeStorage;
-
-/// A ViewType represents a range abstraction on top of an underlying storage
-/// type. It is parameterizable by the underlying element type and the rank of
-/// the view.
-class ViewType
-    : public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> {
-public:
-  //////////////////////////////////////////////////////////////////////////
-  // Hooks to customize the behavior of this type.
-  //////////////////////////////////////////////////////////////////////////
-  // Used to implement llvm-style cast.
-  using Base::Base;
-  // Used to implement llvm-style cast.
-  static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
-  /// Construction hook.
-  static ViewType get(mlir::MLIRContext *context, mlir::Type elementType,
-                      unsigned rank);
-
-  //////////////////////////////////////////////////////////////////////////
-  // Type-specific functionality.
-  //////////////////////////////////////////////////////////////////////////
-  /// Return the underlying elemental type.
-  mlir::Type getElementType();
-  /// Return the rank of the view.
-  /// This is the number of indexings needed to reach an underlying element.
-  unsigned getRank();
-};
-
-} // namespace linalg
-
-#endif // LINALG1_VIEWTYPE_H_
@@ -1,75 +0,0 @@
-//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements analysis functions for views and slices in the
-// linalg dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "linalg1/Analysis.h"
-#include "linalg1/Ops.h"
-#include "mlir/IR/StandardTypes.h"
-
-using namespace mlir;
-using namespace linalg;
-
-ViewOp linalg::getViewBaseViewOp(Value *view) {
-  auto viewType = view->getType().dyn_cast<ViewType>();
-  (void)viewType;
-  assert(viewType.isa<ViewType>() && "expected a ViewType");
-  while (auto slice = dyn_cast<SliceOp>(view->getDefiningOp())) {
-    view = slice.getParentView();
-    assert(view->getType().isa<ViewType>() && "expected a ViewType");
-  }
-  return cast<ViewOp>(view->getDefiningOp());
-}
-
-Value *linalg::getViewSupportingMemRef(Value *view) {
-  return getViewBaseViewOp(view).getSupportingMemRef();
-}
-
-std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view,
-                                                               unsigned dim) {
-  auto viewType = view->getType().dyn_cast<ViewType>();
-  (void)viewType;
-  assert(viewType.isa<ViewType>() && "expected a ViewType");
-  assert(dim < viewType.getRank() && "dim exceeds rank");
-  if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
-    return std::make_pair(viewOp.getIndexing(dim), dim);
-
-  auto sliceOp = cast<SliceOp>(view->getDefiningOp());
-  auto *parentView = sliceOp.getParentView();
-  unsigned sliceDim = sliceOp.getSlicingDim();
-  auto *indexing = sliceOp.getIndexing();
-  if (indexing->getDefiningOp()) {
-    if (auto rangeOp = dyn_cast<RangeOp>(indexing->getDefiningOp())) {
-      // If I sliced with a range and I sliced at this dim, then I'm it.
-      if (dim == sliceDim) {
-        return std::make_pair(rangeOp.getResult(), dim);
-      }
-      // Otherwise, I did not change the rank, just go look for `dim` in my
-      // parent.
-      return getViewRootIndexing(parentView, dim);
-    }
-  }
-  assert(indexing->getType().isa<IndexType>());
-  // If I get here, I indexed and reduced along the dim `sliceDim` from my
-  // parent. I need to query my parent for `dim` or `dim + 1` depending on
-  // whether dim > sliceDim or not.
-  unsigned parentDim = dim > sliceDim ? dim + 1 : dim;
-  return getViewRootIndexing(parentView, parentDim);
-}
@@ -1,76 +0,0 @@
-set(LLVM_OPTIONAL_SOURCES
-  Analysis.cpp
-  ConvertToLLVMDialect.cpp
-  SliceOp.cpp
-  ViewOp.cpp
-  Common.cpp
-  Dialect.cpp
-  RangeOp.cpp
-  Utils.cpp
-  ViewType.cpp
-  DialectConstruction.cpp
-  DialectRegistration.cpp
-  )
-
-set(LIBS
-  MLIRAffineOps
-  MLIRAnalysis
-  MLIRLoopToStandard
-  MLIREDSC
-  MLIRLLVMIR
-  MLIRParser
-  MLIRPass
-  MLIRStandardOps
-  MLIRStandardToLLVM
-  MLIRSupport
-  MLIRTransforms
-  )
-
-add_llvm_library(Linalg1LLVMConversion
-  ConvertToLLVMDialect.cpp
-  )
-target_link_libraries(Linalg1LLVMConversion PUBLIC MLIRLLVMIR
-  MLIRLoopToStandard MLIRStandardOps)
-
-add_llvm_library(Linalg1
-  Analysis.cpp
-  SliceOp.cpp
-  ViewOp.cpp
-  Common.cpp
-  Dialect.cpp
-  RangeOp.cpp
-  Utils.cpp
-  ViewType.cpp
-  DEPENDS
-  intrinsics_gen
-  )
-target_link_libraries(Linalg1
-  PUBLIC
-    ${LIBS}
-    Linalg1LLVMConversion
-  )
-
-add_llvm_library(Linalg1DialectConstruction
-  DialectConstruction.cpp
-  )
-target_link_libraries(Linalg1DialectConstruction PUBLIC Linalg1)
-
-add_llvm_executable(linalg1-opt
-  DialectRegistration.cpp
-  )
-llvm_update_compile_flags(linalg1-opt)
-whole_archive_link(linalg1-opt
-  Linalg1LLVMConversion
-  Linalg1DialectConstruction
-  ${LIBS}
-  )
-target_link_libraries(linalg1-opt
-  PRIVATE
-  Linalg1
-  Linalg1LLVMConversion
-  Linalg1DialectConstruction
-  MLIRLLVMIR
-  MLIRMlirOptLib
-  MLIROptMain
-  ${LIBS}
-  LLVMSupport)
@@ -1,68 +0,0 @@
-//===- Common.cpp - Implementation of common supporting functions ---------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements common supporting functions shared by the linalg
-// tutorial examples.
-//
-//===----------------------------------------------------------------------===//
-
-#include "linalg1/Common.h"
-#include "linalg1/Ops.h"
-#include "linalg1/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/EDSC/Intrinsics.h"
-
-using llvm::ArrayRef;
-using mlir::ConstantIndexOp;
-using mlir::edsc::ValueHandle;
-using mlir::edsc::intrinsics::alloc;
-using mlir::edsc::intrinsics::ret;
-
-using namespace linalg;
-
-linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
-    llvm::ArrayRef<ValueHandle *> ivs, llvm::ArrayRef<ValueHandle> indexings) {
-  assert(ivs.size() == indexings.size());
-  for (unsigned i = 0, e = indexings.size(); i < e; ++i) {
-    auto rangeOp =
-        llvm::dyn_cast<RangeOp>(indexings[i].getValue()->getDefiningOp());
-    if (!rangeOp) {
-      continue;
-    }
-    auto lb = rangeOp.getMin();
-    auto ub = rangeOp.getMax();
-    // This must be a constexpr index until we relax the affine.for constraint.
-    auto step = llvm::cast<ConstantIndexOp>(rangeOp.getStep()->getDefiningOp())
-                    .getValue();
-    loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
-  }
-}
-
-linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
-    llvm::ArrayRef<ValueHandle *> ivs, llvm::ArrayRef<mlir::Value *> indexings)
-    : LoopNestRangeBuilder(ivs, llvm::SmallVector<ValueHandle, 4>(
-                                    indexings.begin(), indexings.end())) {}
-
-ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
-    std::function<void(void)> fun) {
-  if (fun)
-    fun();
-  for (auto &lit : llvm::reverse(loops)) {
-    lit({});
-  }
-  return ValueHandle::null();
-}
@ -1,433 +0,0 @@
|
|||
//===- ConvertToLLVMDialect.cpp - conversion from Linalg to LLVM dialect --===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/LowerAffine.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg1/ConvertToLLVMDialect.h"
|
||||
#include "linalg1/LLVMIntrinsics.h"
|
||||
#include "linalg1/Ops.h"
|
||||
#include "linalg1/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Convert the given type to the LLVM IR Dialect type. The following
|
||||
// conversions are supported:
|
||||
// - an Index type is converted into an LLVM integer type with pointer
|
||||
// bitwidth (analogous to intptr_t in C);
|
||||
// - an Integer type is converted into an LLVM integer type of the same width;
|
||||
// - an F32 type is converted into an LLVM float type
|
||||
// - a Range or View is converted into an LLVM structure type containing the
|
||||
// respective dynamic values.
|
||||
Type linalg::convertLinalgType(Type t) {
|
||||
auto *context = t.getContext();
|
||||
auto *dialect = context->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
|
||||
// Simple conversions.
|
||||
if (t.isa<IndexType>()) {
|
||||
int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
|
||||
return LLVM::LLVMType::getIntNTy(dialect, width);
|
||||
}
|
||||
if (auto intTy = t.dyn_cast<IntegerType>())
|
||||
return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
|
||||
if (t.isF32())
|
||||
return LLVM::LLVMType::getFloatTy(dialect);
|
||||
if (t.isF64())
|
||||
return LLVM::LLVMType::getDoubleTy(dialect);
|
||||
|
||||
// Range descriptor contains the range bounds and the step as 64-bit integers.
|
||||
//
|
||||
// struct {
|
||||
// int64_t min;
|
||||
// int64_t max;
|
||||
// int64_t step;
|
||||
// };
|
||||
if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
||||
return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
||||
}
|
||||
|
||||
// View descriptor contains the pointer to the data buffer, followed by a
|
||||
// 64-bit integer containing the distance between the beginning of the buffer
|
||||
// and the first element to be accessed through the view, followed by two
|
||||
// arrays, each containing as many 64-bit integers as the rank of the View.
|
||||
// The first array represents the size, in number of original elements, of the
|
||||
// view along the given dimension. When taking the view, the size is the
|
||||
// difference between the upper and the lower bound of the range. The second
|
||||
// array represents the "stride" (in tensor abstraction sense), i.e. the
|
||||
// number of consecutive elements of the underlying buffer that separate two
|
||||
// consecutive elements addressable through the view along the given
|
||||
// dimension. When taking the view, the strides are constructed as products
|
||||
// of the original sizes along the trailing dimensions, multiplied by the view
|
||||
// step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
|
||||
// i.e. the view of a complete memref, will have strides N and 1. A view with
|
||||
// ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
|
||||
//
|
||||
// template <typename Elem, size_t Rank>
|
||||
// struct {
|
||||
// Elem *ptr;
|
||||
// int64_t offset;
|
||||
// int64_t sizes[Rank];
|
||||
// int64_t strides[Rank];
|
||||
// };
|
||||
if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
|
||||
auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
||||
auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
|
||||
return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
|
||||
}
|
||||
|
||||
// All other types are kept as is.
|
||||
return t;
|
||||
}
|
||||
|
||||
// RangeOp creates a new range descriptor.
|
||||
class RangeOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit RangeOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<linalg::RangeOp>(op);
|
||||
auto rangeDescriptorType =
|
||||
linalg::convertLinalgType(rangeOp.getResult()->getType());
|
||||
|
||||
using namespace intrinsics;
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
|
||||
// Fill in an aggregate value of the descriptor.
|
||||
Value *rangeDescriptor = undef(rangeDescriptorType);
|
||||
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
||||
operands[0], rewriter.getI64ArrayAttr(0));
|
||||
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
||||
operands[1], rewriter.getI64ArrayAttr(1));
|
||||
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
||||
operands[2], rewriter.getI64ArrayAttr(2));
|
||||
rewriter.replaceOp(op, rangeDescriptor);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class ViewOpConversion : public ConversionPattern {
public:
  explicit ViewOpConversion(MLIRContext *context)
      : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto viewOp = cast<linalg::ViewOp>(op);
    auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
    auto memrefType =
        viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
    auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));

    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Helper function to emit an LLVMIR Dialect 64-bit integer constant given
    // its value.
    auto i64cst = [&rewriter, int64Ty](int64_t value) {
      return intrinsics::constant(
          int64Ty, IntegerAttr::get(rewriter.getIndexType(), value));
    };

    // Helper function to obtain the size of the given `memref` along the
    // dimension `dim`. For static dimensions, emits a constant; for dynamic
    // dimensions, extracts the size from the memref descriptor.
    auto memrefSize = [&rewriter, int64Ty, i64cst](
        MemRefType type, Value *memref, int dim) -> Value * {
      assert(dim < type.getRank());
      if (type.getShape()[dim] != -1) {
        return i64cst(type.getShape()[dim]);
      }
      return intrinsics::extractvalue(int64Ty, memref,
                                      rewriter.getI64ArrayAttr({2, dim}));
    };

    // Helper function to obtain the data pointer of the given `memref`.
    auto memrefPtr = [pos](MemRefType type, Value *memref) -> Value * {
      auto elementTy = linalg::convertLinalgType(type.getElementType())
                           .cast<LLVM::LLVMType>()
                           .getPointerTo();
      return intrinsics::extractvalue(elementTy, memref, pos(0));
    };

    using namespace intrinsics;
    edsc::ScopedContext context(rewriter, op->getLoc());

    // Declare the view descriptor.
    Value *viewDescriptor = undef(viewDescriptorType);
    // Insert the data pointer.
    Value *bufferPtr = memrefPtr(memrefType, operands[0]);
    viewDescriptor =
        insertvalue(viewDescriptorType, viewDescriptor, bufferPtr, pos(0));

    // Collect all memref sizes but the first, which are needed for further
    // computation.
    SmallVector<Value *, 4> trueSizes(memrefType.getRank());
    for (int i = 1, e = memrefType.getRank(); i < e; ++i) {
      trueSizes[i] = memrefSize(memrefType, operands[0], i);
    }

    // Compute all strides of the memref.
    SmallVector<Value *, 4> trueStrides(memrefType.getRank());
    if (viewOp.getRank() != 0)
      trueStrides[memrefType.getRank() - 1] = i64cst(1);
    for (int i = memrefType.getRank() - 2; i >= 0; --i)
      trueStrides[i] = mul(trueStrides[i + 1], trueSizes[i + 1]);

    // Compute and insert the base offset.
    Value *baseOffset = i64cst(0);
    for (int j = 0, e = memrefType.getRank(); j < e; ++j) {
      Value *indexing = operands[1 + j];
      Value *min = viewOp.getIndexing(j)->getType().isa<linalg::RangeType>()
                       ? (Value *)extractvalue(int64Ty, indexing, pos(0))
                       : indexing;
      Value *product = mul(min, trueStrides[j]);
      baseOffset = add(baseOffset, product);
    }
    viewDescriptor =
        insertvalue(viewDescriptorType, viewDescriptor, baseOffset, pos(1));

    // Compute and insert view sizes (max - min along the range). Skip the
    // non-range operands as they will be projected away from the view. Note
    // that `opPos` tracks the position among all indexing operands, while `i`
    // tracks the compacted position in the resulting descriptor.
    int i = 0;
    int opPos = 0;
    for (Value *index : viewOp.getIndexings()) {
      if (index->getType().isa<linalg::RangeType>()) {
        Value *rangeDescriptor = operands[1 + opPos];
        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value *size = sub(max, min);

        viewDescriptor =
            insertvalue(viewDescriptorType, viewDescriptor, size, pos({2, i}));
        ++i;
      }
      ++opPos;
    }

    // Compute and insert view strides. Step over the strides that correspond
    // to non-range operands as they are projected away from the view.
    i = 0;
    for (int j = 0, e = trueStrides.size(); j < e; ++j) {
      if (!viewOp.getIndexing(j)->getType().isa<linalg::RangeType>())
        continue;
      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
      Value *stride = mul(trueStrides[j], step);
      viewDescriptor =
          insertvalue(viewDescriptorType, viewDescriptor, stride, pos({3, i}));
      ++i;
    }

    rewriter.replaceOp(op, viewDescriptor);
    return matchSuccess();
  }
};

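// The view descriptor built above is, conceptually, the LLVM aggregate
//   { T* ptr, i64 offset, [rank x i64] sizes, [rank x i64] strides }.
// A small worked example (a sketch, not output of this exact revision): for a
// memref<?x42xf32> viewed with ranges %lb0:%ub0:%s0 and %lb1:%ub1:%s1 (names
// hypothetical), the row-major "true" strides are (42, 1), so the pattern
// emits:
//   offset  = %lb0 * 42 + %lb1 * 1
//   sizes   = (%ub0 - %lb0, %ub1 - %lb1)
//   strides = (42 * %s0, 1 * %s1)
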
class SliceOpConversion : public ConversionPattern {
public:
  explicit SliceOpConversion(MLIRContext *context)
      : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto sliceOp = cast<linalg::SliceOp>(op);
    auto newViewDescriptorType =
        linalg::convertLinalgType(sliceOp.getViewType());
    auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
                           .cast<LLVM::LLVMType>()
                           .getPointerTo();
    auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));

    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // First operand to `slice` is the old view descriptor.
    Value *oldViewDescriptor = operands[0];

    // Properties of the slice.
    bool isRankDecreasing = sliceOp.isRankDecreasing();
    int dim = sliceOp.getSlicingDim();
    assert(isRankDecreasing ^
           sliceOp.getIndexing()->getType().isa<linalg::RangeType>());

    // Declare the descriptor of the new view.
    using namespace intrinsics;
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *newViewDescriptor = undef(newViewDescriptorType);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value *buffer = extractvalue(elementType, oldViewDescriptor, pos(0));
    newViewDescriptor =
        insertvalue(newViewDescriptorType, newViewDescriptor, buffer, pos(0));

    // Update the base offset:
    //   base_offset' = base_offset + min_d * stride_d
    // where d is the dimension being sliced, min_d is the minimum value of the
    // range (in case of a single-value slice, that value), and stride_d is the
    // stride along this dimension.
    Value *baseOffset = extractvalue(int64Ty, oldViewDescriptor, pos(1));
    Value *slicingValue = operands[1];
    // If `slice` is not rank-decreasing, we need to extract the "min" value
    // from the range descriptor. Otherwise, we take the value directly.
    Value *min = !isRankDecreasing
                     ? (Value *)extractvalue(int64Ty, slicingValue, pos(0))
                     : slicingValue;
    Value *stride = extractvalue(int64Ty, oldViewDescriptor, pos({3, dim}));
    baseOffset = add(baseOffset, mul(min, stride));
    newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                    baseOffset, pos(1));

    // Copy the sizes and strides into the new descriptor, updating or dropping
    // the affected dimension. If the `slice` is rank-decreasing, the resulting
    // view no longer contains the sliced dimension, so its size and stride
    // become unnecessary and can be dropped. Otherwise, the size of the
    // affected dimension is updated to the size of the range, and its stride
    // is multiplied with the step of the range.
    for (int i = 0, e = sliceOp.getRank(); i < e; ++i) {
      int originalPos = (isRankDecreasing && i >= dim) ? i + 1 : i;
      Value *size;
      Value *stride;
      if (!isRankDecreasing && i == dim) {
        Value *upper = extractvalue(int64Ty, slicingValue, pos(1));
        Value *lower = extractvalue(int64Ty, slicingValue, pos(0));
        size = sub(upper, lower);

        Value *previousStride =
            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
        Value *step = extractvalue(int64Ty, slicingValue, pos(2));
        stride = mul(previousStride, step);
      } else {
        size = extractvalue(int64Ty, oldViewDescriptor, pos({2, originalPos}));
        stride =
            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
      }
      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                      size, pos({2, i}));
      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                      stride, pos({3, i}));
    }

    rewriter.replaceOp(op, newViewDescriptor);
    return matchSuccess();
  }
};

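// A quick numeric check of the offset update above (an illustrative example,
// not taken from the test suite): slicing dimension d = 1 of a view with
// base_offset = 8 and strides = (16, 1) at the single index 3 yields
//   base_offset' = 8 + 3 * 1 = 11
// and, being rank-decreasing, drops the size and stride of dimension 1.
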
// When converting the "some_consumer" operation, don't emit anything and
|
||||
// effectively drop it.
|
||||
class DropConsumer : public ConversionPattern {
|
||||
public:
|
||||
explicit DropConsumer(MLIRContext *context)
|
||||
: ConversionPattern("some_consumer", 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, llvm::None);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
void linalg::populateLinalg1ToLLVMConversionPatterns(
|
||||
mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
|
||||
patterns.insert<DropConsumer, RangeOpConversion, SliceOpConversion,
|
||||
ViewOpConversion>(context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// A type conversion class that converts Linalg and Std types to LLVM.
|
||||
struct LinalgTypeConverter : public LLVMTypeConverter {
|
||||
using LLVMTypeConverter::LLVMTypeConverter;
|
||||
|
||||
// This gets called for block and region arguments, and attributes.
|
||||
Type convertType(Type t) override {
|
||||
if (auto result = LLVMTypeConverter::convertType(t))
|
||||
return result;
|
||||
return linalg::convertLinalgType(t);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
|
||||
// Convert Linalg ops to the LLVM IR dialect using the converter defined
|
||||
// above.
|
||||
LinalgTypeConverter converter(module.getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, module.getContext());
|
||||
populateLoopToStdConversionPatterns(patterns, module.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
|
||||
|
||||
ConversionTarget target(*module.getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
return applyFullConversion(module, target, patterns, &converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
|
||||
void runOnModule() {
|
||||
if (failed(linalg::convertToLLVM(getModule())))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OpPassBase<ModuleOp> *linalg::createLowerLinalgToLLVMPass() {
|
||||
return new LowerLinalgToLLVMPass();
|
||||
}
|
||||
|
||||
static PassRegistration<LowerLinalgToLLVMPass>
|
||||
pass("lower-linalg-to-llvm",
|
||||
"Lower the operations from the linalg dialect into the LLVM dialect");
|
|
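// Usage sketch: any tool that links this library in can exercise the pass
// through the flag registered above, e.g. (the binary name is hypothetical;
// only the flag string is defined here):
//
//   linalg-opt -lower-linalg-to-llvm input.mlir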
@ -1,98 +0,0 @@

//===- Dialect.cpp - Implementation of the linalg dialect -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple Linalg dialect to which we gradually add
// complexity.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"
#include "linalg1/Ops.h"
#include "linalg1/Types.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"

using llvm::raw_ostream;
using llvm::StringRef;
using namespace mlir;
using namespace linalg;

Type LinalgDialect::parseType(StringRef spec, Location loc) const {
  MLIRContext *context = getContext();
  if (spec == "range")
    return RangeType::get(getContext());

  StringRef str = spec;
  if (str.consume_front("view<")) {
    // Just count the number of `?x` prefixes to get the rank; the element
    // type must be one of the supported floating-point types.
    unsigned rank = 0;
    while (str.consume_front("?x"))
      ++rank;
    if (str.consume_front("bf16>"))
      return ViewType::get(context, FloatType::getBF16(context), rank);
    if (str.consume_front("f16>"))
      return ViewType::get(context, FloatType::getF16(context), rank);
    if (str.consume_front("f32>"))
      return ViewType::get(context, FloatType::getF32(context), rank);
    if (str.consume_front("f64>"))
      return ViewType::get(context, FloatType::getF64(context), rank);
  }
  return (emitError(loc, "unknown Linalg type: " + spec), nullptr);
}

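// For instance, parsing proceeds as follows (an illustrative walk-through of
// the code above):
//   "view<?x?xf32>"  -> consume "view<", strip "?x" twice (rank = 2),
//                       match "f32>"  => ViewType(f32, rank 2)
//   "range"          -> RangeType
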
/// RangeType prints as just "range".
static void print(RangeType rt, raw_ostream &os) { os << "range"; }

/// ViewType prints as:
///
/// ```{.mlir}
///   view<?x?xf32>
/// ```
///
/// or
///
/// ```{.mlir}
///   view<f32>
/// ```
///
/// for 0-D views (a.k.a pointer to a scalar value).
static void print(linalg::ViewType rt, raw_ostream &os) {
  os << "view<";
  if (rt.getRank() > 0) {
    for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
      os << "?x";
    }
  }
  os << rt.getElementType();
  os << ">";
}

void LinalgDialect::printType(Type type, raw_ostream &os) const {
  switch (type.getKind()) {
  default:
    llvm_unreachable("Unhandled linalg type");
  case LinalgTypes::Range:
    print(type.cast<RangeType>(), os);
    break;
  case linalg::LinalgTypes::View:
    print(type.cast<linalg::ViewType>(), os);
    break;
  }
}

@ -1,35 +0,0 @@
//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the constructor for the Linalg Dialect. This is
// explicitly separated from the core library to allow incremental buildup of
// the codebase for the tutorial.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"
#include "linalg1/Ops.h"
#include "linalg1/Types.h"

using namespace mlir;
using namespace linalg;

LinalgDialect::LinalgDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addTypes<RangeType, ViewType>();
  addOperations<RangeOp, SliceOp, ViewOp>();
}

@ -1,31 +0,0 @@
//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a registration for the Linalg Dialect globally.
// This can just be linked in as a dependence into any binary to enable the
// Linalg dialect. Note that the binary it is linked into must not already
// register Linalg or double registration will occur.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"

using namespace linalg;

// Dialect registration triggers the creation of a `LinalgDialect` object which
// adds the proper types and operations to the dialect.
static mlir::DialectRegistration<LinalgDialect> LinalgOps;

@ -1,74 +0,0 @@
//===- RangeOp.cpp - Implementation of the linalg RangeOp operation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple IR operation to create a new RangeType in the
// linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Ops.h"
#include "linalg1/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

using llvm::SmallVector;
using namespace mlir;
using namespace linalg;

// Minimal example for a new RangeOp operating on RangeType.
void linalg::RangeOp::build(Builder *b, OperationState &result, Value *min,
                            Value *max, Value *step) {
  result.addOperands({min, max, step});
  result.addTypes({RangeType::get(b->getContext())});
}

// Verification is simply that a RangeOp takes three index ssa-values.
mlir::LogicalResult linalg::RangeOp::verify() {
  if (!getMin() || !getMin()->getType().isa<IndexType>())
    return emitOpError("first operand should be of type index");
  if (!getMax() || !getMax()->getType().isa<IndexType>())
    return emitOpError("second operand should be of type index");
  if (!getStep() || !getStep()->getType().isa<IndexType>())
    return emitOpError("third operand should be of type index");
  return mlir::success();
}

ParseResult linalg::RangeOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
  RangeType type;
  auto indexTy = parser.getBuilder().getIndexType();
  return failure(parser.parseOperand(rangeInfo[0]) || parser.parseColon() ||
                 parser.parseOperand(rangeInfo[1]) || parser.parseColon() ||
                 parser.parseOperand(rangeInfo[2]) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperands(rangeInfo, indexTy, result.operands) ||
                 parser.addTypeToList(type, result.types));
}

// A RangeOp prints as:
//
// ```{.mlir}
//   linalg.range %arg0:%arg1:%c42 : !linalg.range
// ```
void linalg::RangeOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << *getMin() << ":" << *getMax() << ":"
    << *getStep();
  p.printOptionalAttrDict(getAttrs());
  p << " : " << getType();
}

@ -1,160 +0,0 @@
//===- SliceOp.cpp - Implementation of the linalg SliceOp operation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements an IR operation to extract a "sub-View" from a ViewType
// in the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Analysis.h"
#include "linalg1/Ops.h"
#include "linalg1/Types.h"
#include "linalg1/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace linalg;

// A view may itself be coming either from a ViewOp or from a SliceOp.
// TODO assert statically or dynamically that indexing is within the bounds of
// view.
void linalg::SliceOp::build(Builder *b, OperationState &result, Value *view,
                            Value *indexing, unsigned dim) {
  // Early sanity checks + extract rank.
  unsigned rank = getViewRank(view);
  ViewType viewType = view->getType().cast<ViewType>();
  Type elementType = viewType.getElementType();

  result.addOperands({view, indexing});
  result.addAttribute(getSlicingDimAttrName(),
                      b->getIntegerAttr(b->getIndexType(), dim));
  if (indexing->getType().isa<RangeType>()) {
    // Taking a range slice does not decrease the rank, the view has the same
    // type.
    result.addTypes({viewType});
  } else {
    assert(indexing->getType().cast<IndexType>());
    result.addTypes(
        {linalg::ViewType::get(b->getContext(), elementType, rank - 1)});
  }
}

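// Illustrative types (a sketch restating the build logic above): slicing a
// !linalg.view<?x?xf32> with an index value is rank-decreasing and yields
// !linalg.view<?xf32>, while slicing it with a !linalg.range keeps the
// !linalg.view<?x?xf32> type unchanged.
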
mlir::LogicalResult linalg::SliceOp::verify() {
  if (!getAttr(getSlicingDimAttrName()))
    return emitOpError("slice op expects a dim attribute");
  unsigned dim = getSlicingDim();
  if (dim >= getParentRank())
    return emitOpError("slicing dim must be in the [0 .. parent_rank) range");
  if (!getOperand(0)->getType().isa<ViewType>())
    return emitOpError(
        "first operand must be of ViewType (i.e. a ViewOp or a SliceOp)");
  auto type = getOperand(1)->getType().dyn_cast<IndexType>();
  auto range = getOperand(1)->getType().dyn_cast<RangeType>();
  if (!range && !type)
    return emitOpError(
        "second operand must be of RangeType (i.e. a RangeOp) or IndexType");
  return mlir::success();
}

ParseResult linalg::SliceOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType viewInfo;
  SmallVector<OpAsmParser::OperandType, 1> indexingInfo;
  SmallVector<Type, 8> types;
  if (parser.parseOperand(viewInfo) ||
      parser.parseOperandList(indexingInfo, 1,
                              OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttributeDict(result.attributes) ||
      parser.parseColonTypeList(types))
    return failure();

  if (indexingInfo.size() != 1)
    return parser.emitError(parser.getNameLoc(), "expected 1 indexing type");

  ViewType viewType = types.front().dyn_cast<ViewType>();
  if (!viewType)
    return parser.emitError(parser.getNameLoc(),
                            "view type expected as first type");

  IndexType indexType = types.back().dyn_cast<IndexType>();
  RangeType rangeType = types.back().dyn_cast<RangeType>();
  if (!indexType && !rangeType) {
    llvm::errs() << types.back();
    return parser.emitError(parser.getNameLoc(),
                            "indexing must be of range or index type");
  }

  unsigned rank = viewType.getRank();
  if (indexType)
    --rank;
  ViewType resultViewType =
      ViewType::get(viewType.getContext(), viewType.getElementType(), rank);

  return failure(
      parser.resolveOperand(viewInfo, viewType, result.operands) ||
      parser.resolveOperands(indexingInfo[0], types.back(), result.operands) ||
      parser.addTypeToList(resultViewType, result.types));
}

// A SliceOp prints as:
//
// ```{.mlir}
//   linalg.slice %0[%i0] {dim = 0} : !linalg.view<?xf32>, index
// ```
//
// Where %0 is an ssa-value holding a `view<?x?xf32>`, %i0 is an ssa-value
// holding an index.
void linalg::SliceOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << *getParentView() << "[" << *getIndexing()
    << "]";
  p << " {dim = " << getAttrOfType<IntegerAttr>("dim").getInt() << "}";
  p.printOptionalAttrDict(getAttrs(), {"dim"});
  p << " : " << getParentViewType() << ", " << getIndexing()->getType();
}

ViewType linalg::SliceOp::getViewType() { return getType().cast<ViewType>(); }

unsigned linalg::SliceOp::getRank() { return getViewType().getRank(); }

mlir::Type linalg::SliceOp::getElementType() {
  return getViewType().getElementType();
}

ViewType linalg::SliceOp::getParentViewType() {
  return getParentView()->getType().cast<ViewType>();
}

unsigned linalg::SliceOp::getParentRank() {
  return getParentViewType().getRank();
}

mlir::Type linalg::SliceOp::getParentElementType() {
  return getParentViewType().getElementType();
}

bool linalg::SliceOp::isRankDecreasing() {
  return getParentRank() != getRank();
}

mlir::Operation::operand_range linalg::SliceOp::getIndexings() {
  return {this->getOperation()->operand_begin() + SliceOp::FirstIndexingOperand,
          this->getOperation()->operand_end()};
}

@ -1,51 +0,0 @@
//===- Utils.cpp - Implementation of utility functions for Linalg ---------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements utility functions for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Utils.h"
#include "linalg1/Intrinsics.h"
#include "linalg1/Ops.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;

unsigned linalg::getViewRank(Value *view) {
  assert(view->getType().isa<ViewType>() && "expected a ViewType");
  if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
    return viewOp.getRank();
  return cast<SliceOp>(view->getDefiningOp()).getRank();
}

ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) {
  // Syntactic sugar helper to extract and emit view-like information from an
  // mlir::MemRef without boilerplate.
  mlir::edsc::MemRefView v(memRef);
  SmallVector<Value *, 8> indices(v.rank());
  for (unsigned i = 0; i < v.rank(); ++i) {
    indices[i] = range(v.lb(i), v.ub(i), constant_index(v.step(i)));
  }
  return ScopedContext::getBuilder().create<ViewOp>(
      ScopedContext::getLocation(), memRef, indices);
}

@ -1,184 +0,0 @@
//===- ViewOp.cpp - Implementation of the linalg ViewOp operation ---------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple IR operation to create a new ViewType in the
// linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Ops.h"
#include "linalg1/Types.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"

using llvm::ArrayRef;
using llvm::SmallVector;
using llvm::Twine;

using namespace mlir;
using namespace linalg;

void linalg::ViewOp::build(Builder *b, OperationState &result, Value *memRef,
                           ArrayRef<Value *> indexings) {
  MemRefType memRefType = memRef->getType().cast<MemRefType>();
  result.addOperands({memRef});
  assert(static_cast<int64_t>(indexings.size()) == memRefType.getRank() &&
         "unexpected number of indexings (must match the memref rank)");

  result.addOperands(indexings);
  unsigned rank = memRefType.getRank();
  for (auto *v : indexings) {
    if (!v->getType().isa<RangeType>()) {
      rank--;
    }
  }
  Type elementType = memRefType.getElementType();
  result.addTypes({linalg::ViewType::get(b->getContext(), elementType, rank)});
}

LogicalResult linalg::ViewOp::verify() {
  if (llvm::empty(getOperands()))
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  auto memRefType = getOperand(0)->getType().dyn_cast<MemRefType>();
  if (!memRefType)
    return emitOpError("first operand must be of MemRefType");
  unsigned memrefRank = memRefType.getRank();
  unsigned index = 0;
  for (auto indexing : getIndexings()) {
    if (!indexing->getType().isa<RangeType>() &&
        !indexing->getType().isa<IndexType>()) {
      return emitOpError(Twine(index) +
                         "^th index must be of range or index type");
    }
    ++index;
  }
  if (llvm::size(getIndexings()) != memrefRank) {
    return emitOpError("requires at least a memref operand followed by " +
                       Twine(memrefRank) + " indices");
  }
  unsigned rank = memrefRank;
  for (auto *v : getIndexings()) {
    if (!v->getType().isa<RangeType>()) {
      rank--;
    }
  }
  if (getRank() != rank) {
    return emitOpError("the rank of the view must be the number of its range "
                       "indices: " +
                       Twine(rank));
  }
  return success();
}

ParseResult linalg::ViewOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType memRefInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
  SmallVector<Type, 8> types;
  if (parser.parseOperand(memRefInfo) ||
      parser.parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttributeDict(result.attributes) ||
      parser.parseColonTypeList(types))
    return failure();

  if (types.size() != 2 + indexingsInfo.size())
    return parser.emitError(parser.getNameLoc(), "unexpected number of types");
  MemRefType memRefType = types[0].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(parser.getNameLoc(),
                            "memRef type expected for first type");
  if (static_cast<int64_t>(indexingsInfo.size()) != memRefType.getRank())
    return parser.emitError(parser.getNameLoc(),
                            "expected " + Twine(memRefType.getRank()) +
                                " indexings");
  ViewType viewType = types.back().dyn_cast<ViewType>();
  if (!viewType)
    return parser.emitError(parser.getNameLoc(), "view type expected");

  ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
  if (static_cast<int64_t>(indexingTypes.size()) != memRefType.getRank())
    return parser.emitError(parser.getNameLoc(),
                            "expected " + Twine(memRefType.getRank()) +
                                " indexing types");
  return failure(
      parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
      (!indexingsInfo.empty() &&
       parser.resolveOperands(indexingsInfo, indexingTypes,
                              indexingsInfo.front().location,
                              result.operands)) ||
      parser.addTypeToList(viewType, result.types));
}

// A ViewOp prints as:
//
// ```{.mlir}
//   linalg.view %0[%1, %2] :
//     memref-type, [indexing-types], !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-values each
// holding a range.
void linalg::ViewOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << *getSupportingMemRef() << "[";
  unsigned numRanges = llvm::size(getIndexings());
  unsigned index = 0;
  for (auto indexing : getIndexings()) {
    p << *indexing << ((index++ == numRanges - 1) ? "" : ", ");
  }
  p.printOptionalAttrDict(getAttrs());
  p << "] : " << getSupportingMemRef()->getType().cast<MemRefType>();
  for (auto indexing : getIndexings()) {
    p << ", " << indexing->getType();
  }
  p << ", " << getType();
}

Type linalg::ViewOp::getElementType() { return getViewType().getElementType(); }

ViewType linalg::ViewOp::getViewType() { return getType().cast<ViewType>(); }

unsigned linalg::ViewOp::getRank() { return getViewType().getRank(); }

// May be something else than a MemRef in the future.
Value *linalg::ViewOp::getSupportingMemRef() {
  auto *res = getOperand(0);
  assert(res->getType().isa<MemRefType>());
  return res;
}

SmallVector<mlir::Value *, 8> linalg::ViewOp::getRanges() {
  llvm::SmallVector<mlir::Value *, 8> res;
  for (auto *operand : getIndexings()) {
    if (!operand->getType().isa<mlir::IndexType>()) {
      res.push_back(operand);
    }
  }
  return res;
}

Value *linalg::ViewOp::getIndexing(unsigned rank) {
  SmallVector<Value *, 1> ranges(getIndexings().begin(), getIndexings().end());
  return ranges[rank];
}

mlir::Operation::operand_range linalg::ViewOp::getIndexings() {
  return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
}

@ -1,79 +0,0 @@
//===- ViewType.cpp - Implementation of the ViewType custom type ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a custom ViewType in the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/ViewType.h"

using mlir::MLIRContext;
using mlir::Type;
using mlir::TypeStorage;
using mlir::TypeStorageAllocator;

namespace linalg {

struct ViewTypeStorage : public mlir::TypeStorage {
  /// Underlying Key type to transport the payload needed to construct a custom
  /// type in a generic way.
  struct Key {
    Key(Type elementType, unsigned rank)
        : elementType(elementType), rank(rank) {}
    Type elementType;
    unsigned rank;
  };
  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
  using KeyTy = Key;

  /// Construction in the llvm::BumpPtrAllocator given a key.
  static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
                                    const Key &key) {
    return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
  }

  /// Equality operator for hashing.
  bool operator==(const Key &key) const {
    return elementType == key.elementType && rank == key.rank;
  }

  /// Hashing for unique'ing.
  static unsigned hashKey(const Key &key) {
    return llvm::hash_combine(key.elementType, key.rank);
  }

  unsigned getRank() { return rank; }
  Type getElementType() { return elementType; }

private:
  ViewTypeStorage(const Key &key)
      : elementType(key.elementType), rank(key.rank) {}

  Type elementType;
  unsigned rank;
};

ViewType linalg::ViewType::get(MLIRContext *context, Type elementType,
                               unsigned rank) {
  return Base::get(context, LinalgTypes::View, elementType, rank);
}

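// Because the storage is uniqued by (elementType, rank), repeated `get` calls
// with an equal key return the same type object, a general property of MLIR
// type uniquing. A minimal sketch (assuming a context `ctx` and an element
// type `f32` in scope):
//
//   assert(ViewType::get(ctx, f32, 2) == ViewType::get(ctx, f32, 2));
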
Type linalg::ViewType::getElementType() { return getImpl()->getElementType(); }

unsigned linalg::ViewType::getRank() { return getImpl()->getRank(); }

} // namespace linalg

@ -1,23 +0,0 @@
add_subdirectory(lib)

set(LLVM_LINK_COMPONENTS
  Core
  Support
  )

set(LLVM_OPTIONAL_SOURCES Example.cpp)

add_llvm_example(linalg-example-2
  Example.cpp
  )

target_link_libraries(linalg-example-2
  PRIVATE
    Linalg2
    Linalg2DialectConstruction
  )

whole_archive_link(linalg-example-2
  MLIRAffineOps
  MLIRStandardOps
  )

@ -1,120 +0,0 @@
//===- Example.cpp - Our running example ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// RUN: %p/test | FileCheck %s

#include "TestHarness.h"
#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg2/Ops.h"
#include "linalg2/Transforms.h"
#include "mlir/IR/Function.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;

TEST_FUNC(linalg_ops) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  auto indexType = mlir::IndexType::get(&context);
  mlir::FuncOp f = makeFunction(*module, "linalg_ops",
                                {indexType, indexType, indexType}, {});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());

  // clang-format off
  ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
    rM = range(constant_index(0), M, constant_index(1)),
    rN = range(constant_index(0), N, constant_index(1)),
    rK = range(constant_index(0), K, constant_index(1)),
    vA = view(alloc(floatMemRefType<2>(&context), {M, K}), {rM, rK}),
    vB = view(alloc(floatMemRefType<2>(&context), {K, N}), {rK, rN}),
    vC = view(alloc(floatMemRefType<2>(&context), {M, N}), {rM, rN}),
    sB = slice(vB, constant_index(0), 1),
    sC = slice(vC, constant_index(0), 1),
    sA = slice(vA, constant_index(0), 0),
    ssC = slice(sC, constant_index(0), 0);
  matmul(vA, vB, vC);
  matvec(vA, sB, sC);
  dot(sA, sB, ssC);
  ret();
  // CHECK-LABEL: func @linalg_ops(%arg0: index, %arg1: index, %arg2: index) {
  //       CHECK: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim = 1} : !linalg.view<?x?xf32>, index
  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim = 1} : !linalg.view<?x?xf32>, index
  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim = 0} : !linalg.view<?x?xf32>, index
  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim = 0} : !linalg.view<?xf32>, index
  //       CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?x?xf32>
  //  CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?xf32>
  //  CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
  // clang-format on

  cleanupAndPrintFunction(f);
}

TEST_FUNC(linalg_ops_folded_slices) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  auto indexType = mlir::IndexType::get(&context);
  mlir::FuncOp f = makeFunction(*module, "linalg_ops_folded_slices",
                                {indexType, indexType, indexType}, {});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());

  // clang-format off
  ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
    rM = range(constant_index(0), M, constant_index(1)),
    rN = range(constant_index(0), N, constant_index(1)),
    rK = range(constant_index(0), K, constant_index(1)),
    vA = view(alloc(floatMemRefType<2>(&context), {M, K}), {rM, rK}),
    vB = view(alloc(floatMemRefType<2>(&context), {K, N}), {rK, rN}),
    vC = view(alloc(floatMemRefType<2>(&context), {M, N}), {rM, rN}),
    sB = slice(vB, constant_index(0), 1),
    sC = slice(vC, constant_index(0), 1),
    sA = slice(vA, constant_index(0), 0),
    ssC = slice(sC, constant_index(0), 0);
  matmul(vA, vB, vC);
  matvec(vA, sB, sC);
  dot(sA, sB, ssC);
  ret();
  // CHECK-LABEL: func @linalg_ops_folded_slices(%arg0: index, %arg1: index, %arg2: index) {
  //   CHECK-NOT: linalg.slice
  //       CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?x?xf32>
  //  CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?xf32>
  //  CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
  // clang-format on

  f.walk([](SliceOp slice) {
    auto *sliceResult = slice.getResult();
    auto viewOp = emitAndReturnFullyComposedView(sliceResult);
    sliceResult->replaceAllUsesWith(viewOp.getResult());
    slice.erase();
  });

  cleanupAndPrintFunction(f);
}

int main() {
  mlir::registerDialect<linalg::LinalgDialect>();
  RUN_TESTS();
  return 0;
}

@ -1,23 +0,0 @@
//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG2_ANALYSIS_H_
#define LINALG2_ANALYSIS_H_

#include "linalg1/Analysis.h"

#endif // LINALG2_ANALYSIS_H_

@ -1,32 +0,0 @@
//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG2_INTRINSICS_H_
#define LINALG2_INTRINSICS_H_

#include "linalg1/Intrinsics.h"
#include "linalg2/Ops.h"

namespace linalg {
namespace intrinsics {
using dot = mlir::edsc::intrinsics::OperationBuilder<DotOp>;
using matmul = mlir::edsc::intrinsics::OperationBuilder<MatmulOp>;
using matvec = mlir::edsc::intrinsics::OperationBuilder<MatvecOp>;
} // namespace intrinsics
} // namespace linalg

#endif // LINALG2_INTRINSICS_H_

@ -1,24 +0,0 @@
//===- Ops.h - Linalg Ops single entry point ------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG2_OPS_H_
#define LINALG2_OPS_H_

#include "linalg1/Ops.h"
#include "linalg2/TensorOps.h"

#endif // LINALG2_OPS_H_

@ -1,121 +0,0 @@
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension
/// of TensorOps by adding implementations as they are needed in the
/// appropriate step in the tutorial.
#ifndef LINALG2_TENSOROPS_INL_H_
#define LINALG2_TENSOROPS_INL_H_

#include "linalg2/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

namespace linalg {

template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getInputs() {
  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
  return {op->operand_begin(), op->operand_begin() + getNumInputs()};
}

template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
  return {op->operand_begin() + getNumInputs(),
          op->operand_begin() + getNumInputs() + getNumOutputs()};
}

template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
  return {getInputs().begin(), getOutputs().end()};
}

template <class ConcreteOp>
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
  if (getNumInputs() <= 0)
    return concreteOp->emitOpError("expected at least one input");
  if (getNumOutputs() <= 0)
    return concreteOp->emitOpError("expected at least one output");
  if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
    return concreteOp->emitOpError(
        "expected " + llvm::Twine(getNumInputs() + getNumOutputs()) +
        " operands");
  }
  for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
    if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " not a ViewType");
  }
  for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
       ++i) {
    auto viewType =
        concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
    if (!viewType)
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " not a ViewType");
    if (viewType.getRank() != getNumParallelDims())
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " must be of rank " +
                                     llvm::Twine(getNumParallelDims()));
  }
  return mlir::success();
}

template <class ConcreteOp>
mlir::ParseResult
linalg::TensorContractionBase<ConcreteOp>::parse(mlir::OpAsmParser &parser,
                                                 mlir::OperationState &result) {
  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}

// A TensorContraction prints as:
//
// ```{.mlir}
//   concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
// ```
//
// for example:
//
// ```
//   linalg.matmul(%0, %1, %2) : view<?x?xf32>
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter &p) {
  p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
  auto *last = *std::prev(getInputsAndOutputs().end());
  for (auto *i : getInputsAndOutputs()) {
    p << *i << ((i == last) ? "" : ", ");
  }
  p << ") : ";
  auto *lastOutput = *std::prev(getOutputs().end());
  for (auto *o : getOutputs()) {
    p << o->getType() << ((o == lastOutput) ? "" : ",");
  }
}

} // namespace linalg

#endif // LINALG2_TENSOROPS_INL_H_

@ -1,291 +0,0 @@
//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG2_TENSOROPS_H_
|
||||
#define LINALG2_TENSOROPS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineForOp;
|
||||
} // namespace mlir
|
||||
|
||||
namespace linalg {
|
||||

/// A generic TensorContraction base class which captures the generic behavior
/// of tensor contraction operations (with broadcast).
template <class ConcreteOp> class TensorContractionBase {
protected:
  using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  /// Generic implementation of hooks that should be called from `ConcreteType`s
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

public:
  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  TensorContractionBase() = default;
  mlir::Operation::operand_range getInputs();
  mlir::Operation::operand_range getOutputs();
  mlir::Operation::operand_range getInputsAndOutputs();

  /// These are better as methods calling into the ConcreteOp instead of
  /// template parameters because methods allow more generic behavior and avoid
  /// specializing for number of arguments. All derived classes have
  /// `VariadicOperands` and a build method from both an ArrayRef<mlir::Value *>
  /// and the proper number of mlir::Value *.
  unsigned getNumInputs() {
    return static_cast<ConcreteOp *>(this)->numInputs;
  };
  unsigned getNumOutputs() {
    return static_cast<ConcreteOp *>(this)->numOutputs;
  };
  unsigned getNumParallelDims() {
    return static_cast<ConcreteOp *>(this)->numParallelDims;
  };
  unsigned getNumReductionDims() {
    return static_cast<ConcreteOp *>(this)->numReductionDims;
  };

  //////////////////////////////////////////////////////////////////////////////
  // Used in Linalg3 and later.
  //////////////////////////////////////////////////////////////////////////////
  mlir::Value *getInputView(unsigned viewIndex);
  mlir::Value *getOutputView(unsigned viewIndex);
  mlir::Value *getView(unsigned viewIndex) {
    return viewIndex < getNumInputs()
               ? getInputView(viewIndex)
               : getOutputView(viewIndex - getNumInputs());
  }

  /// Each op is responsible for declaring how it lowers itself to scalar form,
  /// given the enclosing parallel and reduction induction variables.
  /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
  /// context of the innermost enclosing loop (i.e. `reductionIvs.back()` or
  /// `parallelIvs.back()`).
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);

  /// Represents a mapping from the loops to all the ranges of the operands.
  /// The operands and their ranges are in the order defined by the particular
  /// ConcreteOp implementation; the resulting map must match that order.
  /// In favorable cases, this can be calculated by an analysis but specifying
  /// it explicitly is not expensive and generalizes to cases where an analysis
  /// is not available. For details, see the description of
  /// loopsToOperandRangeMaps in each ConcreteOp.
  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
};
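
// Illustrative sketch only (not part of the original tutorial): how a derived
// op is expected to plug into the CRTP base above. The class name is a
// hypothetical placeholder; DotOp, MatvecOp and MatmulOp below are the real
// instances.
//
//   class ExampleContractionOp
//       : public TensorContractionBase<ExampleContractionOp>,
//         public mlir::Op<ExampleContractionOp,
//                         mlir::OpTrait::VariadicOperands,
//                         mlir::OpTrait::ZeroResult> {
//   public:
//     static constexpr unsigned numInputs = 2;
//     static constexpr unsigned numOutputs = 1;
//     static constexpr unsigned numParallelDims = 1;
//     static constexpr unsigned numReductionDims = 1;
//   };
//
// With this, getNumInputs() on the base resolves at compile time to
// static_cast<ExampleContractionOp *>(this)->numInputs, i.e. 2.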

/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
class DotOp : public TensorContractionBase<DotOp>,
              public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
                              mlir::OpTrait::ZeroResult> {
public:
  using Op::Op;
  using TensorContractionBaseType =
      TensorContractionBase::TensorContractionBaseType;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.dot"; }
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    llvm::ArrayRef<mlir::Value *> operands);
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    mlir::Value *A, mlir::Value *B, mlir::Value *C) {
    return build(b, result, {A, B, C});
  }
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  static constexpr unsigned numInputs = 2;
  static constexpr unsigned numOutputs = 1;
  static constexpr unsigned numParallelDims = 0;
  static constexpr unsigned numReductionDims = 1;

  //////////////////////////////////////////////////////////////////////////////
  // Used in Linalg3 and later.
  //////////////////////////////////////////////////////////////////////////////
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(K), B(K), C() is:
  ///   (d0) -> (d0, d0)(%k)
  /// And the operands ranges are:
  ///   (%k, %k)
  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();

  /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
  /// to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[] from memory (0-D tensor)
  ///   2. multiply A[r_i] by B[r_i] and add to scalarC
  ///   3. store back scalarC at C[]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_i == zero)
  ///   scalarC = select(cond, zerof, C[]);
  ///   C[] = scalarC + A[r_i] * B[r_i];
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};
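
// Usage sketch (illustrative, assuming a caller that already has an mlir
// Builder, an OperationState `state`, and three view-typed values a, b, c of
// ranks 1, 1 and 0): the three-operand build hook above is sugar for the
// variadic one.
//
//   mlir::OperationState state(loc, DotOp::getOperationName());
//   DotOp::build(&builder, state, a, b, c);
//   // equivalent to: DotOp::build(&builder, state, {a, b, c});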

/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors.
class MatvecOp : public TensorContractionBase<MatvecOp>,
                 public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
                                 mlir::OpTrait::ZeroResult> {
public:
  using Op::Op;
  using TensorContractionBaseType =
      TensorContractionBase::TensorContractionBaseType;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.matvec"; }
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    llvm::ArrayRef<mlir::Value *> operands);
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    mlir::Value *A, mlir::Value *B, mlir::Value *C) {
    return build(b, result, {A, B, C});
  }
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  static constexpr unsigned numInputs = 2;
  static constexpr unsigned numOutputs = 1;
  static constexpr unsigned numParallelDims = 1;
  static constexpr unsigned numReductionDims = 1;

  //////////////////////////////////////////////////////////////////////////////
  // Used in Linalg3 and later.
  //////////////////////////////////////////////////////////////////////////////
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%m, %k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
  ///   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
  /// And the operands ranges are:
  ///   (%m, %k, %k, %m)
  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();

  /// Given an enclosing parallel loop with iv `i` and an enclosing reduction
  /// loop with iv `r_j`, emits MLIR corresponding to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[i]
  ///   2. multiply A[i, r_j] by B[r_j] and add to scalarC
  ///   3. store back scalarC at C[i]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_j == zero)
  ///   scalarC = select(cond, zerof, C(i));
  ///   C(i) = scalarC + A(i, r_j) * B(r_j);
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};
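
// Reading the matvec map above, grouped per operand: the flattened results
// (d0, d1, d1, d0) split by operand ranks (A: 2, B: 1, C: 1) into
//   A(M, K) <- (d0, d1),  B(K) <- (d1),  C(M) <- (d0),
// so with (d0, d1) = (%m, %k) the per-operand ranges are (%m, %k), (%k), (%m).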

/// Implements C = A * B on 2-D matrices.
class MatmulOp : public TensorContractionBase<MatmulOp>,
                 public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
                                 mlir::OpTrait::ZeroResult> {
public:
  using Op::Op;
  using TensorContractionBaseType =
      TensorContractionBase::TensorContractionBaseType;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.matmul"; }
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    llvm::ArrayRef<mlir::Value *> operands);
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    mlir::Value *A, mlir::Value *B, mlir::Value *C) {
    return build(b, result, {A, B, C});
  }
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  static constexpr unsigned numInputs = 2;
  static constexpr unsigned numOutputs = 1;
  static constexpr unsigned numParallelDims = 2;
  static constexpr unsigned numReductionDims = 1;

  //////////////////////////////////////////////////////////////////////////////
  // Used in Linalg3 and later.
  //////////////////////////////////////////////////////////////////////////////
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
  ///   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
  /// And the operands ranges are:
  ///   (%m, %k, %k, %n, %m, %n)
  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();

  /// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing
  /// reduction loop with iv `r_k`, emits MLIR corresponding to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[i, j]
  ///   2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
  ///   3. store back scalarC at C[i, j]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_k == zero)
  ///   scalarC = select(cond, zerof, C[i, j]);
  ///   C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};
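
// A hedged sketch (an assumption about the shape of an implementation, not
// the tutorial's actual one, which lives in a later Linalg step) of how the
// matmul maps documented above could be materialized with mlir::AffineMap,
// grouped per operand as A(M, K) <- (d0, d2), B(K, N) <- (d2, d1),
// C(M, N) <- (d0, d1):
//
//   llvm::SmallVector<mlir::AffineMap, 8>
//   exampleMatmulRangeMaps(mlir::MLIRContext *context) {
//     auto d0 = mlir::getAffineDimExpr(0, context);
//     auto d1 = mlir::getAffineDimExpr(1, context);
//     auto d2 = mlir::getAffineDimExpr(2, context);
//     return {mlir::AffineMap::get(3, 0, {d0, d2}),
//             mlir::AffineMap::get(3, 0, {d2, d1}),
//             mlir::AffineMap::get(3, 0, {d0, d1})};
//   }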

} // namespace linalg

/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension
/// of TensorOps by adding implementations as they are needed in the
/// appropriate step in the tutorial.
#include "linalg2/TensorOps-inl.h"

#endif // LINALG2_TENSOROPS_H_
@ -1,36 +0,0 @@
//===- Transforms.h - Linalg dialect Transformations definition -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG2_TRANSFORMS_H_
#define LINALG2_TRANSFORMS_H_

namespace mlir {
class Value;
} // namespace mlir

namespace linalg {

class ViewOp;

/// Takes a `view` of type ViewType (i.e. either a ViewOp or a SliceOp) and
/// composes away all the SliceOps to return a single ViewOp.
/// Inserts the required operations after `view`.
ViewOp emitAndReturnFullyComposedView(mlir::Value *v);

} // namespace linalg

#endif // LINALG2_TRANSFORMS_H_
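
// Usage sketch (illustrative IR, following the tutorial's linalg1 syntax):
// given a value produced by a chain such as
//   %v  = linalg.view %A[%r0, %r1]
//   %s0 = linalg.slice %v[%i]
// calling emitAndReturnFullyComposedView(%s0) inserts the range arithmetic
// needed to express %s0 directly on %A and returns the single equivalent
// linalg.view.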
@ -1,31 +0,0 @@
set(LLVM_OPTIONAL_SOURCES
  DialectConstruction.cpp
  TensorOps.cpp
  Transforms.cpp
  )

add_llvm_library(Linalg2
  TensorOps.cpp
  Transforms.cpp
  )

target_link_libraries(Linalg2
  PUBLIC
    MLIRAffineOps
    MLIRAnalysis
    MLIRDialect
    MLIREDSC
    MLIRIR
    MLIRLLVMIR
    MLIRParser
    MLIRPass
    MLIRTransforms
    Linalg1
  )

add_llvm_library(Linalg2DialectConstruction
  DialectConstruction.cpp
  )

target_link_libraries(Linalg2DialectConstruction
  PUBLIC Linalg2)
@ -1,33 +0,0 @@
//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the constructor for the Linalg Dialect. This is
// explicitly separated from the core library to allow incremental buildup of
// the codebase for the tutorial.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"
#include "linalg2/Ops.h"

using namespace linalg;

LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
    : Dialect("linalg", context) {
  addTypes<RangeType, ViewType>();
  addOperations<DotOp, MatvecOp, MatmulOp, RangeOp, SliceOp, ViewOp>();
}
@ -1,133 +0,0 @@
//===- TensorOps.cpp - Implementation of the linalg TensorOps operation ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple IR operation to create new tensor computation
// operations in the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Utils.h"
#include "linalg2/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/STLExtras.h"

using llvm::ArrayRef;
using llvm::Twine;

using namespace mlir;
using namespace linalg;

//////////////////////////////////////////////////////////////////////////////
// Op-specific Dot.
//////////////////////////////////////////////////////////////////////////////
void linalg::DotOp::build(Builder *b, OperationState &result,
                          ArrayRef<Value *> operands) {
  result.addOperands(operands);
}

LogicalResult linalg::DotOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  unsigned index = 0;
  for (auto *v : {A, B}) {
    if (getViewRank(v) != 1)
      return emitOpError("operand " + Twine(index++) + " must be of rank 1");
  }
  if (getViewRank(C) != 0)
    return emitOpError("operand 2 must be of rank 0");
  // TODO(ntv): check ranges match.
  return success();
}

// Parsing of the linalg dialect is not supported in this tutorial.
ParseResult linalg::DotOp::parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result) {
  return TensorContractionBaseType::parse(parser, result);
}

void linalg::DotOp::print(mlir::OpAsmPrinter &p) {
  TensorContractionBaseType::print(p);
}

//////////////////////////////////////////////////////////////////////////////
// Op-specific Matvec.
//////////////////////////////////////////////////////////////////////////////
void linalg::MatvecOp::build(Builder *b, OperationState &result,
                             ArrayRef<Value *> operands) {
  result.addOperands(operands);
}

LogicalResult linalg::MatvecOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  if (getViewRank(A) != 2)
    return emitOpError("operand 0 must be of rank 2");
  unsigned index = 0;
  for (auto *v : {B, C}) {
    if (getViewRank(v) != 1)
      return emitOpError("operand " + Twine(1 + index++) +
                         " must be of rank 1");
  }
  // TODO(ntv): check ranges match.
  return success();
}

// Parsing of the linalg dialect is not supported in this tutorial.
ParseResult linalg::MatvecOp::parse(mlir::OpAsmParser &parser,
                                    mlir::OperationState &result) {
  return TensorContractionBaseType::parse(parser, result);
}

void linalg::MatvecOp::print(mlir::OpAsmPrinter &p) {
  TensorContractionBaseType::print(p);
}

//////////////////////////////////////////////////////////////////////////////
// Op-specific Matmul.
//////////////////////////////////////////////////////////////////////////////
void linalg::MatmulOp::build(Builder *b, OperationState &result,
                             ArrayRef<Value *> operands) {
  result.addOperands(operands);
}

LogicalResult linalg::MatmulOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  unsigned index = 0;
  for (auto *v : {A, B, C}) {
    if (getViewRank(v) != 2)
      return emitOpError("operand " + Twine(index++) + " must be of rank 2");
  }
  // TODO(ntv): check ranges match.
  return success();
}

// Parsing of the linalg dialect is not supported in this tutorial.
ParseResult linalg::MatmulOp::parse(mlir::OpAsmParser &parser,
                                    mlir::OperationState &result) {
  return TensorContractionBaseType::parse(parser, result);
}

void linalg::MatmulOp::print(mlir::OpAsmPrinter &p) {
  TensorContractionBaseType::print(p);
}
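
// Summary of the rank contracts the verifiers above enforce on their view
// operands (A, B inputs; C output):
//   linalg.dot:    A rank 1, B rank 1, C rank 0
//   linalg.matvec: A rank 2, B rank 1, C rank 1
//   linalg.matmul: A rank 2, B rank 2, C rank 2
// For example, feeding a rank-2 view as the first operand of linalg.dot
// triggers the diagnostic "operand 0 must be of rank 1".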
@ -1,116 +0,0 @@
//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg2/Transforms.h"
#include "linalg2/Analysis.h"
#include "linalg2/Intrinsics.h"
#include "linalg2/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

using llvm::ArrayRef;
using llvm::cast;
using llvm::isa;
using llvm::SmallVector;
using mlir::MemRefType;
using mlir::OpBuilder;
using mlir::Value;
using mlir::edsc::ScopedContext;
using mlir::edsc::ValueHandle;
using mlir::edsc::intrinsics::constant_index;

using namespace linalg;
using namespace linalg::intrinsics;

// We need to traverse the slice chain from the original ViewOp for various
// analyses. This builds the chain.
static SmallVector<Value *, 8> getViewChain(mlir::Value *v) {
  assert(v->getType().isa<ViewType>() && "ViewType expected");
  if (isa<ViewOp>(v->getDefiningOp())) {
    return SmallVector<mlir::Value *, 8>{v};
  }

  SmallVector<mlir::Value *, 8> tmp;
  do {
    auto sliceOp = cast<SliceOp>(v->getDefiningOp()); // must be a slice op
    tmp.push_back(v);
    v = sliceOp.getParentView();
  } while (!isa<ViewOp>(v->getDefiningOp()));
  assert(isa<ViewOp>(v->getDefiningOp()) && "must be a ViewOp");
  tmp.push_back(v);
  return SmallVector<mlir::Value *, 8>(tmp.rbegin(), tmp.rend());
}
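
// Worked example (illustrative IR): for the chain
//   %v  = linalg.view %A[...]
//   %s0 = linalg.slice %v[...]
//   %s1 = linalg.slice %s0[...]
// getViewChain(%s1) walks back through the defining SliceOps to the root
// ViewOp and returns the chain in source order: [%v, %s0, %s1].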

static mlir::Value *createFullyComposedIndexing(unsigned dim,
                                                ArrayRef<Value *> chain) {
  using namespace mlir::edsc::op;
  assert(chain.front()->getType().isa<ViewType>() && "must be a ViewType");
  auto viewOp = cast<ViewOp>(chain.front()->getDefiningOp());
  auto *indexing = viewOp.getIndexing(dim);
  if (!indexing->getType().isa<RangeType>())
    return indexing;
  auto rangeOp = cast<RangeOp>(indexing->getDefiningOp());
  Value *min = rangeOp.getMin(), *max = rangeOp.getMax(),
        *step = rangeOp.getStep();
  for (auto *v : chain.drop_front(1)) {
    auto slice = cast<SliceOp>(v->getDefiningOp());
    if (slice.getRank() != slice.getParentRank()) {
      // Rank-reducing slice.
      if (slice.getSlicingDim() == dim) {
        // Slice a single element across dim -> done.
        return ValueHandle(min) +
               ValueHandle(slice.getIndexing()) * ValueHandle(step);
      }
      // Adjust the dim to account for the slice.
      dim = (slice.getSlicingDim() < dim) ? dim - 1 : dim;
    } else { // not a rank-reducing slice.
      if (slice.getSlicingDim() == dim) {
        auto range = cast<RangeOp>(slice.getIndexing()->getDefiningOp());
        auto oldMin = min;
        min = ValueHandle(min) + ValueHandle(range.getMin());
        // ideally: max = min(oldMin + ValueHandle(range.getMax()), oldMax);
        // but we cannot represent min/max with index and have it compose with
        // affine.map atm.
        max = ValueHandle(oldMin) + ValueHandle(range.getMax());
        // ideally: parametric steps.
        // but we cannot represent parametric steps with index atm.
        step = ValueHandle(step) * ValueHandle(range.getStep());
      }
    }
  }
  return linalg::intrinsics::range(min, max, step).getValue();
}
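
// Worked example for the non-rank-reducing branch above, with constants for
// clarity: composing an outer range {min = 0, max = 42, step = 2} with a
// slice whose indexing is the range {min = 5, max = 10, step = 3} yields
//   min  = 0 + 5  = 5
//   max  = 0 + 10 = 10  (an over-approximation, per the TODO on min/max)
//   step = 2 * 3  = 6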

ViewOp linalg::emitAndReturnFullyComposedView(Value *v) {
  OpBuilder builder(v->getDefiningOp());
  ScopedContext scope(builder, v->getDefiningOp()->getLoc());
  assert(v->getType().isa<ViewType>() && "must be a ViewType");
  auto *memRef = getViewSupportingMemRef(v);
  auto chain = getViewChain(v);
  unsigned rank = memRef->getType().cast<MemRefType>().getRank();
  SmallVector<Value *, 8> ranges;
  ranges.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    ranges.push_back(createFullyComposedIndexing(idx, chain));
  }
  return cast<ViewOp>(view(memRef, ranges).getOperation());
}
@ -1,61 +0,0 @@
add_definitions(-DLINALG_STEP=3)

add_subdirectory(lib)

set(LLVM_LINK_COMPONENTS
  Core
  OrcJIT
  Support
  native
  )

set(LLVM_OPTIONAL_SOURCES
  Conversion.cpp
  Example.cpp
  Execution.cpp
  )

add_llvm_example(linalg-conversion-3
  Conversion.cpp
  )

add_llvm_example(linalg-example-3
  Example.cpp
  )

add_llvm_example(linalg-execution-3
  Execution.cpp
  )

target_link_libraries(linalg-example-3
  PRIVATE
    Linalg3
    Linalg3DialectConstruction
  )

whole_archive_link(linalg-example-3
  MLIRAffineOps
  MLIRStandardOps
  )

target_link_libraries(linalg-conversion-3
  PRIVATE
    Linalg3
    Linalg3DialectConstruction
  )

whole_archive_link(linalg-conversion-3
  MLIRStandardOps
  )

target_link_libraries(linalg-execution-3
  PRIVATE
    MLIRExecutionEngine
    Linalg3
    Linalg3DialectConstruction
  )

whole_archive_link(linalg-execution-3
  MLIRStandardOps
  )
@ -1,115 +0,0 @@
//===- Conversion.cpp - Linalg to LLVM conversion driver ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// RUN: %p/conversion | FileCheck %s

#include "TestHarness.h"

#include "linalg3/ConvertToLLVMDialect.h"

#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "linalg3/Transforms.h"
#include "mlir/IR/OpImplementation.h"

using llvm::StringRef;

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;

FuncOp makeFunctionWithAMatmulOp(ModuleOp module, StringRef name) {
  MLIRContext *context = module.getContext();
  auto dynamic2DMemRefType = floatMemRefType<2>(context);
  mlir::FuncOp f = linalg::common::makeFunction(
      module, name,
      {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  // clang-format off
  ValueHandle
      M = dim(f.getArgument(0), 0),
      N = dim(f.getArgument(2), 1),
      K = dim(f.getArgument(0), 1),
      rM = range(constant_index(0), M, constant_index(1)),
      rN = range(constant_index(0), N, constant_index(1)),
      rK = range(constant_index(0), K, constant_index(1)),
      vA = view(f.getArgument(0), {rM, rK}),
      vB = view(f.getArgument(1), {rK, rN}),
      vC = view(f.getArgument(2), {rM, rN});
  matmul(vA, vB, vC);
  ret();
  // clang-format on

  return f;
}
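
// For reference, the function built above is schematically (SSA names are
// illustrative, not the exact printer output):
//   func @<name>(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
//     %M = dim %A, 0        %N = dim %C, 1        %K = dim %A, 1
//     %rM = linalg.range %c0:%M:%c1   // likewise %rN over %N, %rK over %K
//     %vA = linalg.view %A[%rM, %rK]  // %vB = %B[%rK, %rN], %vC = %C[%rM, %rN]
//     linalg.matmul(%vA, %vB, %vC)
//     return
//   }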

TEST_FUNC(foo) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
  lowerToLoops(f);

  convertLinalg3ToLLVM(*module);

  // clang-format off
  // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  // CHECK-NEXT: {{.*}} = llvm.load {{.*}} : !llvm<"float*">
  // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  // CHECK-NEXT: {{.*}} = llvm.load {{.*}} : !llvm<"float*">
  // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.mul {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
  // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  // CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*">
  // clang-format on
  module->print(llvm::outs());
}

int main() {
  mlir::registerDialect<linalg::LinalgDialect>();
  RUN_TESTS();
  return 0;
}
@ -1,205 +0,0 @@
//===- Example.cpp - Our running example ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// RUN: %p/test | FileCheck %s

#include "TestHarness.h"
#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "linalg3/Transforms.h"
#include "mlir/IR/OpImplementation.h"

using llvm::StringRef;

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;

FuncOp makeFunctionWithAMatmulOp(ModuleOp module, StringRef name) {
  MLIRContext *context = module.getContext();
  auto dynamic2DMemRefType = floatMemRefType<2>(context);
  mlir::FuncOp f = linalg::common::makeFunction(
      module, name,
      {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});

  mlir::OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  // clang-format off
  ValueHandle
      M = dim(f.getArgument(0), 0),
      N = dim(f.getArgument(2), 1),
      K = dim(f.getArgument(0), 1),
      rM = range(constant_index(0), M, constant_index(1)),
      rN = range(constant_index(0), N, constant_index(1)),
      rK = range(constant_index(0), K, constant_index(1)),
      vA = view(f.getArgument(0), {rM, rK}),
      vB = view(f.getArgument(1), {rK, rN}),
      vC = view(f.getArgument(2), {rM, rN});
  matmul(vA, vB, vC);
  ret();
  // clang-format on

  return f;
}

TEST_FUNC(matmul_as_matvec) {
  MLIRContext context;
  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
  lowerToFinerGrainedTensorContraction(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_matvec(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] {
  // CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  // CHECK: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  // CHECK: linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?xf32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_dot) {
  MLIRContext context;
  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
  lowerToFinerGrainedTensorContraction(f);
  lowerToFinerGrainedTensorContraction(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_dot(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] {
  // CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  // CHECK-NEXT: affine.for %{{.*}} = 0 to %[[M]] {
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
  // CHECK-NEXT: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, index, index, !linalg.view<f32>
  // CHECK-NEXT: linalg.dot(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<f32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_loops) {
  MLIRContext context;
  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
  lowerToLoops(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_loops(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[rM:.*]] = linalg.range %{{.*}}:%[[M]]:%{{.*}} : !linalg.range
  // CHECK: %[[rN:.*]] = linalg.range %{{.*}}:%[[N]]:%{{.*}} : !linalg.range
  // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%[[K]]:%{{.*}} : !linalg.range
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%[[rM]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: %[[vB:.*]] = linalg.view %{{.*}}[%[[rK]], %[[rN]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: %[[vC:.*]] = linalg.view %{{.*}}[%[[rM]], %[[rN]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] {
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] {
  // CHECK: affine.for %{{.*}} = 0 to %[[K]] {
  // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index
  // CHECK: %{{.*}} = linalg.load %[[vC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK: %{{.*}} = select {{.*}} : f32
  // CHECK: %{{.*}} = linalg.load %[[vB]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK: %{{.*}} = linalg.load %[[vA]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK: %{{.*}} = mulf {{.*}} : f32
  // CHECK: %{{.*}} = addf {{.*}} : f32
  // CHECK: linalg.store {{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_matvec_as_loops) {
  MLIRContext context;
  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f =
      makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
  lowerToFinerGrainedTensorContraction(f);
  lowerToLoops(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_matvec_as_loops(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] {
  // CHECK: %[[vB:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  // CHECK: %[[vC:.*]] = linalg.view %{{.*}}[{{.*}}, {{.*}}] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] {
  // CHECK: affine.for %{{.*}} = 0 to %[[K]] {
  // CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
  // CHECK: %[[C:.*]] = linalg.load %[[vC]][%{{.*}}] : !linalg.view<?xf32>
  // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
  // CHECK: %[[B:.*]] = linalg.load %[[vB]][%{{.*}}] : !linalg.view<?xf32>
  // CHECK: %[[A:.*]] = linalg.load %[[vA]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32
  // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32
  // CHECK: linalg.store %{{.*}}, %{{.*}}[%{{.*}}] : !linalg.view<?xf32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_matvec_as_affine) {
  MLIRContext context;
  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f =
      makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
  lowerToFinerGrainedTensorContraction(f);
  composeSliceOps(f);
  lowerToLoops(f);
  PassManager pm(&context);
  pm.addPass(createLowerLinalgLoadStorePass());
  if (succeeded(pm.run(f.getParentOfType<mlir::ModuleOp>())))
    cleanupAndPrintFunction(f);

  // clang-format off
  // CHECK-LABEL: func @matmul_as_matvec_as_affine(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] {
  // CHECK-NOT: {{.*}} = linalg.
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] {
  // CHECK: affine.for %{{.*}} = 0 to %[[K]] {
  // CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
  // CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
  // CHECK-NOT: {{.*}} = linalg.
  // CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
  // CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
  // CHECK-NOT: {{.*}} = linalg.
  // CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // clang-format on
}

int main() {
  mlir::registerDialect<linalg::LinalgDialect>();
  RUN_TESTS();
  return 0;
}
@ -1,166 +0,0 @@
//===- Execution.cpp - Linalg to LLVM execution driver --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "TestHarness.h"

#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/ConvertToLLVMDialect.h"
#include "linalg3/Ops.h"
#include "linalg3/Transforms.h"

#include "llvm/Support/TargetSelect.h"

#include "mlir/ExecutionEngine/ExecutionEngine.h"

// RUN: %p/execution | FileCheck %s

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;

FuncOp makeFunctionWithAMatmulOp(ModuleOp module, StringRef name) {
  MLIRContext *context = module.getContext();
  auto dynamic2DMemRefType = floatMemRefType<2>(context);
  mlir::FuncOp f = linalg::common::makeFunction(
      module, name,
      {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});

  mlir::OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  // clang-format off
  ValueHandle
      M = dim(f.getArgument(0), 0),
      N = dim(f.getArgument(2), 1),
      K = dim(f.getArgument(0), 1),
      rM = range(constant_index(0), M, constant_index(1)),
      rN = range(constant_index(0), N, constant_index(1)),
      rK = range(constant_index(0), K, constant_index(1)),
      vA = view(f.getArgument(0), {rM, rK}),
      vB = view(f.getArgument(1), {rK, rN}),
      vC = view(f.getArgument(2), {rM, rN});
  matmul(vA, vB, vC);
  ret();
  // clang-format on

  return f;
}

// Representation of a Memref descriptor for a 2D dynamically-sized Memref in C.
// This is equivalent to the structure that the conversion produces.
struct MemRefDescriptor2D {
  float *ptr;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};
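
// Element (i, j) of the memref described above lives at
// ptr[offset + i * strides[0] + j * strides[1]]. A small helper, added here
// only for illustration; print2DMemref below inlines the same arithmetic
// (with offset 0).
static float &elementAt(const MemRefDescriptor2D &d, int64_t i, int64_t j) {
  return d.ptr[d.offset + i * d.strides[0] + j * d.strides[1]];
}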

// Allocate a 2D memref of the given size, store the sizes in the descriptor
// and initialize all values with 1.0f.
static MemRefDescriptor2D allocateInit2DMemref(int64_t sz1, int64_t sz2) {
  MemRefDescriptor2D descriptor;
  descriptor.ptr = static_cast<float *>(malloc(sizeof(float) * sz1 * sz2));
  descriptor.offset = 0;
  descriptor.sizes[0] = sz1;
  descriptor.sizes[1] = sz2;
  descriptor.strides[0] = sz2;
  descriptor.strides[1] = 1;
  for (int64_t i = 0, e = sz1 * sz2; i < e; ++i)
    descriptor.ptr[i] = 1.0f;
  return descriptor;
}

// Print the contents of the memref given its descriptor.
static void print2DMemref(const MemRefDescriptor2D &descriptor) {
  for (int64_t i = 0; i < descriptor.sizes[0]; ++i) {
    llvm::outs() << '[';
    for (int64_t j = 0; j < descriptor.sizes[1]; ++j) {
      if (j != 0)
        llvm::outs() << ", ";
      llvm::outs() << descriptor.ptr[i * descriptor.strides[0] +
                                     j * descriptor.strides[1]];
    }
    llvm::outs() << "]\n";
  }
}

// Free a 2D memref given its descriptor. Resets the pointer in the descriptor
// to nullptr.
static void free2DMemref(MemRefDescriptor2D &descriptor) {
  free(descriptor.ptr);
  descriptor.ptr = nullptr;
}

TEST_FUNC(execution) {
  // Create an MLIR module, create a function "matmul_as_loops" containing a
  // linalg.matmul operation and lower it all the way down to the LLVM IR
  // dialect through partial conversions.
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
  lowerToLoops(f);
  convertLinalg3ToLLVM(*module);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(*module);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Prepare arguments for the function invocation: allocate input and output
  // buffers.
  auto A = allocateInit2DMemref(5, 3);
  auto B = allocateInit2DMemref(3, 2);
  auto C = allocateInit2DMemref(5, 2);
  auto *pA = &A, *pB = &B, *pC = &C;
  llvm::SmallVector<void *, 3> args({&pA, &pB, &pC});

  // Invoke the JIT-compiled function with the arguments. Note that, for API
  // uniformity reasons, it takes a list of type-erased pointers to arguments.
  auto invocationResult =
      engine->invoke("matmul_as_loops", MutableArrayRef<void *>(args));
  assert(!invocationResult && "call failed");

  // clang-format off
  // CHECK: [3.000000e+00, 3.000000e+00]
  // CHECK-NEXT: [3.000000e+00, 3.000000e+00]
  // CHECK-NEXT: [3.000000e+00, 3.000000e+00]
  // CHECK-NEXT: [3.000000e+00, 3.000000e+00]
  // CHECK-NEXT: [3.000000e+00, 3.000000e+00]
  // clang-format on
  print2DMemref(C);

  // Cleanup.
  free2DMemref(A);
  free2DMemref(B);
  free2DMemref(C);
}
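
// Why the CHECK lines above expect 3.0: A is 5x3 and B is 3x2, both filled
// with 1.0f, and the lowered matmul zero-initializes the accumulator on the
// first reduction iteration, so every entry of the 5x2 result is
// sum over k of A[i][k] * B[k][j] = 3 * (1.0f * 1.0f) = 3.0f.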

int main() {
  mlir::registerDialect<linalg::LinalgDialect>();

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  RUN_TESTS();
  return 0;
}
@ -1,37 +0,0 @@
//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_ANALYSIS_H_
#define LINALG3_ANALYSIS_H_

#include "linalg2/Analysis.h"

namespace mlir {
class AffineMap;
} // namespace mlir

namespace linalg {

/// Given a `map` specification and a subset of its results
/// `[beginResult, endResult)`, returns the inverse map that maps result
/// positions to dim positions.
mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
                              unsigned endResult = 0);

} // namespace linalg

#endif // LINALG3_ANALYSIS_H_
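
// Worked example (illustrative, derived from the documented semantics): for
//   map = (d0, d1, d2) -> (d2, d1)
// taken over its full result range, result position 0 reads d2 and result
// position 1 reads d1, so the inverse sub-map sends result positions (0, 1)
// back to dim positions (2, 1).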
@ -1,30 +0,0 @@
//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_CONVERTTOLLVMDIALECT_H_
#define LINALG3_CONVERTTOLLVMDIALECT_H_

namespace mlir {
struct LogicalResult;
class ModuleOp;
} // end namespace mlir

namespace linalg {
mlir::LogicalResult convertLinalg3ToLLVM(mlir::ModuleOp module);
} // end namespace linalg

#endif // LINALG3_CONVERTTOLLVMDIALECT_H_
@ -1,31 +0,0 @@
//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_INTRINSICS_H_
#define LINALG3_INTRINSICS_H_

#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"

namespace linalg {
namespace intrinsics {
using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
} // namespace intrinsics
} // namespace linalg

#endif // LINALG3_INTRINSICS_H_
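
// Usage sketch (illustrative, assuming an active mlir::edsc::ScopedContext
// and in-scope ValueHandles `view` and `i`, which are not defined here):
//   using namespace linalg::intrinsics;
//   mlir::edsc::ValueHandle x = load(view, {i}); // builds a linalg.load
//   store(x, view, {i});                         // builds a linalg.store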
@ -1,91 +0,0 @@
//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_LOADSTOREOP_H_
#define LINALG3_LOADSTOREOP_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

namespace linalg {

class ViewType;

/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
/// instead of MemRefType.
class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
                               mlir::OpTrait::OneResult> {
public:
  using Op::Op;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.load"; }
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    mlir::Value *view,
                    mlir::ArrayRef<mlir::Value *> indices = {});
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  unsigned getRank();
  ViewType getViewType();
  mlir::Value *getView() { return getOperand(0); }
  mlir::Operation::operand_range getIndices() {
    return {operand_begin() + 1, operand_end()};
  }
};

/// A linalg.StoreOp is the counterpart of affine.store but operating on
/// ViewType instead of MemRefType.
class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
                                mlir::OpTrait::ZeroResult> {
public:
  using Op::Op;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.store"; }
  static void build(mlir::Builder *b, mlir::OperationState &result,
                    mlir::Value *valueToStore, mlir::Value *view,
                    mlir::ArrayRef<mlir::Value *> indices = {});
  mlir::LogicalResult verify();
  static mlir::ParseResult parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result);
  void print(mlir::OpAsmPrinter &p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  unsigned getRank();
  ViewType getViewType();
  mlir::Value *getValueToStore() { return getOperand(0); }
  mlir::Value *getView() { return getOperand(1); }
  mlir::Operation::operand_range getIndices() {
    return {operand_begin() + 2, operand_end()};
  }
};

} // namespace linalg

#endif // LINALG3_LOADSTOREOP_H_
@ -1,25 +0,0 @@
//===- Ops.h - Linalg Ops single entry point ------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_OPS_H_
#define LINALG3_OPS_H_

#include "linalg2/Ops.h"
#include "linalg3/LoadStoreOps.h"
#include "linalg3/TensorOps.h"

#endif // LINALG3_OPS_H_
@@ -1,145 +0,0 @@
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension
/// of TensorOps by adding implementations as they are needed in the
/// appropriate step in the tutorial.
#ifndef LINALG3_TENSOROPS_INL_H_
#define LINALG3_TENSOROPS_INL_H_

#include "linalg1/Common.h"
#include "linalg1/Utils.h"
#include "linalg2/TensorOps.h"
#include "linalg3/Analysis.h"
#include "linalg3/Ops.h"

template <class ConcreteOp>
mlir::Value *
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
  return *(getInputs().begin() + viewIndex);
}

template <class ConcreteOp>
mlir::Value *
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
  return *(getOutputs().begin() + viewIndex);
}

template <class ConcreteOp>
llvm::SmallVector<mlir::AffineMap, 8>
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
  return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
}

template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
    llvm::ArrayRef<mlir::Value *> parallelIvs,
    llvm::ArrayRef<mlir::Value *> reductionIvs) {
  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
                                                            reductionIvs);
}

template <class ConcreteOp>
mlir::AffineMap linalg::operandRangesToLoopsMap(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  mlir::AffineMap current;
  // Individual submaps may not be invertible but their union must be
  // invertible by construction.
  for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
    if (!m)
      continue;
    if (!current) {
      current = m;
      continue;
    }
    llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
                                                   current.getResults().end());
    results.append(m.getResults().begin(), m.getResults().end());
    current = mlir::AffineMap::get(
        std::max(current.getNumDims(), m.getNumDims()),
        current.getNumSymbols() + m.getNumSymbols(), results);
  }
  return inverseSubMap(current);
}
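
// Worked example (illustrative, using the per-op maps defined in
// TensorOps.cpp): for matvec the submaps are
//   A: (d0, d1) -> (d0, d1),  B: (d0, d1) -> (d1),  C: (d0, d1) -> (d0).
// Concatenating their results gives (d0, d1) -> (d0, d1, d1, d0); inverting
// that union with inverseSubMap recovers, for each operand range, the
// enclosing loop dimension that indexes it.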

// Extract the ranges from a given ViewOp or SliceOp.
//
// In the case of a ViewOp, things are simple: just traverse the indexings and
// get all the ranges (i.e. drop the indices).
//
// In the case of a SliceOp, things are trickier because we need to handle a
// potential rank-reduction:
//   1. Examine the indexing to determine if it is rank-reducing.
//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
//      that `d >= slicingDim`. This accounts for the rank reduction.
// `getViewRootIndexing` is then called on the **parent** view.
inline llvm::SmallVector<mlir::Value *, 8>
extractRangesFromViewOrSliceOp(mlir::Value *view) {
  // This expects a viewType which must come from either ViewOp or SliceOp.
  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
  if (auto viewOp = llvm::dyn_cast<linalg::ViewOp>(view->getDefiningOp()))
    return viewOp.getRanges();

  auto sliceOp = llvm::cast<linalg::SliceOp>(view->getDefiningOp());
  unsigned slicingDim = sliceOp.getSlicingDim();
  auto *indexing = *(sliceOp.getIndexings().begin());
  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
  unsigned offset = 0;
  llvm::SmallVector<mlir::Value *, 8> res;
  res.reserve(sliceOp.getRank());
  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
    if (d == slicingDim && isRankReducing)
      offset = 1;
    auto *parentView = sliceOp.getParentView();
    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
    res.push_back(indexingPosPair.first);
  }
  return res;
}
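
// Illustrative example (assumed shapes, not from the original source): slicing
// a 2-d view at dimension 0 with an index value rank-reduces it to a 1-d view;
// the loop above then skips the sliced dimension and returns the single range
// taken from dimension 1 of the **parent** view.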

template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res;
  for (auto *in : tensorContraction.getInputs()) {
    auto subres = extractRangesFromViewOrSliceOp(in);
    res.append(subres.begin(), subres.end());
  }
  return res;
}

template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res;
  for (auto *out : tensorContraction.getOutputs()) {
    auto subres = extractRangesFromViewOrSliceOp(out);
    res.append(subres.begin(), subres.end());
  }
  return res;
}

template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
  res.append(tmp.begin(), tmp.end());
  return res;
}

#endif // LINALG3_TENSOROPS_INL_H_
@@ -1,54 +0,0 @@
//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_TENSOROPS_H_
#define LINALG3_TENSOROPS_H_

#include "linalg2/TensorOps.h"

namespace linalg {

///
/// Ideally all these functions would go in an Analysis but as long as
/// TensorContractionBase is templated, they need to remain close enough.
///

/// Takes a `tensorContraction` and returns an AffineMap that can be used to
/// map ranges to enclosing loops for all the operands' ranges.
template <class ConcreteOp>
mlir::AffineMap operandRangesToLoopsMap(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction);

/// Takes a `tensorContraction` and returns the ranges of all its operands.
/// When an operand comes from a ViewOp, things are simple:
/// just traverse the indexings and get all the ranges
/// (i.e. drop the rank-reducing indices).
/// In the case of a SliceOp, things are more involved because we need to
/// handle potential rank-reductions.
/// This function abstracts this complexity away and returns all the ranges.
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8>
getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);

} // namespace linalg

/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension
/// of TensorOps by adding implementations as they are needed in the
/// appropriate step in the tutorial.
#include "linalg3/TensorOps-inl.h"

#endif // LINALG3_TENSOROPS_H_
@@ -1,82 +0,0 @@
//===- Transforms.h - Linalg dialect Transformations definition -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_TRANSFORMS_H_
#define LINALG3_TRANSFORMS_H_

#include "linalg2/Transforms.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"

namespace mlir {
class AffineForOp;
class AffineMap;
class FuncOp;
class Operation;
class Value;

template <typename T> class OpPassBase;
} // namespace mlir

namespace linalg {

struct RangeParts {
  explicit RangeParts(unsigned reserved);
  RangeParts(llvm::ArrayRef<mlir::Value *> ranges);
  llvm::SmallVector<mlir::Value *, 4> makeRanges();

  llvm::SmallVector<mlir::Value *, 4> mins;
  llvm::SmallVector<mlir::Value *, 4> maxes;
  llvm::SmallVector<mlir::Value *, 4> steps;
};

mlir::Value *
makeFoldedComposedAffineApply(mlir::AffineMap map,
                              llvm::ArrayRef<mlir::Value *> operandsRef);

llvm::SmallVector<mlir::Value *, 4>
makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
                      llvm::ArrayRef<mlir::Value *> ranges,
                      llvm::ArrayRef<mlir::Value *> tileSizes = {});

/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
/// to only use linalg.view operations.
void composeSliceOps(mlir::FuncOp f);

/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
/// as linalg.matvec (resp. linalg.dot).
void lowerToFinerGrainedTensorContraction(mlir::FuncOp f);

/// Operation-wise writing of linalg operations to loop form.
/// It is the caller's responsibility to erase the `op` if necessary.
/// This returns the enclosing loops around the body of `op` for further
/// composition of transformations.
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
writeAsLoops(mlir::Operation *op);

/// Traverses `f` and rewrites linalg operations in loop form.
void lowerToLoops(mlir::FuncOp f);

/// Creates a pass that rewrites linalg.load and linalg.store to affine.load
/// and affine.store operations.
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
createLowerLinalgLoadStorePass();

} // namespace linalg

#endif // LINALG3_TRANSFORMS_H_
@@ -1,62 +0,0 @@
//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analysis functions for the linalg dialect, most
// notably the inversion of permutation-like affine maps.
//
//===----------------------------------------------------------------------===//

#include "linalg3/Analysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/StandardTypes.h"

using llvm::SmallVector;
using namespace mlir;

// Compute an inverse map (only works with permutations for now).
// Note that the mapping is generally non-full rank, so this returns the first
// seen entry for each dim.
static AffineMap inversePermutationMap(AffineMap map) {
  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
  for (auto en : llvm::enumerate(map.getResults())) {
    auto expr = en.value();
    auto d = expr.dyn_cast<AffineDimExpr>();
    assert(d && "permutation map expected");
    if (exprs[d.getPosition()])
      continue;
    exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
  }
  SmallVector<AffineExpr, 4> seenExprs;
  seenExprs.reserve(map.getNumDims());
  for (auto expr : exprs)
    if (expr)
      seenExprs.push_back(expr);
  assert(map.getNumSymbols() == 0 && "expected map without symbols");
  assert(seenExprs.size() == map.getNumInputs() && "map is not invertible");
  return AffineMap::get(map.getNumResults(), 0, seenExprs);
}
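
// Illustrative example (not from the original file): inverting the matvec
// union map (d0, d1) -> (d0, d1, d1, d0) yields (d0, d1, d2, d3) -> (d0, d1):
// result 0 is the first definition of d0 and result 1 of d1; the duplicate
// trailing results are skipped by the `if (exprs[d.getPosition()]) continue;`
// above.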

mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult,
                                      unsigned endResult) {
  if (beginResult == 0 && endResult == 0)
    endResult = map.getNumResults();
  auto subMap = AffineMap::get(
      map.getNumDims(), map.getNumSymbols(),
      map.getResults().slice(beginResult, endResult - beginResult));
  return inversePermutationMap(subMap);
}
@@ -1,39 +0,0 @@
set(LLVM_OPTIONAL_SOURCES
  Analysis.cpp
  ConvertToLLVMDialect.cpp
  LoadStoreOps.cpp
  Transforms.cpp
  DialectConstruction.cpp
  TensorOps.cpp
  )

add_llvm_library(Linalg3
  Analysis.cpp
  ConvertToLLVMDialect.cpp
  LoadStoreOps.cpp
  Transforms.cpp
  TensorOps.cpp
  )

target_link_libraries(Linalg3
  PUBLIC
    MLIRAffineOps
    MLIRAnalysis
    MLIRLoopToStandard
    MLIRDialect
    MLIREDSC
    MLIRIR
    MLIRLLVMIR
    MLIRParser
    MLIRPass
    MLIRStandardToLLVM
    MLIRTransforms
    Linalg2
  )

add_llvm_library(Linalg3DialectConstruction
  DialectConstruction.cpp
  )

target_link_libraries(Linalg3DialectConstruction
  PUBLIC Linalg3)
@@ -1,159 +0,0 @@
//===- ConvertToLLVMDialect.cpp - conversion from Linalg to LLVM dialect --===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"

#include "linalg1/ConvertToLLVMDialect.h"
#include "linalg1/LLVMIntrinsics.h"

#include "linalg3/ConvertToLLVMDialect.h"
#include "linalg3/Ops.h"

using namespace mlir;

namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
template <typename Op> class LoadStoreOpConversion : public ConversionPattern {
public:
  explicit LoadStoreOpConversion(MLIRContext *context)
      : ConversionPattern(Op::getOperationName(), 1, context) {}
  using Base = LoadStoreOpConversion<Op>;

  // Compute the pointer to an element of the buffer underlying the view given
  // current view indices. Use the base offset and strides stored in the view
  // descriptor to emit IR iteratively computing the actual offset, followed by
  // a getelementptr.
  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                       ArrayRef<Value *> indices, Builder &rewriter) const {
    auto loadOp = cast<Op>(op);
    auto elementType =
        loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
    elementType = linalg::convertLinalgType(elementType)
                      .template cast<LLVM::LLVMType>()
                      .getPointerTo();
    auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));

    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    using namespace intrinsics;

    // Linearize subscripts as:
    //   base_offset + SUM_i index_i * stride_i.
    Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
      Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
      Value *additionalOffset = mul(indices[i], stride);
      offset = add(offset, additionalOffset);
    }
    Value *base = extractvalue(elementType, viewDescriptor, pos(0));
    return gep(elementType, base, ArrayRef<Value *>{offset});
  }
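
  // Worked example (illustrative, assumed rank-2 view): with base pointer
  // %base at descriptor position 0, base offset %off at position 1, and
  // strides %s0, %s1 at positions {3, 0} and {3, 1}, indices (%i, %j)
  // linearize to
  //   %off + %i * %s0 + %j * %s1
  // and the returned pointer is gep %base[%linearized].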
};

// A load is converted into the actual address computation, getelementptr and
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    auto elementType = linalg::convertLinalgType(*op->result_type_begin());
    Value *viewDescriptor = operands[0];
    ArrayRef<Value *> indices = operands.drop_front();
    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
    Value *element = intrinsics::load(elementType, ptr);
    rewriter.replaceOp(op, {element});
    return matchSuccess();
  }
};

// A store is converted into the actual address computation, getelementptr and
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    Value *viewDescriptor = operands[1];
    Value *data = operands[0];
    ArrayRef<Value *> indices = operands.drop_front(2);
    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
    intrinsics::store(data, ptr);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};

/// A type conversion class that converts Linalg and Std types to LLVM.
struct LinalgTypeConverter : public LLVMTypeConverter {
  using LLVMTypeConverter::LLVMTypeConverter;

  // This gets called for block and region arguments, and attributes.
  Type convertType(Type t) override {
    if (auto result = LLVMTypeConverter::convertType(t))
      return result;
    return linalg::convertLinalgType(t);
  }
};
} // end anonymous namespace

// Helper function that allocates the descriptor converters and adds load/store
// converters to the list.
static void populateLinalg3ToLLVMConversionPatterns(
    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
  patterns.insert<LoadOpConversion, StoreOpConversion>(context);
}

LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
  // Convert Linalg ops to the LLVM IR dialect using the converter defined
  // above.
  LinalgTypeConverter converter(module.getContext());
  OwningRewritePatternList patterns;
  populateAffineToStdConversionPatterns(patterns, module.getContext());
  populateLoopToStdConversionPatterns(patterns, module.getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
  populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());

  ConversionTarget target(*module.getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, patterns, &converter)))
    return failure();

  return success();
}
@@ -1,35 +0,0 @@
//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the constructor for the Linalg Dialect. This is
// explicitly separated from the core library to allow incremental buildup of
// the codebase for the tutorial.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"
#include "linalg1/Types.h"
#include "linalg3/Ops.h"

using namespace linalg;

LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
    : Dialect("linalg", context) {
  addTypes<RangeType, ViewType>();
  addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
                ViewOp>();
}
@@ -1,137 +0,0 @@
//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements linalg.load and linalg.store operations which allow
// accessing memory through ViewType values.
//
//===----------------------------------------------------------------------===//

#include "linalg3/LoadStoreOps.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"

using llvm::ArrayRef;
using namespace mlir;
using namespace linalg;

////////////////////////////////////////////////////////////////////////////////
// LoadOp.
////////////////////////////////////////////////////////////////////////////////
void linalg::LoadOp::build(Builder *b, OperationState &result, Value *view,
                           ArrayRef<Value *> indices) {
  auto viewType = view->getType().cast<ViewType>();
  result.addOperands(view);
  result.addOperands(indices);
  result.addTypes(viewType.getElementType());
}

void linalg::LoadOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << *getView() << '[';
  p.printOperands(getIndices());
  p << ']';
  p.printOptionalAttrDict(getAttrs());
  p << " : " << getViewType();
}
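
// For reference (reconstructed from the printer above, not quoted from the
// original file), a load of a 2-d f32 view prints as:
//   linalg.load %view[%i, %j] : !linalg.view<?x?xf32>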

ParseResult linalg::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
  return success();
}

LogicalResult linalg::LoadOp::verify() {
  if (getNumOperands() == 0)
    return emitOpError("expected a view to load from");

  auto viewType = getView()->getType().dyn_cast<ViewType>();
  if (!viewType)
    return emitOpError("first operand must be a view");

  if (getType() != viewType.getElementType())
    return emitOpError("result type must match element type of the view");

  if (getRank() != getNumOperands() - 1)
    return emitOpError("incorrect number of indices for load");

  for (auto *idx : getIndices())
    if (!idx->getType().isIndex())
      return emitOpError("index to load must have 'index' type");

  return success();
}

ViewType linalg::LoadOp::getViewType() {
  return getView()->getType().cast<ViewType>();
}

unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); }

////////////////////////////////////////////////////////////////////////////////
// StoreOp.
////////////////////////////////////////////////////////////////////////////////
void linalg::StoreOp::build(Builder *b, OperationState &result,
                            Value *valueToStore, Value *view,
                            ArrayRef<Value *> indices) {
  result.addOperands(valueToStore);
  result.addOperands(view);
  result.addOperands(indices);
}

void linalg::StoreOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << *getValueToStore();
  p << ", " << *getView() << '[';
  p.printOperands(getIndices());
  p << ']';
  p.printOptionalAttrDict(getAttrs());
  p << " : " << getViewType();
}
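
// For reference (again reconstructed from the printer, not from the original
// file), the corresponding store prints as:
//   linalg.store %val, %view[%i, %j] : !linalg.view<?x?xf32>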

ParseResult linalg::StoreOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
  return success();
}

LogicalResult linalg::StoreOp::verify() {
  if (getNumOperands() < 2)
    return emitOpError("expected a value to store and a view");

  // Second operand must be a view type.
  auto viewType = getView()->getType().dyn_cast<ViewType>();
  if (!viewType)
    return emitOpError("second operand must be a view");

  // First operand must have the same type as the view's element type.
  if (getValueToStore()->getType() != viewType.getElementType())
    return emitOpError("first operand must have same element type as the view");

  if (getNumOperands() != 2 + viewType.getRank())
    return emitOpError("store index operand count not equal to view rank");

  for (auto *idx : getIndices())
    if (!idx->getType().isIndex())
      return emitOpError("index to store must have 'index' type");

  return success();
}

unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); }

ViewType linalg::StoreOp::getViewType() {
  return getView()->getType().cast<ViewType>();
}
@@ -1,223 +0,0 @@
//===- TensorOps.cpp - Implementation of the linalg TensorOps operation ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the tensor computation operations (dot, matvec,
// matmul) of the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Analysis.h"
#include "linalg1/Common.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;

//////////////////////////////////////////////////////////////////////////////
// Implementation of DotOp.
//////////////////////////////////////////////////////////////////////////////
SmallVector<AffineMap, 8> linalg::DotOp::loopsToOperandRangeMaps() {
  // A(K), B(K), C()
  assert(getRanges(*this).size() == 2);
  auto *context = ScopedContext::getContext();
  auto d0 = getAffineDimExpr(0, context); // K
  // A(K), B(K), C()
  //   (d0) -> (d0, d0)(%k)
  return SmallVector<AffineMap, 8>{AffineMap::get(1, 0, {d0}), // A(K)
                                   AffineMap::get(1, 0, {d0}), // B(K)
                                   AffineMap()};               // C()
}

void linalg::DotOp::emitScalarImplementation(
    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
                                             linalg::intrinsics::store>;
  assert(reductionIvs.size() == 1);
  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
  auto *body = innermostLoop.getBody();
  using edsc::op::operator+;
  using edsc::op::operator*;
  using edsc::op::operator==;
  using edsc::intrinsics::select;

  // Account for the affine.terminator in the loop.
  OpBuilder builder(body, std::prev(body->end(), 1));
  ScopedContext scope(builder, innermostLoop.getLoc());
  FloatType fTy = getOperand(0)
                      ->getType()
                      .cast<ViewType>()
                      .getElementType()
                      .cast<FloatType>();
  IndexHandle zero(constant_index(0));
  ValueHandle zerof =
      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
  IndexHandle r_i(reductionIvs[0]);
  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
  // On the first reduction iteration, substitute 0 for C() so the accumulator
  // needs no separate initialization loop.
  ValueHandle cond = (r_i == zero);
  ValueHandle scalarC = select(cond, zerof, *C());
  C() = scalarC + A(r_i) * B(r_i);
}

//////////////////////////////////////////////////////////////////////////////
// Implementation of MatvecOp.
//////////////////////////////////////////////////////////////////////////////
SmallVector<AffineMap, 8> linalg::MatvecOp::loopsToOperandRangeMaps() {
  // A(M, K), B(K), C(M)
  assert(getRanges(*this).size() == 4);
  auto *context = ScopedContext::getContext();
  auto d0 = getAffineDimExpr(0, context); // M
  auto d1 = getAffineDimExpr(1, context); // K
  // A(M, K), B(K), C(M)
  //   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
  return SmallVector<AffineMap, 8>{AffineMap::get(2, 0, {d0, d1}), // A(M, K)
                                   AffineMap::get(2, 0, {d1}),     // B(K)
                                   AffineMap::get(2, 0, {d0})};    // C(M)
}

// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
// The body expression for dot is: C() = A(r_i) * B(r_i);
// So we must drop the `i` loop from the matvec.
void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
  auto *op = getOperation();
  auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
  auto indexingPosPair = getViewRootIndexing(vA, 0);
  assert(
      llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
  // clang-format off
  OpBuilder builder(op);
  ScopedContext scope(builder, op->getLoc());
  IndexHandle i;
  using linalg::common::LoopNestRangeBuilder;
  LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))(
    [&i, &vA, &vB, &vC]() {
      ValueHandle sliceA = slice(vA, i, 0);
      ValueHandle sliceC = slice(vC, i, 0);
      dot(sliceA, vB, sliceC);
    });
  // clang-format on
}
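
// Sketch of the rewrite above (exposition only, pseudo-IR): assuming dimension
// 0 of %A is indexed by a range, linalg.matvec(%A, %B, %C) becomes
//   affine.for %i = ... {
//     %a_i = linalg.slice %A at dim 0 by %i   // row i of A
//     %c_i = linalg.slice %C at dim 0 by %i
//     linalg.dot(%a_i, %B, %c_i)
//   }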

void linalg::MatvecOp::emitScalarImplementation(
    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
                                             linalg::intrinsics::store>;
  assert(reductionIvs.size() == 1);
  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
  auto *body = innermostLoop.getBody();
  using edsc::op::operator+;
  using edsc::op::operator*;
  using edsc::op::operator==;
  using edsc::intrinsics::select;
  // Account for the affine.terminator in the loop.
  OpBuilder builder(body, std::prev(body->end(), 1));
  ScopedContext scope(builder, innermostLoop.getLoc());
  FloatType fTy = getOperand(0)
                      ->getType()
                      .cast<ViewType>()
                      .getElementType()
                      .cast<FloatType>();
  IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
  IndexHandle zero(constant_index(0));
  ValueHandle zerof =
      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
  ValueHandle cond = (r_j == zero);
  ValueHandle scalarC = select(cond, zerof, *C(i));
  C(i) = scalarC + A(i, r_j) * B(r_j);
}

//////////////////////////////////////////////////////////////////////////////
// Implementation of MatmulOp.
//////////////////////////////////////////////////////////////////////////////
SmallVector<AffineMap, 8> linalg::MatmulOp::loopsToOperandRangeMaps() {
  // A(M, K), B(K, N), C(M, N)
  assert(getRanges(*this).size() == 6);
  auto *context = ScopedContext::getContext();
  auto d0 = getAffineDimExpr(0, context); // M
  auto d1 = getAffineDimExpr(1, context); // N
  auto d2 = getAffineDimExpr(2, context); // K
  // A(M, K), B(K, N), C(M, N):
  //   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
  return SmallVector<AffineMap, 8>{
      AffineMap::get(3, 0, {d0, d2}), // A(M, K)
      AffineMap::get(3, 0, {d2, d1}), // B(K, N)
      AffineMap::get(3, 0, {d0, d1})  // C(M, N)
  };
}

// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
// So we must drop the `j` loop from the matmul.
// This is fine because parallel dimensions permute: we can just do it
// declaratively.
void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
  auto *op = getOperation();
  auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
  auto indexingPosPair = getViewRootIndexing(vB, 1);
  assert(
      llvm::isa_and_nonnull<RangeOp>(indexingPosPair.first->getDefiningOp()));
  using linalg::common::LoopNestRangeBuilder;
  // clang-format off
  OpBuilder builder(op);
  ScopedContext scope(builder, op->getLoc());
  IndexHandle j;
  LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))(
    [&j, &vA, &vB, &vC]() {
      ValueHandle sliceB = slice(vB, j, 1);
      ValueHandle sliceC = slice(vC, j, 1);
      matvec(vA, sliceB, sliceC);
    });
  // clang-format on
}
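
// Sketch (exposition only, pseudo-IR): assuming dimension 1 of %B is indexed
// by a range, linalg.matmul(%A, %B, %C) becomes
//   affine.for %j = ... {
//     %b_j = linalg.slice %B at dim 1 by %j   // column j of B
//     %c_j = linalg.slice %C at dim 1 by %j
//     linalg.matvec(%A, %b_j, %c_j)
//   }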

void linalg::MatmulOp::emitScalarImplementation(
    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
                                             linalg::intrinsics::store>;
  assert(reductionIvs.size() == 1);
  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
  auto *body = innermostLoop.getBody();
  using edsc::op::operator+;
  using edsc::op::operator*;
  using edsc::op::operator==;
  using edsc::intrinsics::select;
  // Account for the affine.terminator in the loop.
  OpBuilder builder(body, std::prev(body->end(), 1));
  ScopedContext scope(builder, innermostLoop.getLoc());
  FloatType fTy = getOperand(0)
                      ->getType()
                      .cast<ViewType>()
                      .getElementType()
                      .cast<FloatType>();
  IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
  IndexHandle zero(constant_index(0));
  ValueHandle zerof =
      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
  ValueHandle cond = r_k == zero;
  ValueHandle scalarC = select(cond, zerof, *C(i, j));
  C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
}
@@ -1,305 +0,0 @@
//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg3/Transforms.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;

void linalg::composeSliceOps(mlir::FuncOp f) {
  f.walk([](SliceOp sliceOp) {
    auto *sliceResult = sliceOp.getResult();
    auto viewOp = emitAndReturnFullyComposedView(sliceResult);
    sliceResult->replaceAllUsesWith(viewOp.getResult());
    sliceOp.erase();
  });
}

void linalg::lowerToFinerGrainedTensorContraction(mlir::FuncOp f) {
  f.walk([](Operation *op) {
    if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
      matmulOp.writeAsFinerGrainTensorContraction();
    } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
      matvecOp.writeAsFinerGrainTensorContraction();
    } else {
      return;
    }
    op->erase();
  });
}

// Folding eagerly is necessary to abide by the affine.for requirement that
// loop steps be static constants. Returns nullptr if folding is not trivially
// feasible.
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
  assert(map.getNumResults() == 1 && "single result map expected");
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return operands[dim.getPosition()];
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return operands[map.getNumDims() + sym.getPosition()];
  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
    return constant_index(cst.getValue());
  return nullptr;
}
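
// Examples (exposition only): a map (d0) -> (d0) folds to its single operand,
// () -> (8) folds to a fresh constant_index(8), and anything compound such as
// (d0) -> (d0 + 1) returns nullptr and is emitted as an affine.apply by the
// caller below.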

Value *linalg::makeFoldedComposedAffineApply(AffineMap map,
                                             ArrayRef<Value *> operandsRef) {
  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  if (auto *v = tryFold(map, operands)) {
    return v;
  }
  auto &b = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  return b.create<AffineApplyOp>(loc, map, operands).getResult();
}

linalg::RangeParts::RangeParts(unsigned reserved) {
  mins.reserve(reserved);
  maxes.reserve(reserved);
  steps.reserve(reserved);
}

static SmallVector<Value *, 4>
extractFromRanges(ArrayRef<Value *> ranges,
                  std::function<Value *(RangeOp)> extract) {
  SmallVector<Value *, 4> res;
  res.reserve(ranges.size());
  for (auto *v : ranges) {
    auto r = cast<RangeOp>(v->getDefiningOp());
    res.push_back(extract(r));
  }
  return res;
}

linalg::RangeParts::RangeParts(ArrayRef<Value *> ranges)
    : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
      maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
      steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}

SmallVector<Value *, 4> linalg::RangeParts::makeRanges() {
  SmallVector<Value *, 4> res;
  res.reserve(mins.size());
  for (auto z : llvm::zip(mins, maxes, steps)) {
    res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
  }
  return res;
}

static RangeParts makeGenericRangeParts(AffineMap map,
                                        ArrayRef<Value *> ranges) {
  assert(map.getNumInputs() == ranges.size());
  unsigned numDims = map.getNumDims();
  assert(map.getNumSymbols() == 0);

  RangeParts res(map.getNumResults());
  RangeParts rangeParts(ranges);
  for (auto expr : map.getResults()) {
    AffineMap exprMap = AffineMap::get(numDims, 0, expr);
    res.mins.push_back(makeFoldedComposedAffineApply(exprMap, rangeParts.mins));
    res.maxes.push_back(
        makeFoldedComposedAffineApply(exprMap, rangeParts.maxes));
    res.steps.push_back(
        makeFoldedComposedAffineApply(exprMap, rangeParts.steps));
  }
  return res;
}

SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
                                          ArrayRef<Value *> ranges) {
  return makeGenericRangeParts(map, ranges).makeRanges();
}

SmallVector<Value *, 4>
linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
                              ArrayRef<Value *> ranges,
                              ArrayRef<Value *> tileSizes) {
  RangeParts res = makeGenericRangeParts(operandRangesToLoopMaps, ranges);
  if (tileSizes.empty())
    return res.makeRanges();
  SmallVector<Value *, 4> tiledSteps;
  for (auto z : llvm::zip(res.steps, tileSizes)) {
    auto *step = std::get<0>(z);
    auto tileSize = std::get<1>(z);
    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
    auto tileSizeValue =
        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
    assert(stepValue > 0);
    tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
  }
  res.steps = tiledSteps;
  return res.makeRanges();
}
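
// Illustration (assumed constant steps): with loop ranges {0:%M:1, 0:%N:1} and
// tileSizes {8, 9}, the returned ranges are {0:%M:8, 0:%N:9}. Steps must fold
// to constants so that the scaled step remains a valid static affine.for step,
// which is why makeFoldedComposedAffineApply folds eagerly.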

template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeContractionAsLoops(ContractionOp contraction) {
  OpBuilder builder(contraction.getOperation());
  ScopedContext scope(builder, contraction.getLoc());
  auto allRanges = getRanges(contraction);
  auto loopRanges =
      makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);

  SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
  SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
  auto pivs = makeIndexHandlePointers(parallelIvs);
  auto rivs = makeIndexHandlePointers(reductionIvs);
  assert(loopRanges.size() == pivs.size() + rivs.size());

  // clang-format off
  using linalg::common::LoopNestRangeBuilder;
  ArrayRef<Value *> ranges(loopRanges);
  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
    LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
      [&contraction, &parallelIvs, &reductionIvs] {
        SmallVector<mlir::Value *, 4> parallel(
            parallelIvs.begin(), parallelIvs.end());
        SmallVector<mlir::Value *, 4> reduction(
            reductionIvs.begin(), reductionIvs.end());
        contraction.emitScalarImplementation(parallel, reduction);
      });
  });
  // clang-format on

  // Return the AffineForOp for better compositionality (e.g. tiling).
  SmallVector<mlir::AffineForOp, 4> loops;
  loops.reserve(pivs.size() + rivs.size());
  for (auto iv : parallelIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));
  for (auto iv : reductionIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));

  return loops;
}
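
// Shape of the emitted nest for matmul (exposition only, pseudo-IR):
//   affine.for %i = ... {       // parallel
//     affine.for %j = ... {     // parallel
//       affine.for %k = ... {   // reduction
//         <scalar body from emitScalarImplementation>
//       }}}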

llvm::Optional<SmallVector<mlir::AffineForOp, 4>>
linalg::writeAsLoops(Operation *op) {
  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
    return writeContractionAsLoops(matmulOp);
  } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
    return writeContractionAsLoops(matvecOp);
  } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
    return writeContractionAsLoops(dotOp);
  }
  return llvm::None;
}

void linalg::lowerToLoops(mlir::FuncOp f) {
  f.walk([](Operation *op) {
    if (writeAsLoops(op))
      op->erase();
  });
}

/// Emits and returns the standard load and store ops from the view indexings.
/// If the indexing is of index type, use it as an index to the load/store.
/// If the indexing is a range, use range.min + indexing as an index to the
/// load/store.
template <typename LoadOrStoreOp>
static SmallVector<Value *, 8>
emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
  unsigned storeDim = 0;
  SmallVector<Value *, 8> operands;
  for (auto *indexing : viewOp.getIndexings()) {
    if (indexing->getType().isa<IndexType>()) {
      operands.push_back(indexing);
      continue;
    }
    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
    ValueHandle min(range.getMin());
    Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
    using edsc::op::operator+;
    operands.push_back(min + ValueHandle(storeIndex));
  }
  return operands;
}

namespace {

/// Rewriting linalg::LoadOp and linalg::StoreOp to mlir::LoadOp and
/// mlir::StoreOp requires finding the proper indexing in the supporting
/// MemRef. This is most easily achieved by calling
/// emitAndReturnFullyComposedView to fold away all the SliceOp.
template <typename LoadOrStoreOpTy>
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
  using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;

  /// Performs the rewrite.
  PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
                                     PatternRewriter &rewriter) const override;
};

struct LowerLinalgLoadStorePass
    : public FunctionPass<LowerLinalgLoadStorePass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    auto *context = &getContext();
    patterns.insert<Rewriter<linalg::LoadOp>, Rewriter<linalg::StoreOp>>(
        context);
    applyPatternsGreedily(getFunction(), patterns);
  }
};

template <>
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
                                          PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(load.getView()->getDefiningOp());
  OpBuilder builder(load);
  ScopedContext scope(builder, load.getLoc());
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(load, view);
  rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
  return matchSuccess();
}

template <>
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
                                           PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(store.getView()->getDefiningOp());
  OpBuilder builder(store);
  ScopedContext scope(builder, store.getLoc());
  auto *valueToStore = store.getValueToStore();
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(store, view);
  rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
                                             operands);
  return matchSuccess();
}
} // namespace

std::unique_ptr<OpPassBase<FuncOp>> linalg::createLowerLinalgLoadStorePass() {
  return std::make_unique<LowerLinalgLoadStorePass>();
}
@@ -1,33 +0,0 @@
add_subdirectory(lib)

set(LLVM_LINK_COMPONENTS
  Core
  Support
  )

set(LLVM_OPTIONAL_SOURCES Example.cpp)

add_llvm_example(linalg-example-4
  Example.cpp
  )

target_link_libraries(linalg-example-4
  PRIVATE
    MLIRAffineOps
    MLIRAnalysis
    MLIRDialect
    MLIREDSC
    MLIRIR
    MLIRLLVMIR
    MLIRParser
    MLIRPass
    MLIRTransforms
    Linalg4
    Linalg3DialectConstruction
  )

whole_archive_link(linalg-example-4
  MLIRAffineOps
  MLIRStandardOps
  )
@ -1,173 +0,0 @@
//===- Example.cpp - Our running example ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// RUN: %p/test | FileCheck %s

#include "TestHarness.h"
#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "linalg4/Transforms.h"
#include "mlir/IR/OpImplementation.h"

using llvm::StringRef;

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;

FuncOp makeFunctionWithAMatmulOp(ModuleOp module, StringRef name) {
  MLIRContext *context = module.getContext();
  auto dynamic2DMemRefType = floatMemRefType<2>(context);
  mlir::FuncOp f = linalg::common::makeFunction(
      module, name,
      {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());

  // clang-format off
  ValueHandle
    M = dim(f.getArgument(0), 0),
    N = dim(f.getArgument(2), 1),
    K = dim(f.getArgument(0), 1),
    rM = range(constant_index(0), M, constant_index(1)),
    rN = range(constant_index(0), N, constant_index(1)),
    rK = range(constant_index(0), K, constant_index(1)),
    vA = view(f.getArgument(0), {rM, rK}),
    vB = view(f.getArgument(1), {rK, rN}),
    vC = view(f.getArgument(2), {rM, rN});
  matmul(vA, vB, vC);
  ret();
  // clang-format on

  return f;
}

TEST_FUNC(matmul_tiled_loops) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops");
  lowerToTiledLoops(f, {8, 9});
  PassManager pm(&context);
  pm.addPass(createLowerLinalgLoadStorePass());
  if (succeeded(pm.run(f.getParentOfType<mlir::ModuleOp>())))
    cleanupAndPrintFunction(f);

  // clang-format off
  // CHECK-LABEL: func @matmul_tiled_loops(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
  // CHECK: affine.for %{{.*}} = 0 to %[[N]] step 9 {
  // CHECK: affine.for %{{.*}} = 0 to %[[K]] {
  // CHECK: affine.for %{{.*}} = max (d0) -> (0, d0)(%{{.*}}) to min (d0)[s0] -> (s0, d0 + 8)(%{{.*}})[%[[M]]] {
  // CHECK: affine.for %{{.*}} = max (d0) -> (0, d0)(%{{.*}}) to min (d0)[s0] -> (s0, d0 + 9)(%{{.*}})[%[[N]]] {
  // CHECK-NEXT: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
  // CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
  // clang-format on
}

TEST_FUNC(matmul_tiled_views) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views");
  OpBuilder b(f.getBody());
  lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
                        b.create<ConstantIndexOp>(f.getLoc(), 9)});
  composeSliceOps(f);

  // clang-format off
  // CHECK-LABEL: func @matmul_tiled_views(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
  // CHECK-NEXT: affine.for %{{.*}} = 0 to %[[N]] step 9 {
  // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%{{.*}})
  // CHECK-NEXT: %[[ri0:.*]] = linalg.range %{{.*}}:%[[i0max]]:{{.*}} : !linalg.range
  // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%{{.*}})
  // CHECK-NEXT: %[[ri1:.*]] = linalg.range %{{.*}}:%[[i1max]]:%{{.*}} : !linalg.range
  // CHECK-NEXT: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK-NEXT: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?x?xf32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_tiled_views_as_loops) {
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  mlir::FuncOp f =
      makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops");
  OpBuilder b(f.getBody());
  lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
                        b.create<ConstantIndexOp>(f.getLoc(), 9)});
  composeSliceOps(f);
  lowerToLoops(f);
  // This cannot lower below linalg.load and linalg.store due to lost
  // information related to loop bounds and tiling. There are multiple ways to
  // attack the problem, the best one is an IR change.

  // clang-format off
  // CHECK-LABEL: func @matmul_tiled_views_as_loops(%{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
  // CHECK: affine.for %{{.*}} = 0 to %[[M]] step 8 {
  // CHECK-NEXT: affine.for %{{.*}} = 0 to %[[N]] step 9 {
  // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%{{.*}})
  // CHECK-NEXT: %[[ri0:.*]] = linalg.range %{{.*}}:%[[i0max]]:{{.*}} : !linalg.range
  // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
  // CHECK: %[[vA:.*]] = linalg.view %{{.*}}[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%{{.*}})
  // CHECK-NEXT: %[[ri1:.*]] = linalg.range %{{.*}}:%[[i1max]]:%{{.*}} : !linalg.range
  // CHECK-NEXT: %[[vB:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK-NEXT: %[[vC:.*]] = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  // CHECK-NEXT: affine.for %{{.*}} = (d0) -> (d0)(%{{.*}}) to (d0) -> (d0)(%[[i0max]]) {
  // CHECK-NEXT: affine.for %{{.*}} = (d0) -> (d0)(%{{.*}}) to (d0) -> (d0)(%[[i1max]]) {
  // CHECK-NEXT: affine.for %{{.*}} = 0 to %[[K]] {
  // CHECK-NEXT: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : index
  // CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
  // clang-format on
  cleanupAndPrintFunction(f);
}

int main() {
  mlir::registerDialect<linalg::LinalgDialect>();
  RUN_TESTS();
  return 0;
}
@ -1,46 +0,0 @@
//===- Transforms.h - Linalg dialect Transformations definition -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG4_TRANSFORMS_H_
#define LINALG4_TRANSFORMS_H_

#include "linalg3/Transforms.h"
#include "mlir/Support/LLVM.h"

namespace linalg {

/// Rewrites a linalg `op` in tiled loop form and erases `op`.
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 8>>
writeAsTiledLoops(mlir::Operation *op, llvm::ArrayRef<uint64_t> tileSizes);

/// Rewrites a linalg `op` in tiled view form and erases `op`.
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 8>>
writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes);

/// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
void lowerToTiledLoops(mlir::FuncOp f, llvm::ArrayRef<uint64_t> tileSizes);

/// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
void lowerToTiledViews(mlir::FuncOp f, llvm::ArrayRef<mlir::Value *> tileSizes);

} // namespace linalg

#endif // LINALG4_TRANSFORMS_H_
@ -1,16 +0,0 @@
add_llvm_library(Linalg4
  Transforms.cpp
  )

target_link_libraries(Linalg4
  PUBLIC
  MLIRAnalysis
  MLIRDialect
  MLIREDSC
  MLIRIR
  MLIRLLVMIR
  MLIRParser
  MLIRPass
  MLIRTransforms
  Linalg3
  )
@ -1,200 +0,0 @@
//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg4/Transforms.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/TensorOps.h"

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/LoopUtils.h"

using llvm::ArrayRef;
using llvm::SmallVector;
using namespace mlir;
using namespace mlir::edsc;
using namespace linalg;
using namespace linalg::intrinsics;

llvm::Optional<SmallVector<mlir::AffineForOp, 8>>
linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) {
  auto loops = writeAsLoops(op);
  if (loops.hasValue())
    return mlir::tile(*loops, tileSizes, loops->back());
  return llvm::None;
}

void linalg::lowerToTiledLoops(mlir::FuncOp f, ArrayRef<uint64_t> tileSizes) {
  f.walk([tileSizes](Operation *op) {
    if (writeAsTiledLoops(op, tileSizes).hasValue())
      op->erase();
  });
}

static bool isZeroIndex(Value *v) {
  return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
         cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
}

template <typename ConcreteOp>
static llvm::SmallVector<Value *, 4>
makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
                ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
                llvm::ArrayRef<Value *> tileSizes) {
  assert(ivs.size() == tileSizes.size());
  if (ivs.empty())
    return RangeParts(allRanges).makeRanges();

  auto *op = static_cast<ConcreteOp *>(&contraction);
  RangeParts result(allRanges.size());
  RangeParts rangeParts(allRanges);

  for (auto map : op->loopsToOperandRangeMaps()) {
    // 1. Take the first ivs results of the map, the other ones are not
    // composed but merely copied over.
    assert(map.getNumSymbols() == 0);
    MLIRContext *context = ScopedContext::getContext();
    unsigned numParallel = op->getNumParallelDims();
    unsigned numReduction = op->getNumReductionDims();
    if (ivs.size() < numParallel + numReduction) {
      // Inject zeros in positions that are not tiled.
      SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
      for (unsigned i = 0, e = numParallel + numReduction; i < e; ++i) {
        dimReplacements[i] = (i < ivs.size())
                                 ? getAffineDimExpr(i, context)
                                 : getAffineConstantExpr(0, context);
      }
      map = map.replaceDimsAndSymbols(dimReplacements, {}, ivs.size(), 0);
    }

    // 2. Apply the rewritten map to the ranges.
    unsigned numDims = map.getNumDims();
    for (auto en : llvm::enumerate(map.getResults())) {
      auto index = en.index();
      auto expr = en.value();
      AffineMap exprMap = AffineMap::get(numDims, 0, expr);
      ValueHandle offset(makeFoldedComposedAffineApply(exprMap, ivs));
      // Offset is normally a function of loop induction variables.
      // If it is 0, it must come from a dimension that was not tiled.
      if (isZeroIndex(offset)) {
        result.mins.push_back(rangeParts.mins[index]);
        result.maxes.push_back(rangeParts.maxes[index]);
        continue;
      }

      ValueHandle step(makeFoldedComposedAffineApply(exprMap, tileSizes));
      ValueHandle min(rangeParts.mins[index]);
      using edsc::op::operator+;
      result.mins.push_back(min + offset);
      // Ideally this should be:
      //   `min(rangeParts.max, rangeParts.min + offset + step)`
      // but that breaks the current limitations of the affine dialect.
      result.maxes.push_back(min + offset + step);
    }
  }
  // Note that for the purpose of tiled ranges and views, the steps do not
  // change in our representation.
  result.steps = rangeParts.steps;

  return result.makeRanges();
}

template <class ConcreteOp>
static SmallVector<Value *, 4>
makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
               ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
  auto tiledRanges =
      makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
  SmallVector<Value *, 4> res;
  unsigned currentRange = 0;
  for (auto *in : contraction.getInputsAndOutputs()) {
    unsigned runningSliceDim = 0;
    auto *runningSlice = in;
    assert(runningSlice->getType().template isa<ViewType>());
    for (unsigned d = 0, e = getViewRank(runningSlice); d < e; ++d) {
      auto *r = tiledRanges[currentRange++];
      runningSlice = slice(runningSlice, r, runningSliceDim++).getValue();
    }
    res.push_back(runningSlice);
  }
  return res;
}

template <class ConcreteOp>
static SmallVector<mlir::AffineForOp, 8>
writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
                             ArrayRef<Value *> tileSizes) {
  assert(tileSizes.size() <=
         contraction.getNumParallelDims() + contraction.getNumReductionDims());

  auto *op = static_cast<ConcreteOp *>(&contraction);
  mlir::OpBuilder builder(op->getOperation());
  ScopedContext scope(builder, op->getLoc());
  SmallVector<IndexHandle, 4> ivs(tileSizes.size());
  auto pivs = makeIndexHandlePointers(ivs);

  // clang-format off
  using linalg::common::LoopNestRangeBuilder;
  auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
                                      getRanges(contraction), tileSizes);
  linalg::common::LoopNestRangeBuilder(pivs, ranges)(
    [&contraction, &tileSizes, &ivs]() {
      SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
      auto views = makeTiledViews(contraction, ivValues, tileSizes);
      ScopedContext::getBuilder().create<ConcreteOp>(
          ScopedContext::getLocation(), views);
    });
  // clang-format on

  SmallVector<mlir::AffineForOp, 8> res;
  res.reserve(ivs.size());
  for (auto iv : ivs)
    res.push_back(getForInductionVarOwner(iv.getValue()));
  return res;
}

llvm::Optional<SmallVector<mlir::AffineForOp, 8>>
linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) {
  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
    return writeContractionAsTiledViews(matmulOp, tileSizes);
  } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
    return writeContractionAsTiledViews(matvecOp, tileSizes);
  } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
    return writeContractionAsTiledViews(dotOp, tileSizes);
  }
  return llvm::None;
}

void linalg::lowerToTiledViews(mlir::FuncOp f, ArrayRef<Value *> tileSizes) {
  f.walk([tileSizes](Operation *op) {
    if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
      writeAsTiledViews(matmulOp, tileSizes);
    } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
      writeAsTiledViews(matvecOp, tileSizes);
    } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
      writeAsTiledViews(dotOp, tileSizes);
    } else {
      return;
    }
    op->erase();
  });
}
@ -1,260 +0,0 @@
# Linalg Dialect

This chapter describes the design and implementation of a simple linear algebra
dialect in MLIR. The objective of the `linalg` dialect is to demonstrate that
the MLIR infrastructure is a great fit for implementing high-level operations
and for lowering them gradually to LLVM by reusing existing components and
lowering paths. In particular, `linalg` is built upon the type system of the
[`affine`](../../Dialects/Affine.md) dialect, which allows partial lowering to
be implemented with relative ease.

The `linalg` dialect is introduced gradually following this outline:

1. Type system and type-building operations.
2. Compute operations.
3. Lowerings from the `linalg` operations into `linalg` + `affine`
   operations.
4. Tiling transformations.
5. A simple tiling and fusion transformation.

The Toy language tutorial already introduced core MLIR concepts and best
practices; the `linalg` dialect operates mostly at the level of the C++ API and
in particular makes use of [declarative builders](DeclarativeBuilders.md) for
terser IR-emitting expressions. Without loss of generality, anything in this
section can also be implemented with `mlir::Builder` and enough
`getInsertionPoint` and `setInsertionPoint` manipulations.

The implementation follows a few conventions to decouple, at each step, the
newly introduced concepts and code from those introduced previously, without
duplicating the whole code base in each directory. The code for concepts
introduced at a particular step `k` lives in the `Linalgk/include/linalgk` and
`Linalgk/lib` directories and is linked into the `Linalgk` library.

Lastly, note that simplifying assumptions are made to cut down on boilerplate
and help focus on the core concepts. In particular, parsing the linalg dialect
is currently not supported, as it is used as an intermediary dialect. This does
not impact the ability to lower all the way to LLVM with properly verified IR
at each step of the lowering, or to execute the compiled binary.

# Linalg Part 1: Type system

We first describe the `linalg` type system.

## RangeType and RangeOp

A
[RangeType](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/include/linalg1/RangeType.h)
is a simple triple of `index` values. It represents a minimal range abstraction
`(min, max, step)`. `RangeType` is a fully defined type and is constructed
without any additional type argument. Its implementation illustrates the
minimal amount of information required to implement a new custom MLIR type.

```
class RangeType : public mlir::Type::TypeBase<RangeType, mlir::Type> {
public:
  // Used to implement llvm-style cast.
  using Base::Base;
  /// Construction hook.
  static RangeType get(mlir::MLIRContext *context) {
    /// Custom, uniqued construction in the mlir::MLIRContext.
    return Base::get(context, LinalgTypes::Range);
  }
  /// Used to implement llvm-style cast.
  static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
};
```

Unlike more complex types, RangeType does not require a hashing key for
uniquing in the `MLIRContext`. Note that all MLIR types derive from
`mlir::Type::TypeBase` and expose `using Base::Base` to enable generic hooks to
work properly (in this instance, for llvm-style casts). RangeType does not even
require an implementation file, as the code above is the whole code for the
type.

The `linalg` dialect type `RangeType` pretty-prints simply as `!linalg.range`.

A `linalg::RangeOp`, defined
[here](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/include/linalg1/RangeOp.h),
is the operation that produces ssa-values of `RangeType`. It pretty-prints as

```
%0 = linalg.range %min:%max:%step : !linalg.range
```

The implementations of the `RangeOp::build` and `RangeOp::verify`
[methods](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/lib/RangeOp.cpp)
are straightforward.
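
As a quick illustration of how such ssa-values are created from C++, the
`range` declarative-builder intrinsic used throughout this tutorial's examples
emits a `linalg.range` from three `index` values. A minimal sketch, where the
function `f` with three `index` arguments is a hypothetical placeholder:

```
// Assumes an enclosing edsc::ScopedContext, as in the tutorial's examples;
// `f` is a hypothetical FuncOp whose three arguments are of index type.
using namespace mlir::edsc;
using namespace linalg::intrinsics;
ValueHandle min(f.getArgument(0)), max(f.getArgument(1)),
    step(f.getArgument(2));
// Emits `%0 = linalg.range %min:%max:%step : !linalg.range`.
ValueHandle r = range(min, max, step);
```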

A RangeType is used throughout to step over iteration domains (i.e. loop
iterations via loop bounds and steps) as well as over the view data
abstraction. A `LoopNestRangeBuilder` helper class is
[introduced](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/include/linalg1/Common.h)
to allow emission of loop nests from an `llvm::ArrayRef<mlir::Value*>` where
each `mlir::Value` is a `linalg.range`.

### Simplifying assumption

The `linalg.range` type is generally unrestricted beyond having elements of
`index` type. However, it is used to build loop nests using the `affine.for`
[operation](../../Dialects/Affine.md), whose restrictions it inherits at the
point where `affine.for` operations are materialized. This is a tradeoff to
reuse existing MLIR operations that are already known to lower to LLVM. As a
consequence, the `step` in a `linalg.range` must be a static constant and
cannot be symbolic.

## ViewType and ViewOp

A
[ViewType](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/include/linalg1/ViewType.h)
represents a multi-dimensional range abstraction to iterate over an underlying
storage type. It is backed by a data type, in our case objects of
[MemRefType](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/StandardTypes.h).
A ViewType is a parameterized type which has a base element type and a rank. It
is thus slightly more complex than RangeType and requires uniquing in the
enclosing MLIRContext.

This is materialized by the existence of a storage type and a `hashKey` in the
implementation
[file](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/lib/ViewType.cpp).

```
struct ViewTypeStorage : public mlir::TypeStorage {
  /// Underlying Key type to transport the payload needed to construct a custom
  /// type in a generic way.
  struct Key {
    Key(Type elementType, unsigned rank)
        : elementType(elementType), rank(rank) {}
    Type elementType;
    unsigned rank;
  };
  ...
};
```

The `ViewTypeStorage` is not visible outside of the `ViewType` implementation
and is referred to from `ViewType` as such: `class ViewType : public
mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> { ... }`.

A two-dimensional ViewType over an f32 storage pretty-prints as
`view<?x?xf32>`.

A `linalg::ViewOp`, defined
[here](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/lib/ViewOp.cpp),
is the operation that produces ssa-values of `ViewType` from an ssa-value of
type `MemRefType`. A ViewOp has operands called "indexings", which can be
either of `index` or `!linalg.range` type. The rationale is that an `index`
reduces the rank of a ViewType by 1, while a `!linalg.range` keeps the rank
unchanged. This behavior is a convention that we have found useful during the
implementation in order to fold chains of slice operations (introduced in the
following paragraph) and capture enough information in the ViewOp so it can be
lowered to LLVM.

The entry point to the builder is the method: `static void
ViewOp::build(mlir::Builder *b, mlir::OperationState &result, mlir::Value
*memRef, llvm::ArrayRef<mlir::Value *> indexings = {});`

A `ViewOp` pretty-prints as: `%1 = linalg.view %0[%m, %n, %k] :
!linalg.view<?x?xf32>`

This signifies that `%0` is a three-dimensional `MemRef` of `f32` elemental
type and that the `%1` view uses an `index` into one of the dimensions and two
`!linalg.range` values for the two other dimensions.

The implementations of the `ViewOp::build` and `ViewOp::verify`
[methods](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/lib/ViewOp.cpp)
are simple.

### Simplifying assumption

We choose to reuse the existing MLIR
`MemRef` [type](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/StandardTypes.h)
as the underlying data structure. This avoids the need to redefine a new
abstraction and simplifies lowering all the way to LLVM.

## SliceOp

A slice is a subview that is fully contained within its parent view and is
constructed using a `SliceOp`. A SliceOp takes an ssa-value of type
`linalg.view` and an "indexing" to produce a new `linalg.view` of rank:

1. Equal to the rank of the original view, if the indexing is a
   `!linalg.range`.
2. Equal to the rank of the original view minus one, if the indexing is an
   `index`.

A slice op has an integer attribute which specifies the dimension of the parent
view it slices, and pretty-prints as:

```
%2 = linalg.slice %1[*, *, %0, *] : !linalg.view<?x?x?xf32>
```

In this particular case, %2 slices dimension `2` of the four-dimensional view
%1. The returned `!linalg.view<?x?x?xf32>` indicates that the indexing is
rank-reducing and that %0 is an `index`.

The implementations of the `SliceOp::build` and `SliceOp::verify`
[methods](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg1/lib/SliceOp.cpp)
are simple.
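
From C++, slices are built with the `slice` declarative-builder intrinsic; the
tiling code later in this commit constructs rank-preserving slices one
dimension at a time this way. A minimal sketch, where `v` (a `!linalg.view`
ssa-value) and `r` (a `!linalg.range` ssa-value) are hypothetical placeholders:

```
// Assumes an enclosing edsc::ScopedContext; `v` and `r` are hypothetical
// Value* of view and range type, respectively.
using namespace linalg::intrinsics;
// Emits a linalg.slice along dimension 0; a range indexing keeps the rank of
// the resulting view unchanged (an index indexing would reduce it by one).
Value *s = slice(v, r, 0).getValue();
```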

### Simplifying assumption

In this tutorial we do not enforce the strict subview property or perform
bounds check analysis, and instead assume that the code is correct by
construction.

## Notable remarks

The declarations of the classes implementing the operations we described share
common traits that enable certain API shortcuts and other behaviors. For
instance, the `mlir::OpTrait::OneResult` trait makes the `getResult()` method
available to the class.

```
class RangeOp : public mlir::Op<RangeOp, mlir::OpTrait::NOperands<3>::Impl,
                                mlir::OpTrait::OneResult,
                                mlir::OpTrait::HasNoSideEffect> { ... };

class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
                               mlir::OpTrait::OneResult,
                               mlir::OpTrait::HasNoSideEffect> { ... };

class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::NOperands<2>::Impl,
                                mlir::OpTrait::OneResult,
                                mlir::OpTrait::HasNoSideEffect> { ... };
```

One particular trait of interest is `mlir::OpTrait::HasNoSideEffect`, which
enables constant folding and dead code elimination in the canonicalizer pass.

## Dialect Registration

Similarly to Toy, the dialect must be registered so that the pretty-printer and
verifier can be enabled. Without registration, only the generic op form can be
printed. Beware of ops printed in generic form when a shorthand form exists,
because there is a high chance the IR verification is not enabled.

To register the Linalg dialect, call
`mlir::registerDialect<linalg::LinalgDialect>();`.
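
Concretely, the test drivers in this tutorial perform the registration once at
startup, before any IR is constructed; a minimal sketch following the `main`
function of the Linalg4 example reproduced earlier in this commit:

```
int main() {
  // Register the dialect before building or parsing any linalg IR.
  mlir::registerDialect<linalg::LinalgDialect>();
  // ... construct linalg IR, run passes, print the result ...
  return 0;
}
```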

### Note on code organization

Registration occurs by constructing a new `LinalgDialect`, which registers the
proper types and ops at construction time, with sanity checks guarding against
multiple registrations of the same symbols. At that point, the constructor
needs to be statically aware of all the types and ops. Since our code structure
chooses to isolate independent portions of the tutorial, and certain ops are
introduced in later parts, we explicitly separate `DialectConstruction.cpp`
into its own library. Linking with the proper library enables the types that
have been declared so far.

## Putting it all together

We create a `linalg1-opt` executable which links together `Linalg1` and the
core `MlirOptLib` library to add traditional compiler support for file
handling, parsing, the command-line interface, etc. The FileCheck'd test
[example](https://github.com/tensorflow/mlir/blob/master/test/Examples/Linalg/Linalg1.mlir)
demonstrates parsing, verification and pretty-printing of the IR we have
constructed so far. We introduce a custom op called `some_consumer` to ensure
that dead-code elimination does not optimize these simple examples out of
existence, in case an extra `-canonicalize` option is passed to `linalg1-opt`.
When called with `lower-linalg-to-llvm`, the test uses the
[LLVM conversion](LLVMConversion.md) mechanisms.

@ -1,98 +0,0 @@
# Linalg Part 2: Compute Operations

We now describe the main compute operations `linalg.dot`, `linalg.matvec` and
`linalg.matmul`. These operations are a subset of a more general tensor
contraction
[class](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg2/include/linalg2/TensorOps.h)
of operations. In this tutorial, we define a tensor contraction as a generic
operation which:

1. Reads a `getNumInputs()` number of input ssa-values of ViewType.
2. Writes into a `getNumOutputs()` number of output ssa-values of ViewType.
3. Can be written in scalar loop form as a perfect loop nest with
   `getNumParallelDims()` outermost loops with parallel semantics and
   `getNumReductionDims()` innermost dimensions with reduction semantics (see
   the sketch after this list).
4. Has a scalar form that is specific to each particular specialization.
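
For instance, the scalar form of `linalg.matmul` is a three-deep perfect loop
nest with two outermost parallel loops and one innermost reduction loop. The
following plain C-style sketch (not code from the tutorial) illustrates the
semantics:

```
// Scalar form of C(i, j) += A(i, k) * B(k, j): i and j are parallel
// dimensions, k is the reduction dimension.
for (int i = 0; i < M; ++i)       // parallel
  for (int j = 0; j < N; ++j)     // parallel
    for (int k = 0; k < K; ++k)   // reduction
      C[i][j] += A[i][k] * B[k][j];
```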

## Operation Definition

In this section we do not discuss the specific properties of tensor
contractions but only define the `linalg.dot`, `linalg.matvec` and
`linalg.matmul` operations as opaque operations with side-effects (reads and
writes into input and output views).

These operations take input and output views of the proper rank as operands.
For the purpose of illustration, assume all the elemental types are fixed to
`f32`. The invariants checked by the op-specific
[verify](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg2/lib/TensorOps.cpp)
functions are:

1. `linalg.dot` reads two one-dimensional `view<?xf32>` operands and writes a
   zero-dimensional `view<f32>` (i.e. a scalar).
2. `linalg.matvec` reads a two-dimensional `view<?x?xf32>` matrix and a
   one-dimensional `view<?xf32>` vector, and writes a one-dimensional
   `view<?xf32>`.
3. `linalg.matmul` reads two two-dimensional `view<?x?xf32>` matrices and
   writes a two-dimensional `view<?x?xf32>` matrix.

Other operations on higher-order tensors can be defined and would behave
similarly with respect to IR verification and interactions with ViewType
operands. The generic form of verification and pretty-printing is defined on
the `TensorContractionBase`
[class](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg2/include/linalg2/TensorOps.h).

Note that in order to give TensorContractionBase access to the mlir::Op in a
generic fashion, we use a CRTP pattern where:

```
template <class ConcreteOp> class TensorContractionBase { ... };

class DotOp : public TensorContractionBase<DotOp>,
              public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
                              mlir::OpTrait::ZeroResult> { ... };
```

In turn, this allows the generic behavior of TensorContractionBase to be
implemented once and reused across ops. The generic verify method is:

```
template <class ConcreteOp>
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
  if (getNumInputs() <= 0)
    concreteOp->emitOpError("expected at least one input");
  ...
}
```

Each specialized operation then calls into the generic verification method
before applying its own verification steps.

```
LogicalResult linalg::MatmulOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  unsigned index = 0;
  for (auto *v : {A, B, C}) {
    if (getViewRank(v) != 2)
      return emitOpError("operand " + Twine(index++) + " must be of rank 2");
  }
  return success();
}
```

Note that in a more future-proof design, it is considered a best practice for
operations which share similar behavior to be defined with TableGen.

All TensorContractionBase ops pretty-print similarly. In the case of
`linalg.matmul` the pretty-printed form is: `linalg.matmul(%A, %B, %C) :
view<?x?xf32>`

## Putting it all together

The
[example](https://github.com/tensorflow/mlir/blob/master/examples/Linalg/Linalg2/Example.cpp)
demonstrates how to construct some simple IR snippets that pass the verifier
checks: it allocates three memref buffers from `index` function arguments and
uses those buffers as backing data structures for views that get passed to
`dot`, `matvec` and `matmul` operations.
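
As a concrete reference, the Linalg4 `Example.cpp` reproduced earlier in this
commit emits exactly such IR with declarative builders; in sketch form, for a
function `f` taking three `memref<?x?xf32>` arguments:

```
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle M = dim(f.getArgument(0), 0),
            N = dim(f.getArgument(2), 1),
            K = dim(f.getArgument(0), 1),
            // Ranges over the full extent of each dimension, with step 1.
            rM = range(constant_index(0), M, constant_index(1)),
            rN = range(constant_index(0), N, constant_index(1)),
            rK = range(constant_index(0), K, constant_index(1)),
            // Views backed by the three memref arguments.
            vA = view(f.getArgument(0), {rM, rK}),
            vB = view(f.getArgument(1), {rK, rN}),
            vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC); // linalg.matmul(%vA, %vB, %vC) : view<?x?xf32>
ret();
```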

@ -1,643 +0,0 @@
# Conversion to the LLVM IR Dialect

This chapter of the tutorial uses the DialectConversion framework to implement
the conversion from the Linalg dialect to the [LLVM IR](../../Dialects/LLVM.md)
dialect. This framework is a part of a more general pattern rewriting
infrastructure available in MLIR. Its key feature is the ability to update
function signatures and function/block argument types, along with pattern-based
operation rewriting.

## Structure of a Dialect Conversion

The Dialect Conversion framework comprises three components:

1. Type conversion function.
2. (Optional) function signature conversion function.
3. Operation conversion patterns.

The function signature conversion has a default implementation that performs
type conversion individually for each of the function arguments and results. A
custom implementation is required when the function signature can change when
switching dialects, for example to include dialect-specific attributes or to
accommodate calling conventions.

## Linalg to LLVM IR Conversion

Let us illustrate how one can use the Dialect Conversion framework on the
Linalg to LLVM IR conversion by defining the three components listed above.

Instead of performing progressive lowering from Linalg to the Standard dialect,
we will define the semantics of Linalg types and type-related operations in
terms of their LLVM IR counterparts.

### Linalg Types to LLVM IR Types

#### Range Type

The Linalg Range abstraction is a triple of size values representing `min`,
`max` and `step` of the space of iteration (address or loop). This easily maps
to an LLVM IR structure type containing three integers: `{i64, i64, i64}`,
assuming that 64 bits are sufficient to hold a size.

In a conversion function, this can be implemented by checking if the input type
is indeed `linalg::RangeType` and constructing the corresponding LLVM IR
dialect type. The LLVM IR dialect types are merely LLVM IR types wrapped into
an MLIR object.

```c++
Type linalg::convertLinalgType(Type t) {
  // Obtain the MLIR context and the MLIR LLVM IR dialect. The latter stores
  // the LLVM context that is necessary to construct types.
  auto *context = t.getContext();
  auto *dialect =
      static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm"));

  if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
    // Create the LLVM IR i64 type.
    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
    // Create the LLVM IR structure type containing three i64.
    auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
    // Wrap the LLVM IR type into the MLIR LLVM IR dialect.
    return LLVM::LLVMType::get(context, structTy);
  }

  // Leave an unknown type as is.
  return t;
}
```

#### View Type

Converting a Linalg View type requires more careful consideration. First, a
View is a container type whose element type must be converted as well. Second,
it has a defined semantics that its representation should respect. In
particular, a View is an abstraction around MLIR's standard `memref` type,
which itself represents a pointer with size information attached. Accessing an
element through a view means accessing the same buffer, but with additional
index arithmetic.

We start by restating the features of the underlying data type, `memref`. A
`memref` is a contiguous array of data indexed by multiple values, typically
stored in a row-major format. We will base our representation of a View on the
notion of _stride_ as used in some machine learning frameworks. A _stride_
along a given indexing dimension is the number of contiguous elements that
separate two elements with adjacent indices along the given dimension. For
example, a `2x6x4` array will have strides `6*4=24`, `4` and `1`. Strides
immediately allow us to capture the _step_ of the range used to construct the
array: the step is reflected as an additional multiplier in the stride, making
the View "step over" elements.
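
To make the stride computation concrete, here is a small standalone sketch
(not part of the tutorial's code) computing row-major strides from sizes and
range steps; it reproduces the `24`, `4`, `1` figures for a `2x6x4` array:

```c++
#include <cstdint>
#include <vector>

// Row-major strides: the innermost stride is 1, and each outer stride is the
// next inner size times the next inner stride. A non-unit range step shows up
// as an extra multiplier on the corresponding stride.
std::vector<int64_t> makeStrides(const std::vector<int64_t> &sizes,
                                 const std::vector<int64_t> &steps) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int d = static_cast<int>(sizes.size()) - 1; d >= 0; --d) {
    strides[d] = running * steps[d];
    running *= sizes[d];
  }
  return strides;
}
// makeStrides({2, 6, 4}, {1, 1, 1}) == {24, 4, 1}
```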

A View will contain as many strides as it has dimensions. For rank-reducing
views, this allows one to simply remove the stride of the dimension that is
not included in the view. For example, taking a view that projects away the
middle dimension of a `2x6x4` array will give one strides `24` and `1` over
the original buffer.

In addition to steps, the ranges used to create a View can impose a lower and
an upper bound on the indices along each dimension. These indices are necessary
for two cases: (a) computing the indices in the original array given the
indices in the view, and (b) verifying out-of-bounds accesses and overflows on
loads and stores. For the former purpose, we will introduce the _linearized
offset_ below. For the latter purpose, we will store the _size_ along the given
dimension, i.e. the difference between the maximum and the minimum value of the
indices in the range.

Finally, we need to account for rank-reducing views that fix the projected-away
index at a specific value. This cannot be implemented by keeping the `min`
value for all the projected-away dimensions, because it would make the
representation of Views obtained from `memref`s of different ranks different,
defying the purpose of Views. Instead, we will keep only a single value
representing the (linearized) offset of the first contiguous element that can
be accessed. For example, if the second index in a `2x6x4` array was fixed to
`3` when producing a 2D view, the offset will be `3*4=12` elements. Adding the
strides to this offset will let one access the other elements in the view.
Since addresses are linearized anyway, and since we cannot have a
rank-expanding view by construction, it is sufficient to store a single
linearized offset.
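
Under the same assumptions as the stride sketch above, the linearized offset is
a dot product of the minimum (or fixed) indices with the strides; for the
`2x6x4` example with the second index fixed to `3`, this yields `3*4=12`:

```c++
#include <cstdint>
#include <vector>

// Linearized offset of the first accessible element: sum over all original
// dimensions of (min or fixed index along that dimension) * (its stride).
int64_t linearizedOffset(const std::vector<int64_t> &minIndices,
                         const std::vector<int64_t> &strides) {
  int64_t offset = 0;
  for (size_t d = 0; d < strides.size(); ++d)
    offset += minIndices[d] * strides[d];
  return offset;
}
// linearizedOffset({0, 3, 0}, {24, 4, 1}) == 12
```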

For the sake of simplicity, we will store the offset separately from the buffer
pointer. Combining the two can save the space required for storing the data but
makes functionality like alias analysis more complex. Implementing such a
combination is left as an exercise to the reader.

Bringing all the pieces together, we can define a view _descriptor_ that
consists of the following:

1. the buffer pointer, `T*`, where `T` is the result of converting the
   elemental type of the view;
1. the linearized offset of the first accessed element;
1. as many values as the view rank, representing the size along each dimension;
1. as many values as the view rank, representing the step along each dimension
   in the index space of the original memref.

Using a hypothetical template syntax, the corresponding LLVM IR type would look
like `template <type T, i64 N> { T*, i64, i64[N], i64[N] }`.
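
In plain C++, one could picture the descriptor as the following struct — a
sketch for intuition only; the conversion below builds the equivalent LLVM IR
structure type:

```c++
#include <cstdint>

// A C++ picture of the view descriptor for an N-dimensional view of T: the
// converted element pointer, the linearized offset, and per-dimension sizes
// and strides.
template <typename T, unsigned N> struct ViewDescriptorModel {
  T *ptr;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};
```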

Let's start implementing the conversion by extending the type conversion
function to include some primitive types, for example `f32`, `f64` and
integers.

```c++
Type linalg::convertLinalgType(Type t) {
  /*...*/
  // Construct an LLVM IR integer type of the same width as the MLIR type.
  if (auto intTy = t.dyn_cast<IntegerType>()) {
    int width = intTy.getWidth();
    auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
    return LLVM::LLVMType::get(context, integerTy);
  }
  // Convert f32 to LLVM IR float.
  if (t.isF32()) {
    auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
    return LLVM::LLVMType::get(context, floatTy);
  }
  // Convert f64 to LLVM IR double.
  if (t.isF64()) {
    auto *doubleTy = llvm::Type::getDoubleTy(dialect->getLLVMContext());
    return LLVM::LLVMType::get(context, doubleTy);
  }
  /*...*/
```

Once properly defined, the conversion of the view type to the view descriptor
type is straightforward:

```c++
  /*...*/
  if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
    // Recursively call the type conversion for the element type, and extract
    // the LLVM IR type from the result.
    Type convertedElemTy = linalg::convertLinalgType(viewTy.getElementType());
    llvm::Type *convertedLLVMElemTy =
        convertedElemTy.cast<LLVM::LLVMType>().getUnderlyingType();
    llvm::PointerType *ptrToElemLLVMTy = convertedLLVMElemTy->getPointerTo();

    // Progressively construct the LLVM IR structure type.
    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
    auto *structTy = llvm::StructType::get(
        ptrToElemLLVMTy, int64Ty, arrayTy, arrayTy);

    // Wrap the LLVM IR type into the MLIR type.
    return LLVM::LLVMType::get(context, structTy);
  }
  /*...*/
```

### Function Signature Conversions

For the sake of simplicity, let's rely on the default implementation of the
function signature conversion that just converts the types.

Note that, in practice, LLVM IR does not support multi-result functions while
MLIR does, which would require changing the function signature and introducing
additional instructions during conversion. You can check how this is
[implemented](https://github.com/tensorflow/mlir/blob/master/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp)
in the actual conversion to the LLVM IR dialect.

### Operation Conversions

Operations on the view abstractions are mostly defined by the definition of the
view descriptor: a `linalg.view` operation creates a view descriptor and fills
it in with the initial data; a `linalg.slice` operation creates a new
descriptor from an existing one, with modifications; `linalg.load` and
`linalg.store` use the view descriptor to compute the address of the element
before accessing it.

#### `linalg.view`

The role of a `linalg.view` is to construct a view descriptor given a memref,
or rather, a memref descriptor as
[defined](../../ConversionToLLVMDialect.md#memref-model) by the conversion of
the standard dialect to the LLVM IR dialect. A memref descriptor is similar to
a view descriptor: it contains the buffer pointer and the list of _dynamic_
sizes of the memref. Since memrefs are contiguous, there is no need to store
the offset, the min/max values or the strides. Their static (constant)
dimensions are available directly in the type signature.

An operation conversion is defined as a special pattern by inheriting from
`mlir::ConversionPattern` and by reimplementing the matching and the rewriting
functions:

```c++
class ViewOpConversion : public ConversionPattern {
public:
  // A conversion constructor. It may take arbitrary arguments but must be
  // able to obtain an MLIRContext from them to call the parent constructor.
  explicit ViewOpConversion(MLIRContext *context);

  // A matching function takes an operation and checks whether the pattern is
  // applicable to it by inspecting its properties.
  PatternMatchResult match(Operation *op) const override;

  // A "rewriting" function that takes an original operation `op`, a list of
  // already rewritten operands, and a function builder `rewriter`. It can use
  // the builder to construct new operations and ultimately create new values
  // that will replace those currently produced by the original operation. It
  // needs to define as many values as the original operation, but their types
  // may be different.
  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                  OpBuilder &rewriter) const override;
};
```

The `ConversionPattern` constructor takes, in addition to the context, the name
of the main operation to be matched and the "benefit" of a match. These
arguments are intended to be useful for defining an optimization problem across
multiple possible conversions but are currently ignored by the conversion
framework.

```c++
ViewOpConversion::ViewOpConversion(MLIRContext *context)
    : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
```

The matching function can be used, for example, to apply different conversions
depending on the argument types or on the attributes of the operation. In our
case, it is applicable to any operation of the given type.

```c++
PatternMatchResult ViewOpConversion::match(Operation *op) const {
  if (op->isa<linalg::ViewOp>())
    return matchSuccess();
  return matchFailure();
}
```

The actual conversion function may become quite involved. First, let us go over
the components of a view descriptor and see how they can be constructed to
represent a _complete_ view of a `memref`, i.e. a view that covers all of its
elements.

1. The buffer pointer is copied as is from the memref descriptor.
1. The linearized offset is always 0.
1. The size is originally the size of the memref:
   - static sizes are taken as constants given the values from the memref
     type signature;
   - dynamic sizes are extracted from the memref descriptor.
1. The stride along a dimension can be defined recursively as:
   - the stride along the innermost dimension is always 1;
   - the stride along any other dimension is the size of the next inner
     dimension times its stride.

When a view is not complete, we need to take into account the ranges supplied
as arguments to the `linalg.view` operation. In particular, the minimum and
maximum index for each dimension is extracted from the corresponding range. The
linearized offset is then computed as a sum of products of minimum indices
along each dimension with the strides of these dimensions.

If a single `index` is supplied instead of a range, i.e. if we have a
rank-reducing view, it will not have a dynamic representation in the view
descriptor. However, its value is used as the minimum value in the linearized
offset computation, and the stride for this dimension participates in the
recursive definition, although it is not stored in the descriptor.

The full conversion function is
[available](https://github.com/tensorflow/mlir/blob/master/examples/Linalg1/lib/ConvertToLLVMDialect.cpp)
and accounts for all these details while minimizing the number of instructions
it produces. Let us consider some parts of this function to understand how it
operates.

```c++
SmallVector<Value *, 4> ViewOpConversion::rewrite(
    Operation *op, ArrayRef<Value *> operands,
    OpBuilder &rewriter) const {
  // Obtain the typed operation (we know we matched only one type).
  auto viewOp = op->cast<linalg::ViewOp>();

  // Extract the buffer pointer from the memref descriptor (first argument).
  Value *memrefDescriptor = operands[0];
  Value *bufferPtr;
  // The descriptor type varies depending on the memref type signature, so we
  // inspect the _original_ operand that has the memref type.
  auto memrefType = viewOp.getSupportingMemRef()->getType().cast<MemRefType>();

  // Build the "position" attribute, which corresponds to the trailing
  // constant operands of the LLVM IR extractvalue instruction. (Note: in
  // MLIR, compile-time operation parameters are passed in as attributes).
  // This is an Array attribute holding Integer attributes. In our case, it
  // only holds one value. It will be used in insert/extractvalue below.
  auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(/*width=*/64),
                                      /*value=*/0);
  auto positionAttr = rewriter.getArrayAttr({attr});

  // Create the context object (RAII) in which we can use declarative builders.
  edsc::ScopedContext context(rewriter, op->getLoc());

  // For statically-shaped memrefs, the descriptor is just the pointer.
  if (memrefType.hasStaticShape()) {
    bufferPtr = memrefDescriptor;
  // For dynamically-shaped memrefs, it is a structure whose first element is
  // a pointer. Extract it from the structure.
  } else {
    // Obtain an LLVM IR pointer type to the element, wrapped in an MLIR type.
    Type wrappedElementTy =
        linalg::convertLinalgType(memrefType.getElementType());
    llvm::Type *elementTy =
        wrappedElementTy.cast<LLVM::LLVMType>().getUnderlyingType();
    llvm::Type *elementPtrTy = elementTy->getPointerTo();
    Type wrappedElementPtrTy = rewriter.getType<LLVM::LLVMType>(elementPtrTy);

    // Emit LLVM IR extractvalue to obtain the buffer pointer from the memref
    // descriptor.
    bufferPtr = intrinsics::extractvalue(wrappedElementPtrTy, memrefDescriptor,
                                         positionAttr);
  }

  // Convert the type of the view to get the type of its descriptor.
  Type viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());

  // Define the descriptor value using the "undef" operation, equivalent to
  // LLVM IR's "undef" value. (Note: in MLIR, constants are not a special
  // subclass of a value; instead, they are produced by operations that take
  // compile-time constants as attributes and produce regular SSA values).
  Value *viewDescriptor = intrinsics::undef(viewDescriptorType);

  // Insert the buffer pointer into the descriptor using `insertvalue`.
  viewDescriptor = intrinsics::insertvalue(viewDescriptorType, viewDescriptor,
                                           bufferPtr, positionAttr);

  // *** the function continues with the remaining part of the descriptor *** //
}
```

Pay attention to the functions prefixed with `intrinsics`. They use MLIR's
[declarative builders](DeclarativeBuilders.md) interface for better
readability. They can be rewritten using LLVM-like imperative IR builders. For
example, the `extractvalue` call becomes

```c++
bufferPtr = rewriter.create<LLVM::ExtractValueOp>(
    op->getLoc(), wrappedElementPtrTy, memrefDescriptor, positionAttr);
```
#### `linalg.slice`
|
||||
|
||||
The slice operation creates a view from another view, changing only a single
|
||||
dimension, so its conversion is significantly simpler. Practically, it creates a
|
||||
new view descriptor and fills it given the old descriptor and the range
|
||||
information supplied as argument. The minimum and maximum index values are
|
||||
updated with those supplied in the range, the linearized offset is recomputed
|
||||
with the new minimum, and the stride along the dimension is multiplied with the
|
||||
step of the range. If an index is supplied instead of the range, the minimum,
|
||||
maximum index and the stride corresponding to the slicing dimension are simply
|
||||
omitted in the new descriptor while the linearized offset is recomputed using
|
||||
the index as minimum value.
|
||||
|
||||
In order to avoid the proliferation of magic constants in `insertvalue` and
`extractvalue` operations on the descriptor, we can define an auxiliary
IR-emitting data structure around it as follows.

```c++
struct ViewDescriptor {
  // Obtain the type of the descriptor.
  Type type() { return d->getType(); }

  // Obtain the pointer type to the element.
  Type elementPtrType() {
    llvm::Type *ty = type().cast<LLVM::LLVMType>().getUnderlyingType();
    llvm::StructType *structTy = cast<llvm::StructType>(ty);
    return builder.getType<LLVM::LLVMType>(structTy->getElementType(0));
  }

  // Construct the materialization of the index type (currently, i64).
  Type indexType() {
    auto *dialect = static_cast<LLVM::LLVMDialect *>(
        builder.getContext()->getRegisteredDialect("llvm"));
    llvm::Type *ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
    return builder.getType<LLVM::LLVMType>(ty);
  }

  // Get the array attribute containing the given values as integer attributes.
  Attribute pos(ArrayRef<unsigned> position) {
    SmallVector<Attribute, 4> attrs;
    attrs.reserve(position.size());
    for (auto p : position)
      attrs.push_back(builder.getI64IntegerAttr(p));
    return builder.getArrayAttr(attrs);
  }

  // Emit instructions obtaining individual values from the descriptor.
  Value *ptr() { return intrinsics::extractvalue(elementPtrType(), d, pos(0)); }
  Value *offset() { return intrinsics::extractvalue(indexType(), d, pos(1)); }
  Value *size(unsigned dim) {
    return intrinsics::extractvalue(indexType(), d, pos({2, dim}));
  }
  Value *stride(unsigned dim) {
    return intrinsics::extractvalue(indexType(), d, pos({3, dim}));
  }

  // Emit instructions inserting individual values in the descriptor. Each
  // `insertvalue` produces a new descriptor value that replaces the old one.
  void setPtr(Value *v) {
    d = intrinsics::insertvalue(type(), d, v, pos(0));
  }
  void setOffset(Value *v) {
    d = intrinsics::insertvalue(type(), d, v, pos(1));
  }
  void setSize(unsigned dim, Value *v) {
    d = intrinsics::insertvalue(type(), d, v, pos({2, dim}));
  }
  void setStride(unsigned dim, Value *v) {
    d = intrinsics::insertvalue(type(), d, v, pos({3, dim}));
  }

  // The builder into which we emit code.
  OpBuilder &builder;

  // The actual descriptor.
  Value *d;
};
```

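Wrapping the raw `Value *` in this way keeps the descriptor's positional
constants (0 for the buffer pointer, 1 for the linearized offset, 2 and 3 for
the size and stride arrays) in one place, so the conversion functions below
can be read directly against the descriptor layout instead of against magic
numbers.
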
With such a descriptor, the conversion function closely resembles the
conversion rules described above:

```c++
SmallVector<Value *, 4> SliceOpConversion::rewrite(
    Operation *op, ArrayRef<Value *> operands, OpBuilder &rewriter) const {
  // Obtain the typed operation (we know we matched only one type).
  auto sliceOp = op->cast<linalg::SliceOp>();

  // Create the context object (RAII) in which we can use declarative builders.
  // Bring all the builders into the namespace.
  using namespace intrinsics;
  edsc::ScopedContext context(rewriter, op->getLoc());

  auto newViewDescriptorType =
      linalg::convertLinalgType(sliceOp.getViewType());

  // Define the new descriptor and obtain the old descriptor.
  Value *newViewDescriptor = intrinsics::undef(newViewDescriptorType);
  Value *oldViewDescriptor = operands[0];

  // Create the code-emitting objects around the descriptors. The range
  // descriptor is defined similarly to the view descriptor, with additional
  // support for the value being either of linalg.range type or of index type.
  auto newDescriptor = ViewDescriptor{rewriter, newViewDescriptor};
  auto oldDescriptor = ViewDescriptor{rewriter, oldViewDescriptor};
  auto rangeDescriptor = RangeDescriptor{rewriter, operands[1]};

  // Properties of the slice.
  bool isRankDecreasing = sliceOp.isRankDecreasing();
  int dim = sliceOp.getSlicingDim();

  // Copy the buffer pointer.
  newDescriptor.setPtr(oldDescriptor.ptr());

  // Recompute the offset.
  Value *min = rangeDescriptor.min();
  Value *stride = oldDescriptor.stride(dim);
  newDescriptor.setOffset(add(oldDescriptor.offset(), mul(min, stride)));

  // Copy the sizes and strides into the new descriptor, updating or dropping
  // the affected dimension. If the `slice` is rank-decreasing, the resulting
  // view no longer contains the affected dimension, so its size and stride
  // become unnecessary and can be dropped. Otherwise, the size of the affected
  // dimension is updated to the size of the range and its stride is multiplied
  // by the step of the range.
  for (int i = 0, e = sliceOp.getRank(); i < e; ++i) {
    int originalPos = (isRankDecreasing && i >= dim) ? i + 1 : i;
    if (!isRankDecreasing && i == dim) {
      newDescriptor.setSize(
          i, sub(rangeDescriptor.max(), rangeDescriptor.min()));
      newDescriptor.setStride(
          i, mul(oldDescriptor.stride(i), rangeDescriptor.step()));
    } else {
      newDescriptor.setSize(i, oldDescriptor.size(originalPos));
      newDescriptor.setStride(i, oldDescriptor.stride(originalPos));
    }
  }

  // The setters updated the descriptor value held by `newDescriptor`; return
  // the final SSA value.
  return {newDescriptor.d};
}
```

Note that we used a `using namespace intrinsics` statement to make the
declarative builders for the LLVM IR dialect operations available without
extra qualification, keeping composed expressions simpler. We also omitted the
matching function, which is similar to that of `linalg.view`.

#### `linalg.load` and `linalg.store`

Loads and stores through views are implemented in a similar fashion. Both need
to first compute the effective linearized address of the element in the
underlying buffer, and then emit either a load or a store operation on that
address. The linearization is straightforward given the offset and strides
stored in the descriptor: the total offset is the sum of the base offset and
the products of the access subscripts with the strides along their respective
dimensions.

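For example, given a 2-D view with base offset 8 and strides (128, 1), the
access at subscripts (i, j) addresses element 8 + 128 * i + j of the
underlying buffer.
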
The linearization part can be easily implemented using the code-emitting
object for the view descriptor:

```c++
Value *obtainDataPtr(Location loc, int rank, Value *viewDescriptorVal,
                     ArrayRef<Value *> indices, OpBuilder &rewriter) {
  // Create the context object (RAII) in which we can use declarative builders.
  // Bring all the builders into the namespace.
  using namespace intrinsics;
  edsc::ScopedContext context(rewriter, loc);

  // Create the code-emitting object for the descriptor.
  auto viewDescriptor = ViewDescriptor{rewriter, viewDescriptorVal};

  // Linearize subscripts as:
  //   base_offset + SUM_i index_i * stride_i.
  Value *offset = viewDescriptor.offset();
  for (int i = 0; i < rank; ++i) {
    Value *stride = viewDescriptor.stride(i);
    offset = add(offset, mul(indices[i], stride));
  }

  // Emit a getelementptr instruction with the linearized offset from the
  // buffer pointer, producing a pointer to the accessed element.
  Value *elementPtr = gep(viewDescriptor.elementPtrType(), viewDescriptor.ptr(),
                          ArrayRef<Value *>{offset});
  return elementPtr;
}
```

Given this utility function, it becomes easy to implement the actual
conversions for the load and store operations.

```c++
// Load operation conversion.
SmallVector<Value *, 4> LoadOpConversion::rewrite(
    Operation *op, ArrayRef<Value *> operands, OpBuilder &rewriter) const {
  // Obtain the typed operation (we know we matched only one type).
  auto loadOp = op->cast<linalg::LoadOp>();

  // Separate the descriptor operand from the index operands.
  Value *viewDescriptor = operands[0];
  ArrayRef<Value *> indices = operands.drop_front();

  // Call the auxiliary function to emit code computing the element pointer.
  Value *ptr = obtainDataPtr(op->getLoc(), loadOp.getRank(), viewDescriptor,
                             indices, rewriter);

  // Use declarative builders to load from the element pointer.
  edsc::ScopedContext edscContext(rewriter, op->getLoc());
  auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
  Value *element = intrinsics::load(elementType, ptr);
  return {element};
}

// Store operation conversion.
SmallVector<Value *, 4> StoreOpConversion::rewrite(
    Operation *op, ArrayRef<Value *> operands, OpBuilder &rewriter) const {
  // Obtain the typed operation (we know we matched only one type).
  auto storeOp = op->cast<linalg::StoreOp>();

  // Separate the value and descriptor operands from the index operands.
  Value *data = operands[0];
  Value *viewDescriptor = operands[1];
  ArrayRef<Value *> indices = operands.drop_front(2);

  // Call the auxiliary function to emit code computing the element pointer.
  Value *ptr = obtainDataPtr(op->getLoc(), storeOp.getRank(), viewDescriptor,
                             indices, rewriter);

  // Use declarative builders to store to the element pointer.
  edsc::ScopedContext edscContext(rewriter, op->getLoc());
  intrinsics::store(data, ptr);

  // "Store" does not produce any values.
  return {};
}
```

### Putting It All Together

Having defined the conversions for the types and the operations, we can now
proceed to invoking the dialect conversion framework that will transform
entire MLIR modules for us. The conversion class must inherit from
`DialectConversion` and override two pure virtual functions: one that
initializes the list of operation converters, and one that is called to
convert individual types. The function signature conversion can also be
overridden, but it has a default implementation.

```c++
// Define a dialect conversion class.
class Lowering : public DialectConversion {
protected:
  // Produce a set of operation conversion patterns. This is called once per
  // conversion.
  llvm::DenseSet<ConversionPattern *>
  initConverter(MLIRContext *context) override {
    allocator.Reset();
    // Call a helper function provided by MLIR to build a set of operation
    // conversion instances given a list of classes as template parameters.
    // These instances will be allocated within `allocator` and their lifetime
    // is managed by the Lowering class.
    return RewriteListBuilder<LoadOpConversion, SliceOpConversion,
                              StoreOpConversion,
                              ViewOpConversion>::build(allocator, context);
  }

  // Convert a type. This function will be called for each function/region
  // argument or result type (unless there is a custom function signature
  // conversion) as well as for each block argument type.
  Type convertType(Type t) override { return linalg::convertLinalgType(t); }

  // The individual conversion patterns will live here.
  llvm::BumpPtrAllocator allocator;
};
```

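To run the conversion, an instance of this class is applied to a module. The
following driver is a minimal sketch: the entry-point name `lowerToLLVM` and
the `convert(Module *)` call are assumptions for illustration, not an API
taken from this document.

```c++
// Hypothetical driver: the function name and the convert() entry point are
// assumptions; adjust to the actual DialectConversion API when integrating.
LogicalResult lowerToLLVM(mlir::Module *module) {
  Lowering lowering;
  return lowering.convert(module); // rewrites operations and types in place
}
```
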
@ -44,7 +44,6 @@ set(MLIR_TEST_DEPENDS
if(LLVM_BUILD_EXAMPLES)
  list(APPEND MLIR_TEST_DEPENDS
    linalg1-opt
    toyc-ch1
    toyc-ch2
    toyc-ch3

@ -1,169 +0,0 @@
// RUN: linalg1-opt %s | FileCheck %s
// RUN: linalg1-opt %s -lower-linalg-to-llvm | FileCheck %s -check-prefix=LLVM

func @view_op(%arg0: memref<f32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
  %c3 = constant 3 : index
  %c17 = constant 17 : index
  %c1 = constant 1 : index
  %3 = linalg.range %c3:%c17:%c1 : !linalg.range
  %4 = linalg.view %arg0[] : memref<f32>, !linalg.view<f32>
  %5 = linalg.view %arg1[%3] : memref<?xf32>, !linalg.range, !linalg.view<?xf32>
  %6 = linalg.view %arg2[%3, %3] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  "some_consumer"(%4, %5, %6) : (!linalg.view<f32>, !linalg.view<?xf32>, !linalg.view<?x?xf32>) -> ()
  return
}
// CHECK-LABEL: func @view_op(%arg0: memref<f32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %0 = linalg.range {{.*}} : !linalg.range
// CHECK: {{.*}} = linalg.view %arg0[] : memref<f32>, !linalg.view<f32>
// CHECK: {{.*}} = linalg.view %arg1[%0] : memref<?xf32>, !linalg.range, !linalg.view<?xf32>
// CHECK: {{.*}} = linalg.view %arg2[%0, %0] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>

func @slice_op(%arg0: memref<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %1 = dim %arg0, 0 : memref<?x?xf32>
  %2 = dim %arg0, 1 : memref<?x?xf32>
  %3 = linalg.range %c0:%1:%c1 : !linalg.range
  %4 = linalg.range %c0:%2:%c1 : !linalg.range
  %5 = linalg.view %arg0[%3, %4] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  affine.for %i0 = 0 to (d0) -> (d0)(%1) {
    affine.for %i1 = 0 to (d0) -> (d0)(%2) {
      %6 = linalg.slice %5[%i0] {dim = 1} : !linalg.view<?x?xf32>, index
      "some_consumer"(%6) : (!linalg.view<?xf32>) -> ()
      %7 = linalg.slice %5[%i1] {dim = 0} : !linalg.view<?x?xf32>, index
      %8 = linalg.slice %7[%i0] {dim = 0} : !linalg.view<?xf32>, index
    }
  }
  return
}
// CHECK-LABEL: func @slice_op(%{{.*}}: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32>
// CHECK: %[[r1:.*]] = linalg.range %{{.*}}:%[[M]]:%{{.*}} : !linalg.range
// CHECK: %[[r2:.*]] = linalg.range %{{.*}}:%[[N]]:%{{.*}} : !linalg.range
// CHECK: %[[V:.*]] = linalg.view %{{.*}}[%[[r1]], %[[r2]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK: affine.for %{{.*}} = 0 to #map1(%{{.*}}) {
// CHECK: affine.for %{{.*}} = 0 to #map1(%{{.*}}) {
// CHECK: {{.*}} = linalg.slice %[[V]][%{{.*}}] {dim = 1} : !linalg.view<?x?xf32>, index
// CHECK: %[[V2:.*]] = linalg.slice %[[V]][%{{.*}}] {dim = 0} : !linalg.view<?x?xf32>, index
// CHECK: {{.*}} = linalg.slice %[[V2]][%{{.*}}] {dim = 0} : !linalg.view<?xf32>, index

func @rangeConversion(%arg0: index, %arg1: index, %arg2: index) {
  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
  return
}
// LLVM-LABEL: @rangeConversion
// LLVM-NEXT: llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">

func @viewRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range) {
  %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  return
}
// LLVM-LABEL: @viewRangeConversion
// LLVM-NEXT: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
// LLVM-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">

func @viewNonRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: index) {
  %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
  return
}
// LLVM-LABEL: @viewNonRangeConversion
// LLVM-NEXT: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
// LLVM-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">

func @sliceRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
  %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  %1 = linalg.slice %0[%arg3] {dim = 0} : !linalg.view<?x?xf32>, !linalg.range
  return
}
// LLVM-LABEL: @sliceRangeConversion
// LLVM: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// LLVM-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">

func @sliceNonRangeConversion2(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: index) {
  %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
  %1 = linalg.slice %0[%arg3] {dim = 0} : !linalg.view<?x?xf32>, index
  return
}
// LLVM-LABEL: @sliceNonRangeConversion2
// LLVM: llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.mul %{{.*}}arg3, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">