forked from mindspore-Ecosystem/mindspore
add sparse gatherv2
This commit is contained in:
parent
871d6524c3
commit
e0e055a0b8
|
@ -121,6 +121,7 @@ REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
|
|||
REGISTER(AssignSubInfo);
|
||||
REGISTER(ReLUInfo);
|
||||
REGISTER(GatherV2Info);
|
||||
REGISTER(SparseGatherV2Info);
|
||||
REGISTER(SqrtInfo);
|
||||
REGISTER(SigmoidInfo);
|
||||
REGISTER(GetNextInfo);
|
||||
|
|
|
@ -399,7 +399,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
|
||||
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
|
||||
auto gather_v2 =
|
||||
gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
|
||||
gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
|
||||
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
|
||||
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
|
||||
auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)});
|
||||
|
|
|
@ -63,6 +63,7 @@ class GatherV2PInfo : public OperatorInfo {
|
|||
|
||||
int32_t axis_;
|
||||
std::string target_;
|
||||
std::string replace_op_name_ = GATHERV2;
|
||||
int32_t bias_;
|
||||
int32_t slice_size_;
|
||||
Shape out_dev_matrix_shape_;
|
||||
|
@ -70,6 +71,17 @@ class GatherV2PInfo : public OperatorInfo {
|
|||
bool reduce_scatter_flag_ = false;
|
||||
int32_t split_num_ = 1;
|
||||
};
|
||||
|
||||
class SparseGatherV2Info : public GatherV2PInfo {
|
||||
public:
|
||||
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~SparseGatherV2Info() override = default;
|
||||
|
||||
private:
|
||||
std::string replace_op_name_ = SPARSE_GATHERV2;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||
|
|
|
@ -205,6 +205,7 @@ constexpr char EQUAL[] = "Equal";
|
|||
constexpr char NOT_EQUAL[] = "NotEqual";
|
||||
constexpr char LOGICALNOT[] = "LogicalNot";
|
||||
constexpr char GATHERV2[] = "GatherV2";
|
||||
constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
|
||||
constexpr char STRIDEDSLICE[] = "StridedSlice";
|
||||
constexpr char BROADCAST[] = "Broadcast";
|
||||
constexpr char SQRT[] = "Sqrt";
|
||||
|
|
|
@ -261,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -535,7 +535,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
|
|||
}
|
||||
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
|
||||
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
|
||||
if (prim->name() == GATHERV2) {
|
||||
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
|
||||
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
|
||||
}
|
||||
if (!params.empty()) {
|
||||
|
|
|
@ -0,0 +1,220 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
predict = self.network(x, y)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
return C.grad_all(self.network)(x, y)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
|
||||
super().__init__()
|
||||
if shape is None:
|
||||
shape = [64, 64]
|
||||
self.gatherv2 = P.SparseGatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.gatherv2(x, self.index, self.axis)
|
||||
out = self.mul(out, y)
|
||||
return out
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto1():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto2():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto3():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto4():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto5():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto6():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto7():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_semi_auto8():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8,), (1, 1))
|
||||
strategy2 = ((4, 2), (4, 2))
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_auto0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net(0)))
|
||||
net.set_auto_parallel()
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_auto1():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net(1)))
|
||||
net.set_auto_parallel()
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_cpu0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_cpu1():
|
||||
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((16, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_gatherv2_cpu2():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
Loading…
Reference in New Issue