add sparse gatherv2

This commit is contained in:
lichenever 2020-06-11 20:21:28 +08:00
parent 871d6524c3
commit e0e055a0b8
7 changed files with 237 additions and 3 deletions

View File

@ -121,6 +121,7 @@ REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
REGISTER(AssignSubInfo);
REGISTER(ReLUInfo);
REGISTER(GatherV2Info);
REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo);
REGISTER(SigmoidInfo);
REGISTER(GetNextInfo);

View File

@ -399,7 +399,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
auto gather_v2 =
gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)});

View File

@ -63,6 +63,7 @@ class GatherV2PInfo : public OperatorInfo {
int32_t axis_;
std::string target_;
std::string replace_op_name_ = GATHERV2;
int32_t bias_;
int32_t slice_size_;
Shape out_dev_matrix_shape_;
@ -70,6 +71,17 @@ class GatherV2PInfo : public OperatorInfo {
bool reduce_scatter_flag_ = false;
int32_t split_num_ = 1;
};
class SparseGatherV2Info : public GatherV2PInfo {
public:
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
~SparseGatherV2Info() override = default;
private:
std::string replace_op_name_ = SPARSE_GATHERV2;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_

View File

@ -205,6 +205,7 @@ constexpr char EQUAL[] = "Equal";
constexpr char NOT_EQUAL[] = "NotEqual";
constexpr char LOGICALNOT[] = "LogicalNot";
constexpr char GATHERV2[] = "GatherV2";
constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
constexpr char STRIDEDSLICE[] = "StridedSlice";
constexpr char BROADCAST[] = "Broadcast";
constexpr char SQRT[] = "Sqrt";

View File

@ -261,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) {
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
// clang-format on

View File

@ -535,7 +535,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
if (prim->name() == GATHERV2) {
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
}
if (!params.empty()) {

View File

@ -0,0 +1,220 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return C.grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.SparseGatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
self.mul = P.Mul().set_strategy(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis
def construct(self, x, y):
out = self.gatherv2(x, self.index, self.axis)
out = self.mul(out, y)
return out
def test_gatherv2_semi_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto3():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto4():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto5():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto6():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto7():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_auto8():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8,), (1, 1))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(0)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(1)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)