forked from mindspore-Ecosystem/mindspore
Upgrade ps ci cases
This commit is contained in:
parent
e5c2dcbabb
commit
2e1d231b37
|
@ -22,7 +22,6 @@ export MS_WORKER_NUM=$2
|
|||
export MS_SERVER_NUM=$3
|
||||
export MS_SCHED_HOST=$4
|
||||
export MS_SCHED_PORT=$5
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
@ -26,10 +25,11 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
|
|||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
|
||||
from mindspore.parallel._ps_context import _is_role_worker
|
||||
from mindspore.communication.management import init
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_sparse_embedding")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
context.set_context(
|
||||
|
@ -71,10 +71,6 @@ def do_sparse_embedding(ps=False):
|
|||
for _ in range(epoch):
|
||||
data = Tensor(np.random.randint(-5, 15, (32, 3), np.int32))
|
||||
label = Tensor(np.random.randint(0, 9, (32), np.int32))
|
||||
if _is_role_pserver():
|
||||
train_network(data, label)
|
||||
sys.exit()
|
||||
else:
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
||||
|
@ -84,6 +80,7 @@ def do_sparse_embedding(ps=False):
|
|||
envs = os.environ
|
||||
if __name__ == "__main__":
|
||||
set_seed(0)
|
||||
init()
|
||||
ps_loss = do_sparse_embedding(True)
|
||||
|
||||
if _is_role_worker():
|
||||
|
|
|
@ -13,13 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cmp_sparse_embedding():
|
||||
return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081")
|
||||
return_code = os.system("bash shell_run_test.sh GPU 1 1 127.0.0.1 8081")
|
||||
assert return_code == 0
|
||||
|
|
|
@ -23,8 +23,8 @@ export MS_WORKER_NUM=$3
|
|||
export MS_SERVER_NUM=$4
|
||||
export MS_SCHED_HOST=$5
|
||||
export MS_SCHED_PORT=$6
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
export fusion=True
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/sched_$i/
|
||||
|
|
|
@ -13,15 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_full_ps_ascend_lenet():
|
||||
return_code = os.system(
|
||||
"bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082"
|
||||
"bash shell_run_test.sh GPU /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082"
|
||||
)
|
||||
assert return_code == 0
|
||||
|
|
|
@ -27,9 +27,10 @@ from mindspore.nn.metrics import Accuracy
|
|||
from mindspore.train import Model
|
||||
from mindspore.train.callback import LossMonitor
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.communication.management import init
|
||||
|
||||
parser = argparse.ArgumentParser(description='test_ps_lenet')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -122,6 +123,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
|
|||
return mnist_ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
init()
|
||||
network = LeNet5(10)
|
||||
network.set_param_ps()
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
|
|
@ -23,8 +23,8 @@ export MS_WORKER_NUM=$3
|
|||
export MS_SERVER_NUM=$4
|
||||
export MS_SCHED_HOST=$5
|
||||
export MS_SCHED_PORT=$6
|
||||
|
||||
export MS_ROLE=MS_SCHED
|
||||
export fusion=True
|
||||
for((i=0;i<1;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/sched_$i/
|
||||
|
|
|
@ -13,11 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
def test_ps_embedding_heterogeneous_conv2d_adam():
|
||||
return_code = os.system(
|
||||
"bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8085"
|
||||
"bash shell_run_test.sh GPU /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8085"
|
||||
)
|
||||
assert return_code == 0
|
||||
|
|
|
@ -36,9 +36,10 @@ import mindspore.ops.operations as op
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.train import Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.communication.management import init
|
||||
|
||||
parser = argparse.ArgumentParser(description='test_ps_lenet')
|
||||
parser.add_argument("--device_target", type=str, default="Ascend")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -180,5 +181,6 @@ class NetFactory:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init()
|
||||
fact = NetFactory()
|
||||
fact.part_cmp()
|
||||
|
|
Loading…
Reference in New Issue