forked from mindspore-Ecosystem/mindspore
set auto parallel for dataset warp cell
This commit is contained in:
parent
45484c690c
commit
2f8516e5d7
|
@ -237,6 +237,9 @@ class Model:
|
|||
network.set_train(is_train)
|
||||
network.phase = phase
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
|
||||
return dataset_helper, network
|
||||
|
||||
def init(self, train_dataset=None, valid_dataset=None):
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from ....dataset_mock import MindData
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class MindDataSet(MindData):
|
||||
def __init__(self, dataset_types, dataset_shapes):
|
||||
super(MindDataSet, self).__init__(size=2, batch_size=32,
|
||||
np_types=dataset_types,
|
||||
output_shapes=dataset_shapes,
|
||||
input_indexs=(0, 1))
|
||||
|
||||
def __next__(self):
|
||||
if self._size < self._iter_num:
|
||||
raise StopIteration
|
||||
self._iter_num += 1
|
||||
next = []
|
||||
for shape, type in zip(self._output_shapes, self._np_types):
|
||||
next.append(Tensor(np.ones(shape).astype(type)))
|
||||
return tuple(next)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, in_features, out_features):
|
||||
super(Net, self).__init__()
|
||||
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
|
||||
self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
|
||||
self.matmul = P.MatMul()
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, input):
|
||||
output = self.add(self.matmul(input, self.weight), self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class NetFP16(nn.Cell):
|
||||
def __init__(self, in_features, out_features):
|
||||
super(NetFP16, self).__init__()
|
||||
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
|
||||
self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
|
||||
self.matmul = P.MatMul()
|
||||
self.add = P.TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input):
|
||||
output = self.cast(
|
||||
self.add(self.matmul(self.cast(input, mstype.float16), self.cast(self.weight, mstype.float16)),
|
||||
self.cast(self.bias, mstype.float16)), mstype.float32)
|
||||
return output
|
||||
|
||||
|
||||
def get_axis(x):
|
||||
shape_op = P.Shape()
|
||||
shape = shape_op(x)
|
||||
length = F.tuple_len(shape)
|
||||
perm = F.make_range(0, length)
|
||||
return perm
|
||||
|
||||
|
||||
class MSELoss(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MSELoss, self).__init__()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.square = P.Square()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
|
||||
def construct(self, data, label):
|
||||
diff = data - label
|
||||
return self.reduce_mean(self.square(diff), get_axis(diff))
|
||||
|
||||
|
||||
def test_auto_parallel_flag():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=1)
|
||||
dataset_types = (np.float32, np.float32)
|
||||
dataset_shapes = ((16, 16), (16, 16))
|
||||
|
||||
dataset = MindDataSet(dataset_types, dataset_shapes)
|
||||
net = NetFP16(16, 16)
|
||||
net.set_train()
|
||||
scale_manager = FixedLossScaleManager()
|
||||
loss = MSELoss()
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager)
|
||||
model.train(2, dataset)
|
||||
assert(model._train_network.get_flags()["auto_parallel"] == True)
|
||||
context.reset_auto_parallel_context()
|
Loading…
Reference in New Issue