forked from OSchip/llvm-project
Add quant.const_fake_quant_per_axis op
Comparing to the existing quant.const_fake_quant op, the min and max attributes of this new op is for each channel of last dimension of the input. PiperOrigin-RevId: 268093722
This commit is contained in:
parent
d3a6dbc0b8
commit
f4ae4762bf
|
@ -132,6 +132,38 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
|
|||
);
|
||||
}
|
||||
|
||||
def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
|
||||
[SameOperandsAndResultType, NoSideEffect]> {
|
||||
let summary =
|
||||
"Simulates the effect of per axis uniform quantization with const range.";
|
||||
|
||||
let description = [{
|
||||
Given a const min, max, num_bits and narrow_range attribute, applies the
|
||||
same per axis uniform quantization simulation as is done by the TensorFlow
|
||||
fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
|
||||
utility method and the quant-convert-simulated-quantization pass for futher
|
||||
details.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$inputs,
|
||||
F32ArrayAttr:$min,
|
||||
F32ArrayAttr:$max,
|
||||
// The quantized dimension of the inputs tensor.
|
||||
I64Attr:$axis,
|
||||
// The bitwidth of the quantization; between 2 and 16, inclusive.
|
||||
I64Attr:$num_bits,
|
||||
// Quantization range starts from 0 or 1; starts from 1 if true.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
|
||||
// The sign of the quantization.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_signed
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
F32Tensor:$outputs
|
||||
);
|
||||
}
|
||||
|
||||
def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> {
|
||||
let summary =
|
||||
"Indicates that statistics are resolved by reference.";
|
||||
|
|
|
@ -15,6 +15,21 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
|||
return %2 : tensor<8x4x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: validConstFakeQuantPerAxis
|
||||
func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
|
||||
%0 = "quant.const_fake_quant_per_axis"(%arg0) {
|
||||
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
|
||||
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||
%1 = "quant.const_fake_quant_per_axis"(%0) {
|
||||
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false
|
||||
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||
%2 = "quant.const_fake_quant_per_axis"(%1) {
|
||||
min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8
|
||||
} : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
|
||||
return %2 : tensor<8x4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: validStatisticsRef
|
||||
func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
|
||||
|
|
Loading…
Reference in New Issue