add ReduceAny operator and extend OneHot to multi-dimensional inputs

Xiaoda Zhang 2021-07-31 14:33:36 +08:00
parent 9720bab9c9
commit 4b4b3cdaf4
7 changed files with 176 additions and 17 deletions
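
Before the per-file diffs, a minimal functional sketch (ours, not part of the commit) of the behaviour being extended: with axis=-1, OneHot over an N-D index tensor yields an (N+1)-D output, with the one-hot depth appended as the last dimension.

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

# 2-D indices like these were previously rejected by the parallel OneHotInfo check.
onehot = P.OneHot(axis=-1)
indices = Tensor(np.array([[0, 2], [1, 3]]), dtype=ms.int32)  # shape [2, 2]
on_value, off_value = Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)
out = onehot(indices, 4, on_value, off_value)  # shape [2, 2, 4]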

@@ -203,6 +203,7 @@ REGISTER(BatchNormInfo);
REGISTER(MaxPoolInfo);
REGISTER(AvgPoolInfo);
REGISTER(GatherDInfo);
REGISTER(ReduceAnyInfo);
} // namespace parallel
} // namespace mindspore

@@ -42,11 +42,6 @@ Status OneHotInfo::GetAttrs() {
    }
  }
  if (inputs_shape_[0].size() != 1) {
    MS_LOG(ERROR) << name_ << ": Input's shape only support 1-D now.";
    return FAILED;
  }
  if ((axis_ > 1) || (axis_ < -1)) {
    MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range [-1, 1].";
    return FAILED;
@@ -55,23 +50,38 @@ Status OneHotInfo::GetAttrs() {
}

Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
  return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)});
  if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}) != SUCCESS) {
    return FAILED;
  }
  auto stra = strategy->GetInputDim().at(0);
  bool invalid = false;
  for (size_t i = 1; i < stra.size(); ++i) {
    if (stra.at(i) != 1) {
      invalid = true;
      break;
    }
  }
  if ((inputs_shape_.at(0).size() > 1) && ((axis_ != -1) || invalid)) {
    MS_LOG(ERROR) << "When input dimension is > 1, axis must be -1, and strategy must be data parallel.";
    return FAILED;
  }
  return SUCCESS;
}

Status OneHotInfo::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);
  // Now input only support 1-D tensor, so the output is a 2-D tensor
  // If input is a vector of length features, the output shape will be:
  // [features, depth] if axis == -1 (or axis == 1)
  // [depth, features] if axis == 0
  if (axis_ == 0) {
    // Here, only a 1-D input tensor is supported, so the output is a 2-D tensor
    // If the input is a vector of length features, the output shape will be:
    // [depth, features] if axis == 0
    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
    dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
  } else {
    dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
    for (const auto &input_stra : input_strategy) {
      dev_matrix_shape_.push_back(input_stra);
    }
  }
  old_dev_matrix_back_ = dev_matrix_shape_.back();
  repeated_num_in_dev_matrix_right_ = false;
@@ -81,21 +91,21 @@ Status OneHotInfo::InferDevMatrixShape() {
Status OneHotInfo::InferTensorMap() {
  Shape input_tensor_map_index, output_tensor_map_index;
  size_t size = outputs_shape_[0].size();
  // e.g. if size is 2, tensor_map_index is [1, 0]
  if (axis_ == 0) {
    for (size_t i = 0; i < size; ++i) {
      output_tensor_map_index.push_back((int64_t)(i));
    }
    input_tensor_map_index.push_back(1);
  } else {
    for (size_t i = 0; i < size; ++i) {
      output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
    }
    for (size_t i = 0; i < size - 1; ++i) {
      input_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
    }
  }
  outputs_tensor_map_.push_back(output_tensor_map_index);
  // Now input only support 1-D tensor
  input_tensor_map_index.push_back(1);
  inputs_tensor_map_.push_back(input_tensor_map_index);
  return SUCCESS;
}
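
Taken together, the new OneHotInfo code accepts a multi-dimensional input only in the data-parallel case: axis must be -1, only the first dimension of the output strategy may be split, and the device matrix is then copied straight from that strategy. A hedged illustration (names are ours) for a 2-D indices input, i.e. a 3-D one-hot output, on 8 devices; the two empty tuples are the strategies for the on_value/off_value scalars:

accepted = ((8, 1, 1), (), ())  # only the first output dimension split: passes CheckStrategy
rejected = ((2, 4, 1), (), ())  # a later dimension split: FAILED ("strategy must be data parallel")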

@@ -87,6 +87,7 @@ constexpr char GEN_STRATEGY[] = "gen_strategy";
constexpr char REDUCE_OP_SUM[] = "sum";
constexpr char REDUCE_OP_MAX[] = "max";
constexpr char REDUCE_OP_MIN[] = "min";
constexpr char REDUCE_OP_ANY[] = "any";
constexpr char OP_PATH[] = "mindspore.ops.operations";
constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";

@@ -539,5 +539,24 @@ std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage
  return sp_vector;
}

Status ReduceAnyInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": checking strategy failed.";
    return FAILED;
  }
  auto dim_list = ReduceMethod::reduce_dim();
  Dimensions stra = strategy->GetInputDim().at(0);
  for (size_t index = 0; index < stra.size(); ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      MS_LOG(ERROR) << name_
                    << ": checking strategy failed. The ReduceAny operator does not support splitting a reduced dimension.";
      return FAILED;
    }
  }
  return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
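
The check above walks the strategy of the first input and fails as soon as a dimension listed in reduce_dim() is split: for ReduceAny, a reduced axis must stay whole on each device. A small sketch (strategy names are ours) for a reduction over axis 1 of a [128, 32, 64] tensor on 8 devices; the two tests added at the end of this commit (test_any_mul and test_any_mul2) exercise both sides of this check:

keep_axis_whole = ((8, 1, 1),)  # reduced axis 1 unsplit: accepted
split_reduced = ((1, 8, 1),)    # reduced axis 1 split by 8: CheckStrategy returns FAILED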

@@ -124,6 +124,19 @@ class ReduceSumInfo : public ReduceMethod {
  ~ReduceSumInfo() override = default;
};

class ReduceAnyInfo : public ReduceMethod {
 public:
  ReduceAnyInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceSumCost>()) {
    reduce_method_ = REDUCE_OP_ANY;
  }
  ~ReduceAnyInfo() override = default;

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
};

class ReduceMinInfo : public ReduceMethod {
 public:
  ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

@@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, wi, stra1=None, stra2=None, stra3=None):
        super(Net, self).__init__()
        self.wi = Parameter(wi, "wi")
        self.matmul = P.MatMul().shard(stra1)
        self.onehot = P.OneHot(axis=-1).shard(stra2)
        self.mul = P.Mul().shard(stra3)
        self.on_value = Tensor(1.0, ms.float32)
        self.off_value = Tensor(0.0, ms.float32)
        self.cast = P.Cast()
        self.depth = 48

    def construct(self, x):
        output = self.matmul(x, self.wi)
        output = self.cast(output, ms.int32)
        output = self.onehot(output, self.depth, self.on_value, self.off_value)
        output = self.mul(output, output)
        return output


_x = Tensor(np.ones([16, 48]), dtype=ms.float32)
_wi = Tensor(np.ones([48, 16]), dtype=ms.float32)


def compile_net(net):
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x)
    context.reset_auto_parallel_context()


def test_onehot():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
                                      global_rank=0)
    stra1 = ((8, 1), (1, 1))
    stra2 = ((8, 1, 1), (), ())
    stra3 = ((8, 1, 1), (8, 1, 1))
    net = Net(_wi, stra1=stra1, stra2=stra2, stra3=stra3)
    compile_net(net)
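
As a shape walk-through of the test above (our annotation, not part of the file):

# x [16, 48] @ wi [48, 16]        -> logits  [16, 16]
# OneHot(depth=48, axis=-1)       -> one-hot [16, 16, 48]
# stra2 = ((8, 1, 1), (), ())     -> batch dim split 8 ways: [2, 16, 48] per device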

@@ -13,7 +13,7 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
@@ -585,3 +585,54 @@ def test_max_empty_tuple():
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_any_mul():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 8, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    with pytest.raises(RuntimeError):
        compile_net_no_bias(net, x, y)


def test_any_mul2():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, -1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((8, 1, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)