!33904 use _inner_ops shard instand of array_ops range

Merge pull request !33904 from yanzhenxiang2020/range_shard_inner
This commit is contained in:
i-robot 2022-05-06 08:10:46 +00:00 committed by Gitee
commit 8f565f098e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 5 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import math
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
@ -19,6 +20,7 @@ from mindspore.common import dtype as mstype
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.train import Model from mindspore.train import Model
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
@ -52,16 +54,19 @@ class Net(Cell):
self.type = mstype.float32 self.type = mstype.float32
else: else:
self.type = mstype.int32 self.type = mstype.int32
self.start = Tensor(start, self.type) limit = float(limit)
self.limit = Tensor(limit, self.type) start = float(start)
self.delta = Tensor(delta, self.type) delta = float(delta)
self.range = P.Range()
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(int(length_input))), self.type)
self.range = inner.Range(start, limit, delta)
self.range.shard(strategy2) self.range.shard(strategy2)
self.mul2 = P.Mul().shard(strategy3) self.mul2 = P.Mul().shard(strategy3)
self.weight = Parameter(weight, "w") self.weight = Parameter(weight, "w")
def construct(self, x, b): def construct(self, x, b):
r_out = self.range(self.start, self.limit, self.delta) r_out = self.range(self.input_tensor)
out = self.mul(x, self.weight) out = self.mul(x, self.weight)
out = self.mul2(out, r_out) out = self.mul2(out, r_out)
return out return out