Move QuantTypes out of QuantOps to match the file structures of other dialects

This CL also moves UniformSupport.cpp and FakeQuantSupport.cpp into utils because they are not really part of the core IR.

--

PiperOrigin-RevId: 244001666
This commit is contained in:
Feng Liu 2019-04-17 08:31:39 -07:00 committed by Mehdi Amini
parent a2e08eb384
commit c9f21cf355
12 changed files with 733 additions and 649 deletions

View File

@ -46,7 +46,7 @@
#ifndef MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
#define MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
namespace mlir {
namespace quant {

View File

@ -29,329 +29,6 @@
namespace mlir {
namespace quant {
class QuantizedIntegerType;
namespace detail {
struct QuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
} // namespace detail
namespace QuantizationTypes {
enum Kind {
UniformQuantized = Type::FIRST_QUANTIZATION_TYPE,
UniformQuantizedPerAxis,
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
};
} // namespace QuantizationTypes
/// Enumeration of bit-mapped flags related to quantized types.
namespace QuantizationFlags {
enum FlagValue {
// Indicates that the storage type should be interpreted as a signed
// integer. The default is to interpret it as an unsigned value.
Signed = 1,
};
} // namespace QuantizationFlags
/// Base class for all quantized types known to this dialect.
/// All quantized types have:
/// - storageType: The (narrower) numeric type that is being used to
/// approximate some expressed type.
/// - expressedType: The type that is being approximated.
///
/// The base class provides generic support for manipulating the types based
/// on these fields.
class QuantizedType : public Type {
public:
  using ImplType = detail::QuantizedTypeStorage;
  using Type::Type;

  /// The maximum number of bits supported for storage types.
  static constexpr unsigned MaxStorageBits = 32;

  /// Verifies construction invariants and issues errors/warnings at `loc`
  /// (when provided).
  static LogicalResult
  verifyConstructionInvariants(llvm::Optional<Location> loc,
                               MLIRContext *context, unsigned flags,
                               Type storageType, Type expressedType,
                               int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting.
  /// As the common base of all quantized types, this must match the whole
  /// reserved kind range; matching only UniformQuantized would make
  /// isa<QuantizedType> fail for UniformQuantizedPerAxisType.
  static bool kindof(unsigned kind) {
    return kind >= QuantizationTypes::UniformQuantized &&
           kind <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
  }

  /// Gets the minimum possible value stored by a storageType. storageTypeMin
  /// must be greater than or equal to this value.
  /// (Spelling of the name is historical and kept for API compatibility.)
  static int64_t getDefaultMininumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    if (isSigned) {
      return llvm::minIntN(integralWidth);
    }
    return 0;
  }

  /// Gets the maximum possible value stored by a storageType. storageTypeMax
  /// must be less than or equal to this value.
  /// (Spelling of the name is historical and kept for API compatibility.)
  static int64_t getDefaultMaxinumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    if (isSigned) {
      return llvm::maxIntN(integralWidth);
    }
    return llvm::maxUIntN(integralWidth);
  }

  /// Gets the original expressed type that this quantized type approximates.
  /// Note that this presumes that the quantized type was always derived from
  /// a floating point type, which in the broadest definition, is not true (i.e.
  /// it could be some form of integral, fixed type or affine type in its own
  /// right); however, at the high level, no examples of such usage are
  /// presently known and the restriction serves some useful purposes (such as
  /// always being able to reverse a transformation or measure error). In most
  /// cases, this will be f32.
  Type getExpressedType() const;

  /// Gets the flags associated with this type. Typically a more specific
  /// accessor is appropriate.
  unsigned getFlags() const;

  // Convenience helpers.
  /// Whether the storage type should be interpreted as a signed quantity
  /// (true) or an unsigned value (false).
  bool isSigned() const {
    return (getFlags() & QuantizationFlags::Signed) ==
           QuantizationFlags::Signed;
  }

  /// Gets the underlying type used for to store values. Note that this may
  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
  Type getStorageType() const;

  /// The minimum value that storageType can take.
  int64_t getStorageTypeMin() const;

  /// The maximum value that storageType can take.
  int64_t getStorageTypeMax() const;

  /// Gets the integral bit width that the underlying storage type can exactly
  /// represent. For integral storage types, this will just be their width.
  unsigned getStorageTypeIntegralWidth() const;

  /// Returns whether the candidateExpressedType is a match for this
  /// QuantizedType. This will be true if the candidate type is either a
  /// primitive type or a container type whose element type equals this
  /// QuantizedType's expressed type.
  /// Examples of compatible candidateExpressedType:
  ///   !quant<"uniform[i8:f32]{1.0}"> =~ f32
  ///   !quant<"uniform[i8:f32]{1.0}"> =~ tensor<4xf32>
  bool isCompatibleExpressedType(Type candidateExpressedType);

  /// Returns the element type as a QuantizedType or nullptr if it is not
  /// a quantized type. If the type is primitive, returns that. If it is a
  /// container (vector/tensor), return the element type.
  /// Examples:
  ///   !quant<"uniform[i8:f32]{1.0}"> -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4x!quant<"uniform[i8:f32]{1.0}"> -> quant<"uniform[i8:f32]{1.0}">
  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);

  /// Casts from a type based on the storageType to a corresponding type based
  /// on this type (returns nullptr if the cast is not valid).
  /// Examples:
  ///   i8 -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
  ///   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
  Type castFromStorageType(Type candidateType);

  /// Casts from a type based on a QuantizedType to a corresponding type based
  /// on the storageType (returns nullptr if the cast is not valid).
  /// This is the inverse of castFromStorageType().
  static Type castToStorageType(Type quantizedType);

  /// Casts from a type based on the expressedType to a corresponding type based
  /// on this type (returns nullptr if the cast is not valid).
  /// Examples:
  ///   f32 -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
  ///   vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
  Type castFromExpressedType(Type candidateType);

  /// Casts from a type based on QuantizedType to a corresponding type based
  /// on the expressedType (returns nullptr if the cast is not valid).
  /// This is the inverse of castFromExpressedType.
  static Type castToExpressedType(Type quantizedType);

  /// Casts from a type based on the expressedType to the equivalent type
  /// based on storageType by way of this QuantizedType. Equivalent to:
  ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
  /// (but with validity checks).
  /// Example (for this = !quant<"uniform[i8:f32]{1.0}">):
  ///   tensor<4xf32> -> tensor<4xi8>
  Type castExpressedToStorageType(Type candidateType);
};
/// Represents a family of uniform, quantized types.
///
/// Each instance of this type expresses a mapping between real values (most
/// often expressed in floating point f32) and quantized values (either fixed
/// point or affine).
///
/// The relationship is:
/// real_value = scale * (quantized_value - zero_point)
///
/// It is used as part of high level graph transformations that have the goal
/// of re-expressing parts of a computation in terms of this common form for
/// more efficient execution at runtime. In addition, it is designed to be
/// expressive enough to facilitate lowering to precise types and operations
/// in target hardware.
///
/// As a high-level type, focused on intermediate passes, this type holds
/// opinions consistent with high-level usage. If lowering math kernels below
/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
/// instruction sets), it is expected that the information expressed here
/// will be used to drive low level codegen and target specific type selection,
/// but this type will likely be erased in the process.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
class UniformQuantizedType
    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
                            detail::UniformQuantizedTypeStorage> {
public:
  using Base::Base;

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static UniformQuantizedType get(unsigned flags, Type storageType,
                                  Type expressedType, double scale,
                                  int64_t zeroPoint, int64_t storageTypeMin,
                                  int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static UniformQuantizedType
  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
             Location location);

  /// Verifies construction invariants and issues errors/warnings.
  static LogicalResult verifyConstructionInvariants(
      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
      Type storageType, Type expressedType, double scale, int64_t zeroPoint,
      int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting. Matches exactly the
  /// per-layer uniform kind.
  static bool kindof(unsigned kind) {
    return kind == QuantizationTypes::UniformQuantized;
  }

  /// Gets the scale term. The scale designates the difference between the real
  /// values corresponding to consecutive quantized values differing by 1.
  double getScale() const;

  /// Gets the storage value corresponding to the real value 0 in the affine
  /// equation.
  int64_t getZeroPoint() const;

  // Fixed point values are real numbers divided by a scale.
  // Currently, only signed storage types are treated as fixed point.
  // A fixed point value can be obtained from an affine value by subtracting
  // the zeroPoint.
  // In the future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
};
/// Represents per-axis (also known as per-channel quantization).
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
/// Scale: A legal double value
/// ZeroPoint: An integer value
class UniformQuantizedPerAxisType
    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
                            detail::UniformQuantizedPerAxisTypeStorage> {
public:
  using Base::Base;

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static UniformQuantizedPerAxisType
  get(unsigned flags, Type storageType, Type expressedType,
      ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
      int32_t quantizedDimension, int64_t storageTypeMin,
      int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static UniformQuantizedPerAxisType
  getChecked(unsigned flags, Type storageType, Type expressedType,
             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
             int32_t quantizedDimension, int64_t storageTypeMin,
             int64_t storageTypeMax, Location location);

  /// Verifies construction invariants and issues errors/warnings.
  static LogicalResult verifyConstructionInvariants(
      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
      Type storageType, Type expressedType, ArrayRef<double> scales,
      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
      int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting.
  static bool kindof(unsigned kind) {
    return kind == QuantizationTypes::UniformQuantizedPerAxis;
  }

  /// Gets the quantization scales. The scales designate the difference between
  /// the real values corresponding to consecutive quantized values differing
  /// by 1. The ith scale corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<double> getScales() const;

  /// Gets the storage values corresponding to the real value 0 in the affine
  /// equation. The ith zero point corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<int64_t> getZeroPoints() const;

  /// Specifies the dimension of the Tensor's shape that the scales and
  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
  /// with quantization params:
  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
  /// will be quantized across the second dimension of t.
  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
  int32_t getQuantizedDimension() const;

  /// Fixed point values are real numbers divided by a scale.
  /// Currently, only signed storage types are treated as fixed point.
  /// A fixed point value can be obtained from an affine value by subtracting
  /// the zeroPoint.
  /// In the future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const {
    if (!isSigned())
      return false;
    // Fixed point requires a zero point of 0 for every slice. The previous
    // predicate tested `zeroPoint != 0`, which inverted the result relative
    // to the per-layer UniformQuantizedType::isFixedPoint.
    return llvm::all_of(getZeroPoints(),
                        [](int64_t zeroPoint) { return zeroPoint == 0; });
  }
};
/// Defines the 'Quantization' dialect
class QuantizationDialect : public Dialect {
public:

View File

@ -0,0 +1,358 @@
//===- Quantization/QuantTypes.h - Quantization Type Definitions -*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_QUANTIZATION_QUANT_TYPES_H_
#define MLIR_QUANTIZATION_QUANT_TYPES_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace quant {
class QuantizedIntegerType;
namespace detail {
struct QuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
} // namespace detail
namespace QuantizationTypes {
enum Kind {
UniformQuantized = Type::FIRST_QUANTIZATION_TYPE,
UniformQuantizedPerAxis,
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
};
} // namespace QuantizationTypes
/// Enumeration of bit-mapped flags related to quantized types.
namespace QuantizationFlags {
enum FlagValue {
// Indicates that the storage type should be interpreted as a signed
// integer. The default is to interpret it as an unsigned value.
Signed = 1,
};
} // namespace QuantizationFlags
/// Base class for all quantized types known to this dialect.
/// All quantized types have:
/// - storageType: The (narrower) numeric type that is being used to
/// approximate some expressed type.
/// - expressedType: The type that is being approximated.
///
/// The base class provides generic support for manipulating the types based
/// on these fields.
class QuantizedType : public Type {
public:
  using ImplType = detail::QuantizedTypeStorage;
  using Type::Type;

  /// The maximum number of bits supported for storage types.
  static constexpr unsigned MaxStorageBits = 32;

  /// Verifies construction invariants and issues errors/warnings at `loc`
  /// (when provided).
  static LogicalResult
  verifyConstructionInvariants(llvm::Optional<Location> loc,
                               MLIRContext *context, unsigned flags,
                               Type storageType, Type expressedType,
                               int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting.
  /// As the common base of all quantized types, this must match the whole
  /// reserved kind range; matching only UniformQuantized would make
  /// isa<QuantizedType> fail for UniformQuantizedPerAxisType.
  static bool kindof(unsigned kind) {
    return kind >= QuantizationTypes::UniformQuantized &&
           kind <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
  }

  /// Gets the minimum possible value stored by a storageType. storageTypeMin
  /// must be greater than or equal to this value.
  /// (Spelling of the name is historical and kept for API compatibility.)
  static int64_t getDefaultMininumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    if (isSigned) {
      return llvm::minIntN(integralWidth);
    }
    return 0;
  }

  /// Gets the maximum possible value stored by a storageType. storageTypeMax
  /// must be less than or equal to this value.
  /// (Spelling of the name is historical and kept for API compatibility.)
  static int64_t getDefaultMaxinumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    if (isSigned) {
      return llvm::maxIntN(integralWidth);
    }
    return llvm::maxUIntN(integralWidth);
  }

  /// Gets the original expressed type that this quantized type approximates.
  /// Note that this presumes that the quantized type was always derived from
  /// a floating point type, which in the broadest definition, is not true (i.e.
  /// it could be some form of integral, fixed type or affine type in its own
  /// right); however, at the high level, no examples of such usage are
  /// presently known and the restriction serves some useful purposes (such as
  /// always being able to reverse a transformation or measure error). In most
  /// cases, this will be f32.
  Type getExpressedType() const;

  /// Gets the flags associated with this type. Typically a more specific
  /// accessor is appropriate.
  unsigned getFlags() const;

  // Convenience helpers.
  /// Whether the storage type should be interpreted as a signed quantity
  /// (true) or an unsigned value (false).
  bool isSigned() const {
    return (getFlags() & QuantizationFlags::Signed) ==
           QuantizationFlags::Signed;
  }

  /// Gets the underlying type used for to store values. Note that this may
  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
  Type getStorageType() const;

  /// The minimum value that storageType can take.
  int64_t getStorageTypeMin() const;

  /// The maximum value that storageType can take.
  int64_t getStorageTypeMax() const;

  /// Gets the integral bit width that the underlying storage type can exactly
  /// represent. For integral storage types, this will just be their width.
  unsigned getStorageTypeIntegralWidth() const;

  /// Returns whether the candidateExpressedType is a match for this
  /// QuantizedType. This will be true if the candidate type is either a
  /// primitive type or a container type whose element type equals this
  /// QuantizedType's expressed type.
  /// Examples of compatible candidateExpressedType:
  ///   !quant<"uniform[i8:f32]{1.0}"> =~ f32
  ///   !quant<"uniform[i8:f32]{1.0}"> =~ tensor<4xf32>
  bool isCompatibleExpressedType(Type candidateExpressedType);

  /// Returns the element type as a QuantizedType or nullptr if it is not
  /// a quantized type. If the type is primitive, returns that. If it is a
  /// container (vector/tensor), return the element type.
  /// Examples:
  ///   !quant<"uniform[i8:f32]{1.0}"> -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4x!quant<"uniform[i8:f32]{1.0}"> -> quant<"uniform[i8:f32]{1.0}">
  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);

  /// Casts from a type based on the storageType to a corresponding type based
  /// on this type (returns nullptr if the cast is not valid).
  /// Examples:
  ///   i8 -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
  ///   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
  Type castFromStorageType(Type candidateType);

  /// Casts from a type based on a QuantizedType to a corresponding type based
  /// on the storageType (returns nullptr if the cast is not valid).
  /// This is the inverse of castFromStorageType().
  static Type castToStorageType(Type quantizedType);

  /// Casts from a type based on the expressedType to a corresponding type based
  /// on this type (returns nullptr if the cast is not valid).
  /// Examples:
  ///   f32 -> !quant<"uniform[i8:f32]{1.0}">
  ///   tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
  ///   vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
  Type castFromExpressedType(Type candidateType);

  /// Casts from a type based on QuantizedType to a corresponding type based
  /// on the expressedType (returns nullptr if the cast is not valid).
  /// This is the inverse of castFromExpressedType.
  static Type castToExpressedType(Type quantizedType);

  /// Casts from a type based on the expressedType to the equivalent type
  /// based on storageType by way of this QuantizedType. Equivalent to:
  ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
  /// (but with validity checks).
  /// Example (for this = !quant<"uniform[i8:f32]{1.0}">):
  ///   tensor<4xf32> -> tensor<4xi8>
  Type castExpressedToStorageType(Type candidateType);
};
/// Represents a family of uniform, quantized types.
///
/// Each instance of this type expresses a mapping between real values (most
/// often expressed in floating point f32) and quantized values (either fixed
/// point or affine).
///
/// The relationship is:
/// real_value = scale * (quantized_value - zero_point)
///
/// It is used as part of high level graph transformations that have the goal
/// of re-expressing parts of a computation in terms of this common form for
/// more efficient execution at runtime. In addition, it is designed to be
/// expressive enough to facilitate lowering to precise types and operations
/// in target hardware.
///
/// As a high-level type, focused on intermediate passes, this type holds
/// opinions consistent with high-level usage. If lowering math kernels below
/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
/// instruction sets), it is expected that the information expressed here
/// will be used to drive low level codegen and target specific type selection,
/// but this type will likely be erased in the process.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
class UniformQuantizedType
    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
                            detail::UniformQuantizedTypeStorage> {
public:
  using Base::Base;

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static UniformQuantizedType get(unsigned flags, Type storageType,
                                  Type expressedType, double scale,
                                  int64_t zeroPoint, int64_t storageTypeMin,
                                  int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static UniformQuantizedType
  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
             Location location);

  /// Verifies construction invariants and issues errors/warnings.
  static LogicalResult verifyConstructionInvariants(
      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
      Type storageType, Type expressedType, double scale, int64_t zeroPoint,
      int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting. Matches exactly the
  /// per-layer uniform kind.
  static bool kindof(unsigned kind) {
    return kind == QuantizationTypes::UniformQuantized;
  }

  /// Gets the scale term. The scale designates the difference between the real
  /// values corresponding to consecutive quantized values differing by 1.
  double getScale() const;

  /// Gets the storage value corresponding to the real value 0 in the affine
  /// equation.
  int64_t getZeroPoint() const;

  // Fixed point values are real numbers divided by a scale.
  // Currently, only signed storage types are treated as fixed point.
  // A fixed point value can be obtained from an affine value by subtracting
  // the zeroPoint.
  // In the future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
};
/// Represents per-axis (also known as per-channel quantization).
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
/// Scale: A legal double value
/// ZeroPoint: An integer value
class UniformQuantizedPerAxisType
    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
                            detail::UniformQuantizedPerAxisTypeStorage> {
public:
  using Base::Base;

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static UniformQuantizedPerAxisType
  get(unsigned flags, Type storageType, Type expressedType,
      ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
      int32_t quantizedDimension, int64_t storageTypeMin,
      int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static UniformQuantizedPerAxisType
  getChecked(unsigned flags, Type storageType, Type expressedType,
             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
             int32_t quantizedDimension, int64_t storageTypeMin,
             int64_t storageTypeMax, Location location);

  /// Verifies construction invariants and issues errors/warnings.
  static LogicalResult verifyConstructionInvariants(
      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
      Type storageType, Type expressedType, ArrayRef<double> scales,
      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
      int64_t storageTypeMin, int64_t storageTypeMax);

  /// Support method to enable LLVM-style type casting.
  static bool kindof(unsigned kind) {
    return kind == QuantizationTypes::UniformQuantizedPerAxis;
  }

  /// Gets the quantization scales. The scales designate the difference between
  /// the real values corresponding to consecutive quantized values differing
  /// by 1. The ith scale corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<double> getScales() const;

  /// Gets the storage values corresponding to the real value 0 in the affine
  /// equation. The ith zero point corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<int64_t> getZeroPoints() const;

  /// Specifies the dimension of the Tensor's shape that the scales and
  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
  /// with quantization params:
  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
  /// will be quantized across the second dimension of t.
  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
  int32_t getQuantizedDimension() const;

  /// Fixed point values are real numbers divided by a scale.
  /// Currently, only signed storage types are treated as fixed point.
  /// A fixed point value can be obtained from an affine value by subtracting
  /// the zeroPoint.
  /// In the future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const {
    if (!isSigned())
      return false;
    // Fixed point requires a zero point of 0 for every slice. The previous
    // predicate tested `zeroPoint != 0`, which inverted the result relative
    // to the per-layer UniformQuantizedType::isFixedPoint.
    return llvm::all_of(getZeroPoints(),
                        [](int64_t zeroPoint) { return zeroPoint == 0; });
  }
};
} // namespace quant
} // namespace mlir
#endif // MLIR_QUANTIZATION_QUANT_TYPES_H_

View File

@ -20,7 +20,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"

View File

@ -18,7 +18,7 @@
#include "mlir/FxpMathOps/FxpMathOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"

View File

@ -19,6 +19,7 @@
#include "mlir/FxpMathOps/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/UniformSupport.h"
#include <functional>

View File

@ -1,13 +1,14 @@
add_llvm_library(MLIRQuantization
IR/DialectRegistration.cpp
IR/FakeQuantSupport.cpp
IR/QuantOps.cpp
IR/QuantTypes.cpp
IR/TypeDetail.h
IR/TypeParser.cpp
IR/UniformSupport.cpp
Transforms/ConvertConst.cpp
Transforms/ConvertSimQuant.cpp
Utils/QuantizeUtils.cpp
Utils/UniformSupport.cpp
Utils/FakeQuantSupport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Quantization

View File

@ -19,6 +19,7 @@
#include "TypeDetail.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantization/QuantTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
@ -27,326 +28,6 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
/// Returns the bit-mapped QuantizationFlags recorded in the type storage.
unsigned QuantizedType::getFlags() const {
  return static_cast<ImplType *>(type)->flags;
}
/// Checks the integrity of a quantized type's construction parameters,
/// emitting a diagnostic at `loc` (when provided) on each failure:
///   - expressedType must be a floating point type;
///   - storageType must be an integer type of width 1..MaxStorageBits;
///   - [storageTypeMin, storageTypeMax] must be a non-empty range contained
///     in the storage type's default signed/unsigned range.
LogicalResult QuantizedType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // Verify that the expressed type is floating point.
  // If this restriction is ever eliminated, the parser/printer must be
  // extended.
  if (!expressedType.isa<FloatType>()) {
    if (loc) {
      context->emitError(*loc, "expressed type must be floating point");
    }
    return failure();
  }

  // Verify that the storage type is integral.
  // This restriction may be lifted at some point in favor of using bf16
  // or f16 as exact representations on hardware where that is advantageous.
  auto intStorageType = storageType.dyn_cast<IntegerType>();
  if (!intStorageType) {
    if (loc) {
      context->emitError(*loc, "storage type must be integral");
    }
    return failure();
  }
  unsigned integralWidth = intStorageType.getWidth();

  // Verify storage width.
  if (integralWidth == 0 || integralWidth > MaxStorageBits) {
    if (loc) {
      context->emitError(*loc,
                         "illegal storage type size: " + Twine(integralWidth));
    }
    return failure();
  }

  // Verify storageTypeMin and storageTypeMax.
  // Note: overflow is not a concern here because integralWidth <= 32 bounds
  // the default limits well inside the int64_t range.
  bool isSigned =
      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
  int64_t defaultIntegerMin =
      getDefaultMininumForInteger(isSigned, integralWidth);
  int64_t defaultIntegerMax =
      getDefaultMaxinumForInteger(isSigned, integralWidth);
  if (storageTypeMax - storageTypeMin <= 0 ||
      storageTypeMin < defaultIntegerMin ||
      storageTypeMax > defaultIntegerMax) {
    if (loc) {
      context->emitError(*loc, "illegal storage min and storage max: (" +
                                   Twine(storageTypeMin) + ":" +
                                   Twine(storageTypeMax) + ")");
    }
    return failure();
  }
  return success();
}
/// Returns the (narrow) storage type, e.g. i8 for !quant<"uniform[i8:f32]...">.
Type QuantizedType::getStorageType() const {
  return static_cast<ImplType *>(type)->storageType;
}

/// Returns the inclusive lower bound of representable storage values.
int64_t QuantizedType::getStorageTypeMin() const {
  return static_cast<ImplType *>(type)->storageTypeMin;
}

/// Returns the inclusive upper bound of representable storage values.
int64_t QuantizedType::getStorageTypeMax() const {
  return static_cast<ImplType *>(type)->storageTypeMax;
}

/// Returns the bit width of the storage type.
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
  // NOTE: If ever supporting non-integral storage types, some other scheme
  // for determining the width will be needed.
  return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
}

/// Returns the expressed (approximated) type, typically f32.
Type QuantizedType::getExpressedType() const {
  return static_cast<ImplType *>(type)->expressedType;
}
// Reports whether `candidateExpressedType` matches this type's expressed
// type, either directly or as the element type of a vector/tensor container.
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
  Type comparisonType = candidateExpressedType;
  if (auto containerType =
          candidateExpressedType.dyn_cast<VectorOrTensorType>())
    comparisonType = containerType.getElementType();
  return comparisonType == getExpressedType();
}
// Extracts the QuantizedType of a primitive type, or of the element type of
// a vector/tensor container; returns a null type when none is present.
QuantizedType
QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
  // Strip off any vector/tensor container to reach the scalar type.
  Type scalarType = primitiveOrContainerType;
  if (auto containerType =
          primitiveOrContainerType.dyn_cast<VectorOrTensorType>())
    scalarType = containerType.getElementType();
  return scalarType.dyn_cast<QuantizedType>();
}
// Casts a type based on the storage type of this quantized type into a type
// based on this quantized type. Returns nullptr for unsupported candidates.
Type QuantizedType::castFromStorageType(Type candidateType) {
  // Scalar storage type, i.e. i8 -> quant<"uniform[i8:f32]{1.0}">.
  if (candidateType == getStorageType())
    return *this;
  // NOTE(review): the container branches below rebuild the container around
  // getStorageType() (leaving the element type unchanged) rather than around
  // the quantized type; confirm this asymmetry with castFromExpressedType is
  // intended.
  if (auto rankedType = candidateType.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(rankedType.getShape(), getStorageType());
  if (candidateType.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(getStorageType());
  if (auto vectorType = candidateType.dyn_cast<VectorType>())
    return VectorType::get(vectorType.getShape(), getStorageType());
  // Not a storage-compatible form.
  return nullptr;
}
// Strips quantization from a scalar or vector/tensor type, yielding the
// corresponding storage type; returns nullptr if `quantizedType` does not
// involve a QuantizedType.
Type QuantizedType::castToStorageType(Type quantizedType) {
  // Scalar case, i.e. quant<"uniform[i8:f32]{1.0}"> -> i8.
  if (auto quantType = quantizedType.dyn_cast<QuantizedType>())
    return quantType.getStorageType();
  // Container case, i.e.
  // tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xi8>.
  if (auto vtType = quantizedType.dyn_cast<VectorOrTensorType>()) {
    auto quantElementType = vtType.getElementType().dyn_cast<QuantizedType>();
    if (!quantElementType)
      return nullptr;
    Type storageType = quantElementType.getStorageType();
    if (quantizedType.isa<RankedTensorType>())
      return RankedTensorType::get(vtType.getShape(), storageType);
    if (quantizedType.isa<UnrankedTensorType>())
      return UnrankedTensorType::get(storageType);
    if (quantizedType.isa<VectorType>())
      return VectorType::get(vtType.getShape(), storageType);
  }
  return nullptr;
}
// Casts a type based on the expressed type of this quantized type into a
// type based on this quantized type. Returns nullptr on mismatch.
Type QuantizedType::castFromExpressedType(Type candidateType) {
  // Scalar case, i.e. f32 -> quant<"uniform[i8:f32]{1.0}">.
  if (candidateType == getExpressedType())
    return *this;
  if (auto candidateVtType = candidateType.dyn_cast<VectorOrTensorType>()) {
    // The container's element type must match exactly.
    if (candidateVtType.getElementType() != getExpressedType())
      return nullptr;
    // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>.
    if (candidateType.isa<RankedTensorType>())
      return RankedTensorType::get(candidateVtType.getShape(), *this);
    // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>.
    if (candidateType.isa<UnrankedTensorType>())
      return UnrankedTensorType::get(*this);
    // i.e. vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>.
    if (candidateType.isa<VectorType>())
      return VectorType::get(candidateVtType.getShape(), *this);
  }
  return nullptr;
}
// Strips quantization from a scalar or vector/tensor type, yielding the
// corresponding expressed (approximated) type; returns nullptr if
// `quantizedType` does not involve a QuantizedType.
Type QuantizedType::castToExpressedType(Type quantizedType) {
  // Scalar case, i.e. quant<"uniform[i8:f32]{1.0}"> -> f32.
  if (auto quantType = quantizedType.dyn_cast<QuantizedType>())
    return quantType.getExpressedType();
  // Container case, i.e.
  // tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xf32>.
  if (auto vtType = quantizedType.dyn_cast<VectorOrTensorType>()) {
    auto quantElementType = vtType.getElementType().dyn_cast<QuantizedType>();
    if (!quantElementType)
      return nullptr;
    Type expressedType = quantElementType.getExpressedType();
    if (quantizedType.isa<RankedTensorType>())
      return RankedTensorType::get(vtType.getShape(), expressedType);
    if (quantizedType.isa<UnrankedTensorType>())
      return UnrankedTensorType::get(expressedType);
    if (quantizedType.isa<VectorType>())
      return VectorType::get(vtType.getShape(), expressedType);
  }
  return nullptr;
}
// Composes castFromExpressedType and castToStorageType: maps an
// expressed-based type directly to the equivalent storage-based type, or
// nullptr when the expressed type does not match.
Type QuantizedType::castExpressedToStorageType(Type candidateType) {
  Type expressedQuantizedType = castFromExpressedType(candidateType);
  return expressedQuantizedType ? castToStorageType(expressedQuantizedType)
                                : nullptr;
}
/// Constructs (uniquing in the storage type's context) a UniformQuantizedType
/// from its components. Invariants are not diagnosed here; use getChecked for
/// location-attached verification.
UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
                                               Type expressedType, double scale,
                                               int64_t zeroPoint,
                                               int64_t storageTypeMin,
                                               int64_t storageTypeMax) {
  return Base::get(storageType.getContext(),
                   QuantizationTypes::UniformQuantized, flags, storageType,
                   expressedType, scale, zeroPoint, storageTypeMin,
                   storageTypeMax);
}
/// Same as get(), but runs verifyConstructionInvariants and emits a
/// diagnostic at `location` on failure.
UniformQuantizedType
UniformQuantizedType::getChecked(unsigned flags, Type storageType,
                                 Type expressedType, double scale,
                                 int64_t zeroPoint, int64_t storageTypeMin,
                                 int64_t storageTypeMax, Location location) {
  return Base::getChecked(location, storageType.getContext(),
                          QuantizationTypes::UniformQuantized, flags,
                          storageType, expressedType, scale, zeroPoint,
                          storageTypeMin, storageTypeMax);
}
// Verifies the per-instance invariants of a UniformQuantizedType: the base
// QuantizedType invariants plus a strictly positive, finite scale. Emits an
// error at `loc` (when provided) and returns failure() on violation.
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  // Storage/expressed type checks are shared with the base class.
  if (failed(QuantizedType::verifyConstructionInvariants(
          loc, context, flags, storageType, expressedType, storageTypeMin,
          storageTypeMax)))
    return failure();

  // The scale must be > 0 and finite (NaN compares false to everything, so
  // the positivity test also rejects it; infinity is rejected explicitly).
  bool scaleIsValid = scale > 0.0 && !std::isinf(scale);
  if (!scaleIsValid) {
    if (loc)
      context->emitError(*loc,
                         "illegal scale: " + Twine(std::to_string(scale)));
    return failure();
  }
  return success();
}
// Returns the multiplicative scale applied when dequantizing.
double UniformQuantizedType::getScale() const { return getImpl()->scale; }
// Returns the storage value that corresponds to the expressed value 0.
int64_t UniformQuantizedType::getZeroPoint() const {
  return getImpl()->zeroPoint;
}
/// Constructs (uniquing in the storage type's context) a per-axis uniform
/// quantized type. `scales` and `zeroPoints` are parallel arrays indexed by
/// position along `quantizedDimension`. Invariants are not diagnosed here.
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  return Base::get(storageType.getContext(),
                   QuantizationTypes::UniformQuantizedPerAxis, flags,
                   storageType, expressedType, scales, zeroPoints,
                   quantizedDimension, storageTypeMin, storageTypeMax);
}
/// Same as get(), but runs verifyConstructionInvariants and emits a
/// diagnostic at `location` on failure.
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
    Location location) {
  return Base::getChecked(location, storageType.getContext(),
                          QuantizationTypes::UniformQuantizedPerAxis, flags,
                          storageType, expressedType, scales, zeroPoints,
                          quantizedDimension, storageTypeMin, storageTypeMax);
}
// Verifies the per-instance invariants of a per-axis uniform quantized type:
// the base QuantizedType invariants, matching scale/zeroPoint counts, and a
// strictly positive finite value for every scale. Emits an error at `loc`
// (when provided) and returns failure() on violation.
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> scales,
    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  // Storage/expressed type checks are shared with the base class.
  if (failed(QuantizedType::verifyConstructionInvariants(
          loc, context, flags, storageType, expressedType, storageTypeMin,
          storageTypeMax)))
    return failure();

  // Each quantized slice needs exactly one (scale, zeroPoint) pair.
  if (scales.size() != zeroPoints.size()) {
    if (loc)
      context->emitError(*loc, "illegal number of scales and zeroPoints: " +
                                   Twine(scales.size()) + ", " +
                                   Twine(zeroPoints.size()));
    return failure();
  }

  // Every per-axis scale must be > 0 and finite (a NaN scale fails the
  // positivity test because NaN comparisons are false).
  for (double scale : scales) {
    bool scaleIsValid = scale > 0.0 && !std::isinf(scale);
    if (!scaleIsValid) {
      if (loc)
        context->emitError(*loc,
                           "illegal scale: " + Twine(std::to_string(scale)));
      return failure();
    }
  }
  return success();
}
// Returns the scale for each slice along the quantized dimension.
ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
  return getImpl()->getScales();
}
// Returns the zero point for each slice; parallel to getScales().
ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
  return getImpl()->getZeroPoints();
}
// Returns the index of the dimension along which parameters vary.
int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
  return getImpl()->quantizedDimension;
}
#define GET_OP_CLASSES
#include "mlir/Quantization/QuantOps.cpp.inc"

View File

@ -0,0 +1,348 @@
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Quantization/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
// Returns the bit-mapped QuantizationFlags recorded at construction.
unsigned QuantizedType::getFlags() const {
  auto *storage = static_cast<ImplType *>(type);
  return storage->flags;
}
/// Verifies the invariants common to all QuantizedType instances:
///  - the expressed type is a floating point type;
///  - the storage type is an integer type of width 1..MaxStorageBits;
///  - the [storageTypeMin, storageTypeMax] range is non-empty and lies within
///    the default range for the (signed or unsigned) storage width.
/// Emits an error at `loc` (when provided) and returns failure() on any
/// violation.
LogicalResult QuantizedType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // Verify that the expressed type is floating point.
  // If this restriction is ever eliminated, the parser/printer must be
  // extended.
  if (!expressedType.isa<FloatType>()) {
    if (loc) {
      context->emitError(*loc, "expressed type must be floating point");
    }
    return failure();
  }

  // Verify that the storage type is integral.
  // This restriction may be lifted at some point in favor of using bf16
  // or f16 as exact representations on hardware where that is advantageous.
  auto intStorageType = storageType.dyn_cast<IntegerType>();
  if (!intStorageType) {
    if (loc) {
      context->emitError(*loc, "storage type must be integral");
    }
    return failure();
  }
  unsigned integralWidth = intStorageType.getWidth();

  // Verify storage width.
  if (integralWidth == 0 || integralWidth > MaxStorageBits) {
    if (loc) {
      context->emitError(*loc,
                         "illegal storage type size: " + Twine(integralWidth));
    }
    return failure();
  }

  // Verify storageTypeMin and storageTypeMax.
  bool isSigned =
      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
  int64_t defaultIntegerMin =
      getDefaultMininumForInteger(isSigned, integralWidth);
  int64_t defaultIntegerMax =
      getDefaultMaxinumForInteger(isSigned, integralWidth);
  // Compare directly rather than via `storageTypeMax - storageTypeMin <= 0`:
  // that subtraction can overflow int64_t (undefined behavior) for extreme
  // caller-supplied bounds.
  if (storageTypeMax <= storageTypeMin || storageTypeMin < defaultIntegerMin ||
      storageTypeMax > defaultIntegerMax) {
    if (loc) {
      context->emitError(*loc, "illegal storage min and storage max: (" +
                                   Twine(storageTypeMin) + ":" +
                                   Twine(storageTypeMax) + ")");
    }
    return failure();
  }
  return success();
}
// Returns the (narrower) numeric type used to store quantized values.
Type QuantizedType::getStorageType() const {
  return static_cast<ImplType *>(type)->storageType;
}
// Returns the inclusive lower bound of valid storage values.
int64_t QuantizedType::getStorageTypeMin() const {
  return static_cast<ImplType *>(type)->storageTypeMin;
}
// Returns the inclusive upper bound of valid storage values.
int64_t QuantizedType::getStorageTypeMax() const {
  return static_cast<ImplType *>(type)->storageTypeMax;
}
// Returns the bit width of the integral storage type.
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
  // NOTE: If ever supporting non-integral storage types, some other scheme
  // for determining the width will be needed.
  return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
}
// Returns the (wider) type that this quantized type approximates.
Type QuantizedType::getExpressedType() const {
  return static_cast<ImplType *>(type)->expressedType;
}
// Reports whether `candidateExpressedType` matches the expressed type,
// either directly or as the element type of a vector/tensor container.
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
  if (candidateExpressedType.isa<VectorOrTensorType>()) {
    return candidateExpressedType.cast<VectorOrTensorType>().getElementType() ==
           getExpressedType();
  }
  return candidateExpressedType == getExpressedType();
}
// Extracts the QuantizedType of a primitive type, or of the element type of
// a vector/tensor container; returns a null type when none is present.
QuantizedType
QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
  if (primitiveOrContainerType.isa<VectorOrTensorType>()) {
    Type elementType =
        primitiveOrContainerType.cast<VectorOrTensorType>().getElementType();
    return elementType.dyn_cast<QuantizedType>();
  }
  return primitiveOrContainerType.dyn_cast<QuantizedType>();
}
// Casts a type based on this type's storage type into a type based on this
// quantized type; returns nullptr for unsupported candidates.
Type QuantizedType::castFromStorageType(Type candidateType) {
  if (candidateType == getStorageType()) {
    // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
    return *this;
  } else if (candidateType.isa<RankedTensorType>()) {
    // NOTE(review): these container branches rebuild the container around
    // getStorageType() (leaving the element type unchanged) rather than
    // around the quantized type; confirm this asymmetry with
    // castFromExpressedType is intended.
    return RankedTensorType::get(
        candidateType.cast<RankedTensorType>().getShape(), getStorageType());
  } else if (candidateType.isa<UnrankedTensorType>()) {
    // Unranked tensor of the storage type.
    return UnrankedTensorType::get(getStorageType());
  } else if (candidateType.isa<VectorType>()) {
    // Vector of the storage type.
    return VectorType::get(candidateType.cast<VectorType>().getShape(),
                           getStorageType());
  }
  return nullptr;
}
// Strips quantization from a scalar or vector/tensor type, yielding the
// corresponding storage type; returns nullptr if no QuantizedType is
// involved.
Type QuantizedType::castToStorageType(Type quantizedType) {
  if (quantizedType.isa<QuantizedType>()) {
    // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
    return quantizedType.cast<QuantizedType>().getStorageType();
  } else if (quantizedType.isa<VectorOrTensorType>()) {
    // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xi8>
    VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
    if (!vtType.getElementType().isa<QuantizedType>()) {
      return nullptr;
    }
    Type storageType =
        vtType.getElementType().cast<QuantizedType>().getStorageType();
    if (quantizedType.isa<RankedTensorType>()) {
      return RankedTensorType::get(vtType.getShape(), storageType);
    } else if (quantizedType.isa<UnrankedTensorType>()) {
      return UnrankedTensorType::get(storageType);
    } else if (quantizedType.isa<VectorType>()) {
      return VectorType::get(vtType.getShape(), storageType);
    }
  }
  return nullptr;
}
// Casts a type based on this type's expressed type into a type based on this
// quantized type; returns nullptr on mismatch.
Type QuantizedType::castFromExpressedType(Type candidateType) {
  if (candidateType == getExpressedType()) {
    // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
    return *this;
  } else if (candidateType.isa<VectorOrTensorType>()) {
    VectorOrTensorType candidateVtType =
        candidateType.cast<VectorOrTensorType>();
    // The container's element type must match exactly.
    if (candidateVtType.getElementType() != getExpressedType()) {
      return nullptr;
    }
    if (candidateType.isa<RankedTensorType>()) {
      // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
      return RankedTensorType::get(candidateVtType.getShape(), *this);
    } else if (candidateType.isa<UnrankedTensorType>()) {
      // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
      return UnrankedTensorType::get(*this);
    } else if (candidateType.isa<VectorType>()) {
      // i.e. vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
      return VectorType::get(candidateVtType.getShape(), *this);
    }
  }
  return nullptr;
}
// Strips quantization from a scalar or vector/tensor type, yielding the
// corresponding expressed type; returns nullptr if no QuantizedType is
// involved.
Type QuantizedType::castToExpressedType(Type quantizedType) {
  if (quantizedType.isa<QuantizedType>()) {
    // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
    return quantizedType.cast<QuantizedType>().getExpressedType();
  } else if (quantizedType.isa<VectorOrTensorType>()) {
    // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xf32>
    VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
    if (!vtType.getElementType().isa<QuantizedType>()) {
      return nullptr;
    }
    Type expressedType =
        vtType.getElementType().cast<QuantizedType>().getExpressedType();
    if (quantizedType.isa<RankedTensorType>()) {
      return RankedTensorType::get(vtType.getShape(), expressedType);
    } else if (quantizedType.isa<UnrankedTensorType>()) {
      return UnrankedTensorType::get(expressedType);
    } else if (quantizedType.isa<VectorType>()) {
      return VectorType::get(vtType.getShape(), expressedType);
    }
  }
  return nullptr;
}
// Composes castFromExpressedType and castToStorageType: maps an
// expressed-based type directly to the equivalent storage-based type.
Type QuantizedType::castExpressedToStorageType(Type candidateType) {
  Type expressedQuantizedType = castFromExpressedType(candidateType);
  if (!expressedQuantizedType) {
    return nullptr;
  }
  return QuantizedType::castToStorageType(expressedQuantizedType);
}
/// Constructs (uniquing in the storage type's context) a UniformQuantizedType
/// from its components. Invariants are not diagnosed here; use getChecked for
/// location-attached verification.
UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
                                               Type expressedType, double scale,
                                               int64_t zeroPoint,
                                               int64_t storageTypeMin,
                                               int64_t storageTypeMax) {
  return Base::get(storageType.getContext(),
                   QuantizationTypes::UniformQuantized, flags, storageType,
                   expressedType, scale, zeroPoint, storageTypeMin,
                   storageTypeMax);
}
/// Same as get(), but runs verifyConstructionInvariants and emits a
/// diagnostic at `location` on failure.
UniformQuantizedType
UniformQuantizedType::getChecked(unsigned flags, Type storageType,
                                 Type expressedType, double scale,
                                 int64_t zeroPoint, int64_t storageTypeMin,
                                 int64_t storageTypeMax, Location location) {
  return Base::getChecked(location, storageType.getContext(),
                          QuantizationTypes::UniformQuantized, flags,
                          storageType, expressedType, scale, zeroPoint,
                          storageTypeMin, storageTypeMax);
}
/// Verifies the per-instance invariants: the base QuantizedType invariants
/// plus a strictly positive, finite scale.
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  if (failed(QuantizedType::verifyConstructionInvariants(
          loc, context, flags, storageType, expressedType, storageTypeMin,
          storageTypeMax))) {
    return failure();
  }

  // Verify scale. Must be > 0 and finite; NaN is rejected explicitly.
  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
    if (loc) {
      context->emitError(*loc,
                         "illegal scale: " + Twine(std::to_string(scale)));
    }
    return failure();
  }
  return success();
}
// Returns the multiplicative scale applied when dequantizing.
double UniformQuantizedType::getScale() const { return getImpl()->scale; }
// Returns the storage value that corresponds to the expressed value 0.
int64_t UniformQuantizedType::getZeroPoint() const {
  return getImpl()->zeroPoint;
}
/// Constructs (uniquing in the storage type's context) a per-axis uniform
/// quantized type. `scales` and `zeroPoints` are parallel arrays indexed by
/// position along `quantizedDimension`. Invariants are not diagnosed here.
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  return Base::get(storageType.getContext(),
                   QuantizationTypes::UniformQuantizedPerAxis, flags,
                   storageType, expressedType, scales, zeroPoints,
                   quantizedDimension, storageTypeMin, storageTypeMax);
}
/// Same as get(), but runs verifyConstructionInvariants and emits a
/// diagnostic at `location` on failure.
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
    Location location) {
  return Base::getChecked(location, storageType.getContext(),
                          QuantizationTypes::UniformQuantizedPerAxis, flags,
                          storageType, expressedType, scales, zeroPoints,
                          quantizedDimension, storageTypeMin, storageTypeMax);
}
/// Verifies the per-instance invariants: the base QuantizedType invariants,
/// matching scale/zeroPoint counts, and a strictly positive finite value for
/// every scale.
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> scales,
    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  if (failed(QuantizedType::verifyConstructionInvariants(
          loc, context, flags, storageType, expressedType, storageTypeMin,
          storageTypeMax))) {
    return failure();
  }

  // Ensure that the number of scales and zeroPoints match.
  if (scales.size() != zeroPoints.size()) {
    if (loc) {
      context->emitError(*loc, "illegal number of scales and zeroPoints: " +
                                   Twine(scales.size()) + ", " +
                                   Twine(zeroPoints.size()));
    }
    return failure();
  }

  // Verify scale. Each must be > 0 and finite; NaN is rejected explicitly.
  for (double scale : scales) {
    if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
      if (loc) {
        context->emitError(*loc,
                           "illegal scale: " + Twine(std::to_string(scale)));
      }
      return failure();
    }
  }
  return success();
}
// Returns the scale for each slice along the quantized dimension.
ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
  return getImpl()->getScales();
}
// Returns the zero point for each slice; parallel to getScales().
ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
  return getImpl()->getZeroPoints();
}
// Returns the index of the dimension along which parameters vary.
int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
  return getImpl()->quantizedDimension;
}

View File

@ -19,6 +19,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Format.h"

View File

@ -1,5 +1,22 @@
//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Quantization/FakeQuantSupport.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
using namespace mlir;
using namespace mlir::quant;