diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index 394d3a18ced2..d95b45276074 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -132,6 +132,38 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant", ); } +def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis", + [SameOperandsAndResultType, NoSideEffect]> { + let summary = + "Simulates the effect of per axis uniform quantization with const range."; + + let description = [{ + Given a const min, max, num_bits and narrow_range attribute, applies the + same per axis uniform quantization simulation as is done by the TensorFlow + fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType() + utility method and the quant-convert-simulated-quantization pass for futher + details. + }]; + + let arguments = (ins + F32Tensor:$inputs, + F32ArrayAttr:$min, + F32ArrayAttr:$max, + // The quantized dimension of the inputs tensor. + I64Attr:$axis, + // The bitwidth of the quantization; between 2 and 16, inclusive. + I64Attr:$num_bits, + // Quantization range starts from 0 or 1; starts from 1 if true. + DefaultValuedAttr:$narrow_range, + // The sign of the quantization. + DefaultValuedAttr:$is_signed + ); + + let results = (outs + F32Tensor:$outputs + ); +} + def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> { let summary = "Indicates that statistics are resolved by reference."; diff --git a/mlir/test/Dialect/QuantOps/parse-ops.mlir b/mlir/test/Dialect/QuantOps/parse-ops.mlir index 77968f8011b5..7d6d1abb2538 100644 --- a/mlir/test/Dialect/QuantOps/parse-ops.mlir +++ b/mlir/test/Dialect/QuantOps/parse-ops.mlir @@ -15,6 +15,21 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { return %2 : tensor<8x4x3xf32> } +// ----- +// CHECK-LABEL: validConstFakeQuantPerAxis +func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> { + %0 = "quant.const_fake_quant_per_axis"(%arg0) { + min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true + } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> + %1 = "quant.const_fake_quant_per_axis"(%0) { + min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false + } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> + %2 = "quant.const_fake_quant_per_axis"(%1) { + min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8 + } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> + return %2 : tensor<8x4x2xf32> +} + // ----- // CHECK-LABEL: validStatisticsRef func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {