[feat] change the type of the 2nd input of operator AffineGrid from tensor to tuple
This commit is contained in:
parent a9a9d6ee85
commit c1da0d3b05
@@ -330,6 +330,7 @@ Array操作
  :template: classtemplate.rst

  mindspore.ops.adaptive_max_pool2d
+ mindspore.ops.affine_grid
  mindspore.ops.batch_to_space_nd
  mindspore.ops.broadcast_to
  mindspore.ops.col2im
@@ -350,7 +351,6 @@ Array操作
  mindspore.ops.matrix_diag
  mindspore.ops.matrix_diag_part
  mindspore.ops.meshgrid
- mindspore.ops.affine_grid
  mindspore.ops.nonzero
  mindspore.ops.one_hot
  mindspore.ops.padding
@@ -8,7 +8,7 @@ mindspore.ops.affine_grid
  **Parameters:**

  - **theta** (Tensor) - Input affine matrix, with shape (N, 2, 3) for 2D or (N, 3, 4) for 3D.
- - **output_size** (Tensor) - The target output image size, in the form (N, C, H, W) for 2D or (N, C, D, H, W) for 3D. Example: `Tensor([32, 3, 24, 24], mindspore.int32)`.
+ - **output_size** (tuple[int]) - The target output image size, in the form (N, C, H, W) for 2D or (N, C, D, H, W) for 3D. Example: `(32, 3, 24, 24)`.
  - **align_corners** (bool) - Geometrically, the input pixels are treated as squares rather than points. If set to True, the extreme values -1 and 1 refer to the center points of the input's corner pixels. If set to False, they refer to the corner points of those pixels, making the sampling resolution independent. Default: False.

  **Returns:**

@@ -17,8 +17,8 @@ mindspore.ops.affine_grid

  **Raises:**

- - **TypeError** - `theta` or `output_size` is not a Tensor.
+ - **TypeError** - `theta` is not a Tensor, or `output_size` is not a tuple.
  - **ValueError** - The shape of `theta` is not (N, 2, 3) or (N, 3, 4).
- - **ValueError** - The dimension of `output_size` is not 1, or its length is not 4 or 5.
- - **ValueError** - The shape of `theta` is (N, 2, 3) but the dimension of `output_size` is not 4, or the shape of `theta` is (N, 3, 4) but the dimension of `output_size` is not 5.
+ - **ValueError** - The length of `output_size` is not 4 or 5.
+ - **ValueError** - The shape of `theta` is (N, 2, 3) but the length of `output_size` is not 4, or the shape of `theta` is (N, 3, 4) but the length of `output_size` is not 5.
  - **ValueError** - The first value of `output_size` is not equal to the length of the first dimension of `theta`.
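In practice, the doc change above means `output_size` is now passed as a plain Python tuple instead of an int32 Tensor. A minimal usage sketch of the updated interface, assuming a MindSpore build that already contains this commit (shapes and values are illustrative only):

```python
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

# 2D case: theta is (N, 2, 3), output_size is the tuple (N, C, H, W).
theta_2d = Tensor([[[0.8, 0.5, 0.0], [-0.5, 0.8, 0.0]]], mindspore.float32)
grid_2d = ops.affine_grid(theta_2d, (1, 3, 2, 3), align_corners=False)
print(grid_2d.shape)  # expected (1, 2, 3, 2): one (H, W) grid of (x, y) coordinates

# 3D case: theta is (N, 3, 4), output_size is the tuple (N, C, D, H, W).
theta_3d = Tensor([[[1.0, 0, 0, 0], [0, 1.0, 0, 0], [0, 0, 1.0, 0]]], mindspore.float32)
grid_3d = ops.affine_grid(theta_3d, (1, 1, 2, 2, 2), align_corners=False)
print(grid_3d.shape)  # expected (1, 2, 2, 2, 3)
```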
@@ -329,6 +329,7 @@ Array Operation
  :template: classtemplate.rst

  mindspore.ops.adaptive_max_pool2d
+ mindspore.ops.affine_grid
  mindspore.ops.batch_to_space_nd
  mindspore.ops.broadcast_to
  mindspore.ops.col2im
@@ -28,7 +28,6 @@ namespace mindspore {
  namespace ops {
  namespace {
  constexpr int RANK_THETA = 3;
- constexpr int RANK_IMAGE_SIZE = 1;
  constexpr int N_ROWS_THETA_4D = 2;
  constexpr int N_COLS_THETA_4D = 3;
  constexpr int LEN_IMAGE_SIZE_4D = 4;
@@ -39,25 +38,35 @@ constexpr int LEN_IMAGE_SIZE_5D = 5;
  abstract::ShapePtr AffineGridInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
    auto prim_name = primitive->name();
    auto theta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
-   auto output_size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(
-     input_args[kInputIndex1]->BuildShape())[kShape];
    auto theta_rank = SizeToLong(theta_shape.size());
-   auto output_size_rank = SizeToLong(output_size_shape.size());
    (void)CheckAndConvertUtils::CheckInteger("rank of 'theta'", theta_rank, kEqual, RANK_THETA, prim_name);
-   (void)CheckAndConvertUtils::CheckInteger("rank of 'output_size'", output_size_rank,
-                                            kEqual, RANK_IMAGE_SIZE, prim_name);
-   if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>() &&
-       input_args[kInputIndex1]->BuildValue()->isa<tensor::Tensor>()) {
-     auto output_size = input_args[kInputIndex1]->cast<abstract::AbstractTensorPtr>();
-     MS_EXCEPTION_IF_NULL(output_size);
-     auto output_size_value_ptr = output_size->BuildValue();
-     MS_EXCEPTION_IF_NULL(output_size_value_ptr);
-     auto output_size_tensor = output_size_value_ptr->cast<tensor::TensorPtr>();
-     MS_EXCEPTION_IF_NULL(output_size_tensor);
-     auto output_size_val = reinterpret_cast<int *>(output_size_tensor->data_c());
-     int64_t output_size_val_size = SizeToLong(output_size_tensor->DataSize());
-     CheckAndConvertUtils::CheckInRange<int64_t>("size of 'output_size'", output_size_val_size,
-                                                 kIncludeBoth, {LEN_IMAGE_SIZE_4D, LEN_IMAGE_SIZE_5D}, prim_name);
+   auto theta_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
+   if (theta_shape_ptr->IsDynamic()) {
+     // theta is dynamic shape, verification could not be performed.
+     // launch kernel will fail, and infer shape will run again.
+     ShapeVector grid_shape = {-2};
+     ShapeVector infer_shape_min;
+     ShapeVector infer_shape_max;
+     infer_shape_min = infer_shape_max = {1};
+     return std::make_shared<abstract::Shape>(grid_shape, infer_shape_min, infer_shape_max);
+   }
+   auto output_size_arg = input_args[kInputIndex1];
+   auto output_size_value_ptr = output_size_arg->BuildValue();
+   if ((output_size_arg->isa<abstract::AbstractTuple>() && output_size_value_ptr->isa<ValueTuple>()) ||
+       (output_size_arg->isa<abstract::AbstractTensor>() && output_size_value_ptr->isa<tensor::Tensor>())) {
+     ShapeVector output_size_val;
+     if (output_size_value_ptr->isa<ValueTuple>()) {
+       output_size_val = CheckAndConvertUtils::CheckTupleInt("input[output_size]", output_size_value_ptr, prim_name);
+     } else if (output_size_value_ptr->isa<tensor::Tensor>()) {  // 2nd infer will be a tensor
+       output_size_val = CheckAndConvertUtils::CheckTensorIntValue("output_size", output_size_value_ptr, prim_name);
+     } else {
+       MS_EXCEPTION(TypeError) << "For '" << prim_name << "', "
+                               << "the input[output_size] must be a tuple of int.";
+     }
+     (void)CheckAndConvertUtils::CheckPositiveVector("output_size", output_size_val, prim_name);
+     int64_t output_size_val_size = SizeToLong(output_size_val.size());
+     CheckAndConvertUtils::CheckInRange<int64_t>("size of 'output_size'", output_size_val_size, kIncludeBoth,
+                                                 {LEN_IMAGE_SIZE_4D, LEN_IMAGE_SIZE_5D}, prim_name);
      if (output_size_val[kInputIndex0] != theta_shape[kInputIndex0]) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
                                 << "the output_size[0] must be equal to the shape[0] of theta, "
@@ -65,14 +74,14 @@ abstract::ShapePtr AffineGridInferShape(const PrimitivePtr &primitive, const std
                                 << " and the shape[0] of theta is " << theta_shape[kInputIndex0] << ".";
      }
      ShapeVector grid_shape;
-     if (output_size_val_size == LEN_IMAGE_SIZE_4D &&
-         theta_shape[kInputIndex1] == N_ROWS_THETA_4D && theta_shape[kInputIndex2] == N_COLS_THETA_4D) {
+     if (output_size_val_size == LEN_IMAGE_SIZE_4D && theta_shape[kInputIndex1] == N_ROWS_THETA_4D &&
+         theta_shape[kInputIndex2] == N_COLS_THETA_4D) {
        auto N = static_cast<int64_t>(output_size_val[kInputIndex0]);
        auto H = static_cast<int64_t>(output_size_val[kInputIndex2]);
        auto W = static_cast<int64_t>(output_size_val[kInputIndex3]);
        grid_shape = {N, H, W, kInputIndex2};
-     } else if (output_size_val_size == LEN_IMAGE_SIZE_5D &&
-                theta_shape[kInputIndex1] == N_ROWS_THETA_5D && theta_shape[kInputIndex2] == N_COLS_THETA_5D) {
+     } else if (output_size_val_size == LEN_IMAGE_SIZE_5D && theta_shape[kInputIndex1] == N_ROWS_THETA_5D &&
+                theta_shape[kInputIndex2] == N_COLS_THETA_5D) {
        auto N = static_cast<int64_t>(output_size_val[kInputIndex0]);
        auto D = static_cast<int64_t>(output_size_val[kInputIndex2]);
        auto H = static_cast<int64_t>(output_size_val[kInputIndex3]);
@@ -105,22 +114,19 @@ abstract::ShapePtr AffineGridInferShape(const PrimitivePtr &primitive, const std
  TypePtr AffineGridInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
    const std::string op_name = prim->name();
    CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, kInputIndex0);
-   CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, kInputIndex1);
    auto theta_type = input_args[kInputIndex0]->BuildType();
    MS_EXCEPTION_IF_NULL(theta_type);
    const std::set<TypePtr> theta_valid_types = {kFloat16, kFloat32};
    (void)CheckAndConvertUtils::CheckTensorTypeValid("theta", theta_type, theta_valid_types, op_name);
    auto output_size_type = input_args[kInputIndex1]->BuildType();
    MS_EXCEPTION_IF_NULL(output_size_type);
-   const std::set<TypePtr> output_size_valid_types = {kInt32};
-   (void)CheckAndConvertUtils::CheckTensorTypeValid("output_size", output_size_type, output_size_valid_types, op_name);
+   const std::set<TypePtr> output_size_valid_types = {kTensorType, kTuple};  // 2nd infer will be a tensor
+   (void)CheckAndConvertUtils::CheckTypeValid("output_size", output_size_type, output_size_valid_types, op_name);
    return theta_type;
  }
  }  // namespace

- void AffineGrid::Init(const bool align_corners) {
-   set_align_corners(align_corners);
- }
+ void AffineGrid::Init(const bool align_corners) { set_align_corners(align_corners); }

  bool AffineGrid::get_align_corners() const {
    auto value_ptr = this->GetAttr(kAlignCorners);
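The reworked InferShape above reads `output_size` from the tuple value (falling back to a tensor value on the second inference pass) and derives the grid shape from it together with the shape of `theta`. As a rough mental model, the rule it implements can be sketched in plain Python; `infer_grid_shape` below is a hypothetical helper written for illustration, not part of MindSpore:

```python
def infer_grid_shape(theta_shape, output_size):
    """Sketch of the AffineGrid shape rule:
    theta (N, 2, 3) + output_size (N, C, H, W)    -> grid (N, H, W, 2)
    theta (N, 3, 4) + output_size (N, C, D, H, W) -> grid (N, D, H, W, 3)
    """
    if len(output_size) not in (4, 5):
        raise ValueError("output_size must have length 4 or 5")
    if output_size[0] != theta_shape[0]:
        raise ValueError("output_size[0] must equal theta.shape[0]")
    if len(output_size) == 4 and theta_shape[1:] == (2, 3):
        n, _, h, w = output_size
        return (n, h, w, 2)
    if len(output_size) == 5 and theta_shape[1:] == (3, 4):
        n, _, d, h, w = output_size
        return (n, d, h, w, 3)
    raise ValueError("theta shape and output_size length do not match")

print(infer_grid_shape((1, 2, 3), (1, 3, 2, 3)))     # (1, 2, 3, 2)
print(infer_grid_shape((1, 3, 4), (1, 1, 2, 4, 4)))  # (1, 2, 4, 4, 3)
```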
@@ -2917,7 +2917,7 @@ def affine_grid(theta, output_size, align_corners=False):
  Args:
      theta (Tensor) - The input tensor whose dtype is float16, float32.
          Input batch of affine matrices with shape [N, 2, 3] for 2D grid or [N, 3, 4] for 3D grid.
-     output_size (Tensor[int32]) - The target output image size. The input is a 1-dimensional Tensor.
+     output_size (tuple[int]) - The target output image size.
          The value of target output with format [N, C, H, W] for 2D grid or [N, C, D, H, W] for 3D grid.
      align_corners (bool): If True, consider -1 and 1 to refer to the centers of the corner pixels rather
          than the image corners. The default value is False.

@@ -2927,12 +2927,11 @@ def affine_grid(theta, output_size, align_corners=False):
          or [N, D, H, W, 3] for 3D grid.

  Raises:
-     TypeError: If `theta` or `output_size` is not a Tensor.
+     TypeError: If `theta` is not a Tensor or `output_size` is not a tuple.
      ValueError: If the shape of `theta` is not [N, 2, 3] or [N, 3, 4].
-     ValueError: the dimension of `output_size` is not 1;
-         the size of `output_size` is not 4 or 5.
-     ValueError: If the shape of `theta` is [N, 2, 3], the dimension of `output_size` is not 4;
-         If the shape of `theta` is [N, 3, 4], the dimension of `output_size` is not 5.
+     ValueError: If the size of `output_size` is not 4 or 5.
+     ValueError: If the shape of `theta` is [N, 2, 3], the size of `output_size` is not 4;
+         If the shape of `theta` is [N, 3, 4], the size of `output_size` is not 5.
          If the output_size[0] is not equal to the shape[0] of theta.

  Supported Platforms:

@@ -2943,7 +2942,7 @@ def affine_grid(theta, output_size, align_corners=False):
      >>> from mindspore import Tensor
      >>> import mindspore.ops as ops
      >>> theta = Tensor([[[0.8, 0.5, 0],[-0.5, 0.8, 0]]], mindspore.float32)
-     >>> out_size = Tensor([1, 3, 2, 3], mindspore.int32)
+     >>> out_size = (1, 3, 2, 3)
      >>> output = op.affine_grid(theta, out_size, False)
      >>> print(output)
      [[[[-0.78333336 -0.06666666]
@@ -7846,7 +7846,7 @@ class AffineGrid(Primitive):
  Examples:
      >>> affinegrid = AffineGrid(align_corners=False)
      >>> theta = Tensor([[[0.8, 0.5, 0],[-0.5, 0.8, 0]]], mindspore.float32)
-     >>> out_size = Tensor([1, 3, 2, 3], mindspore.int32)
+     >>> out_size = (1, 3, 2, 3)
      >>> output = affinegrid(theta, out_size)
      >>> print(output)
      [[[[-0.78333336 -0.06666666]
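Both docstring examples above make the same migration: the int32 Tensor passed as `output_size` becomes a plain tuple. A hedged before/after sketch of what calling code needs to change, reusing the values from the examples (the export path of the `AffineGrid` primitive under `mindspore.ops` is assumed here and may differ in your build):

```python
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

theta = Tensor([[[0.8, 0.5, 0], [-0.5, 0.8, 0]]], mindspore.float32)

# Before this commit: output_size had to be an int32 Tensor.
# out_size = Tensor([1, 3, 2, 3], mindspore.int32)

# After this commit: output_size is a tuple of ints.
out_size = (1, 3, 2, 3)

# Functional interface.
print(ops.affine_grid(theta, out_size, False).shape)  # expected (1, 2, 3, 2)

# Primitive interface (import path assumed).
affine_grid = ops.AffineGrid(align_corners=False)
print(affine_grid(theta, out_size).shape)  # expected (1, 2, 3, 2)
```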
@@ -42,7 +42,6 @@ class AffineGridDynamicShapeNet(nn.Cell):

      def construct(self, theta, size):
          theta = self.test_dynamic(theta)
-         size = self.test_dynamic(size)
          grid = self.affine_grid(theta, size)
          return grid

@@ -52,7 +51,7 @@ def generate_nchw():
      c = np.random.randint(1, 128)
      h = np.random.randint(100, 1000)
      w = np.random.randint(100, 1000)
-     return np.array([n, c, h, w]).astype(np.int32)
+     return np.array([n, c, h, w])


  def generate_ncdhw():

@@ -61,7 +60,7 @@ def generate_ncdhw():
      d = np.random.randint(50, 300)
      h = np.random.randint(100, 1000)
      w = np.random.randint(100, 1000)
-     return np.array([n, c, d, h, w]).astype(np.int32)
+     return np.array([n, c, d, h, w])


  def np_linspace_from_neg_one(theta, n_steps, align_corners):

@@ -111,14 +110,14 @@ def test_affine_grid_4d(net, align, dtype):
      Expectation: success or throw AssertionError exception or raise TypeError.
      """
      # Big case, require enormous memory.
-     np_nchw = np.array([128, 1, 1080, 1920]).astype(np.int32)
+     np_nchw = (32, 1, 540, 960)
      np_theta = np.array([[[1, 0, 0], [0, 1, 0]]]).astype(dtype)
      np_theta = np.repeat(np_theta, np_nchw[0], axis=0)

      np_grid = np_affine_grid_4d(np_theta, np_nchw, align_corners=align)

      affine_grid = net(align_corners=align)
-     ms_theta, ms_nchw = Tensor(np_theta), Tensor(np_nchw)
+     ms_theta, ms_nchw = Tensor(np_theta), np_nchw
      ms_grid = affine_grid(ms_theta, ms_nchw)

      print(f"max error: {np.max(np_grid - ms_grid.asnumpy())}")

@@ -144,14 +143,14 @@ def test_affine_grid_5d(net, align, dtype):
      Expectation: success or throw AssertionError exception or raise TypeError.
      """
      # Big case, require enormous memory.
-     np_ncdhw = np.array([128, 1, 16, 270, 480]).astype(np.int32)
+     np_ncdhw = (32, 1, 16, 135, 240)
      np_theta = np.array([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]]).astype(dtype)
      np_theta = np.repeat(np_theta, np_ncdhw[0], axis=0)

      np_grid = np_affine_grid_5d(np_theta, np_ncdhw, align_corners=align)

      affine_grid = net(align_corners=align)
-     ms_theta, ms_ncdhw = Tensor(np_theta), Tensor(np_ncdhw)
+     ms_theta, ms_ncdhw = Tensor(np_theta), np_ncdhw
      ms_grid = affine_grid(ms_theta, ms_ncdhw)

      print(f"max error: {np.max(np_grid - ms_grid.asnumpy())}")
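The tests above compare against `np_affine_grid_4d` / `np_affine_grid_5d` reference helpers that are defined elsewhere in the test file and not shown in this diff. For readers who want to sanity-check the doctest output by hand, a minimal NumPy sketch of the 2D reference (my reconstruction, not the project's actual helper) could look like this:

```python
import numpy as np

def linspace_from_neg_one(num_steps, align_corners):
    # Normalized coordinates in [-1, 1]. With align_corners=False the extremes
    # refer to pixel corners, so sample centers are pulled inward by (n - 1) / n.
    if num_steps <= 1:
        return np.zeros(num_steps)
    grid = np.linspace(-1.0, 1.0, num_steps)
    if not align_corners:
        grid = grid * (num_steps - 1) / num_steps
    return grid

def affine_grid_4d_sketch(theta, output_size, align_corners=False):
    # theta: (N, 2, 3); output_size: tuple (N, C, H, W) -> grid of shape (N, H, W, 2).
    n, _, h, w = output_size
    xs = linspace_from_neg_one(w, align_corners)  # x varies along W
    ys = linspace_from_neg_one(h, align_corners)  # y varies along H
    ones = np.ones((h, w))
    # Homogeneous base grid (H, W, 3) holding (x, y, 1) at every location.
    base = np.stack(np.broadcast_arrays(xs[None, :], ys[:, None], ones), axis=-1)
    # Apply each affine matrix: (H, W, 3) @ (3, 2) -> (H, W, 2), stacked over the batch.
    return np.stack([base @ theta[i].T for i in range(n)], axis=0)

theta = np.array([[[0.8, 0.5, 0.0], [-0.5, 0.8, 0.0]]], dtype=np.float32)
grid = affine_grid_4d_sketch(theta, (1, 3, 2, 3), align_corners=False)
print(grid[0, 0, 0])  # approximately [-0.78333336 -0.06666666], matching the doctest above
```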