forked from mindspore-Ecosystem/mindspore
!7345 fix a bug case in reshape redistribution
Merge pull request !7345 from yao_yf/reshape_redistribution_all_scene_support_add
This commit is contained in:
commit
a5e8c1eab3
|
@ -770,7 +770,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
|
||||
bool FindReshape(const CNodePtr &cnode) {
|
||||
bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
|
||||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
|
@ -780,7 +780,16 @@ bool FindReshape(const CNodePtr &cnode) {
|
|||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return (prim->name() == RESHAPE);
|
||||
if (prim->name() == RESHAPE) {
|
||||
auto operator_info = cnode->user_data<OperatorInfo>();
|
||||
std::string op_info_name = operator_info->name();
|
||||
if (op_cache->find(op_info_name) != op_cache->end()) {
|
||||
return false;
|
||||
}
|
||||
op_cache->insert(op_info_name);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
|
||||
|
@ -871,9 +880,10 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
|
|||
}
|
||||
|
||||
void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
std::unordered_set<std::string> op_cache;
|
||||
for (auto node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!FindReshape(cnode)) {
|
||||
if (!FindReshape(cnode, &op_cache)) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(cnode->inputs().size() == 3);
|
||||
|
|
|
@ -36,11 +36,14 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrange
|
|||
while (!is_unified) {
|
||||
std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
|
||||
if (temp_layout_ptr == nullptr) {
|
||||
return nullptr;
|
||||
out_layout_ptr->SetExpandAble(false);
|
||||
return out_layout_ptr;
|
||||
}
|
||||
out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
|
||||
if (out_layout_ptr == nullptr) {
|
||||
return nullptr;
|
||||
std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
|
||||
layout_ptr->SetExpandAble(false);
|
||||
return layout_ptr;
|
||||
}
|
||||
is_unified = out_layout_ptr->IsSameTensorShape();
|
||||
}
|
||||
|
|
|
@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
|
|||
MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
|
||||
return Status::SUCCESS;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
|
||||
if (layout_transfer_) {
|
||||
MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
|
||||
}
|
||||
return Status::FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
|
|||
return false;
|
||||
}
|
||||
if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
|
||||
MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
||||
if (layout_transfer_) {
|
||||
MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
|
|||
return nullptr;
|
||||
}
|
||||
TensorLayout tensor_layout_new;
|
||||
tensor_layout_new.set_layout_transfer(true);
|
||||
Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
|
||||
if (status != Status::SUCCESS) {
|
||||
return nullptr;
|
||||
|
@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
|
|||
}
|
||||
|
||||
TensorLayout TensorLayout::TransferRepeatLayout() const {
|
||||
Shape dev_mat(device_arrangement_.array());
|
||||
Shape tensor_map(tensor_map_.GetDimSize(), -1);
|
||||
Shape tensor_shape(tensor_shape_.array());
|
||||
Shape dev_mat(device_arrangement_origin_.array());
|
||||
Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
|
||||
Shape tensor_shape(tensor_shape_origin_.array());
|
||||
TensorLayout repeat;
|
||||
repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
|
||||
return repeat;
|
||||
|
|
|
@ -46,6 +46,10 @@ class TensorLayout {
|
|||
|
||||
void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
|
||||
|
||||
bool layout_transfer() const { return layout_transfer_; }
|
||||
|
||||
void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
|
||||
|
||||
int32_t get_field_size() const { return field_size_; }
|
||||
|
||||
void set_field_size(int32_t field_size) { field_size_ = field_size; }
|
||||
|
@ -113,14 +117,15 @@ class TensorLayout {
|
|||
int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
|
||||
|
||||
Arrangement device_arrangement_origin_;
|
||||
Map tensor_map_origin_;
|
||||
Arrangement tensor_shape_origin_;
|
||||
Arrangement device_arrangement_;
|
||||
Map tensor_map_;
|
||||
Arrangement tensor_shape_;
|
||||
Map tensor_map_;
|
||||
Map tensor_map_origin_;
|
||||
bool skip_redistribution_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
bool uniform_split_ = true;
|
||||
bool layout_transfer_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
Shape opt_shard_slice_shape_;
|
||||
std::string opt_shard_group_ = "";
|
||||
};
|
||||
|
|
|
@ -43,7 +43,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
|
|||
TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
|
||||
TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
|
||||
MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
|
||||
MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
|
||||
MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
|
||||
MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
|
||||
|
|
|
@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
|
|||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
||||
def test_reshape_unexpand_7():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
|
||||
mul_size=(32, 1, 220, 220)):
|
||||
super().__init__()
|
||||
mul_np = np.full(mul_size, 0.5, dtype=np.float32)
|
||||
self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
|
||||
self.mul = P.Mul()
|
||||
self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
|
||||
kernel_size=5, has_bias=True, weight_init='ones',
|
||||
bias_init='ones', pad_mode='valid')
|
||||
self.softmax = nn.Softmax(axis=axis)
|
||||
self.relu = nn.ReLU()
|
||||
self.reshape = P.Reshape()
|
||||
self.input_shape = input_shape
|
||||
|
||||
def construct(self, inputs):
|
||||
x = self.conv(inputs)
|
||||
x = self.softmax(x)
|
||||
x = self.relu(x)
|
||||
x = self.mul(x, self.mul_weight)
|
||||
x = self.reshape(x, self.input_shape)
|
||||
return x
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
|
Loading…
Reference in New Issue