forked from mindspore-Ecosystem/mindspore
use _inner_ops shard instand of array_ops range
This commit is contained in:
parent
4451a6a3ae
commit
04c5a582da
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -19,6 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.nn import Cell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.train import Model
|
||||
from tests.dataset_mock import MindData
|
||||
|
||||
|
@ -52,16 +54,19 @@ class Net(Cell):
|
|||
self.type = mstype.float32
|
||||
else:
|
||||
self.type = mstype.int32
|
||||
self.start = Tensor(start, self.type)
|
||||
self.limit = Tensor(limit, self.type)
|
||||
self.delta = Tensor(delta, self.type)
|
||||
self.range = P.Range()
|
||||
limit = float(limit)
|
||||
start = float(start)
|
||||
delta = float(delta)
|
||||
|
||||
length_input = math.ceil((limit - start) / delta)
|
||||
self.input_tensor = Tensor(list(range(int(length_input))), self.type)
|
||||
self.range = inner.Range(start, limit, delta)
|
||||
self.range.shard(strategy2)
|
||||
self.mul2 = P.Mul().shard(strategy3)
|
||||
self.weight = Parameter(weight, "w")
|
||||
|
||||
def construct(self, x, b):
|
||||
r_out = self.range(self.start, self.limit, self.delta)
|
||||
r_out = self.range(self.input_tensor)
|
||||
out = self.mul(x, self.weight)
|
||||
out = self.mul2(out, r_out)
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue