Add UnsortedSegmentMax Operation

This commit is contained in:
huangxinjing 2020-11-19 17:32:40 +08:00
parent 8cf3a072b9
commit 89e7778497
8 changed files with 220 additions and 3 deletions

View File

@ -179,6 +179,7 @@ REGISTER(SquareInfo);
REGISTER(UniformCandidateSamplerInfo);
REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo);
REGISTER(UnsortedSegmentMaxInfo);
REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);

View File

@ -305,6 +305,7 @@ constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";

View File

@ -332,5 +332,41 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS;
}
// The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
// Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {
auto input_id_strategy = strategy_->GetInputDim().at(1);
// 1. the two input shapes are same, and the strategy is not all ones
if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int64_t &shard) { return shard > 1; })) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
}
return replace_graph_;
}
Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph();
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
// Get the attributes of the UnsortedSegmentMin
auto num_segments = GetValue<int64_t>(input_value_.at(2));
// Step1: Output branch
auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});
auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output});
auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MAX), all_gather_output, CreatInt64Imm(0)});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(segment_max, 1),
std::make_pair(segment_max, 2)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, final_output));
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -79,6 +79,20 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
Status ComputeReplaceGraph(const CNodePtr &cnode);
};
class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
~UnsortedSegmentMaxInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
Status InferForwardCommunication() override { return SUCCESS; }
protected:
Status ComputeReplaceGraph(const CNodePtr &cnode);
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_

View File

@ -317,7 +317,8 @@ bool IsSplittableOperator(const std::string &op_name) {
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE};
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
UNSORTED_SEGMENT_MAX};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -0,0 +1,162 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, num_segments):
super(Net, self).__init__()
self.virtual_dataset = _VirtualDataset()
self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2))
self.num_segments = num_segments
def construct(self, vectors, segment_ids):
predict = self.merge_op(vectors, segment_ids, self.num_segments)
return predict
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
_executor.compile(net, x, y)
def test_UnsortedSegmentMax_model_parallel_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (8,)
strategy2 = (8,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_no_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (1,)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_index_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.arange(4), ms.int32)
num_segments = 4
strategy1 = (4, 1)
strategy2 = (4,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 4)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 2, 2)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (2, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_float16():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float16)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_UnsortedSegmentMax_model_parallel_int32():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.int32)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np