diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 427e5a29ce7..d2b5d4f5d86 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -237,6 +237,9 @@ class Model: network.set_train(is_train) network.phase = phase + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + return dataset_helper, network def init(self, train_dataset=None, valid_dataset=None): diff --git a/tests/ut/python/parallel/test_auto_parallel_flag.py b/tests/ut/python/parallel/test_auto_parallel_flag.py new file mode 100644 index 00000000000..d18ae7687cb --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_flag.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.nn.optim import Momentum +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.train import Model +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from ....dataset_mock import MindData + +context.set_context(mode=context.GRAPH_MODE) + + +class MindDataSet(MindData): + def __init__(self, dataset_types, dataset_shapes): + super(MindDataSet, self).__init__(size=2, batch_size=32, + np_types=dataset_types, + output_shapes=dataset_shapes, + input_indexs=(0, 1)) + + def __next__(self): + if self._size < self._iter_num: + raise StopIteration + self._iter_num += 1 + next = [] + for shape, type in zip(self._output_shapes, self._np_types): + next.append(Tensor(np.ones(shape).astype(type))) + return tuple(next) + + +class Net(nn.Cell): + def __init__(self, in_features, out_features): + super(Net, self).__init__() + self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight") + self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias") + self.matmul = P.MatMul() + self.add = P.TensorAdd() + + def construct(self, input): + output = self.add(self.matmul(input, self.weight), self.bias) + return output + + +class NetFP16(nn.Cell): + def __init__(self, in_features, out_features): + super(NetFP16, self).__init__() + self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight") + self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias") + self.matmul = P.MatMul() + self.add = P.TensorAdd() + self.cast = P.Cast() + + def construct(self, input): + output = self.cast( + self.add(self.matmul(self.cast(input, mstype.float16), self.cast(self.weight, mstype.float16)), + self.cast(self.bias, mstype.float16)), mstype.float32) + return output + + +def get_axis(x): + shape_op = P.Shape() + shape = shape_op(x) + length = F.tuple_len(shape) + perm = F.make_range(0, length) + return perm + + +class MSELoss(nn.Cell): + def __init__(self): + super(MSELoss, self).__init__() + self.reduce_sum = P.ReduceSum() + self.square = P.Square() + self.reduce_mean = P.ReduceMean() + + def construct(self, data, label): + diff = data - label + return self.reduce_mean(self.square(diff), get_axis(diff)) + + +def test_auto_parallel_flag(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=1) + dataset_types = (np.float32, np.float32) + dataset_shapes = ((16, 16), (16, 16)) + + dataset = MindDataSet(dataset_types, dataset_shapes) + net = NetFP16(16, 16) + net.set_train() + scale_manager = FixedLossScaleManager() + loss = MSELoss() + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager) + model.train(2, dataset) + assert(model._train_network.get_flags()["auto_parallel"] == True) + context.reset_auto_parallel_context()