!115 Add parallel operator for ExpandDims
Merge pull request !115 from yangzhenzhang/expand-dim
This commit is contained in:
commit
2c3c1577b1
|
@ -125,6 +125,7 @@ REGISTER(SqrtInfo);
|
|||
REGISTER(GetNextInfo);
|
||||
REGISTER(NegInfo);
|
||||
REGISTER(BatchMatMulInfo);
|
||||
REGISTER(ExpandDimsInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -381,5 +381,168 @@ Status CastInfo::InferMirrorOps() {
|
|||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ExpandDimsInfo::GetAttrs() {
|
||||
if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (!input_value_.back()->isa<Int32Imm>()) {
|
||||
MS_LOG(ERROR) << name_ << ": The type of axis is not int";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int32_t axis = GetValue<int32_t>(input_value_.back());
|
||||
|
||||
if (inputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int32_t dim = SizeToInt(inputs_shape_[0].size());
|
||||
if ((axis > dim) || (axis < -dim - 1)) {
|
||||
MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (axis < 0) {
|
||||
positive_axis_ = dim + axis + 1;
|
||||
} else {
|
||||
positive_axis_ = axis;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ExpandDimsInfo::InferTensorMap() {
|
||||
if (inputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// for example: if the dimension of input is 3, and the axis is 2,
|
||||
// then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
|
||||
std::vector<int32_t> input_tensor_map, output_tensor_map;
|
||||
size_t size = inputs_shape_[0].size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
input_tensor_map.push_back(SizeToInt(size - i - 1));
|
||||
}
|
||||
|
||||
inputs_tensor_map_.push_back(input_tensor_map);
|
||||
|
||||
output_tensor_map = input_tensor_map;
|
||||
if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
|
||||
return FAILED;
|
||||
}
|
||||
(void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
|
||||
outputs_tensor_map_.push_back(output_tensor_map);
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
|
||||
<< ", and the tensor map of output is " << ShapeToString(output_tensor_map);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ExpandDimsInfo::InferTensorStrategy() {
|
||||
if (strategy_ == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is null";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
inputs_strategy_ = strategy_->GetInputDim();
|
||||
if (inputs_strategy_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Shape output_strategy = inputs_strategy_[0];
|
||||
if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
|
||||
return FAILED;
|
||||
}
|
||||
(void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
|
||||
outputs_strategy_ = {output_strategy};
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ExpandDimsInfo::InferTensorInfo() {
|
||||
if (inputs_shape_.empty() || outputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Shape input_shape = inputs_shape_[0];
|
||||
Shape output_shape = outputs_shape_[0];
|
||||
|
||||
// infer slice shape
|
||||
if (InferTensorStrategy() != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
|
||||
return FAILED;
|
||||
}
|
||||
Shapes inputs_slice_shape, outputs_slice_shape;
|
||||
if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Shape input_slice_shape = inputs_slice_shape[0];
|
||||
Shape output_slice_shape = outputs_slice_shape[0];
|
||||
|
||||
TensorLayout input_tensor_layout, output_tensor_layout;
|
||||
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
|
||||
TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
|
||||
|
||||
inputs_tensor_info_.push_back(input_tensor_info);
|
||||
outputs_tensor_info_.push_back(output_tensor_info);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ExpandDimsInfo::InferMirrorOps() {
|
||||
mirror_ops_.clear();
|
||||
|
||||
if (inputs_tensor_map_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<Group> group;
|
||||
if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Create group failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (group.empty()) {
|
||||
MS_LOG(INFO) << name_ << ": No need to create mirror ops";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
OperatorVector mirror_op, placeholder_op;
|
||||
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
||||
mirror_ops_.push_back(mirror_op);
|
||||
mirror_ops_.push_back(placeholder_op);
|
||||
MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -174,6 +174,26 @@ class NegInfo : public ActivationOther {
|
|||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~NegInfo() override = default;
|
||||
};
|
||||
|
||||
class ExpandDimsInfo : public ActivationOther {
|
||||
public:
|
||||
ExpandDimsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||
const PrimitiveAttrs& attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~ExpandDimsInfo() override = default;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status InferTensorMap() override;
|
||||
Status InferTensorInfo() override;
|
||||
Status InferMirrorOps() override;
|
||||
Status InferTensorStrategy();
|
||||
|
||||
private:
|
||||
int32_t positive_axis_ = -1;
|
||||
Strategys inputs_strategy_;
|
||||
Strategys outputs_strategy_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
|
||||
|
|
|
@ -24,6 +24,8 @@ constexpr size_t PRELU_OUTPUTS_SIZE = 1;
|
|||
constexpr size_t PRELU_SECOND_INPUT_SIZE = 1;
|
||||
constexpr int32_t PRELU_CHANNEL_INDEX = 1;
|
||||
constexpr int32_t PRELU_CHANNEL_STRATEGY = 1;
|
||||
constexpr int32_t NO_SPLIT_MAP = -1;
|
||||
constexpr int32_t NO_SPLIT_STRATEGY = 1;
|
||||
constexpr size_t MATMUL_ATTRS_SIZE = 2;
|
||||
constexpr size_t MATMUL_INPUTS_SIZE = 2;
|
||||
constexpr size_t MATMUL_OUTPUTS_SIZE = 1;
|
||||
|
@ -31,6 +33,7 @@ constexpr size_t ACTIVATION_ATTR_SIZE = 1;
|
|||
constexpr size_t SOFTMAX_ATTR_SIZE = 1;
|
||||
constexpr size_t ACTIVATION_INPUTS_SIZE = 1;
|
||||
constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1;
|
||||
constexpr size_t EXPANDDIMS_INPUT_SIZE = 2;
|
||||
constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1;
|
||||
constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2;
|
||||
constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2;
|
||||
|
@ -191,6 +194,7 @@ constexpr char GET_NEXT[] = "GetNext";
|
|||
constexpr char SQUEEZE[] = "Squeeze";
|
||||
constexpr char Neg[] = "Neg";
|
||||
constexpr char BATCH_MATMUL[] = "BatchMatMul";
|
||||
constexpr char EXPAND_DIMS[] = "ExpandDims";
|
||||
|
||||
// Parallel don't care
|
||||
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
|
||||
|
|
|
@ -104,6 +104,7 @@ std::vector<std::string> splittable_op_ = {MATMUL,
|
|||
CAST,
|
||||
Neg,
|
||||
BATCH_MATMUL,
|
||||
EXPAND_DIMS,
|
||||
SQUEEZE};
|
||||
|
||||
std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT,
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import _executor
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().set_strategy(strategy1)
|
||||
self.expand_dims = P.ExpandDims().set_strategy(strategy2)
|
||||
self.mul2 = P.Mul().set_strategy(strategy3)
|
||||
self.mul_weight = Parameter(mul_weight, "w1")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.mul(x, self.mul_weight)
|
||||
out = self.expand_dims(out, -1)
|
||||
out = self.mul2(out, b)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, mul_weight, strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.expand_dims = P.ExpandDims().set_strategy(strategy1)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
self.mul_weight = Parameter(mul_weight, "w1")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.expand_dims(self.mul_weight, -1)
|
||||
out = self.mul(out, b)
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64, 32, 1]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_expand_dims_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((16, 1, 1), (16, 1, 1))
|
||||
strategy2 = ((16, 1, 1), )
|
||||
strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1))
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
compile(net)
|
||||
|
||||
|
||||
def test_expand_dims_model_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((1, 1, 16), (1, 1, 16))
|
||||
strategy2 = ((1, 1, 16), )
|
||||
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
compile(net)
|
||||
|
||||
|
||||
def test_expand_dims_hybrid_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 4), (2, 2, 4))
|
||||
strategy2 = ((2, 2, 4), )
|
||||
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
compile(net)
|
||||
|
||||
|
||||
def test_expand_dims_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
|
||||
net = Net(_w1)
|
||||
compile(net)
|
||||
|
||||
|
||||
def test_expand_dims_repeat_calc():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 4), (2, 2, 4))
|
||||
strategy2 = ((1, 2, 2), )
|
||||
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
compile(net)
|
||||
|
||||
|
||||
def test_expand_dims_parameter():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((1, 2, 2), )
|
||||
strategy2 = ((2, 2, 4, 1), (2, 2, 4, 1))
|
||||
net = Net2(_w1, strategy1, strategy2)
|
||||
compile(net)
|
Loading…
Reference in New Issue