Allow std.constant to hold a boolean value.

This was an oversight in the original implementation, std.constant already supports IntegerAttr just not BoolAttr.

PiperOrigin-RevId: 259467710
This commit is contained in:
River Riddle 2019-07-22 21:43:14 -07:00 committed by A. Unique TensorFlower
parent b5f8a4be27
commit 42a767b23d
3 changed files with 14 additions and 8 deletions

View File

@ -1067,7 +1067,7 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError() << "requires attribute's type (" << value.getType() return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")"; << ") to match op's return type (" << type << ")";
if (type.isa<IndexType>()) if (type.isa<IndexType>() || value.isa<BoolAttr>())
return success(); return success();
if (auto intAttr = value.dyn_cast<IntegerAttr>()) { if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
@ -1114,8 +1114,7 @@ static LogicalResult verify(ConstantOp &op) {
if (type.isa<NoneType>() && value.isa<UnitAttr>()) if (type.isa<NoneType>() && value.isa<UnitAttr>())
return success(); return success();
return op.emitOpError( return op.emitOpError("unsupported 'value' attribute: ") << value;
"requires a result type that aligns with the 'value' attribute");
} }
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
@ -1133,8 +1132,9 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (value.getType() != type) if (value.getType() != type)
return false; return false;
// Finally, check that the attribute kind is handled. // Finally, check that the attribute kind is handled.
return value.isa<IntegerAttr>() || value.isa<FloatAttr>() || return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
value.isa<ElementsAttr>() || value.isa<UnitAttr>(); value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
value.isa<UnitAttr>();
} }
void ConstantFloatOp::build(Builder *builder, OperationState *result, void ConstantFloatOp::build(Builder *builder, OperationState *result,

View File

@ -285,11 +285,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
// CHECK: = constant unit // CHECK: = constant unit
%73 = constant unit %73 = constant unit
// CHECK: constant true
%74 = constant true
// CHECK: constant false
%75 = constant false
// CHECK: = index_cast {{.*}} : index to i64 // CHECK: = index_cast {{.*}} : index to i64
%74 = index_cast %idx : index to i64 %76 = index_cast %idx : index to i64
// CHECK: = index_cast {{.*}} : i32 to index // CHECK: = index_cast {{.*}} : i32 to index
%75 = index_cast %i : i32 to index %77 = index_cast %i : i32 to index
return return
} }

View File

@ -34,7 +34,7 @@ func @rank(f32) {
func @constant() { func @constant() {
^bb: ^bb:
%x = "std.constant"(){value = "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}} %x = "std.constant"(){value = "xyz"} : () -> i32 // expected-error {{unsupported 'value' attribute}}
return return
} }