forked from OSchip/llvm-project
[mlir][tosa] Add folder for tosa.cast
Tosa.cast should fold on splats as it is trivial to fold the operation into the splatted value. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D132518
This commit is contained in:
parent
43e1fc58dd
commit
088f15e346
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
|
@ -687,6 +688,63 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
|
|||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getInput().getType() == getType())
|
||||
return getInput();
|
||||
|
||||
auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
|
||||
if (!operand)
|
||||
return {};
|
||||
|
||||
auto inTy = getInput().getType().cast<ShapedType>();
|
||||
auto outTy = getType().cast<ShapedType>();
|
||||
auto inETy = inTy.getElementType();
|
||||
auto outETy = outTy.getElementType();
|
||||
|
||||
if (operand.isSplat()) {
|
||||
if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
|
||||
bool overflow;
|
||||
auto splatVal = operand.getSplatValue<APFloat>();
|
||||
auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
|
||||
splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
|
||||
&overflow);
|
||||
return SplatElementsAttr::get(outTy, splatVal);
|
||||
}
|
||||
|
||||
if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
|
||||
auto unsign = inETy.cast<IntegerType>().isUnsignedInteger();
|
||||
APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
|
||||
splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
|
||||
llvm::RoundingMode::NearestTiesToEven);
|
||||
return SplatElementsAttr::get(outTy, splatVal);
|
||||
}
|
||||
|
||||
if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
|
||||
auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
|
||||
auto intVal =
|
||||
APSInt(outETy.cast<IntegerType>().getIntOrFloatBitWidth(), unsign);
|
||||
auto floatVal = operand.getSplatValue<APFloat>();
|
||||
bool exact;
|
||||
floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
|
||||
return SplatElementsAttr::get(outTy, intVal);
|
||||
}
|
||||
|
||||
if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
|
||||
auto unsignIn = inETy.cast<IntegerType>().isUnsignedInteger();
|
||||
bool trunc =
|
||||
inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
|
||||
auto intVal = operand.getSplatValue<APInt>();
|
||||
auto bitwidth = outETy.getIntOrFloatBitWidth();
|
||||
|
||||
if (trunc) {
|
||||
intVal = intVal.trunc(bitwidth);
|
||||
} else if (unsignIn) {
|
||||
intVal = intVal.zext(bitwidth);
|
||||
} else {
|
||||
intVal = intVal.sext(bitwidth);
|
||||
}
|
||||
|
||||
return SplatElementsAttr::get(outTy, intVal);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
|
@ -427,3 +427,58 @@ func.func @slice_singleton() -> tensor<1x1xi32> {
|
|||
// CHECK: return %[[SLICE]]
|
||||
return %slice : tensor<1x1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @cast_float_to_float
|
||||
func.func @cast_float_to_float() -> tensor<f16> {
|
||||
%splat = "tosa.const"() {value = dense<42.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.200000e+01> : tensor<f16>} : () -> tensor<f16>
|
||||
%cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<f16>
|
||||
// CHECK: return %[[SPLAT]]
|
||||
return %cast : tensor<f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @cast_int_to_float
|
||||
func.func @cast_int_to_float() -> tensor<f16> {
|
||||
%splat = "tosa.const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<4.000000e+00> : tensor<f16>} : () -> tensor<f16>
|
||||
%cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<f16>
|
||||
// CHECK: return %[[SPLAT]]
|
||||
return %cast : tensor<f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @cast_float_to_int
|
||||
func.func @cast_float_to_int() -> tensor<i16> {
|
||||
%splat = "tosa.const"() {value = dense<-4.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-4> : tensor<i16>} : () -> tensor<i16>
|
||||
%cast = "tosa.cast"(%splat) : (tensor<f32>) -> tensor<i16>
|
||||
// CHECK: return %[[SPLAT]]
|
||||
return %cast : tensor<i16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @cast_int_to_int_trunc
|
||||
func.func @cast_int_to_int_trunc() -> tensor<i16> {
|
||||
%splat = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
|
||||
%cast = "tosa.cast"(%splat) : (tensor<i32>) -> tensor<i16>
|
||||
// CHECK: return %[[SPLAT]]
|
||||
return %cast : tensor<i16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @cast_int_to_int_sign
|
||||
func.func @cast_int_to_int_sign() -> tensor<i32> {
|
||||
%splat = "tosa.const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
|
||||
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%cast = "tosa.cast"(%splat) : (tensor<i16>) -> tensor<i32>
|
||||
// CHECK: return %[[SPLAT]]
|
||||
return %cast : tensor<i32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue