Allow std.constant to hold a boolean value.

This was an oversight in the original implementation, std.constant already supports IntegerAttr just not BoolAttr.

PiperOrigin-RevId: 259467710
This commit is contained in:
River Riddle 2019-07-22 21:43:14 -07:00 committed by A. Unique TensorFlower
parent b5f8a4be27
commit 42a767b23d
3 changed files with 14 additions and 8 deletions

View File

@ -1067,7 +1067,7 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
if (type.isa<IndexType>())
if (type.isa<IndexType>() || value.isa<BoolAttr>())
return success();
if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
@ -1114,8 +1114,7 @@ static LogicalResult verify(ConstantOp &op) {
if (type.isa<NoneType>() && value.isa<UnitAttr>())
return success();
return op.emitOpError(
"requires a result type that aligns with the 'value' attribute");
return op.emitOpError("unsupported 'value' attribute: ") << value;
}
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
@ -1133,8 +1132,9 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (value.getType() != type)
return false;
// Finally, check that the attribute kind is handled.
return value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
value.isa<ElementsAttr>() || value.isa<UnitAttr>();
return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
value.isa<UnitAttr>();
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,

View File

@ -285,11 +285,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
// CHECK: = constant unit
%73 = constant unit
// CHECK: constant true
%74 = constant true
// CHECK: constant false
%75 = constant false
// CHECK: = index_cast {{.*}} : index to i64
%74 = index_cast %idx : index to i64
%76 = index_cast %idx : index to i64
// CHECK: = index_cast {{.*}} : i32 to index
%75 = index_cast %i : i32 to index
%77 = index_cast %i : i32 to index
return
}

View File

@ -34,7 +34,7 @@ func @rank(f32) {
func @constant() {
^bb:
%x = "std.constant"(){value = "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}}
%x = "std.constant"(){value = "xyz"} : () -> i32 // expected-error {{unsupported 'value' attribute}}
return
}