!69063 Add pass for dyn paddings in PadV3 on Ascend

Merge pull request !69063 from yangruoqi713/master

commit 7b8f8e7193
@@ -1562,18 +1562,24 @@ AnfNodePtr CreateValueNodeWithKernelInfo(const FuncGraphPtr &graph, const ValueP
  value_node->set_kernel_info(kernel_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_DEFAULT});
  MS_EXCEPTION_IF_NULL(value->type());
  auto type_id = value->type()->type_id();
  if (value->isa<ValueSequence>()) {
    auto value_sequence = value->cast<ValueSequencePtr>()->value();
    if (value_sequence.empty()) {
      type_id = kNumberTypeInt64;
    } else {
      MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
      type_id = value_sequence[0]->type()->type_id();
  if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    builder.SetOutputsDeviceType({tensor->data_type()});
  } else {
    MS_EXCEPTION_IF_NULL(value->type());
    auto type_id = value->type()->type_id();
    if (value->isa<ValueSequence>()) {
      auto value_sequence = value->cast<ValueSequencePtr>()->value();
      if (value_sequence.empty()) {
        type_id = kNumberTypeInt64;
      } else {
        MS_EXCEPTION_IF_NULL(value_sequence[0]->type());
        type_id = value_sequence[0]->type()->type_id();
      }
    }
    builder.SetOutputsDeviceType({type_id});
  }
  builder.SetOutputsDeviceType({type_id});
  auto object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(value_abs));
  builder.SetOutputsKernelObjectType({object_type});
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());
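Review note on the hunk above: for constant value nodes, a Tensor value now advertises its own data type, and the previous sequence/scalar inference becomes the else branch. A small Python sketch of that selection order (illustrative stand-ins, not MindSpore API):

```python
import numpy as np

def select_output_dtype(value):
    """Selection order after the change: tensor first, then sequence, then scalar."""
    if isinstance(value, np.ndarray):          # tensor: trust its own dtype
        return value.dtype
    if isinstance(value, (list, tuple)):       # sequence
        if not value:
            return np.dtype(np.int64)          # empty sequence defaults to int64
        return np.asarray(value[0]).dtype      # else: first element's type
    return np.asarray(value).dtype             # plain scalar

print(select_output_dtype(np.zeros(3, np.float16)))  # float16
print(select_output_dtype(()))                       # int64
print(select_output_dtype(7))                        # platform int
```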
@@ -24,6 +24,8 @@
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "mindspore/core/ops/array_op_name.h"
#include "mindspore/core/ops/sequence_op_name.h"
#include "mindspore/core/ops/auto_generate/gen_ops_name.h"

namespace mindspore {
namespace opt {
@@ -37,9 +39,6 @@ bool ConvertBasePaddings::HasDynPaddings(const CNodePtr &cnode) const {
  MS_EXCEPTION_IF_NULL(paddings_abstract);
  auto paddings_value = paddings_abstract->GetValue();
  MS_EXCEPTION_IF_NULL(paddings_value);
  if (paddings_value->isa<ValueAny>() || paddings_value->isa<None>()) {
    return true;
  }
  auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, kIndex1);
  if (input_paddings_type_id == kNumberTypeInt32) {
    auto paddings_array_value = ops::GetArrayValue<int32_t>(paddings_value);
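HasDynPaddings treats paddings as dynamic when their value is unknown at compile time (ValueAny or None); the truncated tail of the hunk also inspects the element values when they are known. A rough Python analogue of the decision, under the assumption that a compile-time-unknown value is modeled as None:

```python
def has_dyn_paddings(paddings_value):
    """True when paddings cannot be folded to constants at compile time.

    None stands in for ValueAny/None: the value only exists at runtime.
    """
    if paddings_value is None:
        return True
    # Known at compile time: every element must be a concrete int.
    return not all(isinstance(p, int) for p in paddings_value)

assert has_dyn_paddings(None)
assert not has_dyn_paddings((1, 2, 1, 1, 0, 1))
```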
@@ -64,6 +63,146 @@ const CNodePtr ConvertBasePaddings::CreateReshapeNode(const FuncGraphPtr &graph,
  return reshape_node;
}

const CNodePtr ConvertBasePaddings::CreateStridedSliceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                                                           int64_t index) const {
  // set inputs
  auto begin_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{index}));
  MS_EXCEPTION_IF_NULL(begin_node);
  auto end_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{index + 1}));
  MS_EXCEPTION_IF_NULL(end_node);
  auto strides_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(std::vector<int64_t>{1}));
  MS_EXCEPTION_IF_NULL(strides_node);
  int64_t const_value = 0;
  auto begin_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
  MS_EXCEPTION_IF_NULL(begin_mask);
  auto end_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
  MS_EXCEPTION_IF_NULL(end_mask);
  auto ellipsis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
  MS_EXCEPTION_IF_NULL(ellipsis_mask);
  auto new_axis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
  MS_EXCEPTION_IF_NULL(new_axis_mask);
  auto shrink_axis_mask = CreateValueNodeWithKernelInfo(func_graph, MakeValue(const_value));
  MS_EXCEPTION_IF_NULL(shrink_axis_mask);

  auto prim = std::make_shared<Primitive>(kStridedSliceOpName);
  MS_EXCEPTION_IF_NULL(prim);
  AnfNodePtrList inputs = {NewValueNode(prim), input_node, begin_node, end_node, strides_node,
                           begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask};
  auto strided_slice_node = NewCNode(inputs, func_graph);
  MS_EXCEPTION_IF_NULL(strided_slice_node);
  auto abs = InferAbstract(prim, {input_node, begin_node, end_node, strides_node, begin_mask, end_mask, ellipsis_mask,
                                  new_axis_mask, shrink_axis_mask});
  MS_EXCEPTION_IF_NULL(abs);
  strided_slice_node->set_abstract(abs);
  static size_t slice_index = 0;
  strided_slice_node->set_fullname_with_scope(input_node->fullname_with_scope() + "_strided_slice_" +
                                              std::to_string(slice_index++));
  return strided_slice_node;
}
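Each generated StridedSlice picks out a single element of the 1-D paddings tensor: begin = index, end = index + 1, stride 1, all masks zero. In plain numpy terms (illustrative):

```python
import numpy as np

def strided_slice_1d(x, index):
    """What each generated StridedSlice computes: a length-1 slice at `index`."""
    return x[index:index + 1]  # begin=index, end=index+1, strides=1, masks all 0

paddings = np.array([1, 2, 1, 1, 0, 1], dtype=np.int64)
print(strided_slice_1d(paddings, 2))  # [1]
```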
const CNodePtr ConvertBasePaddings::CreateConcatNode(const FuncGraphPtr &func_graph,
                                                     const std::vector<AnfNodePtr> &concat_input_vec,
                                                     const std::string &concat_node_name) const {
  auto concat_prim = std::make_shared<Primitive>(kConcatOpName);
  MS_EXCEPTION_IF_NULL(concat_prim);
  std::vector<int64_t> dyn_input_sizes = {SizeToLong(concat_input_vec.size()), -1};
  concat_prim->AddAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes));

  AnfNodePtrList inputs = {NewValueNode(concat_prim)};
  inputs.insert(inputs.end(), concat_input_vec.begin(), concat_input_vec.end());
  int64_t axis = 0;
  auto axis_node = CreateValueNodeWithKernelInfo(func_graph, MakeValue(axis));
  inputs.push_back(axis_node);
  auto concat_node = NewCNode(inputs, func_graph);
  MS_EXCEPTION_IF_NULL(concat_node);

  std::vector<AnfNodePtr> concat_inputs = concat_input_vec;
  concat_inputs.push_back(axis_node);
  auto concat_abs = InferAbstract(concat_prim, concat_inputs);
  MS_EXCEPTION_IF_NULL(concat_abs);
  concat_node->set_abstract(concat_abs);
  concat_node->set_fullname_with_scope(concat_node_name);
  return concat_node;
}
const CNodePtr ConvertBasePaddings::ProcessSliceNConcat(const FuncGraphPtr &func_graph, const AnfNodePtr &pad_node,
                                                        const AnfNodePtr &input_node, const int64_t &padding_dst_length,
                                                        const int64_t &padding_src_length) const {
  auto prim = GetCNodePrimitive(pad_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
  std::vector<AnfNodePtr> concat_input_vec;

  // slice and insert to concat in reverse order
  if (paddings_contiguous) {
    for (int64_t i = 0; i < padding_src_length; i += static_cast<int64_t>(kSizeTwo)) {
      auto slice_node_2 = CreateStridedSliceNode(func_graph, input_node, i + kSizeOne);
      concat_input_vec.insert(concat_input_vec.begin(), slice_node_2);

      auto slice_node_1 = CreateStridedSliceNode(func_graph, input_node, i);
      concat_input_vec.insert(concat_input_vec.begin(), slice_node_1);
    }
  } else {
    for (int64_t i = 0; i < padding_src_length / 2; ++i) {
      auto slice_node_2 = CreateStridedSliceNode(func_graph, input_node, i + padding_src_length / 2);
      concat_input_vec.insert(concat_input_vec.begin(), slice_node_2);

      auto slice_node_1 = CreateStridedSliceNode(func_graph, input_node, i);
      concat_input_vec.insert(concat_input_vec.begin(), slice_node_1);
    }
    prim->AddAttr("paddings_contiguous", MakeValue(true));
  }

  if (padding_dst_length > padding_src_length) {
    auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(pad_node, kIndex1);
    std::shared_ptr<tensor::Tensor> fill_tensor;
    if (input_paddings_type_id == kNumberTypeInt32) {
      fill_tensor =
        std::make_shared<tensor::Tensor>(std::vector<int32_t>(padding_dst_length - padding_src_length, 0), kInt32);
    } else if (input_paddings_type_id == kNumberTypeInt64) {
      fill_tensor =
        std::make_shared<tensor::Tensor>(std::vector<int64_t>(padding_dst_length - padding_src_length, 0), kInt64);
    } else {
      MS_LOG_EXCEPTION << "Unsupported data type for PadV3 paddings input.";
    }
    MS_EXCEPTION_IF_NULL(fill_tensor);
    auto fill_node = CreateValueNodeWithKernelInfo(func_graph, fill_tensor);
    MS_EXCEPTION_IF_NULL(fill_node);
    concat_input_vec.insert(concat_input_vec.begin(), fill_node);
  }
  static size_t concat_index = 0;
  auto concat_node =
    CreateConcatNode(func_graph, concat_input_vec,
                     pad_node->fullname_with_scope() + "_pad_slice_concat" + std::to_string(concat_index++));
  return concat_node;
}
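The net effect of ProcessSliceNConcat on the runtime paddings value: the (left, right) pairs are re-emitted in reverse dimension order (the concat input list is built by inserting at the head), and the result is zero-filled at the front until its length reaches twice the rank of `x`. A numpy sketch of the same value transformation (illustrative, not the graph code itself):

```python
import numpy as np

def adjust_paddings(paddings, rank, contiguous=True):
    """Mimic the slice+concat subgraph on a concrete paddings vector.

    contiguous=True  : paddings laid out as (l0, r0, l1, r1, ...)
    contiguous=False : paddings laid out as (l0, l1, ..., r0, r1, ...)
    The result is (l_{k-1}, r_{k-1}, ..., l0, r0), zero-filled at the
    front to length 2 * rank, the layout ge::PadV3 expects.
    """
    p = list(paddings)
    half = len(p) // 2
    pairs = [(p[2 * i], p[2 * i + 1]) if contiguous else (p[i], p[i + half])
             for i in range(half)]
    out = [v for pair in reversed(pairs) for v in pair]
    fill = [0] * (2 * rank - len(out))     # zero-fill for the unpadded dims
    return np.array(fill + out, dtype=np.int64)

# paddings for the last 3 dims of a 5-D input, lower-to-higher ordering:
print(adjust_paddings([1, 2, 1, 1, 0, 1], rank=5))
# -> [0 0 0 0 0 1 1 1 1 2]
```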
const AnfNodePtr ConvertBasePaddings::CreateDynPaddingsPass(const FuncGraphPtr &graph, const CNodePtr &pad_node,
                                                            const bool &is_grad) const {
  // For dyn paddings in PadV3 and PadV3Grad on Ascend, add StridedSlice -> Concat to adjust paddings in ge::PadV3.
  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(pad_node, kIndex0);
  size_t dst_length = input_x_shape.size() * 2;
  auto prim_name = "PadV3";
  if (is_grad) {
    prim_name = "PadV3Grad";
  }

  auto paddings = common::AnfAlgo::GetInputNode(pad_node, kIndex1);
  MS_EXCEPTION_IF_NULL(paddings);
  auto paddings_abstract = paddings->abstract();
  MS_EXCEPTION_IF_NULL(paddings_abstract);
  auto paddings_shape_ptr = paddings_abstract->GetShape();
  MS_EXCEPTION_IF_NULL(paddings_shape_ptr);
  auto paddings_shape = paddings_shape_ptr->GetShapeVector();
  (void)CheckAndConvertUtils::CheckInteger("paddings_shape_size", SizeToLong(paddings_shape.size()), kEqual, kDim1,
                                           prim_name);
  auto paddings_length = paddings_shape[0];
  // Not implemented: if is_grad and dst_length < 8, the filled paddings should be expanded to 8.
  auto concat_node = ProcessSliceNConcat(graph, pad_node, paddings, dst_length, paddings_length);
  MS_EXCEPTION_IF_NULL(concat_node);
  return concat_node;
}

template <typename T, TypeId type_id>
const AnfNodePtr ConvertBasePaddings::OptimizePaddingsValue(const FuncGraphPtr &graph,
                                                            const AbstractBasePtr &ori_paddings,
@@ -127,6 +266,25 @@ const AnfNodePtr ConvertBasePaddings::OptimizePaddingsValue(const FuncGraphPtr &
  return extend_paddings;
}

const AnfNodePtr ConvertBasePaddings::CreateConstPaddingsNode(const FuncGraphPtr &graph,
                                                              const CNodePtr &pad_node) const {
  auto prim = GetCNodePrimitive(pad_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
  // ge::padV3 only supports `paddings` whose length is twice the rank of `x`
  auto input_paddings = common::AnfAlgo::GetInputNode(pad_node, kIndex1);
  MS_EXCEPTION_IF_NULL(input_paddings);
  auto paddings_abstract = input_paddings->abstract();
  MS_EXCEPTION_IF_NULL(paddings_abstract);

  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(pad_node, kIndex0);
  auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(pad_node, kIndex1);
  auto paddings_value_node = CreateConstPaddingsPass(graph, paddings_abstract, paddings_contiguous,
                                                     input_x_shape.size() * 2, input_paddings_type_id);
  MS_EXCEPTION_IF_NULL(paddings_value_node);
  return paddings_value_node;
}

const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
@@ -135,30 +293,24 @@ const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const A
  MS_EXCEPTION_IF_NULL(cnode);

  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex0);
  auto input_paddings_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, kIndex1);
  auto opt_paddings_size = 2 * input_x_shape.size();
  auto padding_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex1);
  if (IsDynamicRank(input_x_shape) || IsDynamic(padding_shape)) {
    MS_LOG_EXCEPTION << "The input is dynamic rank";
  }

  if (HasDynPaddings(cnode)) {
    MS_EXCEPTION(TypeError) << "While running in Ascend, the input [paddings] of PadV3 is required to be constant, but "
                               "that is dynamic in node["
                            << node->fullname_with_scope() << "]";
    auto concat_node = CreateDynPaddingsNode(graph, cnode);
    MS_EXCEPTION_IF_NULL(concat_node);
    auto node_prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(node_prim);
    node_prim->AddAttr("is_dyn_paddings", MakeValue(true));
    cnode->set_input(kIndex2, concat_node);
  } else {
    auto prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    auto paddings_contiguous = GetValue<bool>(prim->GetAttr("paddings_contiguous"));
    // ge::padV3 only supports `paddings` whose length is twice the rank of `x`
    auto input_paddings = common::AnfAlgo::GetInputNode(cnode, kIndex1);
    MS_EXCEPTION_IF_NULL(input_paddings);
    auto paddings_abstract = input_paddings->abstract();
    MS_EXCEPTION_IF_NULL(paddings_abstract);
    auto paddings_type = paddings_abstract->GetType();
    MS_EXCEPTION_IF_NULL(paddings_type);

    auto paddings_value_node =
      CreatePaddingsNode(graph, paddings_abstract, paddings_contiguous, opt_paddings_size, input_paddings_type_id);
    auto paddings_value_node = CreateConstPaddingsNode(graph, cnode);
    MS_EXCEPTION_IF_NULL(paddings_value_node);
    cnode->set_input(kIndex2, paddings_value_node);
  }
  // Not verified: for PadV3Grad, if the input tensor rank < 4, the input should be expanded to 4.
  auto is_expand = ExpandInputXDims(graph, cnode);
  if (is_expand) {
    ReduceOutputDims(graph, cnode);
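The ExpandInputXDims/ReduceOutputDims hooks let PadV3Grad temporarily lift a low-rank input to 4-D, as the "Not verified" comment above notes. A numpy sketch of the expand step, assuming the usual prepend-1s reshape convention (the layout choice is my assumption; the hunk does not show it):

```python
import numpy as np

def expand_to_4d(x):
    """Sketch of the ExpandInputXDims hook: lift rank < 4 inputs to rank 4."""
    if x.ndim >= 4:
        return x, False
    # Assumption: new axes are prepended (shape left-padded with 1s).
    return x.reshape((1,) * (4 - x.ndim) + x.shape), True

y, expanded = expand_to_4d(np.zeros((2, 3, 3)))
print(y.shape, expanded)  # (1, 2, 3, 3) True
```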
@@ -166,6 +318,18 @@ const AnfNodePtr ConvertBasePaddings::Process(const FuncGraphPtr &graph, const A
  return node;
}

const AnfNodePtr ConvertPadV3GradPaddings::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                   const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (HasDynPaddings(cnode)) {
    MS_EXCEPTION(RuntimeError) << "PadV3Grad doesn't support dynamic paddings input.";
  }
  return ConvertBasePaddings::Process(graph, node, equiv);
}

bool ConvertPadV3GradPaddings::ExpandInputXDims(const FuncGraphPtr &graph, const CNodePtr &node) const {
  auto input_x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, kIndex0);
  auto input_x_rank = input_x_shape.size();
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_CONVERT_PAD_V3_PADDINGS_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_CONVERT_PAD_V3_PADDINGS_H_

#include <vector>
#include <string>
#include "include/backend/optimizer/optimizer.h"
@@ -31,13 +32,25 @@ class ConvertBasePaddings : public PatternProcessPass {

  bool HasDynPaddings(const CNodePtr &) const;
  const CNodePtr CreateReshapeNode(const FuncGraphPtr &, const AnfNodePtr &, const ShapeVector &) const;
  const CNodePtr CreateStridedSliceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                                        int64_t index) const;
  const CNodePtr CreateConcatNode(const FuncGraphPtr &, const std::vector<AnfNodePtr> &, const std::string &) const;
  const CNodePtr ProcessSliceNConcat(const FuncGraphPtr &, const AnfNodePtr &, const AnfNodePtr &, const int64_t &,
                                     const int64_t &) const;

  const AnfNodePtr CreateDynPaddingsPass(const FuncGraphPtr &, const CNodePtr &, const bool &) const;
  virtual const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &, const CNodePtr &) const { return nullptr; }

  template <typename T, TypeId type_id>
  const AnfNodePtr OptimizePaddingsValue(const FuncGraphPtr &, const AbstractBasePtr &, const bool &, const size_t &,
                                         bool force_length8) const;
  virtual const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &, const AbstractBasePtr &, const bool &,
                                              const size_t &, const TypeId &) const {
  virtual const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &, const AbstractBasePtr &, const bool &,
                                                   const size_t &, const TypeId &) const {
    return nullptr;
  }
  const AnfNodePtr CreateConstPaddingsNode(const FuncGraphPtr &, const CNodePtr &) const;

 private:
  virtual bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const { return false; }
  virtual void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const {}
};
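The header keeps the rewrite skeleton in ConvertBasePaddings and lets the two subclasses override only the dyn/const paddings builders and the dim expand/reduce hooks. A compact Python sketch of the same template-method layout (names mirror the header; bodies are stand-ins, not the real pass):

```python
class ConvertBasePaddings:
    """Skeleton: process() is shared; subclasses override the hooks."""

    def process(self, node):
        if self.has_dyn_paddings(node):
            node["paddings"] = self.create_dyn_paddings_node(node)
        else:
            node["paddings"] = self.create_const_paddings_pass(node)
        if self.expand_input_x_dims(node):   # PadV3Grad may lift rank < 4 inputs
            self.reduce_output_dims(node)
        return node

    def has_dyn_paddings(self, node):
        return node.get("paddings") is None

    def create_dyn_paddings_node(self, node):    # overridden per op
        return None

    def create_const_paddings_pass(self, node):  # overridden per op
        return None

    def expand_input_x_dims(self, node):
        return False

    def reduce_output_dims(self, node):
        pass


class ConvertPadV3Paddings(ConvertBasePaddings):
    def create_dyn_paddings_node(self, node):
        return "slice_concat(is_grad=False)"


class ConvertPadV3GradPaddings(ConvertBasePaddings):
    def create_dyn_paddings_node(self, node):
        raise RuntimeError("PadV3Grad doesn't support dynamic paddings input.")


print(ConvertPadV3Paddings().process({"paddings": None}))
```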
@@ -50,15 +63,21 @@ class ConvertPadV3Paddings : public ConvertBasePaddings {
  const BaseRef DefinePattern() const override;

 private:
  const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
                                      const bool &paddings_contiguous, const size_t &dst_length,
                                      const TypeId &type_id) const override {
  const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
                                           const bool &paddings_contiguous, const size_t &dst_length,
                                           const TypeId &type_id) const override {
    if (type_id == kNumberTypeInt32) {
      return ConvertBasePaddings::OptimizePaddingsValue<int32_t, kNumberTypeInt32>(
        graph, ori_paddings, paddings_contiguous, dst_length, false);
    } else if (type_id == kNumberTypeInt64) {
      return ConvertBasePaddings::OptimizePaddingsValue<int64_t, kNumberTypeInt64>(
        graph, ori_paddings, paddings_contiguous, dst_length, false);
    } else {
      MS_LOG_EXCEPTION << "Unsupported data type for PadV3 paddings input.";
    }
    return ConvertBasePaddings::OptimizePaddingsValue<int64_t, kNumberTypeInt64>(
      graph, ori_paddings, paddings_contiguous, dst_length, false);
  }
  const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &graph, const CNodePtr &pad_node) const override {
    return ConvertBasePaddings::CreateDynPaddingsPass(graph, pad_node, false);
  }
  bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const override { return false; }
  void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const override {}
@@ -70,17 +89,24 @@ class ConvertPadV3GradPaddings : public ConvertBasePaddings {
    : ConvertBasePaddings("convert_pad_v3_grad_paddings", multi_graph) {}
  ~ConvertPadV3GradPaddings() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  const AnfNodePtr CreatePaddingsNode(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
                                      const bool &paddings_contiguous, const size_t &dst_length,
                                      const TypeId &type_id) const override {
  const AnfNodePtr CreateConstPaddingsPass(const FuncGraphPtr &graph, const AbstractBasePtr &ori_paddings,
                                           const bool &paddings_contiguous, const size_t &dst_length,
                                           const TypeId &type_id) const override {
    if (type_id == kNumberTypeInt32) {
      return ConvertBasePaddings::OptimizePaddingsValue<int32_t, kNumberTypeInt32>(
        graph, ori_paddings, paddings_contiguous, dst_length, true);
    } else if (type_id == kNumberTypeInt64) {
      return ConvertBasePaddings::OptimizePaddingsValue<int64_t, kNumberTypeInt64>(
        graph, ori_paddings, paddings_contiguous, dst_length, true);
    } else {
      MS_LOG_EXCEPTION << "Unsupported data type for PadV3Grad paddings input.";
    }
    return ConvertBasePaddings::OptimizePaddingsValue<int64_t, kNumberTypeInt64>(graph, ori_paddings,
                                                                                 paddings_contiguous, dst_length, true);
  }
  const AnfNodePtr CreateDynPaddingsNode(const FuncGraphPtr &graph, const CNodePtr &pad_node) const override {
    return ConvertBasePaddings::CreateDynPaddingsPass(graph, pad_node, true);
  }
  bool ExpandInputXDims(const FuncGraphPtr &, const CNodePtr &) const override;
  void ReduceOutputDims(const FuncGraphPtr &, const CNodePtr &) const override;
@@ -60,6 +60,17 @@ void PaddingsSizeCheck(const PrimitivePtr &primitive, const int64_t paddings_siz
  constexpr int64_t nFour = 4;
  constexpr int64_t nFive = 5;
  auto prim_name = primitive->name();
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto is_dyn_paddings = primitive->GetAttr("is_dyn_paddings");
    if (is_dyn_paddings != nullptr && GetValue<bool>(is_dyn_paddings)) {
      if (paddings_size / nTwo != size) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings length must be equal to " << size * nTwo;
      }
      return;
    }
  }
  auto mode = GetValue<std::string>(primitive->GetAttr("mode"));
  if (mode == kConstant) {
    if (paddings_size / nTwo > size) {
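With is_dyn_paddings set, the infer-time size check becomes exact: the rewritten paddings must have length exactly twice the rank of `x`, whereas the generic constant-mode check only bounds it from above. A one-function sketch of the stricter branch:

```python
def check_paddings_size(paddings_size, rank, is_dyn_paddings):
    """Validation sketch: dyn-rewritten paddings must match 2 * rank exactly."""
    if is_dyn_paddings and paddings_size != 2 * rank:
        raise ValueError(f"paddings length must be equal to {2 * rank}")

check_paddings_size(10, 5, True)     # ok: the pass rewrote paddings to 2 * rank
# check_paddings_size(6, 5, True)    # would raise ValueError
```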
@@ -162,6 +173,23 @@ void CheckAscendInputXDim(const size_t &x_dim, const std::string &prim_name) {
  }
}

void AscendTransformPaddingsAttr(const PrimitivePtr &primitive,
                                 std::vector<std::pair<int64_t, int64_t>> *ori_paddings_attr) {
  // If the `paddings` comes from the node added by the pass, it has two features:
  // 1. the length of `paddings` is twice the rank of `x`;
  // 2. the mapping between `x` and `paddings` is lower to lower,
  //    which is different from other backends, where it is lower to higher.
  // So the transform should be activated only when the `paddings` comes from the node added by the pass.
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto is_dyn_paddings = primitive->GetAttr("is_dyn_paddings");
    if (is_dyn_paddings != nullptr && GetValue<bool>(is_dyn_paddings)) {
      std::reverse(ori_paddings_attr->begin(), ori_paddings_attr->end());
    }
  }
}

abstract::ShapePtr PadV3InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  constexpr int64_t kEdgeMaxDims = 5;
  constexpr int64_t kOtherMinDims = 3;
@@ -242,6 +270,7 @@ abstract::ShapePtr PadV3InferShape(const PrimitivePtr &primitive, const std::vec
        std::make_pair(paddings_val[LongToSize(nTwo * i)], paddings_val[LongToSize(nTwo * i + 1)]));
    }
  }
  AscendTransformPaddingsAttr(primitive, &paddings_attr);
  std::vector<int64_t> out_shape;
  for (int64_t i = 0; i < size; ++i) {
    int64_t now_dim_size = x_shape[LongToSize(i)] + paddings_attr[LongToSize(size - i - 1)].first +
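The shape infer pairs the paddings and reads them back-to-front (paddings_attr[size - i - 1] pads x_shape[i]); AscendTransformPaddingsAttr reverses the pair list first when the paddings came from the new pass, so the lower-to-lower layout maps onto the same loop. A numpy-free sketch of the resulting output shape:

```python
def pad_v3_out_shape(x_shape, paddings, is_dyn_paddings=False):
    """Output shape of PadV3: dim i grows by its (left, right) padding pair.

    `paddings` is the flat (l0, r0, l1, r1, ...) list; missing dims pad by 0.
    """
    pairs = [(paddings[2 * i], paddings[2 * i + 1]) for i in range(len(paddings) // 2)]
    pairs += [(0, 0)] * (len(x_shape) - len(pairs))   # unpadded dims
    if is_dyn_paddings:
        pairs.reverse()                               # AscendTransformPaddingsAttr
    # paddings_attr[size - i - 1] pads x_shape[i]
    return [d + sum(pairs[len(x_shape) - i - 1]) for i, d in enumerate(x_shape)]

# x: (1, 1, 2, 3, 3) with paddings (1, 2, 1, 1, 0, 1) over the last three dims
print(pad_v3_out_shape([1, 1, 2, 3, 3], [1, 2, 1, 1, 0, 1]))
# -> [1, 1, 3, 5, 6]
# the dyn-paddings rewrite feeds the same information pre-reordered:
print(pad_v3_out_shape([1, 1, 2, 3, 3], [0, 0, 0, 0, 0, 1, 1, 1, 1, 2], True))
# -> [1, 1, 3, 5, 6]
```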
@@ -184,7 +184,8 @@ def test_padv3_constant_shape_4d(x_data_type, mode, ms_mode):
@pytest.mark.parametrize('x_data_type', [np.int16, np.float32])
@pytest.mark.parametrize('mode', ["constant", "edge"])
@pytest.mark.parametrize('ms_mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode):
@pytest.mark.parametrize('is_dyn_paddings', [True, False])
def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode, is_dyn_paddings):
    """
    Feature: test padv3 x and const shape paddings
    Description: test padv3 with const shape paddings
@@ -193,6 +194,8 @@ def test_padv3_constant_shape_5d(x_data_type, mode, ms_mode):
    context.set_context(mode=ms_mode, device_target="Ascend")
    x = Tensor(np.arange(18).reshape(1, 1, 2, 3, 3).astype(x_data_type))
    paddings = (1, 2, 1, 1, 0, 1)
    if is_dyn_paddings:
        paddings = Tensor(paddings, dtype=ms.int64)
    value = None
    if mode == "constant":
        value = 99 if x_data_type == np.int16 else 99.0