diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
index 1c8128f2660..4264a6b3a5d 100644
--- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -203,6 +203,7 @@
 REGISTER(BatchNormInfo);
 REGISTER(MaxPoolInfo);
 REGISTER(AvgPoolInfo);
 REGISTER(GatherDInfo);
+REGISTER(ReduceAnyInfo);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
index 8f318be3b0e..763e787a537 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
@@ -42,11 +42,6 @@ Status OneHotInfo::GetAttrs() {
     }
   }
 
-  if (inputs_shape_[0].size() != 1) {
-    MS_LOG(ERROR) << name_ << ": Input's shape only support 1-D now.";
-    return FAILED;
-  }
-
   if ((axis_ > 1) || (axis_ < -1)) {
     MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1].";
     return FAILED;
@@ -55,23 +50,38 @@ Status OneHotInfo::GetAttrs() {
 }
 
 Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
-  return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)});
+  if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}) != SUCCESS) {
+    return FAILED;
+  }
+  auto stra = strategy->GetInputDim().at(0);
+  bool invalid = false;
+  for (size_t i = 1; i < stra.size(); ++i) {
+    if (stra.at(i) != 1) {
+      invalid = true;
+      break;
+    }
+  }
+  if ((inputs_shape_.at(0).size() > 1) && ((axis_ != -1) || invalid)) {
+    MS_LOG(ERROR) << "When input dimension is > 1, axis must be -1, and strategy must be data parallel.";
+    return FAILED;
+  }
+  return SUCCESS;
 }
 
 Status OneHotInfo::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions input_strategy = stra.at(0);
 
-  // Now input only support 1-D tensor, so the output is a 2-D tensor
-  // If input is a vector of length features, the output shape will be:
-  // [features, depth] if axis == -1 (or axis == 1)
-  // [depth, features] if axis == 0
   if (axis_ == 0) {
+    // Here, only support 1-D input tensor, so the output is a 2-D tensor
+    // If input is a vector of length features, the output shape will be:
+    // [depth, features] if axis == 0
     dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
     dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
   } else {
-    dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
-    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
+    for (const auto &input_stra : input_strategy) {
+      dev_matrix_shape_.push_back(input_stra);
+    }
   }
   old_dev_matrix_back_ = dev_matrix_shape_.back();
   repeated_num_in_dev_matrix_right_ = false;
@@ -81,21 +91,21 @@ Status OneHotInfo::InferDevMatrixShape() {
 Status OneHotInfo::InferTensorMap() {
   Shape input_tensor_map_index, output_tensor_map_index;
   size_t size = outputs_shape_[0].size();
-  // such as 2: tensor_map_index [1,0]
   if (axis_ == 0) {
     for (size_t i = 0; i < size; ++i) {
       output_tensor_map_index.push_back((int64_t)(i));
     }
+    input_tensor_map_index.push_back(1);
   } else {
     for (size_t i = 0; i < size; ++i) {
       output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
     }
+    for (size_t i = 0; i < size - 1; ++i) {
+      input_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
+    }
   }
   outputs_tensor_map_.push_back(output_tensor_map_index);
 
-  // Now input only support 1-D tensor
-  input_tensor_map_index.push_back(1);
-
   inputs_tensor_map_.push_back(input_tensor_map_index);
   return SUCCESS;
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 118ddff948b..dd6a3237da5 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -87,6 +87,7 @@ constexpr char GEN_STRATEGY[] = "gen_strategy";
 constexpr char REDUCE_OP_SUM[] = "sum";
 constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
+constexpr char REDUCE_OP_ANY[] = "any";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
 constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
 constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
index 26f7fefbf70..f9d468489ab 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
@@ -539,5 +539,24 @@ std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage
 
   return sp_vector;
 }
+
+Status ReduceAnyInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": checking strategy failed.";
+    return FAILED;
+  }
+  auto dim_list = ReduceMethod::reduce_dim();
+  Dimensions stra = strategy->GetInputDim().at(0);
+  for (size_t index = 0; index < stra.size(); ++index) {
+    auto pos =
+      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
+    if (pos != dim_list.end() && stra[index] != 1) {
+      MS_LOG(ERROR) << name_
+                    << ": checking strategy failed. ReduceAny operator does not support reduced dimension split.";
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
index 33b09aeed66..15680184c56 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
@@ -124,6 +124,19 @@
   ~ReduceSumInfo() override = default;
 };
 
+class ReduceAnyInfo : public ReduceMethod {
+ public:
+  ReduceAnyInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                const PrimitiveAttrs &attrs)
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceAnyCost>()) {
+    reduce_method_ = REDUCE_OP_ANY;
+  }
+  ~ReduceAnyInfo() override = default;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+};
+
 class ReduceMinInfo : public ReduceMethod {
  public:
   ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
diff --git a/tests/ut/python/parallel/test_onehot_2dim.py b/tests/ut/python/parallel/test_onehot_2dim.py
new file mode 100644
index 00000000000..26d4bb22a3b
--- /dev/null
+++ b/tests/ut/python/parallel/test_onehot_2dim.py
@@ -0,0 +1,64 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.context as context
+from mindspore import Tensor, Parameter
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore.nn import TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+class Net(nn.Cell):
+    def __init__(self, wi, stra1=None, stra2=None, stra3=None):
+        super(Net, self).__init__()
+        self.wi = Parameter(wi, "wi")
+        self.matmul = P.MatMul().shard(stra1)
+        self.onehot = P.OneHot(axis=-1).shard(stra2)
+        self.mul = P.Mul().shard(stra3)
+        self.on_value = Tensor(1.0, ms.float32)
+        self.off_value = Tensor(0.0, ms.float32)
+        self.cast = P.Cast()
+        self.depth = 48
+
+    def construct(self, x):
+        output = self.matmul(x, self.wi)
+        output = self.cast(output, ms.int32)
+        output = self.onehot(output, self.depth, self.on_value, self.off_value)
+        output = self.mul(output, output)
+        return output
+
+_x = Tensor(np.ones([16, 48]), dtype=ms.float32)
+_wi = Tensor(np.ones([48, 16]), dtype=ms.float32)
+
+
+def compile_net(net):
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _executor.compile(train_net, _x)
+    context.reset_auto_parallel_context()
+
+
+def test_onehot():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
+                                      global_rank=0)
+    stra1 = ((8, 1), (1, 1))
+    stra2 = ((8, 1, 1), (), ())
+    stra3 = ((8, 1, 1), (8, 1, 1))
+    net = Net(_wi, stra1=stra1, stra2=stra2, stra3=stra3)
+    compile_net(net)
diff --git a/tests/ut/python/parallel/test_reduce_method_info.py b/tests/ut/python/parallel/test_reduce_method_info.py
index 14021c0a3c4..ddf44098c48 100644
--- a/tests/ut/python/parallel/test_reduce_method_info.py
+++ b/tests/ut/python/parallel/test_reduce_method_info.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import numpy as np
-
+import pytest
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -585,3 +585,54 @@ def test_max_empty_tuple():
     b = Tensor(np.ones([128, 32]), dtype=ms.float32)
 
     compile_net(net, x, y, b)
+
+
+def test_any_mul():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.mul1 = P.Mul().shard(strategy1)
+            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
+            self.cast = P.Cast()
+
+        def construct(self, x, y):
+            out = self.mul1(x, y)
+            out = self.cast(out, ms.bool_)
+            out = self.reduce_any(out, 1)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0)
+    strategy1 = ((1, 8, 1), (1, 8, 1))
+    strategy2 = ((1, 8, 1),)
+    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+
+    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
+    with pytest.raises(RuntimeError):
+        compile_net_no_bias(net, x, y)
+
+
+def test_any_mul2():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.mul1 = P.Mul().shard(strategy1)
+            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
+            self.cast = P.Cast()
+
+        def construct(self, x, y):
+            out = self.mul1(x, y)
+            out = self.cast(out, ms.bool_)
+            out = self.reduce_any(out, -1)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0)
+    strategy1 = ((8, 1, 1), (8, 1, 1))
+    strategy2 = ((8, 1, 1),)
+    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+
+    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
+    compile_net_no_bias(net, x, y)
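
Usage note (illustration only, not part of the patch): with the 1-D restriction removed above, OneHotInfo::CheckStrategy now accepts an index input with more than one dimension, but only when axis is -1 and the sharding is pure data parallel, i.e. every strategy value after the first must be 1. The sketch below is a minimal compile-only example in the style of the unit tests in this patch; the names OneHot2DNet and compile_onehot_2d are made up for the illustration.

import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.ops import operations as P


class OneHot2DNet(nn.Cell):
    def __init__(self, strategy):
        super(OneHot2DNet, self).__init__()
        # With a 2-D [16, 48] index input the OneHot output is 3-D [16, 48, depth].
        # Only the first (batch) dimension may be split: ((8, 1, 1), (), ()) passes the
        # new check, while e.g. ((1, 8, 1), (), ()) or any axis other than -1 is rejected.
        self.onehot = P.OneHot(axis=-1).shard(strategy)
        self.on_value = Tensor(1.0, ms.float32)
        self.off_value = Tensor(0.0, ms.float32)
        self.depth = 48

    def construct(self, indices):
        return self.onehot(indices, self.depth, self.on_value, self.off_value)


def compile_onehot_2d():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    net = OneHot2DNet(((8, 1, 1), (), ()))
    net.set_auto_parallel()
    net.set_train()
    indices = Tensor(np.ones([16, 48]), dtype=ms.int32)
    _executor.compile(net, indices)
    context.reset_auto_parallel_context()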
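
A similar note for ReduceAny: the new ReduceAnyInfo::CheckStrategy only adds the rule that a reduced dimension must not be split; splitting the other dimensions is still allowed, as test_any_mul2 above compiles successfully while test_any_mul expects the rejection. Below is a self-contained sketch of an accepted strategy; ReduceAnyNet and compile_reduce_any are illustrative names only.

import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.ops import operations as P


class ReduceAnyNet(nn.Cell):
    def __init__(self, strategy):
        super(ReduceAnyNet, self).__init__()
        # The reduced axis (axis 1 below) must keep a strategy value of 1:
        # ((8, 1, 1),) splits only axis 0 and is accepted, while ((1, 8, 1),)
        # would split the reduced axis and fail ReduceAnyInfo::CheckStrategy.
        self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy)
        self.cast = P.Cast()

    def construct(self, x):
        x = self.cast(x, ms.bool_)
        return self.reduce_any(x, 1)


def compile_reduce_any():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    net = ReduceAnyNet(((8, 1, 1),))
    net.set_auto_parallel()
    net.set_train()
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    _executor.compile(net, x)
    context.reset_auto_parallel_context()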