fix Supported Platforms api.

fix msssim.
This commit is contained in:
liuxiao93 2020-12-26 10:54:52 +08:00
parent 32877aeffb
commit a9617aec88
2 changed files with 4 additions and 3 deletions

View File

@ -338,7 +338,8 @@ class MSSSIM(Cell):
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
_check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name)
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
P.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1))

View File

@ -6435,7 +6435,7 @@ class DynamicRNN(PrimitiveWithInfer):
Has the same type with input `b`.
Supported Platforms:
``Ascend``
``Ascend``
Examples:
>>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
@ -6570,7 +6570,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
- If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
Supported Platforms:
``Ascend``
``Ascend``
Examples:
>>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))