forked from OSchip/llvm-project
Add quant.const_fake_quant_per_axis op
Comparing to the existing quant.const_fake_quant op, the min and max attributes of this new op is for each channel of last dimension of the input. PiperOrigin-RevId: 268093722
This commit is contained in:
parent
d3a6dbc0b8
commit
f4ae4762bf
|
@ -132,6 +132,38 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
|
||||||
|
[SameOperandsAndResultType, NoSideEffect]> {
|
||||||
|
let summary =
|
||||||
|
"Simulates the effect of per axis uniform quantization with const range.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Given a const min, max, num_bits and narrow_range attribute, applies the
|
||||||
|
same per axis uniform quantization simulation as is done by the TensorFlow
|
||||||
|
fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
|
||||||
|
utility method and the quant-convert-simulated-quantization pass for futher
|
||||||
|
details.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
F32Tensor:$inputs,
|
||||||
|
F32ArrayAttr:$min,
|
||||||
|
F32ArrayAttr:$max,
|
||||||
|
// The quantized dimension of the inputs tensor.
|
||||||
|
I64Attr:$axis,
|
||||||
|
// The bitwidth of the quantization; between 2 and 16, inclusive.
|
||||||
|
I64Attr:$num_bits,
|
||||||
|
// Quantization range starts from 0 or 1; starts from 1 if true.
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
|
||||||
|
// The sign of the quantization.
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$is_signed
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
F32Tensor:$outputs
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
|
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
|
||||||
let summary =
|
let summary =
|
||||||
"Indicates that statistics are resolved by reference.";
|
"Indicates that statistics are resolved by reference.";
|
||||||
|
|
|
@ -15,6 +15,21 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||||
return %2 : tensor<8x4x3xf32>
|
return %2 : tensor<8x4x3xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: validConstFakeQuantPerAxis
|
||||||
|
func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
|
||||||
|
%0 = "quant.const_fake_quant_per_axis"(%arg0) {
|
||||||
|
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
|
||||||
|
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||||
|
%1 = "quant.const_fake_quant_per_axis"(%0) {
|
||||||
|
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false
|
||||||
|
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||||
|
%2 = "quant.const_fake_quant_per_axis"(%1) {
|
||||||
|
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8
|
||||||
|
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||||
|
return %2 : tensor<8x4x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: validStatisticsRef
|
// CHECK-LABEL: validStatisticsRef
|
||||||
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||||
|
|
Loading…
Reference in New Issue