forked from OSchip/llvm-project
[mlir] Prevent SubElementInterface from going into infinite recursion
Since only mutable types and attributes can go into infinite recursion inside SubElementInterface::walkSubElement, and there are only a few of them (mutable types and attributes), we introduce new traits for Type and Attribute: TypeTrait::IsMutable and AttributeTrait::IsMutable, respectively. They indicate whether a type or attribute is mutable. Such traits are required if the ImplType defines a `mutate` function. Then, inside SubElementInterface, we use a set to record visited mutable types and attributes that have been visited before. Differential Revision: https://reviews.llvm.org/D127537
This commit is contained in:
parent
bc5e7ced1c
commit
d41028610b
|
@ -264,7 +264,8 @@ public:
|
|||
/// structs, but does not in uniquing of identified structs.
|
||||
class LLVMStructType
|
||||
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
|
||||
DataLayoutTypeInterface::Trait> {
|
||||
DataLayoutTypeInterface::Trait,
|
||||
TypeTrait::IsMutable> {
|
||||
public:
|
||||
/// Inherit base constructors.
|
||||
using Base::Base;
|
||||
|
|
|
@ -275,8 +275,9 @@ public:
|
|||
/// In the above, expressing recursive struct types is accomplished by giving a
|
||||
/// recursive struct a unique identified and using that identifier in the struct
|
||||
/// definition for recursive references.
|
||||
class StructType : public Type::TypeBase<StructType, CompositeType,
|
||||
detail::StructTypeStorage> {
|
||||
class StructType
|
||||
: public Type::TypeBase<StructType, CompositeType,
|
||||
detail::StructTypeStorage, TypeTrait::IsMutable> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
|
|
|
@ -231,6 +231,18 @@ private:
|
|||
friend InterfaceBase;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Core AttributeTrait
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This trait is used to determine if an attribute is mutable or not. It is
|
||||
/// attached on an attribute if the corresponding ImplType defines a `mutate`
|
||||
/// function with proper signature.
|
||||
namespace AttributeTrait {
|
||||
template <typename ConcreteType>
|
||||
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
|
||||
} // namespace AttributeTrait
|
||||
|
||||
} // namespace mlir.
|
||||
|
||||
namespace llvm {
|
||||
|
|
|
@ -53,6 +53,16 @@ protected:
|
|||
}
|
||||
};
|
||||
|
||||
namespace StorageUserTrait {
|
||||
/// This trait is used to determine if a storage user, like Type, is mutable
|
||||
/// or not. A storage user is mutable if ImplType of the derived class defines
|
||||
/// a `mutate` function with a proper signature. Note that this trait is not
|
||||
/// supposed to be used publicly. Users should use alias names like
|
||||
/// `TypeTrait::IsMutable` instead.
|
||||
template <typename ConcreteType>
|
||||
struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
|
||||
} // namespace StorageUserTrait
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StorageUserBase
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -173,6 +183,10 @@ protected:
|
|||
/// Mutate the current storage instance. This will not change the unique key.
|
||||
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
||||
template <typename... Args> LogicalResult mutate(Args &&...args) {
|
||||
static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
|
||||
ConcreteT>::value,
|
||||
"The `mutate` function expects mutable trait "
|
||||
"(e.g. TypeTrait::IsMutable) to be attached on parent.");
|
||||
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
|
|
@ -222,6 +222,18 @@ private:
|
|||
friend InterfaceBase;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Core TypeTrait
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This trait is used to determine if a type is mutable or not. It is attached
|
||||
/// on a type if the corresponding ImplType defines a `mutate` function with
|
||||
/// a proper signature.
|
||||
namespace TypeTrait {
|
||||
template <typename ConcreteType>
|
||||
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
|
||||
} // namespace TypeTrait
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Utils
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -8,12 +8,16 @@
|
|||
|
||||
#include "mlir/IR/SubElementInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
template <typename InterfaceT>
|
||||
static void walkSubElementsImpl(InterfaceT interface,
|
||||
function_ref<void(Attribute)> walkAttrsFn,
|
||||
function_ref<void(Type)> walkTypesFn) {
|
||||
function_ref<void(Type)> walkTypesFn,
|
||||
DenseSet<Attribute> &visitedAttrs,
|
||||
DenseSet<Type> &visitedTypes) {
|
||||
interface.walkImmediateSubElements(
|
||||
[&](Attribute attr) {
|
||||
// Guard against potentially null inputs. This removes the need for the
|
||||
|
@ -21,9 +25,17 @@ static void walkSubElementsImpl(InterfaceT interface,
|
|||
if (!attr)
|
||||
return;
|
||||
|
||||
// Avoid infinite recursion when visiting sub attributes later, if this
|
||||
// is a mutable attribute.
|
||||
if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
|
||||
if (!visitedAttrs.insert(attr).second)
|
||||
return;
|
||||
}
|
||||
|
||||
// Walk any sub elements first.
|
||||
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
|
||||
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
|
||||
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
|
||||
visitedTypes);
|
||||
|
||||
// Walk this attribute.
|
||||
walkAttrsFn(attr);
|
||||
|
@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
|
|||
if (!type)
|
||||
return;
|
||||
|
||||
// Avoid infinite recursion when visiting sub types later, if this
|
||||
// is a mutable type.
|
||||
if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
|
||||
if (!visitedTypes.insert(type).second)
|
||||
return;
|
||||
}
|
||||
|
||||
// Walk any sub elements first.
|
||||
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
|
||||
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
|
||||
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
|
||||
visitedTypes);
|
||||
|
||||
// Walk this type.
|
||||
walkTypesFn(type);
|
||||
|
@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
|
|||
function_ref<void(Attribute)> walkAttrsFn,
|
||||
function_ref<void(Type)> walkTypesFn) {
|
||||
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
|
||||
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
|
||||
DenseSet<Attribute> visitedAttrs;
|
||||
DenseSet<Type> visitedTypes;
|
||||
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
|
||||
visitedTypes);
|
||||
}
|
||||
|
||||
void SubElementTypeInterface::walkSubElements(
|
||||
function_ref<void(Attribute)> walkAttrsFn,
|
||||
function_ref<void(Type)> walkTypesFn) {
|
||||
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
|
||||
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
|
||||
DenseSet<Attribute> visitedAttrs;
|
||||
DenseSet<Type> visitedTypes;
|
||||
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
|
||||
visitedTypes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,11 +1,17 @@
|
|||
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
|
||||
|
||||
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
|
||||
|
||||
// CHECK-LABEL: @roundtrip
|
||||
func.func @roundtrip() {
|
||||
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
|
||||
// CHECK: !test.test_rec<c, test_rec<c>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
|
||||
// Make sure walkSubElementType, which is used to generate aliases, doesn't go
|
||||
// into inifinite recursion.
|
||||
// CHECK: !testrec
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
|
|||
return AliasResult::FinalAlias;
|
||||
}
|
||||
}
|
||||
if (auto recType = type.dyn_cast<TestRecursiveType>()) {
|
||||
if (recType.getName() == "type_to_alias") {
|
||||
// We only make alias for a specific recursive type.
|
||||
os << "testrec";
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
}
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SubElementInterfaces.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
||||
|
||||
|
@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
|
|||
/// from type creation.
|
||||
class TestRecursiveType
|
||||
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
|
||||
TestRecursiveTypeStorage> {
|
||||
TestRecursiveTypeStorage,
|
||||
::mlir::SubElementTypeInterface::Trait,
|
||||
::mlir::TypeTrait::IsMutable> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
|
@ -141,10 +144,16 @@ public:
|
|||
|
||||
/// Body getter and setter.
|
||||
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
|
||||
::mlir::Type getBody() { return getImpl()->body; }
|
||||
::mlir::Type getBody() const { return getImpl()->body; }
|
||||
|
||||
/// Name/key getter.
|
||||
::llvm::StringRef getName() { return getImpl()->name; }
|
||||
|
||||
void walkImmediateSubElements(
|
||||
::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
|
||||
::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
|
||||
walkTypesFn(getBody());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
|
|
|
@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
|
|||
MLIRDialect)
|
||||
|
||||
add_subdirectory(Affine)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(MemRef)
|
||||
|
||||
add_subdirectory(Quant)
|
||||
add_subdirectory(SparseTensor)
|
||||
add_subdirectory(SPIRV)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
add_mlir_unittest(MLIRLLVMIRTests
|
||||
LLVMTypeTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRLLVMIRTests
|
||||
PRIVATE
|
||||
MLIRLLVMDialect
|
||||
)
|
|
@ -0,0 +1,27 @@
|
|||
//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Test fixure for LLVM dialect tests.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
|
||||
#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
class LLVMIRTest : public ::testing::Test {
|
||||
protected:
|
||||
LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
|
||||
|
||||
mlir::MLIRContext context;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -0,0 +1,20 @@
|
|||
//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "LLVMTestBase.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/SubElementInterfaces.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
TEST_F(LLVMIRTest, IsStructTypeMutable) {
|
||||
auto structTy = LLVMStructType::getIdentified(&context, "foo");
|
||||
ASSERT_TRUE(bool(structTy));
|
||||
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
|
||||
}
|
Loading…
Reference in New Issue