forked from mindspore-Ecosystem/mindspore
!15865 fix fasterrcnn fail in pynative
From: @chujinjin Reviewed-by: @linqingke,@kisnwang Signed-off-by: @linqingke
This commit is contained in:
commit
9fbfc63de9
|
@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer):
|
|||
|
||||
def check_elim(self, x, dtype):
|
||||
if isinstance(x, (Tensor, numbers.Number, Parameter)):
|
||||
if isinstance(x, Tensor) and x.dtype == dtype:
|
||||
return (True, x)
|
||||
if isinstance(x, numbers.Number):
|
||||
return (True, Tensor(x, dtype=dtype))
|
||||
if isinstance(x, Parameter):
|
||||
data = x.data
|
||||
if data.dtype == dtype:
|
||||
return (True, x)
|
||||
if isinstance(x, Tensor) and x.dtype == dtype:
|
||||
x = Tensor(x)
|
||||
x.set_cast_dtype()
|
||||
return (True, x)
|
||||
if isinstance(x, numbers.Number):
|
||||
return (True, Tensor(x, dtype=dtype))
|
||||
return (False, None)
|
||||
|
||||
def __infer__(self, x, t):
|
||||
|
|
|
@ -143,7 +143,7 @@ class Rcnn(nn.Cell):
|
|||
|
||||
if self.training:
|
||||
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
|
||||
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type)
|
||||
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
|
||||
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
|
||||
|
||||
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
|
||||
|
|
Loading…
Reference in New Issue