!17758 Fix some bugs in ops.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-06-07 17:21:52 +08:00 committed by Gitee
commit b1ccab290a
4 changed files with 18 additions and 3 deletions

View File

@ -27,6 +27,7 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAvgPool3DInputNum = 1;
constexpr size_t k5DInferDims = 5;
constexpr size_t kC0 = 16;
@ -229,6 +230,11 @@ const AnfNodePtr AvgPool3DFusion::Process(const FuncGraphPtr &func_graph, const
MS_EXCEPTION_IF_NULL(node);
auto avg_pool_3d_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(avg_pool_3d_node);
if (avg_pool_3d_node->size() != kAvgPool3DInputNum + 1) {
MS_LOG(INFO) << "The node " << avg_pool_3d_node->DebugString() << " is not equal to " << kAvgPool3DInputNum
<< " inputs. Can not do fusion.";
return nullptr;
}
auto dims_in = AnfAlgo::GetPrevNodeOutputInferShape(avg_pool_3d_node, 0);
auto dims_out = AnfAlgo::GetOutputInferShape(avg_pool_3d_node, 0);
if (dims_in.size() < k5DInferDims || dims_out.size() < k5DInferDims) {

View File

@ -27,6 +27,7 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAvgPool3DGradInputNum = 1;
constexpr size_t k5DInferDims = 5;
constexpr size_t kKernelDims = 3;
constexpr size_t kStridesDims = 3;
@ -208,6 +209,11 @@ const AnfNodePtr AvgPool3DGradFusion::Process(const FuncGraphPtr &func_graph, co
MS_EXCEPTION_IF_NULL(node);
auto avg_pool_3d_grad_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(avg_pool_3d_grad_node);
if (avg_pool_3d_grad_node->size() != kAvgPool3DGradInputNum + 1) {
MS_LOG(INFO) << "The node " << avg_pool_3d_grad_node->DebugString() << " is not equal to " << kAvgPool3DGradInputNum
<< " inputs. Can not do fusion.";
return nullptr;
}
std::vector<int64_t> kernel_size;
std::vector<int64_t> strides;
std::vector<int64_t> pad_list;

View File

@ -1004,7 +1004,7 @@ class MaxPool3DGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
args = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return mstype.tensor_type(mstype.float32)
return x_dtype
class MaxPool3DGradGrad(PrimitiveWithInfer):

View File

@ -7277,6 +7277,9 @@ class DynamicRNN(PrimitiveWithInfer):
are learnable weights between the output and the input in the formula. For instance,
:math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
Note:
The `hidden_size` in shape of inputs must be multiple of 16.
Args:
cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
Only 'LSTM' is currently supported.
@ -7785,8 +7788,8 @@ class AvgPool3D(Primitive):
>>> avg_pool3d = ops.AvgPool3D(kernel_size=2, strides=1, pad_mode="valid")
>>> output = avg_pool3d(input)
>>> print(output)
[[[[[233.5 248.625]]]
[[[233.5 238.125]]]]]
[[[[[5. 6.]]]
[[[17. 18.]]]]]
"""
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", pad=0, ceil_mode=False,