forked from OSchip/llvm-project
Rename shape_cast to tensor_cast.
"shape_cast" only applies to tensors, and there are other operations that actually affect shape, for example "reshape". Rename "shape_cast" to "tensor_cast" in both the code and the documentation. PiperOrigin-RevId: 218528122
This commit is contained in:
parent
c1b0918617
commit
e8d254b909
|
@ -1847,27 +1847,27 @@ TODO: In the distant future, this will accept
|
|||
optional attributes for fast math, contraction, rounding mode, and other
|
||||
controls.
|
||||
|
||||
#### 'shape_cast' operation {#'shape_cast'-operation}
|
||||
#### 'tensor_cast' operation {#'tensor_cast'-operation}
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.mlir}
|
||||
operation ::= ssa-id `=` `shape_cast` ssa-use `:` type `to` type
|
||||
operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
``` {.mlir}
|
||||
// Convert from unknown rank to rank 2 with unknown dimension sizes.
|
||||
%2 = "shape_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
|
||||
%2 = shape_cast %1 : tensor<*xf32> to tensor<?x?xf32>
|
||||
%2 = "tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
|
||||
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
|
||||
|
||||
// Convert to a type with more known dimensions.
|
||||
%3 = "shape_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
|
||||
%3 = "tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
|
||||
|
||||
// Discard static dimension and rank information.
|
||||
%4 = "shape_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
|
||||
%5 = "shape_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
|
||||
%4 = "tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
|
||||
%5 = "tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
|
||||
```
|
||||
|
||||
Convert a tensor from one type to an equivalent type without changing any data
|
||||
|
|
|
@ -542,32 +542,6 @@ private:
|
|||
explicit MulIOp(const Operation *state) : BinaryOp(state) {}
|
||||
};
|
||||
|
||||
/// The "shape_cast" operation converts a tensor from one type to an equivalent
|
||||
/// type without changing any data elements. The source and destination types
|
||||
/// must both be tensor types with the same element type, and the source and
|
||||
/// destination types may not be the same. They must either have the same rank,
|
||||
/// or one may be an unknown rank. The operation is invalid if converting to a
|
||||
/// mismatching constant dimension.
|
||||
///
|
||||
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
|
||||
/// %2 = shape_cast %1 : tensor<??f32> to tensor<?x?xf32>
|
||||
///
|
||||
class ShapeCastOp : public CastOp<ShapeCastOp> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "shape_cast"; }
|
||||
|
||||
/// The result of a shape_cast is always a tensor.
|
||||
TensorType *getType() const {
|
||||
return cast<TensorType>(getResult()->getType());
|
||||
}
|
||||
|
||||
bool verify() const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit ShapeCastOp(const Operation *state) : CastOp(state) {}
|
||||
};
|
||||
|
||||
/// The "store" op writes an element to a memref specified by an index list.
|
||||
/// The arity of indices is the rank of the memref (i.e. if the memref being
|
||||
/// stored to is of rank 3, then 3 indices are required for the store following
|
||||
|
@ -651,6 +625,32 @@ private:
|
|||
explicit SubIOp(const Operation *state) : BinaryOp(state) {}
|
||||
};
|
||||
|
||||
/// The "tensor_cast" operation converts a tensor from one type to an equivalent
|
||||
/// type without changing any data elements. The source and destination types
|
||||
/// must both be tensor types with the same element type, and the source and
|
||||
/// destination types may not be the same. They must either have the same rank,
|
||||
/// or one may be an unknown rank. The operation is invalid if converting to a
|
||||
/// mismatching constant dimension.
|
||||
///
|
||||
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
|
||||
/// %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
|
||||
///
|
||||
class TensorCastOp : public CastOp<TensorCastOp> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "tensor_cast"; }
|
||||
|
||||
/// The result of a tensor_cast is always a tensor.
|
||||
TensorType *getType() const {
|
||||
return cast<TensorType>(getResult()->getType());
|
||||
}
|
||||
|
||||
bool verify() const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit TensorCastOp(const Operation *state) : CastOp(state) {}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
||||
|
|
|
@ -36,8 +36,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
|||
: Dialect(/*opPrefix=*/"", context) {
|
||||
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, DeallocOp,
|
||||
DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp,
|
||||
MemRefCastOp, MulFOp, MulIOp, ShapeCastOp, StoreOp, SubFOp,
|
||||
SubIOp>();
|
||||
MemRefCastOp, MulFOp, MulIOp, StoreOp, SubFOp, SubIOp,
|
||||
TensorCastOp>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -710,10 +710,10 @@ Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeCastOp
|
||||
// TensorCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool ShapeCastOp::verify() const {
|
||||
bool TensorCastOp::verify() const {
|
||||
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
|
||||
auto *resType = dyn_cast<TensorType>(getType());
|
||||
if (!opType || !resType)
|
||||
|
|
|
@ -169,19 +169,19 @@ mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
|
|||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @shape_cast(%arg0
|
||||
mlfunc @shape_cast(%arg0 : tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
|
||||
// CHECK: %0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
%0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
// CHECK-LABEL: mlfunc @tensor_cast(%arg0
|
||||
mlfunc @tensor_cast(%arg0 : tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
|
||||
// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
|
||||
// CHECK: %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
%1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
// CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
%1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
|
||||
// CHECK: %2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
%2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
// CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
%2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
|
||||
// CHECK: %3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
%3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
%3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue