forked from mindspore-Ecosystem/mindspore
Fix MatrixSetDiag, Select, Conv2D op annotation problem.
This commit is contained in:
parent
11058ad0ef
commit
8eedd68b8e
|
@ -66,7 +66,7 @@ Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph
|
|||
}
|
||||
|
||||
if (graph_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed";
|
||||
MS_LOG(INFO) << "The new graph {" << name << "}'s pointer is null, add graph failed";
|
||||
return Status::INVALID_ARGUMENT;
|
||||
}
|
||||
|
||||
|
|
|
@ -629,9 +629,9 @@ class MatrixSetDiag(Cell):
|
|||
Modify the batched diagonal part of a batched tensor.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The batched tensor. It can be one of the following data types:
|
||||
- **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
|
||||
float32, float16, int32, int8, and uint8.
|
||||
- **diagonal** (Tensor) - The diagonal values.
|
||||
- **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type and shape as input `x`.
|
||||
|
|
|
@ -410,10 +410,10 @@ class MatrixSetDiag(PrimitiveWithInfer):
|
|||
Modifies the batched diagonal part of a batched tensor.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The batched tensor. It can be one of the following data types:
|
||||
- **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
|
||||
float32, float16, int32, int8, uint8.
|
||||
- **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
|
||||
- **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
|
||||
- **diagonal** (Tensor) - The diagonal values.
|
||||
|
||||
Outputs:
|
||||
Tensor, data type same as input `x`. The shape same as `x`.
|
||||
|
|
|
@ -2026,22 +2026,22 @@ class Select(PrimitiveWithInfer):
|
|||
and :math:`y`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
- **input_cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The condition tensor, decides which element is chosen.
|
||||
- **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The first input tensor.
|
||||
- **input_z** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
- **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The second input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `input_y`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
Tensor, has the same shape as `input_x`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> select = P.Select()
|
||||
>>> input_x = Tensor([True, False])
|
||||
>>> input_y = Tensor([2,3], mindspore.float32)
|
||||
>>> input_z = Tensor([1,2], mindspore.float32)
|
||||
>>> select(input_x, input_y, input_z)
|
||||
>>> input_cond = Tensor([True, False])
|
||||
>>> input_x = Tensor([2,3], mindspore.float32)
|
||||
>>> input_y = Tensor([1,2], mindspore.float32)
|
||||
>>> select(input_cond, input_x, input_y)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -58,6 +58,23 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
|
|||
return ret_value
|
||||
|
||||
|
||||
def _check_shape(arg_name, arg_value, prim_name):
|
||||
"""
|
||||
Checks whether an shape dims is a positive int elements.
|
||||
"""
|
||||
|
||||
def _raise_message():
|
||||
raise ValueError(f"For '{prim_name}' attr '{arg_name}' dims elements should be positive int numbers, "
|
||||
f"but got {arg_value}")
|
||||
|
||||
validator.check_value_type(arg_name, arg_value, (list, tuple), prim_name)
|
||||
for item in arg_value:
|
||||
if isinstance(item, int) and item > 0:
|
||||
continue
|
||||
_raise_message()
|
||||
return arg_value
|
||||
|
||||
|
||||
class Flatten(PrimitiveWithInfer):
|
||||
r"""
|
||||
Flattens a tensor without changing its batch size on the 0-th axis.
|
||||
|
@ -1052,6 +1069,7 @@ class Conv2D(PrimitiveWithInfer):
|
|||
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
|
||||
out_channel = self.out_channel
|
||||
out_shape = [x_shape[0], out_channel, h_out, w_out]
|
||||
_check_shape('output', out_shape, self.name)
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
|
||||
|
|
Loading…
Reference in New Issue