[mlir] Prevent SubElementInterface from going into infinite recursion

Since only mutable types and attributes can go into infinite recursion
inside SubElementInterface::walkSubElement, and there are only a few of
them (mutable types and attributes), we introduce new traits for Type
and Attribute: TypeTrait::IsMutable and AttributeTrait::IsMutable,
respectively. They indicate whether a type or attribute is mutable.
Such traits are required if the ImplType defines a `mutate` function.

Then, inside SubElementInterface, we use a set to record visited mutable
types and attributes that have been visited before.

Differential Revision: https://reviews.llvm.org/D127537
This commit is contained in:
Min-Yih Hsu 2022-05-20 21:52:49 -07:00
parent bc5e7ced1c
commit d41028610b
13 changed files with 153 additions and 11 deletions

View File

@ -264,7 +264,8 @@ public:
/// structs, but does not in uniquing of identified structs.
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait> {
DataLayoutTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
using Base::Base;

View File

@ -275,8 +275,9 @@ public:
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identified and using that identifier in the struct
/// definition for recursive references.
class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
class StructType
: public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage, TypeTrait::IsMutable> {
public:
using Base::Base;

View File

@ -231,6 +231,18 @@ private:
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// Core AttributeTrait
//===----------------------------------------------------------------------===//
/// This trait is used to determine if an attribute is mutable or not. It is
/// attached on an attribute if the corresponding ImplType defines a `mutate`
/// function with proper signature.
namespace AttributeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
} // namespace AttributeTrait
} // namespace mlir.
namespace llvm {

View File

@ -53,6 +53,16 @@ protected:
}
};
namespace StorageUserTrait {
/// This trait is used to determine if a storage user, like Type, is mutable
/// or not. A storage user is mutable if ImplType of the derived class defines
/// a `mutate` function with a proper signature. Note that this trait is not
/// supposed to be used publicly. Users should use alias names like
/// `TypeTrait::IsMutable` instead.
template <typename ConcreteType>
struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
} // namespace StorageUserTrait
//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
@ -173,6 +183,10 @@ protected:
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {
static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
ConcreteT>::value,
"The `mutate` function expects mutable trait "
"(e.g. TypeTrait::IsMutable) to be attached on parent.");
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
std::forward<Args>(args)...);
}

View File

@ -222,6 +222,18 @@ private:
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// Core TypeTrait
//===----------------------------------------------------------------------===//
/// This trait is used to determine if a type is mutable or not. It is attached
/// on a type if the corresponding ImplType defines a `mutate` function with
/// a proper signature.
namespace TypeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
} // namespace TypeTrait
//===----------------------------------------------------------------------===//
// Type Utils
//===----------------------------------------------------------------------===//

View File

@ -8,12 +8,16 @@
#include "mlir/IR/SubElementInterfaces.h"
#include "llvm/ADT/DenseSet.h"
using namespace mlir;
template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
function_ref<void(Type)> walkTypesFn,
DenseSet<Attribute> &visitedAttrs,
DenseSet<Type> &visitedTypes) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
@ -21,9 +25,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!attr)
return;
// Avoid infinite recursion when visiting sub attributes later, if this
// is a mutable attribute.
if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
if (!visitedAttrs.insert(attr).second)
return;
}
// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
// Walk this attribute.
walkAttrsFn(attr);
@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!type)
return;
// Avoid infinite recursion when visiting sub types later, if this
// is a mutable type.
if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
if (!visitedTypes.insert(type).second)
return;
}
// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
// Walk this type.
walkTypesFn(type);
@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}
void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}
//===----------------------------------------------------------------------===//

View File

@ -1,11 +1,17 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
// CHECK: !test.test_rec<c, test_rec<c>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
// Make sure walkSubElementType, which is used to generate aliases, doesn't go
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
return
}

View File

@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recType = type.dyn_cast<TestRecursiveType>()) {
if (recType.getName() == "type_to_alias") {
// We only make alias for a specific recursive type.
os << "testrec";
return AliasResult::FinalAlias;
}
}
return AliasResult::NoAlias;
}

View File

@ -21,6 +21,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
TestRecursiveTypeStorage> {
TestRecursiveTypeStorage,
::mlir::SubElementTypeInterface::Trait,
::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;
@ -141,10 +144,16 @@ public:
/// Body getter and setter.
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
::mlir::Type getBody() { return getImpl()->body; }
::mlir::Type getBody() const { return getImpl()->body; }
/// Name/key getter.
::llvm::StringRef getName() { return getImpl()->name; }
void walkImmediateSubElements(
::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
walkTypesFn(getBody());
}
};
} // namespace test

View File

@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
MLIRDialect)
add_subdirectory(Affine)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)

View File

@ -0,0 +1,7 @@
add_mlir_unittest(MLIRLLVMIRTests
LLVMTypeTest.cpp
)
target_link_libraries(MLIRLLVMIRTests
PRIVATE
MLIRLLVMDialect
)

View File

@ -0,0 +1,27 @@
//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test fixure for LLVM dialect tests.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"
class LLVMIRTest : public ::testing::Test {
protected:
LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
mlir::MLIRContext context;
};
#endif

View File

@ -0,0 +1,20 @@
//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LLVMTestBase.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/SubElementInterfaces.h"
using namespace mlir;
using namespace mlir::LLVM;
TEST_F(LLVMIRTest, IsStructTypeMutable) {
auto structTy = LLVMStructType::getIdentified(&context, "foo");
ASSERT_TRUE(bool(structTy));
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
}