!33904 use _inner_ops shard instand of array_ops range
Merge pull request !33904 from yanzhenxiang2020/range_shard_inner
This commit is contained in:
commit
8f565f098e
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
|
@ -19,6 +20,7 @@ from mindspore.common import dtype as mstype
|
||||||
from mindspore import context, Tensor, Parameter
|
from mindspore import context, Tensor, Parameter
|
||||||
from mindspore.nn import Cell, Momentum
|
from mindspore.nn import Cell, Momentum
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from tests.dataset_mock import MindData
|
from tests.dataset_mock import MindData
|
||||||
|
|
||||||
|
@ -52,16 +54,19 @@ class Net(Cell):
|
||||||
self.type = mstype.float32
|
self.type = mstype.float32
|
||||||
else:
|
else:
|
||||||
self.type = mstype.int32
|
self.type = mstype.int32
|
||||||
self.start = Tensor(start, self.type)
|
limit = float(limit)
|
||||||
self.limit = Tensor(limit, self.type)
|
start = float(start)
|
||||||
self.delta = Tensor(delta, self.type)
|
delta = float(delta)
|
||||||
self.range = P.Range()
|
|
||||||
|
length_input = math.ceil((limit - start) / delta)
|
||||||
|
self.input_tensor = Tensor(list(range(int(length_input))), self.type)
|
||||||
|
self.range = inner.Range(start, limit, delta)
|
||||||
self.range.shard(strategy2)
|
self.range.shard(strategy2)
|
||||||
self.mul2 = P.Mul().shard(strategy3)
|
self.mul2 = P.Mul().shard(strategy3)
|
||||||
self.weight = Parameter(weight, "w")
|
self.weight = Parameter(weight, "w")
|
||||||
|
|
||||||
def construct(self, x, b):
|
def construct(self, x, b):
|
||||||
r_out = self.range(self.start, self.limit, self.delta)
|
r_out = self.range(self.input_tensor)
|
||||||
out = self.mul(x, self.weight)
|
out = self.mul(x, self.weight)
|
||||||
out = self.mul2(out, r_out)
|
out = self.mul2(out, r_out)
|
||||||
return out
|
return out
|
||||||
|
|
Loading…
Reference in New Issue