!2499 HostAllGather and HostReduceScatter change to internal interface
Merge pull request !2499 from yihuaijie/master
commit 0a368494db
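In short: the host-side collectives HostAllGather and HostReduceScatter become the internal primitives _HostAllGather and _HostReduceScatter, are dropped from mindspore.ops.operations.__all__, lose their public doctest Examples, and the corresponding public test cases are removed. A minimal sketch of how framework-internal code refers to them after this change (the import path and constructor arguments are taken from the diff below; the group used here is only illustrative and needs the MPI-enabled CPU build):

    # The public import no longer works after this change:
    # from mindspore.ops.operations import HostAllGather, HostReduceScatter
    from mindspore.ops.operations.comm_ops import _HostAllGather, _HostReduceScatter, ReduceOp

    # Internal callers construct the primitives directly, e.g. for a 4-process host group.
    host_all_gather = _HostAllGather(group=(0, 1, 2, 3))
    host_reduce_scatter = _HostReduceScatter(op=ReduceOp.SUM, group=(0, 1, 2, 3))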
@@ -36,7 +36,7 @@ class AllGatherCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };

-MS_REG_CPU_KERNEL(HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   AllGatherCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -37,7 +37,7 @@ class ReduceScatterCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };

-MS_REG_CPU_KERNEL(HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReduceScatterCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -145,7 +145,7 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
-constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
+constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
 constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
 constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
@@ -55,9 +55,7 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
 const char kNameAllReduce[] = "AllReduce";
 const char kNameBroadcast[] = "Broadcast";
 const char kNameAllgather[] = "AllGather";
-const char kNameHostAllgather[] = "HostAllGather";
 const char kNameReduceScatter[] = "ReduceScatter";
-const char kNameHostReduceScatter[] = "HostReduceScatter";
 const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
+from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,
-                                   ReduceScatter, HostReduceScatter, _VirtualDiv)
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv)
 from .grad_base import bprop_getters


@@ -93,10 +93,10 @@ def get_bprop_all_gather(self):
     return bprop


-@bprop_getters.register(HostAllGather)
+@bprop_getters.register(_HostAllGather)
 def get_bprop_host_all_gather(self):
-    """Generate bprop for HostAllGather"""
-    host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
+    """Generate bprop for _HostAllGather"""
+    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_all_gather_grad.set_prim_instance_name(instance_name)
@@ -126,10 +126,10 @@ def get_bprop_reduce_scatter(self):
     return bprop


-@bprop_getters.register(HostReduceScatter)
+@bprop_getters.register(_HostReduceScatter)
 def get_bprop_host_reduce_scatter(self):
-    """Generate bprop for HostReduceScatter"""
-    host_reduce_scatter_grad = HostAllGather(self.group)
+    """Generate bprop for _HostReduceScatter"""
+    host_reduce_scatter_grad = _HostAllGather(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_reduce_scatter_grad.set_prim_instance_name(instance_name)
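These two registrations keep the expected duality: the backward of _HostAllGather is _HostReduceScatter(ReduceOp.SUM) over the same group, and the backward of _HostReduceScatter is _HostAllGather over the same group. A small numpy sketch (an in-process simulation with hypothetical helper names, no MPI involved) of why summing-then-scattering is the right adjoint of gathering:

    import numpy as np

    def all_gather(xs):
        # Forward: every rank ends up with the concatenation of all per-rank inputs.
        return [np.concatenate(xs, axis=0) for _ in xs]

    def reduce_scatter_sum(dys, piece):
        # Backward: sum the per-rank output gradients, then give each rank its own slice.
        total = np.sum(dys, axis=0)
        return [total[i * piece:(i + 1) * piece] for i in range(len(dys))]

    xs = [np.full((2,), float(i)) for i in range(3)]   # rank i holds a (2,) tensor
    ys = all_gather(xs)                                # every rank now holds shape (6,)
    dys = [np.ones_like(y) for y in ys]                # upstream gradient on each rank
    dxs = reduce_scatter_sum(dys, piece=2)             # gradient w.r.t. each rank's input
    assert all(dx.shape == x.shape for dx, x in zip(dxs, xs))
    assert all(np.allclose(dx, 3.0) for dx in dxs)     # each input slice was replicated to 3 ranks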
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
-                       HostAllGather, HostReduceScatter)
+                       _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Debug, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
@@ -244,10 +244,8 @@ __all__ = [
     'UnsortedSegmentSum',
     'UnsortedSegmentMin',
     "AllGather",
-    "HostAllGather",
     "AllReduce",
     "ReduceScatter",
-    "HostReduceScatter",
     "Broadcast",
     "ReduceOp",
     'ScalarCast',
@@ -1166,7 +1166,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
     Perform the gradient for the communication part of EmbeddingLookup operator.

     This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
-    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
     """
     @prim_attr_register
     def __init__(self):
@@ -1177,8 +1177,8 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
         """
         This primitive is implemented by three steps:
             1) Split the 'dy' along dimension 0 into 'split_num' parts.
-            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
-            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+            2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
                along dimension 0.

         The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
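The docstring above fixes the shape contract: 'dy' is split into 'split_num' parts along dimension 0, every part is gathered over the fixed 8-process host group, and the gathered parts are concatenated back, giving shape(output)[0] == shape(dy)[0] * 8. A short numpy check of just that shape arithmetic (a pure simulation of the three steps, not the primitive itself):

    import numpy as np

    group_size = 8   # the fixed (0, 1, 2, 3, 4, 5, 6, 7) host group from the docstring
    split_num = 4

    dy = np.zeros((16, 5), dtype=np.float32)
    parts = np.split(dy, split_num, axis=0)                               # 1) split along dim 0
    gathered = [np.concatenate([p] * group_size, axis=0) for p in parts]  # 2) gather each part (shape only)
    output = np.concatenate(gathered, axis=0)                             # 3) concat along dim 0

    assert output.shape[0] == dy.shape[0] * group_size                    # 16 * 8 == 128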
@@ -176,13 +176,13 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError


-class HostAllGather(PrimitiveWithInfer):
+class _HostAllGather(PrimitiveWithInfer):
     """
     Gathers tensors from the specified communication group on host.

     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
+        _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
         to enable it. Using mpirun command to run it:
         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py

@@ -199,27 +199,6 @@ class HostAllGather(PrimitiveWithInfer):
     Outputs:
         Tensor. If the number of devices in the group is N,
         then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostallgather(x)
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """

     @prim_attr_register
@@ -308,13 +287,13 @@ class ReduceScatter(PrimitiveWithInfer):
         raise NotImplementedError


-class HostReduceScatter(PrimitiveWithInfer):
+class _HostReduceScatter(PrimitiveWithInfer):
     """
     Reduces and scatters tensors from the specified communication group on host.

     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
+        _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
         -M on to enable it. Using mpirun command to run it:
         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py

@@ -328,28 +307,6 @@ class HostReduceScatter(PrimitiveWithInfer):
             or elements of group are not int.
         ValueError: If the first dimension of input can not be divided by group size,
             or group is not set, or rank_id not in [0, 7].
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>> from mindspore.ops.operations.comm_ops import ReduceOp
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostreducescatter(x)
-        >>>
-        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """
     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=None):
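With the doctest Examples removed from both docstrings (the primitives are no longer user-facing), the only remaining usage reference is the deleted ST test below. The following is a hedged sketch adapted from the removed Examples using the renamed internal classes; it assumes the MPI-enabled CPU build (-M on), the script name in the mpirun line is hypothetical, and it must be launched with as many processes as the group has ranks:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops.operations.comm_ops import _HostAllGather, _HostReduceScatter, ReduceOp

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    context.set_mpi_config(enable_mpi=True)

    class HostCollectiveNet(nn.Cell):
        """Reduce-scatter then all-gather over the 4-process host group (0, 1, 2, 3)."""
        def __init__(self):
            super(HostCollectiveNet, self).__init__()
            self.reducescatter = _HostReduceScatter(op=ReduceOp.SUM, group=(0, 1, 2, 3))
            self.allgather = _HostAllGather(group=(0, 1, 2, 3))

        def construct(self, x):
            x = self.reducescatter(x)   # dimension 0 is split across the group
            return self.allgather(x)    # and concatenated back, as in the deleted ST test below

    net = HostCollectiveNet()
    output = net(Tensor(np.ones([8, 8]).astype(np.float32)))
    # mpirun -output-filename log -merge-stderr-to-stdout -np 4 python host_collective_sketch.py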
@@ -1,76 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-import numpy as np
-import pytest
-
-import mindspore.context as context
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
-import mindspore._ms_mpi as mpi
-# run comand:
-# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
-
-context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-context.set_mpi_config(enable_mpi=True)
-
-class Net(nn.Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.op = "sum"
-
-        self.reducescatter = P.HostReduceScatter(op=self.op, group=[0,1,2])
-
-    def construct(self, x):
-        return self.reducescatter(x)
-
-class AllGatherNet(nn.Cell):
-    def __init__(self):
-        super(AllGatherNet, self).__init__()
-        self.hostallgather = P.HostAllGather(group=(0, 1, 2))
-
-    def construct(self, x):
-        return self.hostallgather(x)
-
-def test_net_reduce_scatter():
-    x = np.arange(12).astype(np.float32) * 0.1
-
-    reducescatter = Net()
-    rankid = mpi.get_rank_id()
-    print("self rankid:", rankid)
-    output = reducescatter(Tensor(x, mstype.float32))
-    print("output:\n", output)
-    if rankid == 0:
-        expect_result = np.arange(4).astype(np.float32) * 0.3
-    if rankid == 1:
-        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
-    if rankid == 2:
-        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
-    diff = abs(output.asnumpy() - expect_result)
-    error = np.ones(shape=expect_result.shape) * 1.0e-6
-    assert np.all(diff < error)
-
-    allgather = AllGatherNet()
-    allgather_output = allgather(output)
-    print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
-    diff = abs(allgather_output.asnumpy() - expect_allgather_result)
-    error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
-    assert np.all(diff < error)
-
-if __name__ == '__main__':
-    test_net_reduce_scatter()
@@ -26,7 +26,6 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast

 # pylint: disable=W0212
@@ -87,21 +86,6 @@ class AllGatherNet(nn.Cell):
         return self.relu(x)


-class HostAllGatherNet(nn.Cell):
-    """HostAllGatherNet definition"""
-
-    def __init__(self, input_channel, output_channel):
-        super(HostAllGatherNet, self).__init__()
-        self.dense = Dense(input_channel, output_channel)
-        self.hostallgather = HostAllGather((0, 1))
-        self.relu = ReLU()
-
-    def construct(self, x):
-        x = self.dense(x)
-        x = self.hostallgather(x)
-        return self.relu(x)
-
-
 class ReduceScatterNet(nn.Cell):
     """ReduceScatterNet definition"""

|
@ -117,21 +101,6 @@ class ReduceScatterNet(nn.Cell):
|
||||||
return self.relu(x)
|
return self.relu(x)
|
||||||
|
|
||||||
|
|
||||||
class HostReduceScatterNet(nn.Cell):
|
|
||||||
"""HostReduceScatterNet definition"""
|
|
||||||
|
|
||||||
def __init__(self, input_channel, out_channel, op):
|
|
||||||
super(HostReduceScatterNet, self).__init__()
|
|
||||||
self.dense = Dense(input_channel, out_channel)
|
|
||||||
self.hostreducescatter = HostReduceScatter(op, (0, 1))
|
|
||||||
self.relu = ReLU()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
x = self.dense(x)
|
|
||||||
x = self.hostreducescatter(x)
|
|
||||||
return self.relu(x)
|
|
||||||
|
|
||||||
|
|
||||||
class AlltoAllNet(nn.Cell):
|
class AlltoAllNet(nn.Cell):
|
||||||
"""AlltoAllNet definition"""
|
"""AlltoAllNet definition"""
|
||||||
|
|
||||||
|
@ -185,21 +154,6 @@ def test_allgather():
|
||||||
_executor.compile(network, input_tensor, label_tensor)
|
_executor.compile(network, input_tensor, label_tensor)
|
||||||
|
|
||||||
|
|
||||||
def test_hostallgather():
|
|
||||||
"""test_hostallgather"""
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
|
|
||||||
label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
|
|
||||||
network = HostAllGatherNet(2, 1)
|
|
||||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
||||||
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
|
|
||||||
learning_rate=0.1,
|
|
||||||
momentum=0.9)
|
|
||||||
network = WithLossCell(network, loss_fn)
|
|
||||||
network = TrainOneStepCell(network, optimizer)
|
|
||||||
_executor.compile(network, input_tensor, label_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def run_reducescatter(op):
|
def run_reducescatter(op):
|
||||||
"""run_reducescatter"""
|
"""run_reducescatter"""
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
@@ -221,21 +175,6 @@ def test_reducescatter():
     run_reducescatter(ReduceOp.SUM)


-def test_hostreducescatter():
-    """test_hostreducescatter"""
-    context.set_context(mode=context.GRAPH_MODE)
-    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
-    label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
-    network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
-    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
-                         learning_rate=0.1,
-                         momentum=0.9)
-    network = WithLossCell(network, loss_fn)
-    network = TrainOneStepCell(network, optimizer)
-    _executor.compile(network, input_tensor, label_tensor)
-
-
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)