!15865 fix fasterrcnn fail in pynative

From: @chujinjin
Reviewed-by: @linqingke,@kisnwang
Signed-off-by: @linqingke
This commit is contained in:
mindspore-ci-bot 2021-05-06 14:43:05 +08:00 committed by Gitee
commit 9fbfc63de9
2 changed files with 7 additions and 5 deletions

View File

@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer):
def check_elim(self, x, dtype):
if isinstance(x, (Tensor, numbers.Number, Parameter)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
if isinstance(x, Parameter):
data = x.data
if data.dtype == dtype:
return (True, x)
if isinstance(x, Tensor) and x.dtype == dtype:
x = Tensor(x)
x.set_cast_dtype()
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
return (False, None)
def __infer__(self, x, t):

View File

@ -143,7 +143,7 @@ class Rcnn(nn.Cell):
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type)
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)