Add quant.const_fake_quant_per_axis op

Comparing to the existing quant.const_fake_quant op, the min and max attributes
of this new op is for each channel of last dimension of the input.

PiperOrigin-RevId: 268093722
This commit is contained in:
Feng Liu 2019-09-09 15:42:07 -07:00 committed by A. Unique TensorFlower
parent d3a6dbc0b8
commit f4ae4762bf
2 changed files with 47 additions and 0 deletions

View File

@ -132,6 +132,38 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
);
}
def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
[SameOperandsAndResultType, NoSideEffect]> {
let summary =
"Simulates the effect of per axis uniform quantization with const range.";
let description = [{
Given a const min, max, num_bits and narrow_range attribute, applies the
same per axis uniform quantization simulation as is done by the TensorFlow
fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
utility method and the quant-convert-simulated-quantization pass for futher
details.
}];
let arguments = (ins
F32Tensor:$inputs,
F32ArrayAttr:$min,
F32ArrayAttr:$max,
// The quantized dimension of the inputs tensor.
I64Attr:$axis,
// The bitwidth of the quantization; between 2 and 16, inclusive.
I64Attr:$num_bits,
// Quantization range starts from 0 or 1; starts from 1 if true.
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
// The sign of the quantization.
DefaultValuedAttr<BoolAttr, "false">:$is_signed
);
let results = (outs
F32Tensor:$outputs
);
}
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
let summary =
"Indicates that statistics are resolved by reference.";

View File

@ -15,6 +15,21 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
return %2 : tensor<8x4x3xf32>
}
// -----
// CHECK-LABEL: validConstFakeQuantPerAxis
func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
%0 = "quant.const_fake_quant_per_axis"(%arg0) {
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
%1 = "quant.const_fake_quant_per_axis"(%0) {
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
%2 = "quant.const_fake_quant_per_axis"(%1) {
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
return %2 : tensor<8x4x2xf32>
}
// -----
// CHECK-LABEL: validStatisticsRef
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {