Change Tensor shape and dtype from interface (method call) to attribute

buxue 2020-07-16 19:44:24 +08:00
parent 4936fe487f
commit 6d5c580a61
4 changed files with 8 additions and 8 deletions
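In short, Tensor.shape and Tensor.dtype change from callable interfaces to plain attributes, so every call site drops the parentheses. A minimal sketch of the before/after usage, assuming a MindSpore build that includes this change:

import numpy as np
import mindspore as ms
from mindspore import Tensor

x = Tensor(np.ones([20, 5, 10, 10]), ms.float32)

# Before this commit, shape and dtype were interfaces (methods):
#     x.shape()  ->  (20, 5, 10, 10)
#     x.dtype()  ->  mindspore.float32
# After it, they are plain attributes, so the parentheses go away:
print(x.shape)      # (20, 5, 10, 10)
print(x.dtype)      # Float32
print(x.shape[1:])  # (5, 10, 10); slicing works directly on the attribute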


@@ -1020,7 +1020,7 @@ class LayerNorm(Cell):
     Examples:
         >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
-        >>> shape1 = x.shape()[1:]
+        >>> shape1 = x.shape[1:]
         >>> m = G.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
         >>> m(x)
     """


@@ -746,8 +746,8 @@ class DenseQuant(Cell):
         self.has_bias = check_bool(has_bias)
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
         self.weight = Parameter(initializer(
@@ -755,7 +755,7 @@ class DenseQuant(Cell):
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
             self.bias = Parameter(initializer(
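The same attribute-style shape check recurs in the GNNFeatureTransform hunk below. A standalone sketch of the pattern; the helper name check_init_shapes is hypothetical and not part of this commit:

import numpy as np
from mindspore import Tensor

def check_init_shapes(weight_init, out_channels, in_channels):
    # Hypothetical helper mirroring the checks touched by this commit.
    if isinstance(weight_init, Tensor):
        # .dim() stays a method here; .shape is now an attribute.
        if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                weight_init.shape[1] != in_channels:
            raise ValueError("weight_init shape error")

w = Tensor(np.ones([3, 4]).astype(np.float32))
check_init_shapes(w, out_channels=3, in_channels=4)   # passes silently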


@@ -77,15 +77,15 @@ class GNNFeatureTransform(nn.Cell):
         self.has_bias = check_bool(has_bias)
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")


@@ -28,4 +28,4 @@ def test_cast():
     type_dst = ms.float32
     cast = P.Cast()
     result = cast(input_x, type_dst)
-    assert result.dtype() == type_dst
+    assert result.dtype == type_dst
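A self-contained variant of the updated cast test; the construction of input_x is an assumption, since the hunk does not show it:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

def test_cast():
    # input_x construction is assumed; the original hunk does not show it.
    input_x = Tensor(np.random.randn(2, 3).astype(np.float16))
    type_dst = ms.float32
    cast = P.Cast()
    result = cast(input_x, type_dst)
    # dtype is now an attribute, not a method.
    assert result.dtype == type_dst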