forked from mindspore-Ecosystem/mindspore
add global batch normalization
This commit is contained in:
parent
17e27824c5
commit
97e250d4f1
|
@ -21,7 +21,6 @@ class Hccl():
|
|||
_instance = None
|
||||
_rank_id = 0
|
||||
_rank_size = 1
|
||||
_group_size = 4
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -48,10 +47,6 @@ class Hccl():
|
|||
def rank_size(self):
|
||||
return self._rank_size
|
||||
|
||||
@property
|
||||
def group_size(self):
|
||||
return self._group_size
|
||||
|
||||
@rank_size.setter
|
||||
def rank_size(self, size):
|
||||
self._rank_size = size
|
||||
|
@ -70,14 +65,6 @@ def get_rank_size(group=None):
|
|||
return int(group.split("-")[0])
|
||||
raise ValueError
|
||||
|
||||
def get_group_size(group=None):
|
||||
hccl = Hccl()
|
||||
if group is None:
|
||||
return hccl.group_size
|
||||
if isinstance(group, str):
|
||||
return int(group.split("-")[0])
|
||||
raise ValueError
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_world_rank_from_group_rank(group, group_rank_id):
|
||||
return group_rank_id
|
||||
|
|
|
@ -19,9 +19,6 @@ import pytest
|
|||
import mindspore.nn as nn
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.communication.management import init
|
||||
from mindspore import context
|
||||
from mindspore import ParallelMode
|
||||
|
||||
|
||||
def test_bn_pars_valid1():
|
||||
|
|
Loading…
Reference in New Issue