Rename shape_cast to tensor_cast.

"shape_cast" only applies to tensors, and there are other operations that
actually affect shape, for example "reshape".  Rename "shape_cast" to
"tensor_cast" in both the code and the documentation.

PiperOrigin-RevId: 218528122
Alex Zinenko 2018-10-24 09:52:06 -07:00 committed by jpienaar
parent c1b0918617
commit e8d254b909
4 changed files with 47 additions and 47 deletions
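In IR terms the change is purely a spelling rename. As a quick illustration (using the tensor types from the documentation examples below), the same cast before and after this commit looks like:

``` {.mlir}
// Before this commit:
%2 = shape_cast %1 : tensor<*xf32> to tensor<?x?xf32>

// After this commit:
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
```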


@@ -1847,27 +1847,27 @@ TODO: In the distant future, this will accept
optional attributes for fast math, contraction, rounding mode, and other
controls.
-#### 'shape_cast' operation {#'shape_cast'-operation}
+#### 'tensor_cast' operation {#'tensor_cast'-operation}
Syntax:
``` {.mlir}
-operation ::= ssa-id `=` `shape_cast` ssa-use `:` type `to` type
+operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
```
Examples:
``` {.mlir}
// Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = "shape_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
%2 = shape_cast %1 : tensor<*xf32> to tensor<?x?xf32>
%2 = "tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
// Convert to a type with more known dimensions.
%3 = "shape_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
%3 = "tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
// Discard static dimension and rank information.
%4 = "shape_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
%5 = "shape_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
%4 = "tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
%5 = "tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
```
Convert a tensor from one type to an equivalent type without changing any data


@@ -542,32 +542,6 @@ private:
explicit MulIOp(const Operation *state) : BinaryOp(state) {}
};
/// The "shape_cast" operation converts a tensor from one type to an equivalent
/// type without changing any data elements. The source and destination types
/// must both be tensor types with the same element type, and the source and
/// destination types may not be the same. They must either have the same rank,
/// or one may be an unknown rank. The operation is invalid if converting to a
/// mismatching constant dimension.
///
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
/// %2 = shape_cast %1 : tensor<??f32> to tensor<?x?xf32>
///
class ShapeCastOp : public CastOp<ShapeCastOp> {
public:
static StringRef getOperationName() { return "shape_cast"; }
/// The result of a shape_cast is always a tensor.
TensorType *getType() const {
return cast<TensorType>(getResult()->getType());
}
bool verify() const;
private:
friend class Operation;
explicit ShapeCastOp(const Operation *state) : CastOp(state) {}
};
/// The "store" op writes an element to a memref specified by an index list.
/// The arity of indices is the rank of the memref (i.e. if the memref being
/// stored to is of rank 3, then 3 indices are required for the store following
@@ -651,6 +625,32 @@ private:
explicit SubIOp(const Operation *state) : BinaryOp(state) {}
};
/// The "tensor_cast" operation converts a tensor from one type to an equivalent
/// type without changing any data elements. The source and destination types
/// must both be tensor types with the same element type, and the source and
/// destination types may not be the same. They must either have the same rank,
/// or one may be an unknown rank. The operation is invalid if converting to a
/// mismatching constant dimension.
///
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
/// %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
///
class TensorCastOp : public CastOp<TensorCastOp> {
public:
static StringRef getOperationName() { return "tensor_cast"; }
/// The result of a tensor_cast is always a tensor.
TensorType *getType() const {
return cast<TensorType>(getResult()->getType());
}
bool verify() const;
private:
friend class Operation;
explicit TensorCastOp(const Operation *state) : CastOp(state) {}
};
} // end namespace mlir
#endif
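To make the verification rule in the comment above concrete, here is an illustrative sketch in IR form (the types are chosen for illustration and do not appear in this commit): refining an unknown dimension is allowed, while changing a statically known dimension is a "mismatching constant dimension" and is rejected.

``` {.mlir}
// Allowed: refine an unknown dimension (?) to a known one (4).
%a = tensor_cast %arg : tensor<?x?xf32> to tensor<4x?xf32>

// Invalid: both leading dimensions are statically known but disagree (4 vs. 8),
// so the verifier rejects this cast.
%b = tensor_cast %a : tensor<4x?xf32> to tensor<8x?xf32>
```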


@@ -36,8 +36,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*opPrefix=*/"", context) {
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, DeallocOp,
DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp,
-MemRefCastOp, MulFOp, MulIOp, ShapeCastOp, StoreOp, SubFOp,
-SubIOp>();
+MemRefCastOp, MulFOp, MulIOp, StoreOp, SubFOp, SubIOp,
+TensorCastOp>();
}
//===----------------------------------------------------------------------===//
@@ -710,10 +710,10 @@ Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
}
//===----------------------------------------------------------------------===//
-// ShapeCastOp
+// TensorCastOp
//===----------------------------------------------------------------------===//
-bool ShapeCastOp::verify() const {
+bool TensorCastOp::verify() const {
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
auto *resType = dyn_cast<TensorType>(getType());
if (!opType || !resType)


@@ -169,19 +169,19 @@ mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
return %0 : i32
}
-// CHECK-LABEL: mlfunc @shape_cast(%arg0
-mlfunc @shape_cast(%arg0 : tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
-// CHECK: %0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-%0 = shape_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+// CHECK-LABEL: mlfunc @tensor_cast(%arg0
+mlfunc @tensor_cast(%arg0 : tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
+// CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+%0 = tensor_cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
-// CHECK: %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
-%1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+// CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
+%1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
-// CHECK: %2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
-%2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+// CHECK: %2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
+%2 = tensor_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
-// CHECK: %3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
-%3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+// CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
+%3 = tensor_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
return
}