forked from OSchip/llvm-project
[spirv] Add a skeleton to translate standard ops into SPIR-V dialect
PiperOrigin-RevId: 252651994
This commit is contained in:
parent
420c1f383a
commit
d3a601ce33
|
@ -7,3 +7,5 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||||
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
||||||
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
||||||
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
||||||
|
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This header file defines prototypes that expose pass constructors.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_SPIRV_PASSES_H_
|
||||||
|
#define MLIR_SPIRV_PASSES_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace spirv {
|
||||||
|
|
||||||
|
FunctionPassBase *createStdOpsToSPIRVConversionPass();
|
||||||
|
|
||||||
|
} // namespace spirv
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_SPIRV_PASSES_H_
|
|
@ -0,0 +1,3 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
|
||||||
|
mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)
|
|
@ -0,0 +1,48 @@
|
||||||
|
//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
|
||||||
|
|
||||||
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Defines Patterns to lower standard ops to SPIR-V
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef STANDARD_OPS_TO_SPIRV
|
||||||
|
#else
|
||||||
|
#define STANDARD_OPS_TO_SPIRV
|
||||||
|
|
||||||
|
#ifdef STANDARD_OPS
|
||||||
|
#else
|
||||||
|
include "mlir/StandardOps/Ops.td"
|
||||||
|
#endif // STANDARD_OPS
|
||||||
|
|
||||||
|
#ifdef SPIRV_OPS
|
||||||
|
#else
|
||||||
|
include "mlir/SPIRV/SPIRVOps.td"
|
||||||
|
#endif // SPIRV_OPS
|
||||||
|
|
||||||
|
def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
|
||||||
|
|
||||||
|
class IsVectorLengthPred<int vecLength> :
|
||||||
|
CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
|
||||||
|
"$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;
|
||||||
|
|
||||||
|
class IsVectorOfLength<int vecLength>:
|
||||||
|
TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
|
||||||
|
vecLength # "-element vector">;
|
||||||
|
|
||||||
|
multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
|
||||||
|
def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
|
||||||
|
foreach vecLength = [2, 3, 4] in {
|
||||||
|
def : Pat<(src IsVectorOfLength<vecLength>:$l,
|
||||||
|
IsVectorOfLength<vecLength>:$r),
|
||||||
|
(tgt $l, $r)>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
|
||||||
|
|
||||||
|
#endif // STANDARD_OPS_TO_SPIRV
|
|
@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV
|
||||||
SPIRVDialect.cpp
|
SPIRVDialect.cpp
|
||||||
SPIRVOps.cpp
|
SPIRVOps.cpp
|
||||||
SPIRVTypes.cpp
|
SPIRVTypes.cpp
|
||||||
|
Transforms/StdOpsToSPIRVConversion.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
|
||||||
|
@ -10,6 +11,7 @@ add_llvm_library(MLIRSPIRV
|
||||||
|
|
||||||
add_dependencies(MLIRSPIRV
|
add_dependencies(MLIRSPIRV
|
||||||
MLIRSPIRVOpsIncGen
|
MLIRSPIRVOpsIncGen
|
||||||
MLIRSPIRVEnumsIncGen)
|
MLIRSPIRVEnumsIncGen
|
||||||
|
MLIRStdOpsToSPIRVConversionIncGen)
|
||||||
|
|
||||||
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
|
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a pass to convert MLIR standard ops into the SPIR-V
|
||||||
|
// dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/SPIRV/Passes.h"
|
||||||
|
#include "mlir/SPIRV/SPIRVOps.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
#include "mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp.inc"
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
|
||||||
|
class StdOpsToSPIRVConversionPass
|
||||||
|
: public FunctionPass<StdOpsToSPIRVConversionPass> {
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void StdOpsToSPIRVConversionPass::runOnFunction() {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
auto &func = getFunction();
|
||||||
|
|
||||||
|
populateWithGenerated(func.getContext(), &patterns);
|
||||||
|
applyPatternsGreedily(func, std::move(patterns));
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
|
||||||
|
return new StdOpsToSPIRVConversionPass();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<StdOpsToSPIRVConversionPass>
|
||||||
|
pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
|
|
@ -0,0 +1,46 @@
|
||||||
|
// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_scalar
|
||||||
|
func @fmul_scalar(%arg: f32) -> f32 {
|
||||||
|
// CHECK: spv.FMul
|
||||||
|
%0 = mulf %arg, %arg : f32
|
||||||
|
return %0 : f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_vector2
|
||||||
|
func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> {
|
||||||
|
// CHECK: spv.FMul
|
||||||
|
%0 = mulf %arg, %arg : vector<2xf32>
|
||||||
|
return %0 : vector<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_vector3
|
||||||
|
func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> {
|
||||||
|
// CHECK: spv.FMul
|
||||||
|
%0 = mulf %arg, %arg : vector<3xf32>
|
||||||
|
return %0 : vector<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_vector4
|
||||||
|
func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> {
|
||||||
|
// CHECK: spv.FMul
|
||||||
|
%0 = mulf %arg, %arg : vector<4xf32>
|
||||||
|
return %0 : vector<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_vector5
|
||||||
|
func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> {
|
||||||
|
// Vector length of only 2, 3, and 4 is valid for SPIR-V
|
||||||
|
// CHECK: mulf
|
||||||
|
%0 = mulf %arg, %arg : vector<5xf32>
|
||||||
|
return %0 : vector<5xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fmul_tensor
|
||||||
|
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
// For tensors mulf cannot be lowered directly to spv.FMul
|
||||||
|
// CHECK: mulf
|
||||||
|
%0 = mulf %arg, %arg : tensor<4xf32>
|
||||||
|
return %0 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue