forked from mindspore-Ecosystem/mindspore
add global batch normalization
This commit is contained in:
parent
17e27824c5
commit
97e250d4f1
|
@ -21,7 +21,6 @@ class Hccl():
|
||||||
_instance = None
|
_instance = None
|
||||||
_rank_id = 0
|
_rank_id = 0
|
||||||
_rank_size = 1
|
_rank_size = 1
|
||||||
_group_size = 4
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -48,10 +47,6 @@ class Hccl():
|
||||||
def rank_size(self):
|
def rank_size(self):
|
||||||
return self._rank_size
|
return self._rank_size
|
||||||
|
|
||||||
@property
|
|
||||||
def group_size(self):
|
|
||||||
return self._group_size
|
|
||||||
|
|
||||||
@rank_size.setter
|
@rank_size.setter
|
||||||
def rank_size(self, size):
|
def rank_size(self, size):
|
||||||
self._rank_size = size
|
self._rank_size = size
|
||||||
|
@ -70,14 +65,6 @@ def get_rank_size(group=None):
|
||||||
return int(group.split("-")[0])
|
return int(group.split("-")[0])
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
def get_group_size(group=None):
|
|
||||||
hccl = Hccl()
|
|
||||||
if group is None:
|
|
||||||
return hccl.group_size
|
|
||||||
if isinstance(group, str):
|
|
||||||
return int(group.split("-")[0])
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def get_world_rank_from_group_rank(group, group_rank_id):
|
def get_world_rank_from_group_rank(group, group_rank_id):
|
||||||
return group_rank_id
|
return group_rank_id
|
||||||
|
|
|
@ -19,9 +19,6 @@ import pytest
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter
|
||||||
from mindspore.communication.management import init
|
|
||||||
from mindspore import context
|
|
||||||
from mindspore import ParallelMode
|
|
||||||
|
|
||||||
|
|
||||||
def test_bn_pars_valid1():
|
def test_bn_pars_valid1():
|
||||||
|
|
Loading…
Reference in New Issue