forked from OSchip/llvm-project
[spirv] Add a skeleton to translate standard ops into SPIR-V dialect
PiperOrigin-RevId: 252651994
This commit is contained in:
parent
420c1f383a
commit
d3a601ce33
|
@ -7,3 +7,5 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
|||
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
||||
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This header file defines prototypes that expose pass constructors.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_SPIRV_PASSES_H_
|
||||
#define MLIR_SPIRV_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
FunctionPassBase *createStdOpsToSPIRVConversionPass();
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_SPIRV_PASSES_H_
|
|
@ -0,0 +1,3 @@
|
|||
set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
|
||||
mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)
|
|
@ -0,0 +1,48 @@
|
|||
//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
|
||||
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines Patterns to lower standard ops to SPIR-V
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef STANDARD_OPS_TO_SPIRV
|
||||
#else
|
||||
#define STANDARD_OPS_TO_SPIRV
|
||||
|
||||
#ifdef STANDARD_OPS
|
||||
#else
|
||||
include "mlir/StandardOps/Ops.td"
|
||||
#endif // STANDARD_OPS
|
||||
|
||||
#ifdef SPIRV_OPS
|
||||
#else
|
||||
include "mlir/SPIRV/SPIRVOps.td"
|
||||
#endif // SPIRV_OPS
|
||||
|
||||
def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
|
||||
|
||||
class IsVectorLengthPred<int vecLength> :
|
||||
CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
|
||||
"$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;
|
||||
|
||||
class IsVectorOfLength<int vecLength>:
|
||||
TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
|
||||
vecLength # "-element vector">;
|
||||
|
||||
multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
|
||||
def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
|
||||
foreach vecLength = [2, 3, 4] in {
|
||||
def : Pat<(src IsVectorOfLength<vecLength>:$l,
|
||||
IsVectorOfLength<vecLength>:$r),
|
||||
(tgt $l, $r)>;
|
||||
}
|
||||
}
|
||||
|
||||
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
|
||||
|
||||
#endif // STANDARD_OPS_TO_SPIRV
|
|
@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV
|
|||
SPIRVDialect.cpp
|
||||
SPIRVOps.cpp
|
||||
SPIRVTypes.cpp
|
||||
Transforms/StdOpsToSPIRVConversion.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
|
||||
|
@ -10,6 +11,7 @@ add_llvm_library(MLIRSPIRV
|
|||
|
||||
add_dependencies(MLIRSPIRV
|
||||
MLIRSPIRVOpsIncGen
|
||||
MLIRSPIRVEnumsIncGen)
|
||||
MLIRSPIRVEnumsIncGen
|
||||
MLIRStdOpsToSPIRVConversionIncGen)
|
||||
|
||||
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a pass to convert MLIR standard ops into the SPIR-V
|
||||
// dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/SPIRV/Passes.h"
|
||||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
|
||||
class StdOpsToSPIRVConversionPass
|
||||
: public FunctionPass<StdOpsToSPIRVConversionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void StdOpsToSPIRVConversionPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
|
||||
populateWithGenerated(func.getContext(), &patterns);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
|
||||
return new StdOpsToSPIRVConversionPass();
|
||||
}
|
||||
|
||||
static PassRegistration<StdOpsToSPIRVConversionPass>
|
||||
pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
|
|
@ -0,0 +1,46 @@
|
|||
// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @fmul_scalar
|
||||
func @fmul_scalar(%arg: f32) -> f32 {
|
||||
// CHECK: spv.FMul
|
||||
%0 = mulf %arg, %arg : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fmul_vector2
|
||||
func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> {
|
||||
// CHECK: spv.FMul
|
||||
%0 = mulf %arg, %arg : vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fmul_vector3
|
||||
func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> {
|
||||
// CHECK: spv.FMul
|
||||
%0 = mulf %arg, %arg : vector<3xf32>
|
||||
return %0 : vector<3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fmul_vector4
|
||||
func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> {
|
||||
// CHECK: spv.FMul
|
||||
%0 = mulf %arg, %arg : vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fmul_vector5
|
||||
func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> {
|
||||
// Vector length of only 2, 3, and 4 is valid for SPIR-V
|
||||
// CHECK: mulf
|
||||
%0 = mulf %arg, %arg : vector<5xf32>
|
||||
return %0 : vector<5xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fmul_tensor
|
||||
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// For tensors mulf cannot be lowered directly to spv.FMul
|
||||
// CHECK: mulf
|
||||
%0 = mulf %arg, %arg : tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue