forked from mindspore-Ecosystem/mindspore
!2494 fixed StridedSliceGrad
Merge pull request !2494 from jiangjinsheng/issue_fix4
This commit is contained in:
commit
edeba61c05
|
@ -1093,8 +1093,18 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, dy, shapex, begin, end, strides):
|
def __infer__(self, dy, shapex, begin, end, strides):
|
||||||
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']}
|
args = {"dy": dy['dtype']}
|
||||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||||
|
|
||||||
|
for idx, item in enumerate(shapex['value']):
|
||||||
|
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(begin['value']):
|
||||||
|
validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(end['value']):
|
||||||
|
validator.check_value_type("end[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(strides['value']):
|
||||||
|
validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
|
||||||
|
|
||||||
return {'shape': shapex['value'],
|
return {'shape': shapex['value'],
|
||||||
'dtype': dy['dtype'],
|
'dtype': dy['dtype'],
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
Loading…
Reference in New Issue