Fix MatrixSetDiag, Select, Conv2D op annotation problem.

This commit is contained in:
liangchenghui 2020-10-14 20:30:35 +08:00
parent 11058ad0ef
commit 8eedd68b8e
5 changed files with 31 additions and 13 deletions

View File

@ -66,7 +66,7 @@ Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph
}
if (graph_ptr == nullptr) {
MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed";
MS_LOG(INFO) << "The new graph {" << name << "}'s pointer is null, add graph failed";
return Status::INVALID_ARGUMENT;
}

View File

@ -629,9 +629,9 @@ class MatrixSetDiag(Cell):
Modify the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be one of the following data types:
- **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
float32, float16, int32, int8, and uint8.
- **diagonal** (Tensor) - The diagonal values.
- **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
Outputs:
Tensor, has the same type and shape as input `x`.

View File

@ -410,10 +410,10 @@ class MatrixSetDiag(PrimitiveWithInfer):
Modifies the batched diagonal part of a batched tensor.
Inputs:
- **x** (Tensor) - The batched tensor. It can be one of the following data types:
- **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
float32, float16, int32, int8, uint8.
- **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
- **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
- **diagonal** (Tensor) - The diagonal values.
Outputs:
Tensor, data type same as input `x`. The shape same as `x`.

View File

@ -2026,22 +2026,22 @@ class Select(PrimitiveWithInfer):
and :math:`y`.
Inputs:
- **input_x** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **input_cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The condition tensor, decides which element is chosen.
- **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The first input tensor.
- **input_z** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
- **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The second input tensor.
Outputs:
Tensor, has the same shape as `input_y`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
Tensor, has the same shape as `input_x`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
Examples:
>>> select = P.Select()
>>> input_x = Tensor([True, False])
>>> input_y = Tensor([2,3], mindspore.float32)
>>> input_z = Tensor([1,2], mindspore.float32)
>>> select(input_x, input_y, input_z)
>>> input_cond = Tensor([True, False])
>>> input_x = Tensor([2,3], mindspore.float32)
>>> input_y = Tensor([1,2], mindspore.float32)
>>> select(input_cond, input_x, input_y)
"""
@prim_attr_register

View File

@ -58,6 +58,23 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
return ret_value
def _check_shape(arg_name, arg_value, prim_name):
"""
Checks whether an shape dims is a positive int elements.
"""
def _raise_message():
raise ValueError(f"For '{prim_name}' attr '{arg_name}' dims elements should be positive int numbers, "
f"but got {arg_value}")
validator.check_value_type(arg_name, arg_value, (list, tuple), prim_name)
for item in arg_value:
if isinstance(item, int) and item > 0:
continue
_raise_message()
return arg_value
class Flatten(PrimitiveWithInfer):
r"""
Flattens a tensor without changing its batch size on the 0-th axis.
@ -1052,6 +1069,7 @@ class Conv2D(PrimitiveWithInfer):
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel
out_shape = [x_shape[0], out_channel, h_out, w_out]
_check_shape('output', out_shape, self.name)
return out_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):