forked from mindspore-Ecosystem/mindspore
Add AllSwap Op
This commit is contained in:
parent
e1cfeeb1dd
commit
23284f0b35
|
@ -217,6 +217,8 @@ AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const Primit
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -367,6 +367,45 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co
|
|||
return sparse_tensor->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(tensor_in);
|
||||
MS_EXCEPTION_IF_NULL(tensor_in->shape());
|
||||
auto tensor_in_shape = tensor_in->shape()->shape();
|
||||
|
||||
auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(send_size);
|
||||
auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
MS_EXCEPTION_IF_NULL(recv_size);
|
||||
|
||||
// Get the content of the recv size
|
||||
auto recv_size_value_ptr = recv_size->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(recv_size_value_ptr);
|
||||
auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(recv_size_tensor);
|
||||
auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(data_pos);
|
||||
int64_t infer_max_size = 0;
|
||||
for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) {
|
||||
infer_max_size += *(data_pos + i);
|
||||
}
|
||||
|
||||
ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]};
|
||||
ShapeVector min_shape = {1, tensor_in_shape[1]};
|
||||
|
||||
ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]};
|
||||
|
||||
auto tensor_out = std::make_shared<AbstractTensor>(tensor_in->element(),
|
||||
std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
|
||||
|
||||
AbstractTensorPtr ret = std::make_shared<AbstractTensor>(
|
||||
tensor_out->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
|
|
@ -135,6 +135,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimAllReduce, {InferImplAllReduce, true}},
|
||||
{prim::kPrimBroadcast, {InferImplBroadcast, true}},
|
||||
{prim::kPrimAllGather, {InferImplAllGather, true}},
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
|
||||
{prim::kPrimCast, {InferImplCast, true}},
|
||||
|
|
|
@ -186,6 +186,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
|
|||
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
|
||||
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
|
||||
inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
|
||||
inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
|
||||
|
|
|
@ -21,7 +21,7 @@ from ...common.tensor import RowTensor
|
|||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
|
||||
_GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
|
||||
ReduceScatter, _HostReduceScatter, _VirtualDiv)
|
||||
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
|
||||
from .grad_base import bprop_getters
|
||||
|
||||
|
||||
|
@ -155,6 +155,21 @@ def get_bprop_reduce_scatter(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(AllSwap)
|
||||
def get_bprop_allswap(self):
|
||||
"""Generate bprop for AllSwap."""
|
||||
all_swap_grad = AllSwap(self.group)
|
||||
if self.instance_name:
|
||||
instance_name = "grad" + self.instance_name
|
||||
all_to_all_grad.set_prim_instance_name(instance_name)
|
||||
|
||||
def bprop(x, send_size, recv_size, out, dout):
|
||||
dx = all_swap_grad(dout, recv_size, send_size)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_HostReduceScatter)
|
||||
def get_bprop_host_reduce_scatter(self):
|
||||
"""Generate bprop for _HostReduceScatter"""
|
||||
|
|
|
@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
|
||||
Unique, GatherD, Identity, RepeatElements)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
_VirtualDiv, _GetTensorSlice, Send, Receive,
|
||||
_HostAllGather, _HostReduceScatter)
|
||||
|
@ -294,6 +294,7 @@ __all__ = [
|
|||
'UnsortedSegmentProd',
|
||||
"AllGather",
|
||||
"AllReduce",
|
||||
"AllSwap",
|
||||
"ReduceScatter",
|
||||
"Broadcast",
|
||||
"ReduceOp",
|
||||
|
|
|
@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
|
|||
from ..._checkparam import Rel
|
||||
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
||||
|
||||
|
||||
class ReduceOp:
|
||||
|
@ -507,6 +507,59 @@ class Broadcast(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class AllSwap(PrimitiveWithCheck):
|
||||
"""
|
||||
AllSwap is a collective operation.
|
||||
|
||||
AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
|
||||
|
||||
- The scatter phase: On each process, the operand is split into the send size of blocks along the
|
||||
0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
|
||||
- The gather phase: Each process concatenates the received blocks along the 0-th axis.
|
||||
|
||||
Note:
|
||||
The tensors must have the same format in all processes of the collection.
|
||||
|
||||
Args:
|
||||
group (str): The communication group name.
|
||||
|
||||
Inputs:
|
||||
tensor_in (tensor): A 2-D tensor. On each process, divide blocks into number of the send size.
|
||||
send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
|
||||
recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.
|
||||
|
||||
Returns:
|
||||
tensor_out (tensor): The result tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If group is not a string.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
"""Initialize AllSwap"""
|
||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
|
||||
def __check__(self, tensor_in, send_size, recv_size):
|
||||
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
|
||||
self.name)
|
||||
validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
|
||||
self.name)
|
||||
|
||||
validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
|
||||
validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
|
||||
validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
|
||||
|
||||
out_shape = [-1] + [tensor_in['shape'][1]]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': tensor_in['dtype'],
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class _AlltoAll(PrimitiveWithInfer):
|
||||
"""
|
||||
AlltoAll is a collective operation.
|
||||
|
|
|
@ -26,7 +26,9 @@ from mindspore.nn import Momentum
|
|||
from mindspore.nn import ReLU
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
|
||||
from mindspore.ops.operations.comm_ops import Broadcast
|
||||
from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
|
||||
from mindspore.ops.operations.math_ops import ReduceSum
|
||||
import mindspore
|
||||
|
||||
# pylint: disable=W0212
|
||||
# W0212: protected-access
|
||||
|
@ -117,6 +119,25 @@ class AlltoAllNet(nn.Cell):
|
|||
return self.relu(x)
|
||||
|
||||
|
||||
class AllSwapNet(nn.Cell):
|
||||
"""AlltoAllNet definition"""
|
||||
|
||||
def __init__(self, batch_size, input_channel, out_channel):
|
||||
super(AllSwapNet, self).__init__()
|
||||
self.dense = Dense(input_channel, out_channel)
|
||||
self.allswap = AllSwap()
|
||||
self.relu = ReLU()
|
||||
self.reduce = ReduceSum()
|
||||
part_slice = batch_size / 2
|
||||
self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
|
||||
self.recv_size = Tensor([part_slice*out_channel, part_slice*out_channel, 0], mindspore.int64)
|
||||
def construct(self, x):
|
||||
x = self.dense(x)
|
||||
x = self.allswap(x, self.send_size, self.recv_size)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_allreduce(op):
|
||||
"""run_allreduce"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -154,6 +175,13 @@ def test_allgather():
|
|||
network = TrainOneStepCell(network, optimizer)
|
||||
_executor.compile(network, input_tensor, label_tensor)
|
||||
|
||||
def test_allswap():
|
||||
"""run_allswap"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
|
||||
network = AllSwapNet(100, 20, 20)
|
||||
_executor.compile(network, input_tensor)
|
||||
|
||||
|
||||
def run_reducescatter(op):
|
||||
"""run_reducescatter"""
|
||||
|
|
Loading…
Reference in New Issue