Add support for floating point constants, fixing b/112707848. This also adds string attribute support.

PiperOrigin-RevId: 209074362
This commit is contained in:
Chris Lattner 2018-08-16 16:56:40 -07:00 committed by jpienaar
parent 98a24881d3
commit 2278bcc891
4 changed files with 51 additions and 1 deletions

View File

@ -156,10 +156,30 @@ protected:
explicit ConstantOp(const Operation *state) : OpBase(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
/// %1 = "constant"(){value: 42.0} : bf16
///
class ConstantFloatOp : public ConstantOp {
public:
static OperationState build(Builder *builder, double value, FloatType *type);
double getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
}
static bool isClassFor(const Operation *op);
private:
friend class Operation;
explicit ConstantFloatOp(const Operation *state) : ConstantOp(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
/// %1 = "constant"(){value: 42}
/// %1 = "constant"(){value: 42} : i32
///
class ConstantIntOp : public ConstantOp {
public:

View File

@ -75,6 +75,7 @@ public:
// derived types should use isa/dyn_cast.
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isTFString() const { return getKind() == Kind::TFString; }
bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; }
bool isF32() const { return getKind() == Kind::F32; }

View File

@ -231,6 +231,18 @@ const char *ConstantOp::verify() const {
return nullptr;
}
if (isa<FloatType>(type)) {
if (!isa<FloatAttr>(value))
return "requires 'value' to be a floating point constant";
return nullptr;
}
if (type->isTFString()) {
if (!isa<StringAttr>(value))
return "requires 'value' to be a string constant";
return nullptr;
}
if (isa<FunctionType>(type)) {
// TODO: Verify a function attr.
}
@ -238,6 +250,20 @@ const char *ConstantOp::verify() const {
return "requires a result type that aligns with the 'value' attribute";
}
OperationState ConstantFloatOp::build(Builder *builder, double value,
FloatType *type) {
OperationState result(builder->getIdentifier("constant"));
result.attributes.push_back(
{builder->getIdentifier("value"), builder->getFloatAttr(value)});
result.types.push_back(type);
return result;
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
isa<FloatType>(op->getResult(0)->getType());
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&

View File

@ -48,6 +48,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32):
// CHECK: %c43 = constant 43 {crazy: "foo"} : affineint
%8 = constant 43 {crazy: "foo"} : affineint
// CHECK: %4 = constant 4.300000e+01 : bf16
%9 = constant 43.0 : bf16
return
}