add auto parallel st

lichenever 2020-05-27 22:13:06 +08:00
parent 906389b3a7
commit 191d3c05d0
3 changed files with 37 additions and 43 deletions

View File

@@ -22,7 +22,7 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
-from mindspore.common.initializer import One
+from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
@@ -35,10 +35,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
+np.random.seed(10)
def weight_variable():
-return One()
+return TruncatedNormal(0.01)
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
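The hunk above swaps the weight initializer: One() sets every weight to the same constant, so all filters in a layer compute identical features and a loss-convergence test is degenerate, while TruncatedNormal(0.01) breaks that symmetry with small random weights. A minimal standalone sketch of the two initializers (not part of this diff):

import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer, One, TruncatedNormal

# All-ones tensor: every output channel starts identical.
w_ones = initializer(One(), [64, 3, 3, 3], mstype.float32)
# Truncated normal, sigma=0.01: small random weights, clipped near 2*sigma.
w_rand = initializer(TruncatedNormal(0.01), [64, 3, 3, 3], mstype.float32)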
@@ -93,11 +94,9 @@ class BasicBlock(nn.Cell):
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.downsample:
identity = self.down_sample_layer(identity)
@@ -120,13 +119,10 @@ class ResidualBlock(nn.Cell):
out_chls = out_channels // self.expansion
self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
self.bn1 = _fused_bn(out_chls, momentum=momentum)
self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
self.bn2 = _fused_bn(out_chls, momentum=momentum)
self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
self.bn3 = _fused_bn(out_channels, momentum=momentum)
self.relu = P.ReLU()
self.downsample = (in_channels != out_channels)
@@ -134,7 +130,6 @@ class ResidualBlock(nn.Cell):
if self.downsample:
self.conv_down_sample = _conv1x1(in_channels, out_channels,
stride=stride)
self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
elif self.stride != 1:
self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')
@@ -144,19 +139,15 @@ class ResidualBlock(nn.Cell):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
elif self.stride != 1:
identity = self.maxpool_down(identity)
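The two ResidualBlock hunks above touch the shortcut path, whose selection logic is: project the identity with a 1x1 convolution plus batch norm when the channel counts differ, subsample it with a kernel-1, stride-2 max-pool when only the stride changes, and pass it through unchanged otherwise. A plain-Python sketch of that branch structure (shortcut_kind is a hypothetical helper for illustration):

def shortcut_kind(in_channels, out_channels, stride):
    # Mirrors the branches in ResidualBlock.__init__ above.
    if in_channels != out_channels:
        return "1x1 conv + BN"    # project channels (and stride) to match
    if stride != 1:
        return "maxpool k=1 s=2"  # spatial subsampling only, channels kept
    return "identity"

assert shortcut_kind(64, 256, 1) == "1x1 conv + BN"
assert shortcut_kind(256, 256, 2) == "maxpool k=1 s=2"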
@@ -211,7 +202,7 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True)
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
weight_init=weight_variable(),
-bias_init=weight_variable())
+bias_init=weight_variable()).add_flags_recursive(fp16=True)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
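The only change in this hunk runs the classifier head in half precision: add_flags_recursive(fp16=True) marks the Dense cell and any sub-cells so their computation is carried out in float16. A sketch of the pattern, assuming the Cell API of this MindSpore version:

import mindspore.nn as nn

# Head computes in fp16; its float32 inputs are cast down on entry.
head = nn.Dense(2048, 8192, has_bias=True).add_flags_recursive(fp16=True)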
@@ -231,7 +222,6 @@ class ResNet(nn.Cell):
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
@@ -277,6 +267,7 @@ class SoftmaxCrossEntropyExpand(_Loss):
self.eps = Tensor(1e-24, mstype.float32)
def construct(self, logit, label):
+logit = self.cast(logit, mstype.float32)
logit_max = self.max(logit, -1)
exp = self.exp(self.sub(logit, logit_max))
exp_sum = self.sum(exp, -1)
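The line added at the top of construct casts the fp16 logits back to float32 before the expanded softmax, which is the numerically sensitive part: subtracting the row maximum before exp keeps every exponent non-positive, so exp cannot overflow. The same computation in plain NumPy, as an illustration rather than the test's code:

import numpy as np

def softmax_xent(logit, onehot, eps=1e-24):
    logit = logit.astype(np.float32)        # mirror the added fp32 cast
    m = logit.max(axis=-1, keepdims=True)   # row max, kept for broadcasting
    e = np.exp(logit - m)                   # shifted logits: exp(...) <= 1
    p = e / e.sum(axis=-1, keepdims=True)   # softmax probabilities
    return float(-(onehot * np.log(p + eps)).sum(axis=-1).mean())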
@@ -369,41 +360,19 @@ class ModelCallback(Callback):
self.loss_list.append(result.asnumpy().mean())
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
-def test_train_feed(num_classes=8192):
+def test_train_feed(num_classes=65536):
set_algo_parameters(elementwise_op_strategy_follow=True)
parallel_callback = ModelCallback()
data_gen = DataGenerator()
-_, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
-_, label_part = data_gen.label_data((32 * 2,))
+_, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
+_, label_part = data_gen.label_data((32 * 8,))
dataset = Dataset(input_part, label_part)
net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
-opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
+opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
-expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489]
-assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
-@pytest.mark.level0
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_single
-def test_train_feed2(num_classes=1001):
-set_algo_parameters(elementwise_op_strategy_follow=True)
-parallel_callback = ModelCallback()
-data_gen = DataGenerator()
-_, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
-_, label_part = data_gen.label_data((32 * 2,))
-dataset = Dataset(input_part, label_part)
-net = resnet50(num_classes)
-loss = SoftmaxCrossEntropyExpand(sparse=True)
-opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
-model = Model(net, loss_fn=loss, optimizer=opt)
-model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
-loss_value = np.array(parallel_callback.loss_list)
-expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097]
+expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
+print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
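The rewritten test keeps a single configuration: AUTO_PARALLEL mode with mirror_mean, the strategy search constrained so elementwise operators follow the sharding of their inputs, a usable learning rate (0.01 rather than 10.0), and a global batch of 32 * 8 to match the 8-device run. The parallel setup reduced to its essentials, as a sketch with the ParallelMode import path assumed for this MindSpore version:

from mindspore import context
from mindspore.parallel import set_algo_parameters
from mindspore.train.model import ParallelMode  # import path assumed

context.set_auto_parallel_context(mirror_mean=True,
                                  parallel_mode=ParallelMode.AUTO_PARALLEL)
# Elementwise ops reuse their input's strategy, shrinking the search space.
set_algo_parameters(elementwise_op_strategy_follow=True)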

View File

@@ -18,7 +18,6 @@ BASE_PATH=$(cd "$(dirname $0)"; pwd)
CONFIG_PATH=/home/workspace/mindspore_config
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
ulimit -n 65535
source ${BASE_PATH}/env.sh
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
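The script provisions one environment per spawned worker; the Python side of that handshake is the os.getenv('DEVICE_ID') call near the top of the test file. A sketch of what each process reads (default values assumed for illustration):

import os

device_id = int(os.getenv('DEVICE_ID', '0'))  # unique per spawned worker
rank_size = int(os.getenv('RANK_SIZE', '1'))  # total devices, 8 in this suite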

View File

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_expand_loss():
sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_auto_parallel_resnet50_expand_loss.sh")
assert ret == 0
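os.system returns the shell's exit status, so asserting ret == 0 gates the test on every stage of the 8-device launch finishing cleanly. An equivalent check with subprocess, sketched here as an alternative rather than part of the commit, fails with the offending command attached:

import os
import subprocess

sh_path = os.path.split(os.path.realpath(__file__))[0]
subprocess.run(
    ["sh", f"{sh_path}/run_auto_parallel_resnet50_expand_loss.sh"],
    check=True,  # raises CalledProcessError on a nonzero exit status
)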