mindspore/tests/ut/python/parallel/test_gather_v2_primitive.py

235 lines
8.0 KiB
Python

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore.ops import composite as C
from mindspore.common.parameter import ParameterTuple
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init
from mindspore.train import Model, ParallelMode
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Dense, Cell
from mindspore import context
# Run everything defined in this file in graph mode.
context.set_context(mode=context.GRAPH_MODE)
# Device count handed to the auto-parallel context and used in several shard strategies.
device_number = 32
# First dimension of the input batch; also used to size the gather index lists.
batch_size_per_device = 128
class Dataset:
    """Tiny iterator that yields the same one-element sample a fixed number of times.

    Args:
        predict: Object returned (wrapped in a 1-tuple) on every iteration.
        length: Number of items produced before ``StopIteration`` is raised.

    The iterator is its own ``__iter__`` result, so it stays exhausted until
    ``reset()`` is called.
    """

    def __init__(self, predict, length=3):
        self.predict = predict
        self.length = length
        self.index = 0

    def __iter__(self):
        """Return the dataset itself; iteration state lives on the instance."""
        return self

    def __next__(self):
        """Return ``(predict,)`` until ``length`` items have been produced."""
        if self.index < self.length:
            self.index += 1
            return (self.predict,)
        raise StopIteration

    def reset(self):
        """Rewind the iteration counter so the dataset can be consumed again."""
        self.index = 0

    def get_dataset_size(self):
        """Return the fixed size 128 (independent of ``length``)."""
        return 128

    def get_repeat_count(self):
        """Return the fixed repeat count 1."""
        return 1
class GatherV2(_Loss):
    """Loss cell: gather along axis 0 with two index tensors and return (emb1 - emb2) ** 2.0.

    Args:
        index_dim: Selects how the two index tensors are built.
            1 -> the even- and odd-positioned entries of ``range(index_size)``;
            2 -> the even- and odd-positioned entries of ``arange(index_size * 16)``,
                 each reshaped to ``(int(index_size / 2), 16)``;
            anything else -> the scalars 21 and 2.
        strategy: Value passed to ``set_strategy`` on the GatherV2 primitive.
        index_size: Controls the number of generated indices (see ``index_dim``).
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super().__init__()
        self.pow = P.Pow()
        # Scalar fallbacks, used when index_dim is neither 1 nor 2.
        first_indices, second_indices = 21, 2
        if index_dim == 1:
            all_indices = list(range(index_size))
            first_indices, second_indices = all_indices[0::2], all_indices[1::2]
        elif index_dim == 2:
            all_indices = np.arange(index_size * 16)
            index_shape = (int(index_size / 2), 16)
            first_indices = np.reshape(all_indices[0::2], index_shape)
            second_indices = np.reshape(all_indices[1::2], index_shape)
        self.emb1_param = Tensor(first_indices, dtype=mstype.int32)
        self.emb2_param = Tensor(second_indices, dtype=mstype.int32)
        gather_op = P.GatherV2().set_strategy(strategy)
        self.gatherv2 = gather_op.add_prim_attr("data_parallel", True)

    def construct(self, nembeddings):
        # Gather along axis 0 with each index tensor, then raise the difference to the power 2.0.
        gathered_first = self.gatherv2(nembeddings, self.emb1_param, 0)
        gathered_second = self.gatherv2(nembeddings, self.emb2_param, 0)
        difference = gathered_first - gathered_second
        return self.pow(difference, 2.0)
def fc_with_initialize(input_channels, out_channels):
    """Create the Dense layer used as the network under test.

    Args:
        input_channels: First positional argument forwarded to ``Dense``.
        out_channels: Second positional argument forwarded to ``Dense``.

    Returns:
        The ``Dense`` instance; no other constructor arguments are supplied.
    """
    dense_layer = Dense(input_channels, out_channels)
    return dense_layer
class BuildTrainNetwork(nn.Cell):
    """Cell that chains a network and a criterion into a single loss computation.

    Args:
        network: Cell applied to the input data.
        criterion: Cell applied to the network output; its result is the loss.
    """

    def __init__(self, network, criterion):
        super().__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data):
        # Feed the network output straight into the criterion and return the loss.
        return self.criterion(self.network(input_data))
class TrainOneStepCell(Cell):
    """Cell that runs the forward pass, computes gradients, and applies the optimizer.

    Args:
        network: Cell producing the loss; flagged with ``defer_inline=True``.
        optimizer: Callable applied to the gradients of ``network.trainable_params()``.
        sens: Fill value for the sensitivity tensor passed to the gradient operation.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super().__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data):
        # Forward pass to obtain the loss.
        params = self.weights
        loss_value = self.network(data)
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        grad_seed = P.Fill()(P.DType()(loss_value), P.Shape()(loss_value), self.sens)
        # Gradients with respect to the parameter list, seeded with the tensor above.
        gradients = self.grad(self.network, params)(data, grad_seed)
        # Return the loss through F.depend together with the optimizer call.
        return F.depend(loss_value, self.optimizer(gradients))
def net_trains(gather_v2_strategy, criterion, rank):
    """Train a Dense layer plus ``criterion`` under semi-auto-parallel mode.

    Args:
        gather_v2_strategy: Not used inside this function; kept so existing
            call sites keep working. The strategy is already baked into
            ``criterion`` by the caller.
        criterion: Loss cell applied to the Dense output.
        rank: Value passed as ``global_rank`` to the auto-parallel context.

    The auto-parallel context is reset in a ``finally`` clause so that a
    failure during graph compilation or training cannot leak the
    semi-auto-parallel settings into tests that run afterwards.
    """
    init()
    lr = 0.1
    momentum = 0.9
    max_epoch = 20
    input_channels = 256
    out_channels = 512
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    try:
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                          device_num=device_number,
                                          global_rank=rank)
        predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
        dataset = Dataset(predict, 4)

        network = fc_with_initialize(input_channels, out_channels)
        network.set_train()

        train_network = BuildTrainNetwork(network, criterion)
        train_network.set_train()
        opt = Momentum(train_network.trainable_params(), lr, momentum)
        train_net = TrainOneStepCell(train_network, opt).set_train()

        model = Model(train_net)
        model.train(max_epoch, dataset, dataset_sink_mode=False)
    finally:
        # Always restore the default parallel context, even when training raises.
        context.reset_auto_parallel_context()
def test_auto_batch_parallel():
    """Train with 1-D gather indices and no explicit strategy (``None``), as rank 2."""
    strategy = None
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=1, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
def test_2d_index_auto_batch_parallel():
    """Train with 2-D gather indices and no explicit strategy (``None``), as rank 2."""
    strategy = None
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=2, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
def test_batch_parallel():
    """Train with 1-D gather indices and strategy ``((device_number, 1),)``, as rank 2."""
    strategy = ((device_number, 1),)
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=1, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
def test_strategy1():
    """Train with 1-D gather indices and strategy ``((16, 2),)``, as rank 2."""
    strategy = ((16, 2),)
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=1, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
def test_strategy2():
    """Train with 1-D gather indices and strategy ``((1, device_number),)``, as rank 2."""
    strategy = ((1, device_number),)
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=1, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
def test_strategy3():
    """Train with 1-D gather indices and strategy ``((8, 1),)``, as rank 2."""
    strategy = ((8, 1),)
    total_indices = batch_size_per_device * device_number
    loss_cell = GatherV2(index_dim=1, strategy=strategy, index_size=total_indices)
    net_trains(strategy, loss_cell, rank=2)
class GatherV2Axis1(_Loss):
    """Loss cell: gather along axis 1 with two index tensors and return (emb1 - emb2) ** 2.0.

    Args:
        index_dim: Selects how the two index tensors are built.
            1 -> the even- and odd-positioned entries of ``range(index_size)``;
            2 -> the even- and odd-positioned entries of ``arange(index_size * index_size)``,
                 each reshaped to ``(int(index_size / 2), index_size)``;
            anything else -> the scalars 21 and 2.
        strategy: Value passed to ``set_strategy`` on the GatherV2 primitive.
        index_size: Controls the number of generated indices (see ``index_dim``).
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super().__init__()
        self.pow = P.Pow()
        # Scalar fallbacks, used when index_dim is neither 1 nor 2.
        first_indices, second_indices = 21, 2
        if index_dim == 1:
            all_indices = list(range(index_size))
            first_indices, second_indices = all_indices[0::2], all_indices[1::2]
        elif index_dim == 2:
            all_indices = np.arange(index_size * index_size)
            index_shape = (int(index_size / 2), index_size)
            first_indices = np.reshape(all_indices[0::2], index_shape)
            second_indices = np.reshape(all_indices[1::2], index_shape)
        self.emb1_param = Tensor(first_indices, dtype=mstype.int32)
        self.emb2_param = Tensor(second_indices, dtype=mstype.int32)
        self.gatherv2 = P.GatherV2().set_strategy(strategy)

    def construct(self, nembeddings):
        # Gather along axis 1 with each index tensor, then raise the difference to the power 2.0.
        gathered_first = self.gatherv2(nembeddings, self.emb1_param, 1)
        gathered_second = self.gatherv2(nembeddings, self.emb2_param, 1)
        difference = gathered_first - gathered_second
        return self.pow(difference, 2.0)
def test_axis1_auto_batch_parallel():
    """Train the axis-1 gather loss with no explicit strategy (``None``), as rank 2."""
    strategy = None
    loss_cell = GatherV2Axis1(index_dim=1, strategy=strategy, index_size=512)
    net_trains(strategy, loss_cell, rank=2)
def test_axis1_batch_parallel():
    """Train the axis-1 gather loss with strategy ``((device_number, 1),)``, as rank 2."""
    strategy = ((device_number, 1),)
    loss_cell = GatherV2Axis1(index_dim=1, strategy=strategy, index_size=512)
    net_trains(strategy, loss_cell, rank=2)
def test_axis1_strategy1():
    """Train the axis-1 gather loss with strategy ``((16, 2),)``, as rank 17."""
    strategy = ((16, 2),)
    loss_cell = GatherV2Axis1(index_dim=1, strategy=strategy, index_size=512)
    net_trains(strategy, loss_cell, rank=17)