forked from mindspore-Ecosystem/mindspore
support reshape redistribution in all scenes
This commit is contained in:
parent
478200d2fe
commit
f60d81a15f
|
@ -488,10 +488,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
|
|||
}
|
||||
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
|
||||
SetOutputLayout(next_in_tensor_info.tensor_layout());
|
||||
if (Init(nullptr) == FAILED) {
|
||||
MS_LOG(DEBUG) << "Failure:operator reshape init failed";
|
||||
continue;
|
||||
}
|
||||
InferTensorInfoByLayout();
|
||||
SetCostForReshape(reshape_stra);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,6 +63,14 @@ std::shared_ptr<ReshapeLayoutTransfer> RedistributionLayoutTransfer::UnifyDevice
|
|||
if (unified_device_arrangement_ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
Shape in_expand_shape;
|
||||
Status status = ExpandShape(unified_device_arrangement_ptr->from_in().tensor_shape().array(),
|
||||
unified_device_arrangement_ptr->to_in().tensor_shape().array(), &in_expand_shape);
|
||||
if (status != Status::SUCCESS) {
|
||||
MS_LOG(INFO) << "The shape of from and to cannot transfer by unify";
|
||||
unified_device_arrangement_ptr->SetExpandAble(false);
|
||||
return unified_device_arrangement_ptr;
|
||||
}
|
||||
return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape();
|
||||
}
|
||||
} // namespace parallel
|
||||
|
|
|
@ -35,12 +35,15 @@ class ReshapeLayoutTransfer : public LayoutTransfer {
|
|||
std::shared_ptr<ReshapeLayoutTransfer> ExpandFromTensorShapeAndExpandToDeviceArrangement(
|
||||
const Arrangement &expand_shape) const;
|
||||
std::shared_ptr<ReshapeLayoutTransfer> ExchangeFromAndTo() const;
|
||||
bool ExpandAble() const { return is_expand_able_; }
|
||||
bool FromTensorShapeCanBeExpandByTo() const;
|
||||
bool ToTensorShapeCanBeExpandByFrom() const;
|
||||
void SetExpandAble(const bool is_expand_able) { is_expand_able_ = is_expand_able; }
|
||||
|
||||
private:
|
||||
Status CheckValidTransfer() override;
|
||||
std::shared_ptr<Arrangement> ComputeExpandedFromTensorShapeByTo() const;
|
||||
bool FromTensorShapeCanBeExpandByTo() const;
|
||||
bool ToTensorShapeCanBeExpandByFrom() const;
|
||||
bool is_expand_able_ = true;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -97,11 +97,11 @@ Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *
|
|||
int64_t value = 1;
|
||||
for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) {
|
||||
if (*iter == 0) {
|
||||
MS_LOG(ERROR) << "element of shape_accum should not be zero";
|
||||
MS_LOG(WARNING) << "element of shape_accum should not be zero";
|
||||
return Status::FAILED;
|
||||
}
|
||||
if ((*iter) % value != 0) {
|
||||
MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
|
||||
MS_LOG(WARNING) << "shape_accum is not a accumulate product in ascending order";
|
||||
return Status::FAILED;
|
||||
}
|
||||
(void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
|
||||
|
|
|
@ -390,6 +390,15 @@ TensorLayout TensorLayout::SqueezeShape() const {
|
|||
return out;
|
||||
}
|
||||
|
||||
TensorLayout TensorLayout::TransferRepeatLayout() const {
|
||||
Shape dev_mat(device_arrangement_.array());
|
||||
Shape tensor_map(tensor_map_.GetDimSize(), -1);
|
||||
Shape tensor_shape(tensor_shape_.array());
|
||||
TensorLayout repeat;
|
||||
repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
|
||||
return repeat;
|
||||
}
|
||||
|
||||
// Generate a totally shard tensor slice shape for parallel optimizer
|
||||
Status TensorLayout::GenerateOptShardSliceShape() {
|
||||
MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
|
||||
|
|
|
@ -88,6 +88,8 @@ class TensorLayout {
|
|||
|
||||
TensorLayout SqueezeShape() const;
|
||||
|
||||
TensorLayout TransferRepeatLayout() const;
|
||||
|
||||
Status GenerateOptShardSliceShape();
|
||||
|
||||
Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }
|
||||
|
|
|
@ -39,6 +39,42 @@ Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &
|
|||
return Status::SUCCESS;
|
||||
}
|
||||
|
||||
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
|
||||
TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
|
||||
TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
|
||||
MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
|
||||
MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
|
||||
MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
|
||||
OperatorVector operator_vector;
|
||||
OutPutInfoVector output_info_vector;
|
||||
if (InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
|
||||
Status::FAILED) {
|
||||
return nullptr;
|
||||
}
|
||||
if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
|
||||
reshape_flag_ = true;
|
||||
ConstructOperator constructor;
|
||||
constructor.UpdateTensorShape(from_repeat.slice_shape().array());
|
||||
Arrangement shape = to_repeat.slice_shape();
|
||||
MS_LOG(DEBUG) << "reshape " << shape.ToString();
|
||||
if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
|
||||
return nullptr;
|
||||
} else {
|
||||
(void)operator_vector.push_back(constructor.GetOperator());
|
||||
(void)output_info_vector.push_back(std::make_pair(false, 0));
|
||||
}
|
||||
}
|
||||
if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==
|
||||
Status::FAILED) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
|
||||
std::make_pair(operator_vector, output_info_vector));
|
||||
}
|
||||
|
||||
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
|
||||
// Step 1: Match device arrangement between from_ and to_
|
||||
RedistributionLayoutTransfer layout_transfer;
|
||||
|
@ -51,6 +87,10 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
|
|||
MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
if (!ptr->ExpandAble()) {
|
||||
expand_able_ = false;
|
||||
return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
|
||||
}
|
||||
TensorLayout from_layout = ptr->from_in();
|
||||
TensorLayout to_layout = ptr->to_in();
|
||||
MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
|
||||
|
@ -61,27 +101,17 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
|
|||
MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
|
||||
// Step 2: Infer redistribution and insert operators
|
||||
RedistributionOperatorInfer operator_infer(construct_op_flag_);
|
||||
if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
|
||||
MS_LOG(ERROR) << "Init operatorInfer failed!";
|
||||
return nullptr;
|
||||
}
|
||||
OperatorVector operator_vector;
|
||||
OutPutInfoVector output_info_vector;
|
||||
if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Infer redistribution failed!";
|
||||
if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
|
||||
Status::SUCCESS) {
|
||||
return nullptr;
|
||||
} else {
|
||||
operator_vector = operator_infer.operator_vector();
|
||||
output_info_vector = operator_infer.output_info_vector();
|
||||
operator_list_ = operator_infer.operator_list();
|
||||
}
|
||||
|
||||
// Step 3: Infer reshape and insert operators
|
||||
if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Construct Reshape operator failed!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
|
||||
std::make_pair(operator_vector, output_info_vector));
|
||||
}
|
||||
|
@ -136,6 +166,31 @@ Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const
|
|||
return Status::SUCCESS;
|
||||
}
|
||||
|
||||
Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
|
||||
OperatorVector *const operator_vector,
|
||||
OutPutInfoVector *const output_info_vector, bool is_cost_model) {
|
||||
RedistributionOperatorInfer operator_infer(construct_op_flag_);
|
||||
if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
|
||||
MS_LOG(ERROR) << "Init operatorInfer failed";
|
||||
return Status::FAILED;
|
||||
}
|
||||
if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Infer redistribution failed";
|
||||
return Status::FAILED;
|
||||
} else {
|
||||
for (auto op : operator_infer.operator_vector()) {
|
||||
operator_vector->insert(operator_vector->end(), op);
|
||||
}
|
||||
for (auto info : operator_infer.output_info_vector()) {
|
||||
output_info_vector->insert(output_info_vector->end(), info);
|
||||
}
|
||||
for (auto opc : operator_infer.operator_list()) {
|
||||
operator_list_.insert(operator_list_.end(), opc);
|
||||
}
|
||||
}
|
||||
return Status::SUCCESS;
|
||||
}
|
||||
|
||||
Status TensorRedistribution::ComputeCost() {
|
||||
RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
|
||||
if (redistribution_oplist_ptr == nullptr) {
|
||||
|
@ -162,8 +217,13 @@ Status TensorRedistribution::ComputeCost() {
|
|||
}
|
||||
}
|
||||
if (reshape_flag()) {
|
||||
Shape prev_slice_shape = from_.slice_shape().array();
|
||||
double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
|
||||
Shape prev_shape;
|
||||
if (expand_able_) {
|
||||
prev_shape = from_.slice_shape().array();
|
||||
} else {
|
||||
prev_shape = from_.tensor_shape().array();
|
||||
}
|
||||
double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1, std::multiplies<int>());
|
||||
computation_cost_ += 2.0 * prev_prod;
|
||||
memory_cost_ += 2.0 * prev_prod;
|
||||
}
|
||||
|
|
|
@ -61,8 +61,12 @@ class TensorRedistribution {
|
|||
private:
|
||||
Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
|
||||
OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector);
|
||||
Status InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
|
||||
OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector,
|
||||
bool is_cost_model);
|
||||
Status ComputeConcatCost(double input_size, Shape attrs);
|
||||
Status ComputePermuteCost(double input_size, Shape attrs);
|
||||
RedistributionOpListPtr InferTensorRedistributionOperatorListUnExpand(bool is_cost_model = false);
|
||||
TensorLayout from_origin_;
|
||||
TensorLayout to_origin_;
|
||||
TensorLayout from_;
|
||||
|
@ -84,6 +88,7 @@ class TensorRedistribution {
|
|||
double memory_cost_;
|
||||
bool construct_op_flag_;
|
||||
bool keep_reshape_;
|
||||
bool expand_able_ = true;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x):
|
||||
predict = self.network(x)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x):
|
||||
return grad_all(self.network)(x)
|
||||
|
||||
def test_reshape_unexpand():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
|
||||
self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, x):
|
||||
weight = self.reshape(self.mul_weight, (1, 128, 96))
|
||||
out = self.mul(x, weight)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([128, 96]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
|
||||
self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, x):
|
||||
weight = self.reshape(self.mul_weight, (1, 128, 96))
|
||||
out = self.mul(x, weight)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([128, 96]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_2():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
|
||||
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, data):
|
||||
x = self.reshape(self.mul_weight, (1, 128, 96))
|
||||
out = self.mul(x, self.mul_weight)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([128, 96]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_3():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.relu1 = P.ReLU().shard(((4, 1),))
|
||||
self.relu2 = P.ReLU().shard(((1, 4),))
|
||||
|
||||
def construct(self, data):
|
||||
x = self.relu1(data)
|
||||
x = self.reshape(x, (3, 4))
|
||||
x = self.relu2(x)
|
||||
return x
|
||||
|
||||
size = 4
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([4, 3]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_4():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.relu1 = P.ReLU().shard(((4, 1),))
|
||||
self.relu2 = P.ReLU().shard(((1, 2, 2),))
|
||||
|
||||
def construct(self, data):
|
||||
x = self.relu1(data)
|
||||
x = self.reshape(x, (3, 2, 2))
|
||||
x = self.relu2(x)
|
||||
return x
|
||||
|
||||
size = 4
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([4, 3]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_5():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.relu1 = P.ReLU().shard(((2, 2, 1),))
|
||||
self.relu2 = P.ReLU().shard(((1, 4),))
|
||||
|
||||
def construct(self, data):
|
||||
x = self.relu1(data)
|
||||
x = self.reshape(x, (3, 4))
|
||||
x = self.relu2(x)
|
||||
return x
|
||||
|
||||
size = 4
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_6():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.relu1 = P.ReLU().shard(((2, 1),))
|
||||
self.relu2 = P.ReLU().shard(((1, 1, 4),))
|
||||
|
||||
def construct(self, data):
|
||||
x = self.relu1(data)
|
||||
x = self.reshape(x, (1, 3, 4))
|
||||
x = self.relu2(x)
|
||||
return x
|
||||
|
||||
size = 4
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([4, 3]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
Loading…
Reference in New Issue