!30094 [MSLITE] Add dynamicShape adjust.

Merge pull request !30094 from wangshaocong/dynamic_shape
This commit is contained in:
i-robot 2022-02-16 08:41:06 +00:00 committed by Gitee
commit a97209dd00
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 12 additions and 6 deletions

View File

@ -69,6 +69,7 @@
#include "ops/resize.h"
#include "ops/resize_bilinear.h"
#include "ops/resize_nearest_neighbor.h"
#include "ops/shape.h"
#include "ops/sigmoid.h"
#include "ops/stack.h"
#include "ops/tanh.h"
@ -152,6 +153,7 @@ constexpr auto kNameTanhGrad = "TanhGrad";
constexpr auto kNameResizeBilinearGrad = "ResizeBilinearGrad";
constexpr auto kNameResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
constexpr auto kNameStandardNormal = "StandardNormal";
constexpr auto kNameDynamicShape = "DynamicShape";
constexpr int kNCHW_H = 2;
constexpr int kNCHW_W = 3;
@ -675,5 +677,6 @@ REGIST_PRIMITIVE_ADJUST(kNameSparseSoftmaxCrossEntropyWithLogits,
REGIST_PRIMITIVE_ADJUST(kNameResizeBilinearGrad, MoveAttrMapResizeGrad)
REGIST_PRIMITIVE_ADJUST(kNameResizeNearestNeighborGrad, MoveAttrMapResizeGrad)
REGIST_PRIMITIVE_ADJUST(kNameSoftplus, MoveAttrMapActivation)
REGIST_PRIMITIVE_ADJUST(kNameDynamicShape, MoveAttrMapCommon<ops::Shape>)
} // namespace lite
} // namespace mindspore

View File

@ -542,12 +542,15 @@ bool IsParamOrValueNodeWithData(const BaseRef &n) {
if (utils::isa<ValueNode>(n)) {
auto value_node = utils::cast<ValueNodePtr>(n);
auto value = value_node->value();
if (value != nullptr && value->isa<tensor::Tensor>()) {
if (value == nullptr) {
return false;
}
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
if (tensor == nullptr || tensor->data_c() == nullptr) {
return false;
}
return true;
return tensor != nullptr && tensor->data_c() != nullptr;
} else if (value->isa<ValueSequence>()) {
auto sequence_ptr = value->cast<ValueSequencePtr>();
return sequence_ptr != nullptr && !sequence_ptr->value().empty();
} else {
return false;
}

View File

@ -137,7 +137,7 @@ VectorRef TransposeFusion::DefineTransTransPattern() const {
MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
auto is_transpose2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
auto transpose_param = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
VectorRef trans_trans_ref = VectorRef({is_transpose2, is_transpose1, transpose_param});
return trans_trans_ref;