forked from mindspore-Ecosystem/mindspore
!30094 [MSLITE] Add dynamicShape adjust.
Merge pull request !30094 from wangshaocong/dynamic_shape
This commit is contained in:
commit
a97209dd00
|
@ -69,6 +69,7 @@
|
|||
#include "ops/resize.h"
|
||||
#include "ops/resize_bilinear.h"
|
||||
#include "ops/resize_nearest_neighbor.h"
|
||||
#include "ops/shape.h"
|
||||
#include "ops/sigmoid.h"
|
||||
#include "ops/stack.h"
|
||||
#include "ops/tanh.h"
|
||||
|
@ -152,6 +153,7 @@ constexpr auto kNameTanhGrad = "TanhGrad";
|
|||
constexpr auto kNameResizeBilinearGrad = "ResizeBilinearGrad";
|
||||
constexpr auto kNameResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
|
||||
constexpr auto kNameStandardNormal = "StandardNormal";
|
||||
constexpr auto kNameDynamicShape = "DynamicShape";
|
||||
constexpr int kNCHW_H = 2;
|
||||
constexpr int kNCHW_W = 3;
|
||||
|
||||
|
@ -675,5 +677,6 @@ REGIST_PRIMITIVE_ADJUST(kNameSparseSoftmaxCrossEntropyWithLogits,
|
|||
REGIST_PRIMITIVE_ADJUST(kNameResizeBilinearGrad, MoveAttrMapResizeGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameResizeNearestNeighborGrad, MoveAttrMapResizeGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameSoftplus, MoveAttrMapActivation)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameDynamicShape, MoveAttrMapCommon<ops::Shape>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -542,12 +542,15 @@ bool IsParamOrValueNodeWithData(const BaseRef &n) {
|
|||
if (utils::isa<ValueNode>(n)) {
|
||||
auto value_node = utils::cast<ValueNodePtr>(n);
|
||||
auto value = value_node->value();
|
||||
if (value != nullptr && value->isa<tensor::Tensor>()) {
|
||||
if (value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
if (tensor == nullptr || tensor->data_c() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return tensor != nullptr && tensor->data_c() != nullptr;
|
||||
} else if (value->isa<ValueSequence>()) {
|
||||
auto sequence_ptr = value->cast<ValueSequencePtr>();
|
||||
return sequence_ptr != nullptr && !sequence_ptr->value().empty();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -137,7 +137,7 @@ VectorRef TransposeFusion::DefineTransTransPattern() const {
|
|||
MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
|
||||
auto is_transpose2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
|
||||
MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
|
||||
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
|
||||
auto transpose_param = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
|
||||
MS_CHECK_TRUE_RET(transpose_param != nullptr, {});
|
||||
VectorRef trans_trans_ref = VectorRef({is_transpose2, is_transpose1, transpose_param});
|
||||
return trans_trans_ref;
|
||||
|
|
Loading…
Reference in New Issue