fix StridedSliceGrad

This commit is contained in:
jiangjinsheng 2020-06-23 14:21:26 +08:00
parent 157ee1ca16
commit 6c9d32d446
1 changed files with 11 additions and 1 deletions

View File

@ -1066,8 +1066,18 @@ class StridedSliceGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides): def __infer__(self, dy, shapex, begin, end, strides):
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']} args = {"dy": dy['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensor_type_same(args, mstype.number_type, self.name)
for idx, item in enumerate(shapex['value']):
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(begin['value']):
validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(end['value']):
validator.check_value_type("end[%d]" % idx, item, [int], self.name)
for idx, item in enumerate(strides['value']):
validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
return {'shape': shapex['value'], return {'shape': shapex['value'],
'dtype': dy['dtype'], 'dtype': dy['dtype'],
'value': None} 'value': None}