!49241 ResizeLinear1D 对接 ResizeD PASS,ResizeLinear1D 输入 size 支持 tuple,list
Merge pull request !49241 from haozhang/ResizeLinear1D
This commit is contained in:
commit
d92b7fbac5
|
@ -11,19 +11,20 @@ mindspore.ops.ResizeLinear1D
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
实验特性,接口可能发生变化。
|
实验特性,接口可能发生变化。
|
||||||
|
目前,昇腾平台仅支持输入 `size` 为Tuple或List的场景。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **coordinate_transformation_mode** (str) - 指定进行坐标变换的方式,默认值是"align_corners",还可选"half_pixel"和"asymmetric"。
|
- **coordinate_transformation_mode** (str) - 指定进行坐标变换的方式,默认值是"align_corners",还可选"half_pixel"和"asymmetric"。
|
||||||
|
|
||||||
输入:
|
输入:
|
||||||
- **x** (Tensor) - ResizeBilinear的输入,三维的Tensor,其shape为 :math:`(batch, channels, width)`。支持以下数据类型:float16、float32、double。
|
- **x** (Tensor) - ResizeBilinear的输入,三维的Tensor,其shape为 :math:`(batch, channels, width)`。支持以下数据类型:float16、float32、double。
|
||||||
- **size** (Tensor) - 指定 `x` 宽的新尺寸,一维的Tensor,其shape为 :math:`(1)` ,数据类型为int64。
|
- **size** (Union[Tuple[int], List[int], Tensor[int]) - 指定 `x` 宽的新尺寸,仅含一个整数 :math:`(new\_width)` 的Tuple、List或1-D Tensor。
|
||||||
|
|
||||||
输出:
|
输出:
|
||||||
Tensor,调整大小后的Tensor。shape为 :math:`(batch, channels, new\_width)` 的三维Tensor,数据类型和输入是一致的。
|
Tensor,调整大小后的Tensor。shape为 :math:`(batch, channels, new\_width)` 的三维Tensor,数据类型和输入是一致的。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `x` 的数据类型不支持。
|
- **TypeError** - `x` 的数据类型不支持。
|
||||||
- **TypeError** - `size` 不是int64的数据类型。
|
- **TypeError** - `size` 不是Tuple[int]、List[int]或Tensor[int]。
|
||||||
- **TypeError** - `coordinate_transformation_mode` 不是string。
|
- **TypeError** - `coordinate_transformation_mode` 不是string。
|
||||||
- **TypeError** - `coordinate_transformation_mode` 不在支持的列表中。
|
- **TypeError** - `coordinate_transformation_mode` 不在支持的列表中。
|
||||||
|
|
|
@ -671,6 +671,8 @@ constexpr auto kResizeNearestNeighborV2OpName = "ResizeNearestNeighborV2";
|
||||||
constexpr auto kResizeNearestNeighborV2DOpName = "ResizeNearestNeighborV2D";
|
constexpr auto kResizeNearestNeighborV2DOpName = "ResizeNearestNeighborV2D";
|
||||||
constexpr auto kReverseV2OpName = "ReverseV2";
|
constexpr auto kReverseV2OpName = "ReverseV2";
|
||||||
constexpr auto kReverseV2DOpName = "ReverseV2D";
|
constexpr auto kReverseV2DOpName = "ReverseV2D";
|
||||||
|
constexpr auto kResizeDOpName = "ResizeD";
|
||||||
|
constexpr auto kResizeGradDOpName = "ResizeGradD";
|
||||||
constexpr auto kReturnOpName = "Return";
|
constexpr auto kReturnOpName = "Return";
|
||||||
constexpr auto kRGBToHSVOpName = "RGBToHSV";
|
constexpr auto kRGBToHSVOpName = "RGBToHSV";
|
||||||
constexpr auto kROIAlignGradName = "ROIAlignGrad";
|
constexpr auto kROIAlignGradName = "ROIAlignGrad";
|
||||||
|
|
|
@ -42,6 +42,7 @@
|
||||||
#include "plugin/device/ascend/optimizer/ir_fission/renorm_split.h"
|
#include "plugin/device/ascend/optimizer/ir_fission/renorm_split.h"
|
||||||
#include "plugin/device/ascend/optimizer/ir_fission/tensor_scatter_fission.h"
|
#include "plugin/device/ascend/optimizer/ir_fission/tensor_scatter_fission.h"
|
||||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.h"
|
#include "plugin/device/ascend/optimizer/ir_fission/ascend_clip_by_norm_fission.h"
|
||||||
|
#include "plugin/device/ascend/optimizer/ir_fission/resize_linear1d_fission.h"
|
||||||
#include "backend/common/pass/communication_op_fusion.h"
|
#include "backend/common/pass/communication_op_fusion.h"
|
||||||
#include "backend/common/pass/dropout_gen_mask_fusion.h"
|
#include "backend/common/pass/dropout_gen_mask_fusion.h"
|
||||||
#include "backend/common/pass/dynamic_sequence_ops_adaptation.h"
|
#include "backend/common/pass/dynamic_sequence_ops_adaptation.h"
|
||||||
|
@ -389,6 +390,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
||||||
ir_fusion_pm->AddPass(std::make_shared<RenormSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<RenormSplit>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
|
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DFission>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DGradFission>());
|
||||||
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
|
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
|
||||||
AddAscendIRFusionPass(ir_fusion_pm.get());
|
AddAscendIRFusionPass(ir_fusion_pm.get());
|
||||||
|
|
||||||
|
@ -485,6 +488,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
||||||
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
|
ir_fusion_pm->AddPass(std::make_shared<RemoveHostKernel>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
|
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DFission>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<ResizeLinear1DGradFission>());
|
||||||
const auto &pass_creators =
|
const auto &pass_creators =
|
||||||
opt::Factory<PatternProcessPass>::Instance().GetPassCreatorsByType(kPassType::kIRFusionFissionPass);
|
opt::Factory<PatternProcessPass>::Instance().GetPassCreatorsByType(kPassType::kIRFusionFissionPass);
|
||||||
for (const auto &pass_creator : pass_creators) {
|
for (const auto &pass_creator : pass_creators) {
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "plugin/device/ascend/optimizer/ir_fission/resize_linear1d_fission.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "plugin/device/ascend/optimizer/ascend_helper.h"
|
||||||
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
#include "include/common/utils/utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kResizeLinear1DInputNum = 2;
|
||||||
|
constexpr int64_t kExpandDim = -1;
|
||||||
|
constexpr int64_t kSqueezeDim = 2;
|
||||||
|
|
||||||
|
AnfNodePtr AddExpandDimsNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||||
|
const PatternProcessPass &pass) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
|
|
||||||
|
// Add ExpandDims Node
|
||||||
|
std::vector<AnfNodePtr> expand_dims_inputs = {
|
||||||
|
NewValueNode(std::make_shared<Primitive>(prim::kPrimExpandDims->name())), input_node};
|
||||||
|
auto expand_dims = pass.NewCNode(expand_dims_inputs, func_graph);
|
||||||
|
|
||||||
|
// Set ExpandDims OutShape and Type
|
||||||
|
auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
|
||||||
|
auto expand_shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
|
||||||
|
(void)expand_shape.insert(expand_shape.end() + kExpandDim, 1);
|
||||||
|
(void)common::AnfAlgo::SetOutputInferTypeAndShape({dtype}, {expand_shape}, expand_dims.get());
|
||||||
|
|
||||||
|
// Set ExpandDims Attr
|
||||||
|
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(kExpandDim), expand_dims);
|
||||||
|
common::AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), expand_dims);
|
||||||
|
|
||||||
|
return expand_dims;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef ResizeLinear1DFission::DefinePattern() const {
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
auto resize_linear1d_prim = std::make_shared<Primitive>(prim::kPrimResizeLinear1D->name());
|
||||||
|
return VectorRef({resize_linear1d_prim, Xs});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr ResizeLinear1DFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto resize_linear1d = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(resize_linear1d);
|
||||||
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
if (common::AnfAlgo::IsDynamicShape(node)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resize_linear1d->size() != kResizeLinear1DInputNum + 1) {
|
||||||
|
MS_LOG(INFO) << "The node " << resize_linear1d->DebugString() << " is not equal to " << kResizeLinear1DInputNum
|
||||||
|
<< "inputs";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!common::AnfAlgo::HasNodeAttr("coordinate_transformation_mode", resize_linear1d)) {
|
||||||
|
MS_LOG(EXCEPTION) << "ResizeLinear1D need to set coordinate_transformation_mode attribute.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto ori_inputs = resize_linear1d->inputs();
|
||||||
|
|
||||||
|
// Add ExpandDims Node
|
||||||
|
auto expand_dims = AddExpandDimsNode(func_graph, ori_inputs[kDim1], *this);
|
||||||
|
|
||||||
|
// Get ResizeD Node
|
||||||
|
std::vector<AnfNodePtr> resize_d_inputs = {NewValueNode(std::make_shared<Primitive>(kResizeDOpName)), expand_dims};
|
||||||
|
auto resize_d = func_graph->NewCNode(resize_d_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(resize_d);
|
||||||
|
|
||||||
|
// Set ResizeD OutShape and Type
|
||||||
|
auto out_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
|
||||||
|
auto out_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
|
||||||
|
(void)out_shape.insert(out_shape.end() + kExpandDim, 1);
|
||||||
|
(void)common::AnfAlgo::SetOutputInferTypeAndShape({out_type}, {out_shape}, resize_d.get());
|
||||||
|
|
||||||
|
// Set ResizeD Attr
|
||||||
|
std::vector<int64_t> size_value = {out_shape[kIndex3]};
|
||||||
|
auto x_shape = common::AnfAlgo::GetOutputInferShape(ori_inputs[kDim1], 0);
|
||||||
|
float scale = static_cast<float>(size_value[kDim0]) / static_cast<float>(x_shape[kDim2]);
|
||||||
|
std::vector<float> scales = {scale};
|
||||||
|
common::AnfAlgo::SetNodeAttr("sizes", MakeValue(size_value), resize_d);
|
||||||
|
common::AnfAlgo::SetNodeAttr("scales", MakeValue(scales), resize_d);
|
||||||
|
common::AnfAlgo::SetNodeAttr("mode", MakeValue("linear"), resize_d);
|
||||||
|
common::AnfAlgo::CopyNodeAttr("coordinate_transformation_mode", resize_linear1d, resize_d);
|
||||||
|
|
||||||
|
// Get Squeeze Node
|
||||||
|
std::vector<AnfNodePtr> squeeze_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSqueeze->name())),
|
||||||
|
resize_d};
|
||||||
|
auto squeeze = func_graph->NewCNode(squeeze_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(squeeze);
|
||||||
|
|
||||||
|
// Set Squeeze Attr
|
||||||
|
std::vector<int64_t> axis = {kSqueezeDim};
|
||||||
|
common::AnfAlgo::SetNodeAttr("axis", MakeValue(axis), squeeze);
|
||||||
|
|
||||||
|
// Set abstract and scope
|
||||||
|
squeeze->set_abstract(resize_linear1d->abstract());
|
||||||
|
squeeze->set_scope(resize_linear1d->scope());
|
||||||
|
|
||||||
|
return squeeze;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BaseRef ResizeLinear1DGradFission::DefinePattern() const {
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
auto resize_linear1d_grad_prim = std::make_shared<Primitive>(prim::kPrimResizeLinear1DGrad->name());
|
||||||
|
return VectorRef({resize_linear1d_grad_prim, Xs});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr ResizeLinear1DGradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto resize_linear1d_grad = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(resize_linear1d_grad);
|
||||||
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
|
||||||
|
if (resize_linear1d_grad->size() != kResizeLinear1DInputNum + 1) {
|
||||||
|
MS_LOG(INFO) << "The node " << resize_linear1d_grad->DebugString() << " is not equal to " << kResizeLinear1DInputNum
|
||||||
|
<< "inputs";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!common::AnfAlgo::HasNodeAttr("coordinate_transformation_mode", resize_linear1d_grad)) {
|
||||||
|
MS_LOG(EXCEPTION) << "ResizeLinear1DGrad need to set coordinate_transformation_mode attribute.";
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto ori_inputs = resize_linear1d_grad->inputs();
|
||||||
|
|
||||||
|
// Add ExpandDims Node
|
||||||
|
auto expand_dims = AddExpandDimsNode(func_graph, ori_inputs[kDim1], *this);
|
||||||
|
|
||||||
|
// Get ResizeGradD Node
|
||||||
|
std::vector<AnfNodePtr> resize_grad_d_inputs = {NewValueNode(std::make_shared<Primitive>(kResizeGradDOpName)),
|
||||||
|
expand_dims};
|
||||||
|
auto resize_grad_d = func_graph->NewCNode(resize_grad_d_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(resize_grad_d);
|
||||||
|
|
||||||
|
// Set ResizeGradD OutShape and Type
|
||||||
|
auto out_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
|
||||||
|
auto out_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
|
||||||
|
(void)out_shape.insert(out_shape.end() + kExpandDim, 1);
|
||||||
|
(void)common::AnfAlgo::SetOutputInferTypeAndShape({out_type}, {out_shape}, resize_grad_d.get());
|
||||||
|
|
||||||
|
// Set ResizeGradD Attr
|
||||||
|
auto origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex1);
|
||||||
|
auto x_shape = common::AnfAlgo::GetOutputInferShape(ori_inputs[kDim1], 0);
|
||||||
|
float scale = static_cast<float>(x_shape[kDim2]) / static_cast<float>(origin_shape[kDim2]);
|
||||||
|
std::vector<float> scales = {scale};
|
||||||
|
common::AnfAlgo::SetNodeAttr("original_size", MakeValue(origin_shape), resize_grad_d);
|
||||||
|
common::AnfAlgo::SetNodeAttr("scales", MakeValue(scales), resize_grad_d);
|
||||||
|
common::AnfAlgo::SetNodeAttr("mode", MakeValue("linear"), resize_grad_d);
|
||||||
|
common::AnfAlgo::CopyNodeAttr("coordinate_transformation_mode", resize_linear1d_grad, resize_grad_d);
|
||||||
|
|
||||||
|
// Get Squeeze Node
|
||||||
|
std::vector<AnfNodePtr> squeeze_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSqueeze->name())),
|
||||||
|
resize_grad_d};
|
||||||
|
auto squeeze = func_graph->NewCNode(squeeze_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(squeeze);
|
||||||
|
|
||||||
|
// Set Squeeze Attr
|
||||||
|
std::vector<int64_t> axis = {kSqueezeDim};
|
||||||
|
common::AnfAlgo::SetNodeAttr("axis", MakeValue(axis), squeeze);
|
||||||
|
|
||||||
|
// Set abstract and scope
|
||||||
|
squeeze->set_abstract(resize_linear1d_grad->abstract());
|
||||||
|
squeeze->set_scope(resize_linear1d_grad->scope());
|
||||||
|
|
||||||
|
return squeeze;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "backend/common/optimizer/pattern_engine.h"
|
||||||
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
#include "backend/common/optimizer/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class ResizeLinear1DFission : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit ResizeLinear1DFission(bool multigraph = true) : PatternProcessPass("resize_linear1d_fission", multigraph) {}
|
||||||
|
~ResizeLinear1DFission() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ResizeLinear1DGradFission : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit ResizeLinear1DGradFission(bool multigraph = true)
|
||||||
|
: PatternProcessPass("resize_linear1d_grad_fission", multigraph) {}
|
||||||
|
~ResizeLinear1DGradFission() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESIZE_LINEAR1D_FISSION_H_
|
|
@ -20,6 +20,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "abstract/ops/primitive_infer_map.h"
|
#include "abstract/ops/primitive_infer_map.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
#include "abstract/dshape.h"
|
#include "abstract/dshape.h"
|
||||||
|
@ -50,89 +51,68 @@ const int64_t kInputShape0Dim = 3;
|
||||||
const int64_t kInputShape1Dim = 1;
|
const int64_t kInputShape1Dim = 1;
|
||||||
abstract::ShapePtr ResizeLinear1DInferShape(const PrimitivePtr &primitive,
|
abstract::ShapePtr ResizeLinear1DInferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto input_x_arg = input_args[kInputIndex0];
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto size_arg = input_args[kInputIndex1];
|
auto prim_name = primitive->name();
|
||||||
if (!input_x_arg->isa<abstract::AbstractTensor>()) {
|
|
||||||
MS_EXCEPTION(TypeError) << "Images only support tensor!";
|
|
||||||
}
|
|
||||||
if (!size_arg->isa<abstract::AbstractTensor>()) {
|
|
||||||
MS_EXCEPTION(TypeError) << "Size only support tensor!";
|
|
||||||
}
|
|
||||||
auto input_x_shape = input_x_arg->cast<abstract::AbstractTensorPtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_shape);
|
|
||||||
auto input_x_shape_value_ptr = input_x_shape->BuildValue();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_shape_value_ptr);
|
|
||||||
auto input_x_type = input_x_arg->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_type);
|
|
||||||
auto input_x_type_id = input_x_type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_type_id);
|
|
||||||
auto input_x_type_element = input_x_type_id->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_x_type_element);
|
|
||||||
auto size_shape = size_arg->cast<abstract::AbstractTensorPtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(size_shape);
|
|
||||||
auto size_shape_value_ptr = size_shape->BuildValue();
|
|
||||||
MS_EXCEPTION_IF_NULL(size_shape_value_ptr);
|
|
||||||
auto size_shape_tensor = size_shape_value_ptr->cast<tensor::TensorPtr>();
|
|
||||||
auto size_type = size_arg->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(size_type);
|
|
||||||
auto size_type_id = size_type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(size_type_id);
|
|
||||||
auto size_type_element = size_type_id->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(size_type_element);
|
|
||||||
auto shape0_ptr = std::make_shared<abstract::Shape>(
|
|
||||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_arg->BuildShape())[kShape]);
|
|
||||||
auto shape1_ptr =
|
|
||||||
std::make_shared<abstract::Shape>(CheckAndConvertUtils::ConvertShapePtrToShapeMap(size_arg->BuildShape())[kShape]);
|
|
||||||
auto shape0_v = shape0_ptr->shape();
|
|
||||||
auto shape1_v = shape1_ptr->shape();
|
|
||||||
// support dynamic shape
|
|
||||||
if (IsDynamicRank(shape0_v) || IsDynamicRank(shape1_v)) {
|
|
||||||
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
|
|
||||||
}
|
|
||||||
if (shape0_v.size() < kInputShape0Dim) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the rank of images tensor must be greater than 3. But got "
|
|
||||||
<< shape0_v.size();
|
|
||||||
}
|
|
||||||
if (shape1_v.size() != kInputShape1Dim) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For 'ResizeLinear1D', the size tensor must be a 1-D tensor. But got "
|
|
||||||
<< shape1_v.size() << "-D";
|
|
||||||
}
|
|
||||||
if (size_arg->isa<abstract::AbstractTensor>() && size_arg->BuildValue()->isa<tensor::Tensor>()) {
|
|
||||||
int64_t out_width = 0;
|
|
||||||
if (size_shape_tensor->data_type() == kNumberTypeInt32) {
|
|
||||||
auto size_shape_ptr = reinterpret_cast<int32_t *>(size_shape_tensor->data_c());
|
|
||||||
out_width = static_cast<int64_t>(size_shape_ptr[kInputIndex0]);
|
|
||||||
} else if (size_shape_tensor->data_type() == kNumberTypeInt64) {
|
|
||||||
auto size_shape_ptr = reinterpret_cast<int64_t *>(size_shape_tensor->data_c());
|
|
||||||
out_width = size_shape_ptr[kInputIndex0];
|
|
||||||
}
|
|
||||||
if (out_width <= 0) {
|
|
||||||
MS_EXCEPTION(ValueError) << "The size must be positive , but got " << out_width;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> output_shape = shape0_v;
|
const int64_t shape0_dim = 3;
|
||||||
output_shape.pop_back();
|
std::vector<int64_t> output_shape(shape0_dim, abstract::Shape::kShapeDimAny);
|
||||||
output_shape.push_back(out_width);
|
|
||||||
return std::make_shared<abstract::Shape>(output_shape);
|
auto shape0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
} else {
|
if (!IsDynamicRank(shape0)) {
|
||||||
ShapeVector shape_out = shape0_v;
|
(void)CheckAndConvertUtils::CheckInteger("images' rank", SizeToLong(shape0.size()), kEqual, shape0_dim, prim_name);
|
||||||
shape_out.pop_back();
|
output_shape[kInputIndex0] = shape0[kInputIndex0];
|
||||||
shape_out.push_back(abstract::Shape::kShapeDimAny);
|
output_shape[kInputIndex1] = shape0[kInputIndex1];
|
||||||
return std::make_shared<abstract::Shape>(shape_out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto value_ptr = input_args[kInputIndex1]->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
|
||||||
|
if (!IsValueKnown(value_ptr)) {
|
||||||
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto size_type = input_args[kInputIndex1]->BuildType();
|
||||||
|
std::vector<int64_t> size_value{};
|
||||||
|
if (size_type->isa<TensorType>()) {
|
||||||
|
const int64_t kDimOne = 1;
|
||||||
|
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("rank of size's shape", SizeToLong(size_shape.size()), kEqual, kDimOne,
|
||||||
|
prim_name);
|
||||||
|
size_value = CheckAndConvertUtils::CheckTensorIntValue("size", value_ptr, prim_name);
|
||||||
|
} else if (IsIdentidityOrSubclass(size_type, kTuple) || IsIdentidityOrSubclass(size_type, kList)) {
|
||||||
|
size_value = CheckAndConvertUtils::CheckIntOrTupleInt("size", value_ptr, prim_name);
|
||||||
|
} else {
|
||||||
|
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the `size` "
|
||||||
|
<< " must be a tuple、list or tensor with all Int elements, but got "
|
||||||
|
<< value_ptr->type_name() << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t size_num = 1;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("size", SizeToLong(size_value.size()), kEqual, size_num, prim_name);
|
||||||
|
const int64_t kNumZero = 0;
|
||||||
|
for (size_t i = 0; i < size_value.size(); ++i) {
|
||||||
|
CheckAndConvertUtils::CheckInteger("size", size_value[i], kGreaterThan, kNumZero, prim_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
output_shape[kInputIndex2] = size_value[kInputIndex0];
|
||||||
|
|
||||||
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr ResizeLinear1DInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr ResizeLinear1DInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||||
MS_LOG(EXCEPTION) << "For 'ResizeLinear1D', input args contain nullptr.";
|
MS_LOG(EXCEPTION) << "For 'ResizeLinear1D', input args contain nullptr.";
|
||||||
}
|
}
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
auto input_x_arg = input_args[kInputIndex0];
|
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||||
auto size_arg = input_args[kInputIndex1];
|
auto size_type = input_args[kInputIndex1]->BuildType();
|
||||||
const std::set<TypePtr> valid0_types = {kFloat16, kFloat32, kFloat64};
|
const std::set<TypePtr> valid0_types = {kFloat16, kFloat32, kFloat64};
|
||||||
const std::set<TypePtr> valid1_types = {kInt64, kInt32};
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", x_type, valid0_types, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", input_x_arg->BuildType(), valid0_types, prim_name);
|
if (size_type->isa<TensorType>()) {
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", size_arg->BuildType(), valid1_types, prim_name);
|
const std::set<TypePtr> valid1_types = {kInt32, kInt64};
|
||||||
return input_x_arg->BuildType();
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", size_type, valid1_types, prim_name);
|
||||||
|
}
|
||||||
|
return x_type;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@ -487,7 +487,7 @@ class Upsample(Cell):
|
||||||
upsample = P.image_ops.ResizeLinear1D(
|
upsample = P.image_ops.ResizeLinear1D(
|
||||||
coordinate_transformation_mode=coordinate_transformation_mode
|
coordinate_transformation_mode=coordinate_transformation_mode
|
||||||
)
|
)
|
||||||
return upsample(tensor, Tensor(size, dtype=mstype.int32))
|
return upsample(tensor, size)
|
||||||
|
|
||||||
def run_bilinear(tensor, ndim, size):
|
def run_bilinear(tensor, ndim, size):
|
||||||
if self.scale_factor is not None:
|
if self.scale_factor is not None:
|
||||||
|
|
|
@ -2102,7 +2102,7 @@ def interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=N
|
||||||
resize = _get_cache_prim(P.image_ops.ResizeLinear1D)(
|
resize = _get_cache_prim(P.image_ops.ResizeLinear1D)(
|
||||||
coordinate_transformation_mode
|
coordinate_transformation_mode
|
||||||
)
|
)
|
||||||
return resize(x, Tensor(size, dtype=mstype.int32))
|
return resize(x, size)
|
||||||
|
|
||||||
def run_bilinear(x, size, align_corners=None, scale_factor=None):
|
def run_bilinear(x, size, align_corners=None, scale_factor=None):
|
||||||
resize = _get_cache_prim(P.ResizeBilinearV2)(align_corners, not align_corners)
|
resize = _get_cache_prim(P.ResizeBilinearV2)(align_corners, not align_corners)
|
||||||
|
|
|
@ -641,6 +641,7 @@ class ResizeLinear1D(Primitive):
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
This is an experimental feature and is subjected to change.
|
This is an experimental feature and is subjected to change.
|
||||||
|
Currently, the Ascend platform only supports scenarios where the input `size` is Tuple or List.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
coordinate_transformation_mode (str): Default is 'align_corners'. Describes how to transform the coordinate
|
coordinate_transformation_mode (str): Default is 'align_corners'. Describes how to transform the coordinate
|
||||||
|
@ -649,19 +650,20 @@ class ResizeLinear1D(Primitive):
|
||||||
Inputs:
|
Inputs:
|
||||||
- **x** (Tensor) - A 3-D tensor which to resize, with shape [batch, channel, width]. Must be one of the
|
- **x** (Tensor) - A 3-D tensor which to resize, with shape [batch, channel, width]. Must be one of the
|
||||||
following types: uint8, int8, int16, int32, int64, float16, float32, double.
|
following types: uint8, int8, int16, int32, int64, float16, float32, double.
|
||||||
- **size** (Tensor) - A 1-D int64 Tensor, describes the size of the output tensor.
|
- **size** (Union[Tuple[int], List[int], Tensor[int]]): describes the new width of `x` .
|
||||||
|
A tuple or list or 1-D tensor with only one int element :math:`(new\_width)`.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
A 3-D tensor which shape is [batch, channel, new_width] with the same type as `x`.
|
A 3-D tensor which shape is [batch, channel, new_width] with the same type as `x`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If dtype of `x` is not in the support list.
|
TypeError: If dtype of `x` is not in the support list.
|
||||||
TypeError: If `size` is not a 1-D int64_t tensor.
|
TypeError: If `size` is not in Union[Tuple[int], List[int], Tensor[int]].
|
||||||
TypeError: If `coordinate_transformation_mode` is not a string.
|
TypeError: If `coordinate_transformation_mode` is not a string.
|
||||||
TypeError: If `coordinate_transformation_mode` is not in the support list.
|
TypeError: If `coordinate_transformation_mode` is not in the support list.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``GPU`` ``CPU``
|
``GPU`` ``CPU`` `Ascend`
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input = Tensor([[[1, 2, 3], [4, 5, 6]]], mindspore.float32)
|
>>> input = Tensor([[[1, 2, 3], [4, 5, 6]]], mindspore.float32)
|
||||||
|
|
Loading…
Reference in New Issue