forked from mindspore-Ecosystem/mindspore
add squeeze distributed op
This commit is contained in:
parent
ce9f77346d
commit
32cd280c1a
|
@ -125,6 +125,7 @@ REGISTER(GetNextInfo);
|
|||
REGISTER(NegInfo);
|
||||
REGISTER(BatchMatMulInfo);
|
||||
REGISTER(ExpandDimsInfo);
|
||||
REGISTER(SqueezeInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "parallel/auto_parallel/costmodel.h"
|
||||
|
@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() {
|
|||
MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// Resolves the `axis` attribute of Squeeze into non-negative dimension indices
// and stores them into axis_ as a ValueTuple.
//  - If the attribute tuple is empty, Squeeze removes every dimension whose
//    size is 1, so those indices are collected from the input shape.
//  - Otherwise negative axes are converted to positive ones.
// Returns FAILED when the inputs shape is missing or an axis element is not an
// Int32 value.
Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) {
  std::vector<int32_t> axis;
  auto axis_list = value_tuple->value();
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_.at(0);
  size_t input_size = input_shape.size();
  // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1.
  if (axis_list.empty()) {
    for (size_t i = 0; i < input_size; ++i) {
      if (input_shape[i] == 1) {
        // Fix: use SizeToInt instead of an implicit size_t -> int32_t narrowing
        // conversion, consistent with the other index conversions in this file.
        axis.push_back(SizeToInt(i));
      }
    }
    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
    return SUCCESS;
  }

  // convert negative axis to positive.
  for (auto& dim : axis_list) {
    if (!dim->isa<Int32Imm>()) {
      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
      return FAILED;
    }
    int32_t dim_value = GetValue<int32_t>(dim);
    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
    axis.push_back(positive_value);
  }
  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  return SUCCESS;
}
|
||||
|
||||
// Reads the `axis` attribute, normalizes it via InferAxis, and writes the
// normalized tuple back into attrs_ so later passes (e.g. InferReplaceOps)
// see the resolved, non-negative axis values.
// Returns FAILED when the attribute is absent or cannot be normalized.
Status SqueezeInfo::GetAttrs() {
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
    return FAILED;
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  auto value_tuple = iter->second->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  // Bug fix: the return status of InferAxis was previously ignored, so an
  // invalid axis attribute silently left axis_ unset while GetAttrs reported
  // SUCCESS. Propagate the failure instead.
  if (InferAxis(value_tuple) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer axis failed.";
    return FAILED;
  }
  attrs_[AXIS] = axis_;
  return SUCCESS;
}
|
||||
|
||||
// Builds the replacement Squeeze operator that carries the normalized axis_
// attribute, so the rewritten graph squeezes exactly the resolved dimensions.
Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) {
  OperatorAttrs op_attrs = {std::make_pair(AXIS, axis_)};
  OperatorParams op_params;
  replace_op_ = {std::make_pair(SQUEEZE, std::make_pair(op_attrs, op_params))};
  return SUCCESS;
}
|
||||
|
||||
// Derives the tensor maps of Squeeze's input and output. The output map is the
// input map with the squeezed dimensions removed.
// Example: input shape [32, 32, 1] with axis (2,) gives an input tensor map of
// [2, 1, 0] and an output tensor map of [2, 1].
Status SqueezeInfo::InferTensorMap() {
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  std::vector<int32_t> in_map;
  std::vector<int32_t> out_map;
  size_t dim_num = inputs_shape_[0].size();
  std::vector<int32_t> squeezed_dims = GetValue<const std::vector<int>>(axis_);
  for (size_t i = 0; i < dim_num; ++i) {
    // Tensor map entries count down from (rank - 1) to 0.
    int32_t map_value = SizeToInt(dim_num - i - 1);
    bool is_squeezed =
      (std::find(squeezed_dims.begin(), squeezed_dims.end(), SizeToInt(i)) != squeezed_dims.end());
    if (!is_squeezed) {
      out_map.push_back(map_value);
    }
    in_map.push_back(map_value);
  }
  inputs_tensor_map_.push_back(in_map);
  outputs_tensor_map_.push_back(out_map);
  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(in_map)
               << ", and the tensor map of output is " << ShapeToString(out_map);

  return SUCCESS;
}
|
||||
|
||||
// Builds the TensorInfo (layout plus global and slice shapes) for the input
// and the output of Squeeze under the current strategy. The output strategy is
// derived from the input strategy by dropping the squeezed dimensions, then
// slice shapes and tensor layouts are computed from the device matrix and the
// tensor maps produced by InferTensorMap.
Status SqueezeInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }

  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }

  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy;
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  // The output strategy keeps only the strategy entries of dimensions that are
  // not squeezed away, in their original order.
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_strategy.push_back(inputs_strategy[0].at(i));
    }
  }
  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }

  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }

  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];

  // infer tensor layout
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }

  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }

  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
|
||||
|
||||
// Initializes the operator info for the given strategy and derives the
// replacement op. Returns FAILED when either step fails.
Status SqueezeInfo::Init(const StrategyPtr& strategy) {
  // Bug fix: failures were previously only logged and the function still
  // returned SUCCESS, so callers proceeded with a half-initialized operator.
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }

  if (InferReplaceOps(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther {
|
|||
Strategys inputs_strategy_;
|
||||
Strategys outputs_strategy_;
|
||||
};
|
||||
|
||||
// Distributed operator info for the Squeeze primitive. Inherits the
// ActivationOther behavior and adds axis handling: the `axis` attribute is
// normalized (negative axes resolved; an empty tuple expanded to all size-1
// dimensions) before tensor maps and the replacement op are derived.
class SqueezeInfo : public ActivationOther {
 public:
  SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~SqueezeInfo() override = default;

 protected:
  // Normalize the raw axis tuple into non-negative indices stored in axis_.
  Status InferAxis(const ValueTuplePtr& value_tuple);
  // Read the `axis` attribute from attrs_ and normalize it via InferAxis.
  Status GetAttrs() override;
  // Build the replacement Squeeze operator carrying the normalized axis.
  Status InferReplaceOps(const StrategyPtr& strategy);
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status Init(const StrategyPtr& strategy) override;

 private:
  ValueTuplePtr axis_;  // normalized axis tuple (non-negative indices)
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
|
||||
|
|
|
@ -116,4 +116,4 @@ class AssignSubInfo : public ArithmeticBase {
|
|||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ARITHMETIC_INFO_H_
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
|
||||
|
|
|
@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase {
|
|||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_COMPARISON_FUNCTION_INFO_H_
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
|
||||
|
|
|
@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo {
|
|||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ONEHOT_INFO_H_
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
|
||||
|
|
|
@ -47,8 +47,8 @@ using mindspore::tensor::Tensor;
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
|
||||
const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
||||
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
|
||||
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
||||
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
||||
// it will be one item in map with key: C, and value: (B, i)
|
||||
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
|
||||
|
@ -1832,7 +1832,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePt
|
|||
if (cnode == loss_cnode) {
|
||||
is_loss_cnode = true;
|
||||
}
|
||||
|
||||
// insert forward ops
|
||||
InsertForwardOps(distribute_operator, cnode);
|
||||
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import _executor
|
||||
|
||||
|
||||
class Net(Cell):
    """Squeeze followed by element-wise Mul, with configurable parallel strategies."""

    def __init__(self, strategy1=None, strategy2=None, axis=()):
        super().__init__()
        # axis=() lets Squeeze infer which size-1 dimensions to remove.
        self.squeeze = P.Squeeze(axis=axis).set_strategy(strategy1)
        self.mul = P.Mul().set_strategy(strategy2)

    def construct(self, x, b):
        squeezed = self.squeeze(x)
        return self.mul(squeezed, b)
|
||||
|
||||
|
||||
_x = Tensor(np.ones([64, 1, 32, 1]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile(net):
    # Compile the network against the fixed sample inputs, then reset the
    # parallel context so each test starts from a clean state.
    # NOTE(review): this shadows the Python builtin `compile`; the name is kept
    # because every test below calls it.
    _executor.compile(net, _x, _b)
    context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_squeeze_data_parallel():
    """Squeeze with the first (batch) dimension split across all 16 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_stra = ((16, 1, 1, 1), )
    mul_stra = ((16, 1), (16, 1))
    compile(Net(squeeze_stra, mul_stra))
|
||||
|
||||
|
||||
def test_squeeze_model_parallel():
    """Squeeze with a non-batch dimension split across all 16 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_stra = ((1, 1, 16, 1), )
    mul_stra = ((1, 16), (1, 16))
    compile(Net(squeeze_stra, mul_stra))
|
||||
|
||||
|
||||
def test_squeeze_specified_axis():
    """Squeeze with axis=(1, 3) given explicitly instead of inferred."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_stra = ((4, 1, 4, 1), )
    mul_stra = ((8, 2), (8, 2))
    compile(Net(squeeze_stra, mul_stra, (1, 3)))
|
||||
|
||||
|
||||
def test_squeeze_auto_parallel():
    """Let the framework choose strategies under auto_parallel mode."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    compile(Net())
|
||||
|
||||
|
||||
def test_squeeze_repeat_calc():
    """Squeeze strategy uses only 8 of the 16 devices (repeated calculation case)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    squeeze_stra = ((1, 1, 8, 1), )
    mul_stra = ((2, 8), (2, 8))
    compile(Net(squeeze_stra, mul_stra))
|
Loading…
Reference in New Issue