From 6c9d32d446030a32d598226f5a02ed7fdef4642e Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 23 Jun 2020 14:21:26 +0800 Subject: [PATCH] fix StridedSliceGrad --- mindspore/ops/operations/_grad_ops.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 583005f78db..187a17fd4c4 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1066,8 +1066,18 @@ class StridedSliceGrad(PrimitiveWithInfer): self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) def __infer__(self, dy, shapex, begin, end, strides): - args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']} + args = {"dy": dy['dtype']} validator.check_tensor_type_same(args, mstype.number_type, self.name) + + for idx, item in enumerate(shapex['value']): + validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) + for idx, item in enumerate(begin['value']): + validator.check_value_type("begin[%d]" % idx, item, [int], self.name) + for idx, item in enumerate(end['value']): + validator.check_value_type("end[%d]" % idx, item, [int], self.name) + for idx, item in enumerate(strides['value']): + validator.check_value_type("strides[%d]" % idx, item, [int], self.name) + return {'shape': shapex['value'], 'dtype': dy['dtype'], 'value': None}