[MSLITE] Fix bug for strideslice stride and tensorflow parser for resize

This commit is contained in:
张勇贤 2023-01-05 10:26:12 +08:00
parent 7f7cda558d
commit 0ef2d96d2d
5 changed files with 2 additions and 12 deletions

View File

@ -23,10 +23,6 @@ constexpr int BIAS_INDEX = 2;
int ConvolutionTensorRT::IsSupport(const schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;

View File

@ -22,10 +22,6 @@ namespace mindspore::lite {
int PoolTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;

View File

@ -249,6 +249,7 @@ int StrideSliceTensorRT::ComputeDimsSingle(TensorRTContext *ctx, ITensorHelper *
} else {
size_dims_.d[i] = input_dims.d[i];
}
size_dims_.d[i] = std::abs(size_dims_.d[i] / stride_dims_.d[i]) + ((size_dims_.d[i] % stride_dims_.d[i]) != 0);
}
}
return RET_OK;

View File

@ -44,6 +44,7 @@ PrimitiveCPtr TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
} else if (TensorFlowUtils::FindAttrValue(tf_op, "half_pixel_centers", &attr_value) && attr_value.b()) {
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::HALF_PIXEL);
prim->set_cubic_coeff(-0.5f);
prim->set_nearest_mode(mindspore::NearestMode::ROUND_HALF_UP);
prim_c->AddAttr("half_pixel_centers", MakeValue(true));
} else {
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC);

View File

@ -93,7 +93,6 @@ const BaseRef ResizeFusion1::DefinePattern() const {
}
const BaseRef ResizeFusion2::DefinePattern() const {
MS_LOG(WARNING) << "DefinePattern begin";
input_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_ != nullptr, false);
@ -145,7 +144,6 @@ const BaseRef ResizeFusion2::DefinePattern() const {
auto is_resize = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimResize>);
MS_CHECK_TRUE_RET(is_resize != nullptr, {});
VectorRef resize_ref = VectorRef({is_resize, input_, gather_ref});
MS_LOG(WARNING) << "DefinePattern end";
return resize_ref;
}
@ -211,7 +209,6 @@ int ResizeFusion1::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &no
}
int ResizeFusion2::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_LOG(WARNING) << "DoFuison begin";
MS_ASSERT(node != nullptr);
auto resize_cnode = node->cast<CNodePtr>();
MS_ASSERT(resize_cnode != nullptr);
@ -233,7 +230,6 @@ int ResizeFusion2::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &no
MS_ASSERT(manager != nullptr);
manager->SetEdge(resize_cnode, kInputIndexTwo, resize_input);
MS_LOG(WARNING) << "DoFuison end";
return lite::RET_OK;
}