!49241 ResizeLinear1D 对接 ResizeD PASS,ResizeLinear1D 输入 size 支持 tuple,list

Merge pull request !49241 from haozhang/ResizeLinear1D
This commit is contained in:
i-robot 2023-02-25 06:54:35 +00:00 committed by Gitee
commit d92b7fbac5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 313 additions and 81 deletions

View File

@ -11,19 +11,20 @@ mindspore.ops.ResizeLinear1D
.. warning::
实验特性,接口可能发生变化。
目前,昇腾平台仅支持输入 `size` 为Tuple或List的场景。
参数:
- **coordinate_transformation_mode** (str) - 指定进行坐标变换的方式,默认值是"align_corners",还可选"half_pixel"和"asymmetric"。
输入:
- **x** (Tensor) - ResizeBilinear的输入三维的Tensor其shape为 :math:`(batch, channels, width)`。支持以下数据类型float16、float32、double。
- **size** (Tensor) - 指定 `x` 宽的新尺寸一维的Tensor其shape为 :math:`(1)` 数据类型为int64
- **size** (Union[Tuple[int], List[int], Tensor[int]) - 指定 `x` 宽的新尺寸,仅含一个整数 :math:`(new\_width)` 的Tuple、List或1-D Tensor
输出:
Tensor调整大小后的Tensor。shape为 :math:`(batch, channels, new\_width)` 的三维Tensor数据类型和输入是一致的。
异常:
- **TypeError** - `x` 的数据类型不支持。
- **TypeError** - `size` 不是int64的数据类型
- **TypeError** - `size` 不是Tuple[int]、List[int]或Tensor[int]
- **TypeError** - `coordinate_transformation_mode` 不是string。
- **TypeError** - `coordinate_transformation_mode` 不在支持的列表中。

View File

@ -671,6 +671,8 @@ constexpr auto kResizeNearestNeighborV2OpName = "ResizeNearestNeighborV2";
constexpr auto kResizeNearestNeighborV2DOpName = "ResizeNearestNeighborV2D";
constexpr auto kReverseV2OpName = "ReverseV2";
constexpr auto kReverseV2DOpName = "ReverseV2D";
constexpr auto kResizeDOpName = "ResizeD";
constexpr auto kResizeGradDOpName = "ResizeGradD";
constexpr auto kReturnOpName = "Return";
constexpr auto kRGBToHSVOpName = "RGBToHSV";
constexpr auto kROIAlignGradName = "ROIAlignGrad";

View File

@ -42,6 +42,7 @@
#include "plugin/device/ascend/optimizer/ir_fission/renorm_split.h"
#include "plugin/device/ascend/optimizer/ir_fission/tensor_scatter_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/resize_linear1d_fission.h"
#include "backend/common/pass/communication_op_fusion.h"
#include "backend/common/pass/dropout_gen_mask_fusion.h"
#include "backend/common/pass/dynamic_sequence_ops_adaptation.h"
@ -389,6 +390,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<RenormSplit>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DFission>());
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DGradFission>());
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());
@ -485,6 +488,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DFission>());
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DGradFission>());
const auto &pass_creators =
opt::Factory<PatternProcessPass>::Instance().GetPassCreatorsByType(kPassType::kIRFusionFissionPass);
for (const auto &pass_creator : pass_creators) {

View File

@ -0,0 +1,196 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ir_fission/resize_linear1d_fission.h"
#include <memory>
#include <vector>
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kResizeLinear1DInputNum = 2;
constexpr int64_t kExpandDim = -1;
constexpr int64_t kSqueezeDim = 2;
AnfNodePtr AddExpandDimsNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
// Add ExpandDims Node
std::vector<AnfNodePtr> expand_dims_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimExpandDims->name())), input_node};
auto expand_dims = pass.NewCNode(expand_dims_inputs, func_graph);
// Set ExpandDims OutShape and Type
auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
auto expand_shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
(void)expand_shape.insert(expand_shape.end() + kExpandDim, 1);
(void)common::AnfAlgo::SetOutputInferTypeAndShape({dtype}, {expand_shape}, expand_dims.get());
// Set ExpandDims Attr
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(kExpandDim), expand_dims);
common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), expand_dims);
return expand_dims;
}
} // namespace
const BaseRef ResizeLinear1DFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto resize_linear1d_prim = std::make_shared<Primitive>(prim::kPrimResizeLinear1D->name());
return VectorRef({resize_linear1d_prim, Xs});
}
const AnfNodePtr ResizeLinear1DFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto resize_linear1d = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(resize_linear1d);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (common::AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
if (resize_linear1d->size() != kResizeLinear1DInputNum + 1) {
MS_LOG(INFO) << "The node " << resize_linear1d->DebugString() << " is not equal to " << kResizeLinear1DInputNum
<< "inputs";
return nullptr;
}
if (!common::AnfAlgo::HasNodeAttr("coordinate_transformation_mode", resize_linear1d)) {
MS_LOG(EXCEPTION) << "ResizeLinear1D need to set coordinate_transformation_mode attribute.";
}
const auto ori_inputs = resize_linear1d->inputs();
// Add ExpandDims Node
auto expand_dims = AddExpandDimsNode(func_graph, ori_inputs[kDim1], *this);
// Get ResizeD Node
std::vector<AnfNodePtr> resize_d_inputs = {NewValueNode(std::make_shared<Primitive>(kResizeDOpName)), expand_dims};
auto resize_d = func_graph->NewCNode(resize_d_inputs);
MS_EXCEPTION_IF_NULL(resize_d);
// Set ResizeD OutShape and Type
auto out_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto out_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
(void)out_shape.insert(out_shape.end() + kExpandDim, 1);
(void)common::AnfAlgo::SetOutputInferTypeAndShape({out_type}, {out_shape}, resize_d.get());
// Set ResizeD Attr
std::vector<int64_t> size_value = {out_shape[kIndex3]};
auto x_shape = common::AnfAlgo::GetOutputInferShape(ori_inputs[kDim1], 0);
float scale = static_cast<float>(size_value[kDim0]) / static_cast<float>(x_shape[kDim2]);
std::vector<float> scales = {scale};
common::AnfAlgo::SetNodeAttr("sizes", MakeValue(size_value), resize_d);
common::AnfAlgo::SetNodeAttr("scales", MakeValue(scales), resize_d);
common::AnfAlgo::SetNodeAttr("mode", MakeValue("linear"), resize_d);
common::AnfAlgo::CopyNodeAttr("coordinate_transformation_mode", resize_linear1d, resize_d);
// Get Squeeze Node
std::vector<AnfNodePtr> squeeze_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSqueeze->name())),
resize_d};
auto squeeze = func_graph->NewCNode(squeeze_inputs);
MS_EXCEPTION_IF_NULL(squeeze);
// Set Squeeze Attr
std::vector<int64_t> axis = {kSqueezeDim};
common::AnfAlgo::SetNodeAttr("axis", MakeValue(axis), squeeze);
// Set abstract and scope
squeeze->set_abstract(resize_linear1d->abstract());
squeeze->set_scope(resize_linear1d->scope());
return squeeze;
}
const BaseRef ResizeLinear1DGradFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto resize_linear1d_grad_prim = std::make_shared<Primitive>(prim::kPrimResizeLinear1DGrad->name());
return VectorRef({resize_linear1d_grad_prim, Xs});
}
const AnfNodePtr ResizeLinear1DGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto resize_linear1d_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(resize_linear1d_grad);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (resize_linear1d_grad->size() != kResizeLinear1DInputNum + 1) {
MS_LOG(INFO) << "The node " << resize_linear1d_grad->DebugString() << " is not equal to " << kResizeLinear1DInputNum
<< "inputs";
return nullptr;
}
if (!common::AnfAlgo::HasNodeAttr("coordinate_transformation_mode", resize_linear1d_grad)) {
MS_LOG(EXCEPTION) << "ResizeLinear1DGrad need to set coordinate_transformation_mode attribute.";
}
const auto ori_inputs = resize_linear1d_grad->inputs();
// Add ExpandDims Node
auto expand_dims = AddExpandDimsNode(func_graph, ori_inputs[kDim1], *this);
// Get ResizeGradD Node
std::vector<AnfNodePtr> resize_grad_d_inputs = {NewValueNode(std::make_shared<Primitive>(kResizeGradDOpName)),
expand_dims};
auto resize_grad_d = func_graph->NewCNode(resize_grad_d_inputs);
MS_EXCEPTION_IF_NULL(resize_grad_d);
// Set ResizeGradD OutShape and Type
auto out_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto out_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
(void)out_shape.insert(out_shape.end() + kExpandDim, 1);
(void)common::AnfAlgo::SetOutputInferTypeAndShape({out_type}, {out_shape}, resize_grad_d.get());
// Set ResizeGradD Attr
auto origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex1);
auto x_shape = common::AnfAlgo::GetOutputInferShape(ori_inputs[kDim1], 0);
float scale = static_cast<float>(x_shape[kDim2]) / static_cast<float>(origin_shape[kDim2]);
std::vector<float> scales = {scale};
common::AnfAlgo::SetNodeAttr("original_size", MakeValue(origin_shape), resize_grad_d);
common::AnfAlgo::SetNodeAttr("scales", MakeValue(scales), resize_grad_d);
common::AnfAlgo::SetNodeAttr("mode", MakeValue("linear"), resize_grad_d);
common::AnfAlgo::CopyNodeAttr("coordinate_transformation_mode", resize_linear1d_grad, resize_grad_d);
// Get Squeeze Node
std::vector<AnfNodePtr> squeeze_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSqueeze->name())),
resize_grad_d};
auto squeeze = func_graph->NewCNode(squeeze_inputs);
MS_EXCEPTION_IF_NULL(squeeze);
// Set Squeeze Attr
std::vector<int64_t> axis = {kSqueezeDim};
common::AnfAlgo::SetNodeAttr("axis", MakeValue(axis), squeeze);
// Set abstract and scope
squeeze->set_abstract(resize_linear1d_grad->abstract());
squeeze->set_scope(resize_linear1d_grad->scope());
return squeeze;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_
#include "ir/anf.h"
#include "backend/common/optimizer/pattern_engine.h"
#include "backend/common/optimizer/helper.h"
#include "backend/common/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class ResizeLinear1DFission : public PatternProcessPass {
public:
explicit ResizeLinear1DFission(bool multigraph = true) : PatternProcessPass("resize_linear1d_fission", multigraph) {}
~ResizeLinear1DFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class ResizeLinear1DGradFission : public PatternProcessPass {
public:
explicit ResizeLinear1DGradFission(bool multigraph = true)
: PatternProcessPass("resize_linear1d_grad_fission", multigraph) {}
~ResizeLinear1DGradFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_

View File

@ -20,6 +20,7 @@
#include <algorithm>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
@ -50,89 +51,68 @@ const int64_t kInputShape0Dim = 3;
const int64_t kInputShape1Dim = 1;
abstract::ShapePtr ResizeLinear1DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto input_x_arg = input_args[kInputIndex0];
auto size_arg = input_args[kInputIndex1];
if (!input_x_arg->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "Images only support tensor!";
}
if (!size_arg->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "Size only support tensor!";
}
auto input_x_shape = input_x_arg->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_x_shape);
auto input_x_shape_value_ptr = input_x_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input_x_shape_value_ptr);
auto input_x_type = input_x_arg->BuildType();
MS_EXCEPTION_IF_NULL(input_x_type);
auto input_x_type_id = input_x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_x_type_id);
auto input_x_type_element = input_x_type_id->element();
MS_EXCEPTION_IF_NULL(input_x_type_element);
auto size_shape = size_arg->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(size_shape);
auto size_shape_value_ptr = size_shape->BuildValue();
MS_EXCEPTION_IF_NULL(size_shape_value_ptr);
auto size_shape_tensor = size_shape_value_ptr->cast<tensor::TensorPtr>();
auto size_type = size_arg->BuildType();
MS_EXCEPTION_IF_NULL(size_type);
auto size_type_id = size_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(size_type_id);
auto size_type_element = size_type_id->element();
MS_EXCEPTION_IF_NULL(size_type_element);
auto shape0_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_arg->BuildShape())[kShape]);
auto shape1_ptr =
std::make_shared<abstract::Shape>(CheckAndConvertUtils::ConvertShapePtrToShapeMap(size_arg->BuildShape())[kShape]);
auto shape0_v = shape0_ptr->shape();
auto shape1_v = shape1_ptr->shape();
// support dynamic shape
if (IsDynamicRank(shape0_v) || IsDynamicRank(shape1_v)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
if (shape0_v.size() < kInputShape0Dim) {
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the rank of images tensor must be greater than 3. But got "
<< shape0_v.size();
}
if (shape1_v.size() != kInputShape1Dim) {
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the size tensor must be a 1-D tensor. But got "
<< shape1_v.size() << "-D";
}
if (size_arg->isa<abstract::AbstractTensor>() && size_arg->BuildValue()->isa<tensor::Tensor>()) {
int64_t out_width = 0;
if (size_shape_tensor->data_type() == kNumberTypeInt32) {
auto size_shape_ptr = reinterpret_cast<int32_t *>(size_shape_tensor->data_c());
out_width = static_cast<int64_t>(size_shape_ptr[kInputIndex0]);
} else if (size_shape_tensor->data_type() == kNumberTypeInt64) {
auto size_shape_ptr = reinterpret_cast<int64_t *>(size_shape_tensor->data_c());
out_width = size_shape_ptr[kInputIndex0];
}
if (out_width <= 0) {
MS_EXCEPTION(ValueError) << "The size must be positive , but got " << out_width;
}
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
std::vector<int64_t> output_shape = shape0_v;
output_shape.pop_back();
output_shape.push_back(out_width);
return std::make_shared<abstract::Shape>(output_shape);
} else {
ShapeVector shape_out = shape0_v;
shape_out.pop_back();
shape_out.push_back(abstract::Shape::kShapeDimAny);
return std::make_shared<abstract::Shape>(shape_out);
const int64_t shape0_dim = 3;
std::vector<int64_t> output_shape(shape0_dim, abstract::Shape::kShapeDimAny);
auto shape0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (!IsDynamicRank(shape0)) {
(void)CheckAndConvertUtils::CheckInteger("images' rank", SizeToLong(shape0.size()), kEqual, shape0_dim, prim_name);
output_shape[kInputIndex0] = shape0[kInputIndex0];
output_shape[kInputIndex1] = shape0[kInputIndex1];
}
auto value_ptr = input_args[kInputIndex1]->BuildValue();
MS_EXCEPTION_IF_NULL(value_ptr);
if (!IsValueKnown(value_ptr)) {
return std::make_shared<abstract::Shape>(output_shape);
}
auto size_type = input_args[kInputIndex1]->BuildType();
std::vector<int64_t> size_value{};
if (size_type->isa<TensorType>()) {
const int64_t kDimOne = 1;
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("rank of size's shape", SizeToLong(size_shape.size()), kEqual, kDimOne,
prim_name);
size_value = CheckAndConvertUtils::CheckTensorIntValue("size", value_ptr, prim_name);
} else if (IsIdentidityOrSubclass(size_type, kTuple) || IsIdentidityOrSubclass(size_type, kList)) {
size_value = CheckAndConvertUtils::CheckIntOrTupleInt("size", value_ptr, prim_name);
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the `size` "
<< " must be a tuple、list or tensor with all Int elements, but got "
<< value_ptr->type_name() << ".";
}
const int64_t size_num = 1;
(void)CheckAndConvertUtils::CheckInteger("size", SizeToLong(size_value.size()), kEqual, size_num, prim_name);
const int64_t kNumZero = 0;
for (size_t i = 0; i < size_value.size(); ++i) {
CheckAndConvertUtils::CheckInteger("size", size_value[i], kGreaterThan, kNumZero, prim_name);
}
output_shape[kInputIndex2] = size_value[kInputIndex0];
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr ResizeLinear1DInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
MS_LOG(EXCEPTION) << "For 'ResizeLinear1D', input args contain nullptr.";
}
auto prim_name = primitive->name();
auto input_x_arg = input_args[kInputIndex0];
auto size_arg = input_args[kInputIndex1];
auto x_type = input_args[kInputIndex0]->BuildType();
auto size_type = input_args[kInputIndex1]->BuildType();
const std::set<TypePtr> valid0_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> valid1_types = {kInt64, kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", input_x_arg->BuildType(), valid0_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", size_arg->BuildType(), valid1_types, prim_name);
return input_x_arg->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", x_type, valid0_types, prim_name);
if (size_type->isa<TensorType>()) {
const std::set<TypePtr> valid1_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", size_type, valid1_types, prim_name);
}
return x_type;
}
} // namespace

View File

@ -487,7 +487,7 @@ class Upsample(Cell):
upsample = P.image_ops.ResizeLinear1D(
coordinate_transformation_mode=coordinate_transformation_mode
)
return upsample(tensor, Tensor(size, dtype=mstype.int32))
return upsample(tensor, size)
def run_bilinear(tensor, ndim, size):
if self.scale_factor is not None:

View File

@ -2102,7 +2102,7 @@ def interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=N
resize = _get_cache_prim(P.image_ops.ResizeLinear1D)(
coordinate_transformation_mode
)
return resize(x, Tensor(size, dtype=mstype.int32))
return resize(x, size)
def run_bilinear(x, size, align_corners=None, scale_factor=None):
resize = _get_cache_prim(P.ResizeBilinearV2)(align_corners, not align_corners)

View File

@ -641,6 +641,7 @@ class ResizeLinear1D(Primitive):
.. warning::
This is an experimental feature and is subjected to change.
Currently, the Ascend platform only supports scenarios where the input `size` is Tuple or List.
Args:
coordinate_transformation_mode (str): Default is 'align_corners'. Describes how to transform the coordinate
@ -649,19 +650,20 @@ class ResizeLinear1D(Primitive):
Inputs:
- **x** (Tensor) - A 3-D tensor which to resize, with shape [batch, channel, width]. Must be one of the
following types: uint8, int8, int16, int32, int64, float16, float32, double.
- **size** (Tensor) - A 1-D int64 Tensor, describes the size of the output tensor.
- **size** (Union[Tuple[int], List[int], Tensor[int]]): describes the new width of `x` .
A tuple or list or 1-D tensor with only one int element :math:`(new\_width)`.
Outputs:
A 3-D tensor which shape is [batch, channel, new_width] with the same type as `x`.
Raises:
TypeError: If dtype of `x` is not in the support list.
TypeError: If `size` is not a 1-D int64_t tensor.
TypeError: If `size` is not in Union[Tuple[int], List[int], Tensor[int]].
TypeError: If `coordinate_transformation_mode` is not a string.
TypeError: If `coordinate_transformation_mode` is not in the support list.
Supported Platforms:
``GPU`` ``CPU``
``GPU`` ``CPU`` `Ascend`
Examples:
>>> input = Tensor([[[1, 2, 3], [4, 5, 6]]], mindspore.float32)