[spirv] Add a skeleton to translate standard ops into SPIR-V dialect

PiperOrigin-RevId: 252651994
This commit is contained in:
Mahesh Ravishankar 2019-06-11 10:47:06 -07:00 committed by Mehdi Amini
parent 420c1f383a
commit d3a601ce33
7 changed files with 193 additions and 1 deletions

View File

@ -7,3 +7,5 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen) add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
add_subdirectory(Transforms)

View File

@ -0,0 +1,35 @@
//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes that expose pass constructors.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SPIRV_PASSES_H_
#define MLIR_SPIRV_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace spirv {
FunctionPassBase *createStdOpsToSPIRVConversionPass();
} // namespace spirv
} // namespace mlir
#endif // MLIR_SPIRV_PASSES_H_

View File

@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)

View File

@ -0,0 +1,48 @@
//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines Patterns to lower standard ops to SPIR-V
//
//===----------------------------------------------------------------------===//
#ifdef STANDARD_OPS_TO_SPIRV
#else
#define STANDARD_OPS_TO_SPIRV
#ifdef STANDARD_OPS
#else
include "mlir/StandardOps/Ops.td"
#endif // STANDARD_OPS
#ifdef SPIRV_OPS
#else
include "mlir/SPIRV/SPIRVOps.td"
#endif // SPIRV_OPS
def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
class IsVectorLengthPred<int vecLength> :
CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
"$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;
class IsVectorOfLength<int vecLength>:
TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
vecLength # "-element vector">;
multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
foreach vecLength = [2, 3, 4] in {
def : Pat<(src IsVectorOfLength<vecLength>:$l,
IsVectorOfLength<vecLength>:$r),
(tgt $l, $r)>;
}
}
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
#endif // STANDARD_OPS_TO_SPIRV

View File

@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV
SPIRVDialect.cpp SPIRVDialect.cpp
SPIRVOps.cpp SPIRVOps.cpp
SPIRVTypes.cpp SPIRVTypes.cpp
Transforms/StdOpsToSPIRVConversion.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
@ -10,6 +11,7 @@ add_llvm_library(MLIRSPIRV
add_dependencies(MLIRSPIRV add_dependencies(MLIRSPIRV
MLIRSPIRVOpsIncGen MLIRSPIRVOpsIncGen
MLIRSPIRVEnumsIncGen) MLIRSPIRVEnumsIncGen
MLIRStdOpsToSPIRVConversionIncGen)
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport) target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)

View File

@ -0,0 +1,56 @@
//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to convert MLIR standard ops into the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/SPIRV/Passes.h"
#include "mlir/SPIRV/SPIRVOps.h"
namespace mlir {
#include "mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
class StdOpsToSPIRVConversionPass
: public FunctionPass<StdOpsToSPIRVConversionPass> {
void runOnFunction() override;
};
} // namespace
void StdOpsToSPIRVConversionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
populateWithGenerated(func.getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));
}
FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
return new StdOpsToSPIRVConversionPass();
}
static PassRegistration<StdOpsToSPIRVConversionPass>
pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");

View File

@ -0,0 +1,46 @@
// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
// CHECK-LABEL: @fmul_scalar
func @fmul_scalar(%arg: f32) -> f32 {
// CHECK: spv.FMul
%0 = mulf %arg, %arg : f32
return %0 : f32
}
// CHECK-LABEL: @fmul_vector2
func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> {
// CHECK: spv.FMul
%0 = mulf %arg, %arg : vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: @fmul_vector3
func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> {
// CHECK: spv.FMul
%0 = mulf %arg, %arg : vector<3xf32>
return %0 : vector<3xf32>
}
// CHECK-LABEL: @fmul_vector4
func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> {
// CHECK: spv.FMul
%0 = mulf %arg, %arg : vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: @fmul_vector5
func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> {
// Vector length of only 2, 3, and 4 is valid for SPIR-V
// CHECK: mulf
%0 = mulf %arg, %arg : vector<5xf32>
return %0 : vector<5xf32>
}
// CHECK-LABEL: @fmul_tensor
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
// For tensors mulf cannot be lowered directly to spv.FMul
// CHECK: mulf
%0 = mulf %arg, %arg : tensor<4xf32>
return %0 : tensor<4xf32>
}