data sync

This commit is contained in:
nomindcarry 2023-01-03 09:03:08 +08:00
parent 6708971c1c
commit 40ac114498
1 changed files with 16 additions and 8 deletions

View File

@ -1929,7 +1929,6 @@ static const mindspore::HashMap<std::string, std::set<int64_t>> try_get_value_in
{prim::kPrimROIAlignGrad->name(), ShapeSet{2}},
{prim::kSlice, ShapeSet{1, 2}},
{prim::kPrimSliceGrad->name(), ShapeSet{2, 3, 4}},
{prim::kStridedSliceGrad, ShapeSet{2, 3, 4}},
{prim::kPrimTensorCopySlices->name(), ShapeSet{2, 3, 4}},
{prim::kTranspose, ShapeSet{1}},
{prim::kPrimGatherD->name(), ShapeSet{1}},
@ -1938,9 +1937,7 @@ static const mindspore::HashMap<std::string, std::set<int64_t>> try_get_value_in
{prim::kPrimScatterNd->name(), ShapeSet{2}},
{prim::kStridedSlice, ShapeSet{1, 2, 3}},
{prim::kStridedSliceGrad, ShapeSet{1, 2, 3, 4}},
{prim::kPrimTensorCopySlices->name(), ShapeSet{2, 3, 4}},
{prim::kTile, ShapeSet{1}},
{prim::kTranspose, ShapeSet{1}},
{prim::kPrimConv2DBackpropFilter->name(), ShapeSet{2}},
{prim::kPrimConv2DBackpropInput->name(), ShapeSet{2}},
{prim::kMatrixDiagPartV3, ShapeSet{1, 2}},
@ -1965,13 +1962,24 @@ bool IfNeedSkipResize(const CNodePtr &node) {
for (size_t i = 0; i < input_size; ++i) {
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, false);
auto real_input = input_node_with_index.first;
// Inverse op have constant input need infer ,then resize
auto shape_set = GetShapeSetFromResizeMap(node);
if (shape_set.find(i) != shape_set.end() && real_input->isa<Parameter>()) {
MS_LOG(DEBUG) << "Set Node Attr is Dynamic Shape";
common::AnfAlgo::SetNodeAttr(mindspore::kAttrOutputIsDynamicShape, MakeValue(true), node);
node->func_graph()->cast<KernelGraphPtr>()->SetGraphDynamicAttr(true);
return true;
if (shape_set.find(i) != shape_set.end()) {
if (real_input->isa<Parameter>()) {
MS_LOG(DEBUG) << "Set Node Attr is Dynamic Shape";
common::AnfAlgo::SetNodeAttr(mindspore::kAttrOutputIsDynamicShape, MakeValue(true), node);
node->func_graph()->cast<KernelGraphPtr>()->SetGraphDynamicAttr(true);
return true;
} else if (real_input->isa<ValueNode>()) {
auto value_node = real_input->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
auto value_tensor_ptr = value->cast<tensor::TensorPtr>();
value_tensor_ptr->data_sync();
}
}
}
}
}