!33904 use _inner_ops shard instand of array_ops range

Merge pull request !33904 from yanzhenxiang2020/range_shard_inner
This commit is contained in:
i-robot 2022-05-06 08:10:46 +00:00 committed by Gitee
commit 8f565f098e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 5 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
import mindspore as ms
@ -19,6 +20,7 @@ from mindspore.common import dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.train import Model
from tests.dataset_mock import MindData
@ -52,16 +54,19 @@ class Net(Cell):
self.type = mstype.float32
else:
self.type = mstype.int32
self.start = Tensor(start, self.type)
self.limit = Tensor(limit, self.type)
self.delta = Tensor(delta, self.type)
self.range = P.Range()
limit = float(limit)
start = float(start)
delta = float(delta)
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(int(length_input))), self.type)
self.range = inner.Range(start, limit, delta)
self.range.shard(strategy2)
self.mul2 = P.Mul().shard(strategy3)
self.weight = Parameter(weight, "w")
def construct(self, x, b):
r_out = self.range(self.start, self.limit, self.delta)
r_out = self.range(self.input_tensor)
out = self.mul(x, self.weight)
out = self.mul2(out, r_out)
return out