forked from mindspore-Ecosystem/mindspore
Add UnsortedSegmentMax Operation
This commit is contained in:
parent
8cf3a072b9
commit
89e7778497
|
@ -179,6 +179,7 @@ REGISTER(SquareInfo);
|
||||||
REGISTER(UniformCandidateSamplerInfo);
|
REGISTER(UniformCandidateSamplerInfo);
|
||||||
REGISTER(UnsortedSegmentSumInfo);
|
REGISTER(UnsortedSegmentSumInfo);
|
||||||
REGISTER(UnsortedSegmentMinInfo);
|
REGISTER(UnsortedSegmentMinInfo);
|
||||||
|
REGISTER(UnsortedSegmentMaxInfo);
|
||||||
REGISTER(GatherV2PInfo);
|
REGISTER(GatherV2PInfo);
|
||||||
REGISTER(EmbeddingLookupInfo);
|
REGISTER(EmbeddingLookupInfo);
|
||||||
REGISTER(TileInfo);
|
REGISTER(TileInfo);
|
||||||
|
|
|
@ -305,6 +305,7 @@ constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
|
||||||
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
|
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
|
||||||
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
|
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
|
||||||
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
|
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
|
||||||
|
constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
|
||||||
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
|
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
|
||||||
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
|
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
|
||||||
constexpr char ADD[] = "Add";
|
constexpr char ADD[] = "Add";
|
||||||
|
|
|
@ -332,5 +332,41 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo
|
||||||
|
// Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op
|
||||||
|
ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
|
auto input_id_strategy = strategy_->GetInputDim().at(1);
|
||||||
|
// 1. the two input shapes are same, and the strategy is not all ones
|
||||||
|
if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int64_t &shard) { return shard > 1; })) {
|
||||||
|
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return replace_graph_;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
|
GenerateGraph gen_g = GenerateGraph();
|
||||||
|
if (gen_g.Init(cnode) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// Get the attributes of the UnsortedSegmentMin
|
||||||
|
auto num_segments = GetValue<int64_t>(input_value_.at(2));
|
||||||
|
// Step1: Output branch
|
||||||
|
auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(),
|
||||||
|
gen_g.virtual_input_node(), CreatInt64Imm(num_segments)});
|
||||||
|
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)});
|
||||||
|
auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output});
|
||||||
|
auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MAX), all_gather_output, CreatInt64Imm(0)});
|
||||||
|
|
||||||
|
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(segment_max, 1),
|
||||||
|
std::make_pair(segment_max, 2)};
|
||||||
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
|
||||||
|
std::make_pair(input_nodes, final_output));
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -79,6 +79,20 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
|
||||||
|
public:
|
||||||
|
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
const PrimitiveAttrs &attrs)
|
||||||
|
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
|
||||||
|
~UnsortedSegmentMaxInfo() override = default;
|
||||||
|
|
||||||
|
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||||
|
Status InferForwardCommunication() override { return SUCCESS; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
|
||||||
|
|
|
@ -317,7 +317,8 @@ bool IsSplittableOperator(const std::string &op_name) {
|
||||||
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
|
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
|
||||||
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
|
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
|
||||||
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
||||||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE};
|
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
|
||||||
|
UNSORTED_SEGMENT_MAX};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
auto iter = splittable_op.find(op_name);
|
auto iter = splittable_op.find(op_name);
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common.api import _executor
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||||
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
grad_all = C.GradOperation(get_all=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, strategy1, strategy2, num_segments):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.virtual_dataset = _VirtualDataset()
|
||||||
|
self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2))
|
||||||
|
self.num_segments = num_segments
|
||||||
|
|
||||||
|
def construct(self, vectors, segment_ids):
|
||||||
|
predict = self.merge_op(vectors, segment_ids, self.num_segments)
|
||||||
|
return predict
|
||||||
|
|
||||||
|
|
||||||
|
class GradWrap(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(GradWrap, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return grad_all(self.network)(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class NetWithLoss(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(NetWithLoss, self).__init__()
|
||||||
|
self.loss = VirtualLoss()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
predict = self.network(x, y)
|
||||||
|
return self.loss(predict)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
|
||||||
|
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
net.set_train()
|
||||||
|
if auto:
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
else:
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_slice_1d():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||||
|
x = Tensor(np.ones(8), ms.float32)
|
||||||
|
y = Tensor(np.ones(8), ms.int32)
|
||||||
|
num_segments = 16
|
||||||
|
strategy1 = (8,)
|
||||||
|
strategy2 = (8,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_no_slice_1d():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||||
|
x = Tensor(np.ones(8), ms.float32)
|
||||||
|
y = Tensor(np.ones(8), ms.int32)
|
||||||
|
num_segments = 16
|
||||||
|
strategy1 = (1,)
|
||||||
|
strategy2 = (1,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_index_slice_2d():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||||
|
y = Tensor(np.arange(4), ms.int32)
|
||||||
|
num_segments = 4
|
||||||
|
strategy1 = (4, 1)
|
||||||
|
strategy2 = (4,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_vector_slice_2d():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||||
|
y = Tensor(np.ones(4), ms.int32)
|
||||||
|
num_segments = 4
|
||||||
|
strategy1 = (1, 4)
|
||||||
|
strategy2 = (1,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_vector_slice_3d():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 8, 8)), ms.float32)
|
||||||
|
y = Tensor(np.ones(4), ms.int32)
|
||||||
|
num_segments = 4
|
||||||
|
strategy1 = (1, 2, 2)
|
||||||
|
strategy2 = (1,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_2d():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||||
|
y = Tensor(np.ones(4), ms.int32)
|
||||||
|
num_segments = 4
|
||||||
|
strategy1 = (2, 2)
|
||||||
|
strategy2 = (2,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_3d():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
||||||
|
y = Tensor(np.ones((4)), ms.int32)
|
||||||
|
num_segments = 16
|
||||||
|
strategy1 = (2, 1, 2)
|
||||||
|
strategy2 = (2,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_float16():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 4, 8)), ms.float16)
|
||||||
|
y = Tensor(np.ones((4)), ms.int32)
|
||||||
|
num_segments = 16
|
||||||
|
strategy1 = (2, 1, 2)
|
||||||
|
strategy2 = (2,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||||
|
|
||||||
|
def test_UnsortedSegmentMax_model_parallel_int32():
|
||||||
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||||
|
x = Tensor(np.ones((4, 4, 8)), ms.int32)
|
||||||
|
y = Tensor(np.ones((4)), ms.int32)
|
||||||
|
num_segments = 16
|
||||||
|
strategy1 = (2, 1, 2)
|
||||||
|
strategy2 = (2,)
|
||||||
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue