!49620 dynamic rank

Merge pull request !49620 from liubuyu/master
This commit is contained in:
i-robot 2023-03-08 09:51:37 +00:00 committed by Gitee
commit 917aaf6784
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
20 changed files with 604 additions and 70 deletions

View File

@ -551,7 +551,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
if (trans::IsNeedPadding(format, infer_shape)) {
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
}
auto dtype = GetOutputDeviceDataType(node, output_idx);
@ -570,7 +570,7 @@ ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, si
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
if (trans::IsNeedPadding(format, infer_shape)) {
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
}
auto dtype = GetOutputDeviceDataType(node, output_idx);
@ -591,7 +591,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const A
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
if (trans::IsNeedPadding(format, infer_shape)) {
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
}
auto dtype = GetInputDeviceDataType(node, input_idx);
@ -605,7 +605,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &
return infer_shape;
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
if (trans::IsNeedPadding(format, infer_shape)) {
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
}
auto dtype = GetInputDeviceDataType(node, input_idx);

View File

@ -177,6 +177,11 @@ class COMMON_EXPORT AnfAlgo {
static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
static bool IsNodeOutputDynamicShape(const AnfNodePtr &node);
static bool IsDynamicShape(const AnfNodePtr &node);
static bool IsDynamicRankNode(const AnfNodePtr &node);
static bool IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr);
static bool IsNodeOutputDynamicRank(const AnfNodePtr &node);
static bool IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
static bool IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
static bool IsCondControlKernel(const CNodePtr &node);
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);

View File

@ -1031,6 +1031,9 @@ constexpr auto kAttrKsizes = "ksizes";
constexpr auto kAttrIsKernelDynamicImpl = "is_kernel_dynamic_impl";
constexpr auto kAttrIsKernelDynamicShape = "is_kernel_dynamic_shape";
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
constexpr auto kAttrIsDynamicRank = "is_dynamic_rank";
constexpr auto kAttrInputIsDynamicRank = "input_is_dynamic_rank";
constexpr auto kAttrOutputIsDynamicRank = "output_is_dynamic_rank";
constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape";
constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape";
constexpr auto kAttrPynativeNextOpName = "next_op";
@ -1323,6 +1326,7 @@ COMMON_EXPORT bool IsOneOfComputeDepend(const std::string &name);
COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format);
COMMON_EXPORT bool IsOneOfFormat(const std::string &format);
COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format);
COMMON_EXPORT bool IsOneOfDynRankNeedPadShape(const std::string &format);
// The map between kernel's output and input ref relationship.
// Key is the output index while the value is input index which will be used as the reference of output.

View File

@ -178,11 +178,20 @@ void KernelBuildInfo::SetInputsDeviceType(const std::vector<TypeId> &inputs_devi
void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) {
if (index >= outputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output";
MS_LOG(EXCEPTION) << "The index [" << index
<< "] is exceed the length of output formats list, total size:" << outputs_format_.size();
}
outputs_format_[index] = format;
}
void KernelBuildInfo::SetInputFormat(const std::string &format, size_t index) {
if (index >= inputs_format_.size()) {
MS_LOG(EXCEPTION) << "The index [" << index
<< "] is exceed the length of input formats list, total size:" << inputs_format_.size();
}
inputs_format_[index] = format;
}
void KernelBuildInfo::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
outputs_format_ = outputs_format;
}

View File

@ -118,6 +118,8 @@ class BACKEND_EXPORT KernelBuildInfo {
void SetOutputFormat(const std::string &format, size_t index);
void SetInputFormat(const std::string &format, size_t index);
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
void SetInputsFormat(const std::vector<std::string> &inputs_format);

View File

@ -183,6 +183,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
bool is_init = false;
bool need_change_nd = false;
bool is_5d_input = false;
bool is_dyn_rank = common::AnfAlgo::IsDynamicRankNode(cnode);
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t index = 0; index < input_num; ++index) {
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
@ -194,7 +195,11 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
priority_matched_format = kOpFormat_DEFAULT;
}
auto input_shape_size = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
const auto &prev_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (IsDynamicRank(prev_shape)) {
is_dyn_rank = true;
}
auto input_shape_size = prev_shape.size();
if (input_shape_size == k5dSize) {
is_5d_input = true;
}
@ -206,6 +211,9 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
priority_matched_format = kOpFormat_NDC1HWC0;
}
if (is_dyn_rank) {
priority_matched_format = kOpFormat_ND;
}
common::AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
return priority_matched_format;
}

View File

@ -19,6 +19,7 @@
#include <string>
#include <utility>
#include <vector>
#include <set>
#include "utils/ms_context.h"
#include "backend/common/optimizer/helper.h"
#include "include/backend/anf_runtime_algorithm.h"
@ -93,8 +94,36 @@ inline void GetRangeByShape(const AnfNodePtr &anf_node, const ShapeVector &shape
}
}
ShapeVector TbeDynamicShapeUtil::UpdateShape(const AnfNodePtr &node, const std::string &format,
const ShapeVector &shape, size_t index, bool is_input,
bool *is_change_nd) {
MS_EXCEPTION_IF_NULL(node);
const std::set<std::string> op_names = {kTransDataOpName};
if (!node->isa<CNode>() || op_names.find(common::AnfAlgo::GetCNodeName(node)) == op_names.end()) {
return shape;
}
std::string sp_format = format;
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (kernel_info->select_kernel_build_info() != nullptr) {
auto in_format = AnfAlgo::GetInputFormat(node, 0);
auto out_format = AnfAlgo::GetOutputFormat(node, 0);
sp_format = IsOneOfHWSpecialFormat(in_format) ? in_format : out_format;
}
const auto &pad_idx =
is_input ? AnfAlgo::GetInputReshapeType(node, index) : AnfAlgo::GetOutputReshapeType(node, index);
if (format == kOpFormat_NCHW && shape.size() < kDim4 && IsOneOfDynRankNeedPadShape(sp_format)) {
if (is_change_nd) {
*is_change_nd = true;
}
return trans::PaddingShape(shape, sp_format, pad_idx);
}
return shape;
}
RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index,
const std::string &def_format, const TypeId &type) {
const std::string &def_format, const std::string &ori_format,
const TypeId &type) {
MS_EXCEPTION_IF_NULL(anf_node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@ -110,13 +139,17 @@ RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node,
auto prev_node = common::AnfAlgo::GetPrevNodeOutput(anf_node, index);
MS_EXCEPTION_IF_NULL(prev_node.first);
auto shape = common::AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
if (anf_node->isa<CNode>()) {
shape = UpdateShape(anf_node, ori_format, shape, index, true);
}
GetRangeByShape(anf_node, shape, &ret);
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);
}
RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index,
const std::string &def_format, const TypeId &type) {
const std::string &def_format, const std::string &ori_format,
const TypeId &type) {
MS_EXCEPTION_IF_NULL(anf_node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@ -130,6 +163,9 @@ RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node,
RangePair ret;
auto shape = common::AnfAlgo::GetOutputInferShape(anf_node, index);
if (anf_node->isa<CNode>()) {
shape = UpdateShape(anf_node, ori_format, shape, index, false);
}
GetRangeByShape(anf_node, shape, &ret);
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);

View File

@ -32,15 +32,17 @@ class TbeDynamicShapeUtil {
public:
TbeDynamicShapeUtil() = default;
~TbeDynamicShapeUtil() = default;
static ShapeVector UpdateShape(const AnfNodePtr &node, const std::string &format, const ShapeVector &shape,
size_t index, bool is_input, bool *is_change_nd = nullptr);
static bool GetDynamicShapeAttr(const CNodePtr &cnode);
static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
static std::shared_ptr<OpInfo> FindOp(const CNodePtr &cnode);
static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
const TypeId &type);
const std::string &ori_format, const TypeId &type);
static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
const TypeId &type);
const std::string &ori_format, const TypeId &type);
};
} // namespace tbe
} // namespace kernel

View File

@ -221,10 +221,12 @@ void SingleTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetInputReshapeType(anf_node, real_input_index));
(*input_desc)[kJCValue] = infer_shape[1];
}
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true);
(*input_desc)[kJShape] = shape;
(*input_desc)[kJFormat] = format;
(*input_desc)[kJValid] = true;
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type);
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format,
(*input_desc)[kJOriFormat], d_type);
GenInputConstValue(anf_node, real_input_index, input_desc);
}
@ -236,8 +238,8 @@ void SingleTbeJsonCreator::GenOutputDescJson(const AnfNodePtr &anf_node, size_t
auto type_str = GetJsonValue<std::string>(*output_desc, kJDtype);
auto d_type = tbe::DtypeToTypeId(type_str);
(*output_desc)[kJValid] = true;
(*output_desc)[kJRange] =
tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, node_out_idx, (*output_desc)[kJFormat], d_type);
(*output_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(
anf_node, node_out_idx, (*output_desc)[kJFormat], (*output_desc)[kJOriFormat], d_type);
}
bool SingleTbeJsonCreator::AssignInputsJson(const AnfNodePtr &anf_node, const std::vector<nlohmann::json> &inputs_desc,
@ -398,6 +400,7 @@ void SelectTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_o
(*output_desc)[kJFormat] = format;
(*output_desc)[kJOriFormat] = def_format;
(*output_desc)[kJOriShape] = ori_shape;
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false);
(*output_desc)[kJShape] = shape;
(*output_desc)[kJOutputIndex] = desc_output_idx;
}
@ -422,13 +425,15 @@ void SelectTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
if (common::AnfAlgo::GetCNodeName(anf_node) == kMaxPool3DGradGradDOpName) {
(*input_desc)[kJOriFormat] = kOpFormat_NDHWC;
}
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true);
(*input_desc)[kJShape] = shape;
(*input_desc)[kJFormat] = format;
(*input_desc)[kJValid] = true;
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, real_input_index);
MS_EXCEPTION_IF_NULL(input_node_with_index.first);
if (!input_node_with_index.first->isa<ValueNode>()) {
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type);
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format,
(*input_desc)[kJOriFormat], d_type);
}
}
bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,

View File

@ -487,6 +487,7 @@ void TbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx
(*output_desc)[kJFormat] = format;
(*output_desc)[kJOriFormat] = def_format;
(*output_desc)[kJOriShape] = ori_shape;
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false);
(*output_desc)[kJShape] = shape;
(*output_desc)[kJName] = output_desc_name;
// !! Note: output_index, only node's output use it

View File

@ -83,7 +83,7 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_
return infer_shape;
}
if (trans::IsNeedPadding(format, infer_shape.size())) {
if (trans::IsNeedPadding(format, infer_shape)) {
auto reshape_type =
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);

View File

@ -36,6 +36,23 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
struct CreateNodeArgs {
FuncGraphPtr func_graph{nullptr};
AnfNodePtr node{nullptr};
AnfNodePtr input_node{nullptr};
AnfNodePtr orig_node{nullptr};
KernelSelectPtr kernel_select{nullptr};
std::string trans_opname;
std::string input_format;
std::string dst_format;
std::string spec_format;
std::string reshape_type;
TypeId type_id;
ShapeVector out_shape;
bool is_dynamic_shape;
bool need_padding;
};
std::string GetTransOpName(const std::string &spec_format) {
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
? prim::kPrimTransDataRNN->name()
@ -83,6 +100,61 @@ CNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inp
return reshape;
}
AnfNodePtr CreateTransDataWithOutReshape(const CreateNodeArgs &args) {
// don't need padding insert TransData only
auto trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select,
args.need_padding, args.trans_opname);
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
args.type_id);
return trans_data;
}
AnfNodePtr CreateTransDataWithReshape(const CreateNodeArgs &args) {
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
if (!args.need_padding) {
// don't need padding insert transdata only
trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding,
args.trans_opname);
trans_node = trans_data;
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
args.type_id);
} else if (args.spec_format == args.dst_format) {
// if need padding & default to special format
// ori_shape -> reshape[padding shape] -> transdata[device shape]
auto padding_shape = trans::PaddingShape(args.out_shape, args.dst_format, args.reshape_type, args.node);
std::vector<int> padding_axis;
if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) {
padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node);
}
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(padding_shape);
auto reshape_node = CreateReshapeNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select,
pad_shape_ptr, args.is_dynamic_shape, padding_axis);
trans_data = NewTransOpNode(args.func_graph, reshape_node, args.orig_node, args.kernel_select, args.need_padding,
args.trans_opname);
trans_node = trans_data;
trans_data->set_abstract(args.input_node->abstract());
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
args.type_id);
} else {
// if need padding & special to default format
// device shape -> transdata[padding shape] -> reshape[ori_shape]
trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding,
args.trans_opname);
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
args.type_id);
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(args.out_shape);
std::vector<int> padding_axis;
if (std::count(args.out_shape.begin(), args.out_shape.end(), -1) > 1) {
padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node);
}
auto reshape_node = CreateReshapeNode(args.func_graph, trans_data, args.orig_node, args.kernel_select,
pad_shape_ptr, args.is_dynamic_shape, padding_axis);
trans_node = reshape_node;
}
return trans_node;
}
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(trans_node);
MS_EXCEPTION_IF_NULL(node);
@ -136,6 +208,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(node_with_index.first);
auto real_input = node_with_index.first;
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
MS_LOG(DEBUG)
<< "ValueNode or Parameter has no inputs, try to insert for ValueNode or Parameter at out anchor, node: "
<< node->fullname_with_scope();
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
MS_EXCEPTION_IF_NULL(input_node);
common::AnfAlgo::SetNodeInput(node, input_node, index);
@ -143,8 +218,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
ShapeVector origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index);
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
if (NeedInsertTransData(origin_shape, dest_format)) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index;
MS_LOG(DEBUG) << "Need insert TransData change format from [" << dest_format
<< "] to [DefaultFormat], input index:" << index << ", node: " << node->fullname_with_scope();
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
if (real_input->isa<Parameter>()) {
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
@ -165,7 +240,8 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
<< node->DebugString() << trace::DumpSourceLines(node);
}
if (NeedInsertTransData(origin_shape, output_format)) {
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
MS_LOG(DEBUG) << "Inserted TransData change format from [" << output_format
<< "] to [DefaultFormat], single output index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
}
return node;
@ -243,61 +319,32 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(out_shape_ptr);
ShapeVector out_shape = out_shape_ptr->shape();
auto is_dyn_rank = out_shape_ptr->IsDimUnknown();
auto is_dynamic_shape = out_shape_ptr->IsDynamic();
bool need_padding = trans::IsNeedPadding(spec_format, out_shape.size());
bool need_padding = trans::IsNeedPadding(spec_format, out_shape);
std::string trans_opname = GetTransOpName(spec_format);
bool is_insert_output = node == input_node;
auto orig_node = GetOriginNode(is_insert_output, node);
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
if (!need_padding) {
// don't need padding insert transdata only
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data;
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
} else if (spec_format == dst_format) {
// if need padding & default to special format
// ori_shape -> reshape[padding shape] -> transdata[device shape]
auto padding_shape = trans::PaddingShape(out_shape, dst_format, reshape_type, node);
std::vector<int> padding_axis;
if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) {
padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node);
}
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(padding_shape);
auto reshape_node = CreateReshapeNode(func_graph, input_node, orig_node, kernel_select, pad_shape_ptr,
is_dynamic_shape, padding_axis);
trans_data = NewTransOpNode(func_graph, reshape_node, orig_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data;
trans_data->set_abstract(input_node->abstract());
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
AnfNodePtr trans_data = nullptr;
CreateNodeArgs args = {func_graph, node, input_node, orig_node, kernel_select,
trans_opname, input_format, dst_format, spec_format, reshape_type,
type_id, out_shape, is_dynamic_shape, need_padding};
if (is_dyn_rank) {
trans_data = CreateTransDataWithOutReshape(args);
} else {
// if need padding & special to default format
// device shape -> transdata[padding shape] -> reshape[ori_shape]
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(out_shape);
std::vector<int> padding_axis;
if (std::count(out_shape.begin(), out_shape.end(), -1) > 1) {
padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node);
}
auto reshape_node = CreateReshapeNode(func_graph, trans_data, orig_node, kernel_select, pad_shape_ptr,
is_dynamic_shape, padding_axis);
trans_node = reshape_node;
trans_data = CreateTransDataWithReshape(args);
}
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), trans_data);
}
if (is_insert_output) {
ReFreshInferShape(trans_node, node);
ReFreshInferShape(trans_data, node);
}
return trans_node;
return trans_data;
}
void RefreshKernelBuildInfo(const KernelSelectPtr &kernel_select, const std::string &input_format,
@ -368,15 +415,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
MS_EXCEPTION_IF_NULL(out_shape_base);
ShapeVector out_shape;
bool is_dyn_rank = false;
bool is_dynamic_shape = false;
if (out_shape_base->isa<abstract::Shape>()) {
auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(out_shape_ptr);
out_shape = out_shape_ptr->shape();
is_dynamic_shape = out_shape_ptr->IsDynamic();
is_dyn_rank = out_shape_ptr->IsDimUnknown();
}
if (need_padding) {
if (need_padding && !is_dyn_rank) {
// if need padding we should set the transdata node's shape to the padding shape
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
@ -495,6 +544,7 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
size_t in_num = common::AnfAlgo::GetInputNum(cnode); // include monads.
MS_LOG(DEBUG) << "Try to insert TransData at input anchor for node: " << cnode->fullname_with_scope();
for (size_t input_index = 0; input_index < in_num; ++input_index) {
// Monad inputs keep unchanged from GetTransInputNodePtr().
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);

View File

@ -247,13 +247,16 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
}
}
bool IsNeedPadding(const std::string &format, size_t shape_size) {
if (shape_size == 0) {
bool IsNeedPadding(const std::string &format, const ShapeVector &shape) {
if (shape.size() == 0) {
return false;
}
if (IsDynamicRank(shape) && !IsOneOfDynRankNeedPadShape(format)) {
return false;
}
if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || IsOneOfNoPaddingFormat(format)) {
return false;
} else if (shape_size < kDim4) {
} else if (shape.size() < kDim4) {
return true;
}
return false;
@ -288,7 +291,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
host_shape = common::AnfAlgo::GetOutputInferShape(node, index);
}
auto format = AnfAlgo::GetOutputFormat(node, index);
if (IsNeedPadding(format, host_shape.size())) {
if (IsNeedPadding(format, host_shape)) {
host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index), node);
}
return host_shape;

View File

@ -39,6 +39,7 @@
#include "utils/log_adapter.h"
#include "include/common/utils/utils.h"
#include "include/backend/visible.h"
#include "mindapi/base/shape_vector.h"
namespace mindspore {
namespace trans {
@ -329,7 +330,7 @@ BACKEND_EXPORT ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t
/**
* If need padding
* */
BACKEND_EXPORT bool IsNeedPadding(const std::string &format, size_t shape_size);
BACKEND_EXPORT bool IsNeedPadding(const std::string &format, const ShapeVector &shape);
/**
* Padding shape to 5D by default mode
@ -442,16 +443,21 @@ std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &form
MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format
<< ", detail info: " << node->DebugString();
}
std::vector<T> host_shape;
if (IsOneOf3DFormat(format)) {
if (shape.size() >= kDim5) {
return shape;
}
host_shape = PaddingShapeTo5d(shape, pad_index);
} else {
host_shape = PaddingShapeTo4d(shape, pad_index);
if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
return {-1, -1, -1, -1, -1};
}
return PaddingShapeTo5d(shape, pad_index);
}
return host_shape;
if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
return {-1, -1, -1, -1};
}
return PaddingShapeTo4d(shape, pad_index);
}
/**

View File

@ -1382,6 +1382,93 @@ bool AnfAlgo::HasDynamicShapeFlag(const PrimitivePtr &prim) {
return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape);
}
bool IsNodeDynamicRank(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node is not a cnode";
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto in_dyn_rank = AnfAlgo::IsNodeInputDynamicRank(cnode);
auto out_dyn_rank = AnfAlgo::IsNodeOutputDynamicRank(cnode);
if (in_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicRank, cnode)) {
AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicRank, MakeValue(true), cnode);
MS_LOG(DEBUG) << "Set input dynamic rank attr for node:" << cnode->fullname_with_scope();
}
if (out_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicRank, cnode)) {
AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicRank, MakeValue(true), cnode);
MS_LOG(DEBUG) << "Set output dynamic rank attr for node:" << cnode->fullname_with_scope();
}
return in_dyn_rank || out_dyn_rank;
}
bool AnfAlgo::IsDynamicRankNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<Parameter>()) {
return IsOutputAnchorDynamicRank(node, 0);
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if ((!HasNodeAttr(kAttrInputIsDynamicRank, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicRank, cnode))) {
auto ret = IsNodeDynamicRank(node);
MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic rank: [" << ret << "]";
return ret;
}
return GetBooleanAttr(node, kAttrInputIsDynamicRank) || GetBooleanAttr(node, kAttrOutputIsDynamicRank) ||
GetBooleanAttr(node, kAttrIsDynamicRank);
}
bool AnfAlgo::IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Only cnode has inputs, node: " << node->fullname_with_scope();
}
const auto &in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, idx);
if (mindspore::IsDynamicRank(in_shape)) {
return true;
}
return false;
}
bool AnfAlgo::IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
MS_EXCEPTION_IF_NULL(node);
const auto &out_shape = common::AnfAlgo::GetOutputInferShape(node, idx);
if (mindspore::IsDynamicRank(out_shape)) {
return true;
}
return false;
}
bool AnfAlgo::IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr) {
MS_EXCEPTION_IF_NULL(anf_node_ptr);
const auto &inputs = anf_node_ptr->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
if (IsNodeOutputDynamicRank(input)) {
return true;
}
}
return false;
}
bool AnfAlgo::IsNodeOutputDynamicRank(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto base_shape = node->Shape();
if (base_shape == nullptr) {
MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
return false;
}
if (base_shape->isa<abstract::DynamicSequenceShape>()) {
auto b_ptr = base_shape->cast<abstract::DynamicSequenceShapePtr>();
if (b_ptr->IsDimUnknown()) {
return true;
}
}
return base_shape->IsDimUnknown();
}
bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {

View File

@ -211,4 +211,11 @@ bool IsOneOfServerFormatC04(const std::string &format) {
static const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
return kServerFormatC04List.find(format) != kServerFormatC04List.end();
}
bool IsOneOfDynRankNeedPadShape(const std::string &format) {
const std::set<std::string> kOpFormats = {kOpFormat_NC1HWC0, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z,
kOpFormat_NDC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_3D, kOpFormat_FRACTAL_Z_C04, kOpFormat_NCDHW};
return kOpFormats.find(format) != kOpFormats.end();
}
} // namespace mindspore

View File

@ -244,6 +244,7 @@ class MS_CORE_API DynamicSequenceShape : public BaseShape {
// element's shape
BaseShapePtr element_shape_{nullptr};
};
using DynamicSequenceShapePtr = std::shared_ptr<DynamicSequenceShape>;
GVAR_DEF(std::shared_ptr<DynamicSequenceShape>, kDynamicSequenceShape, std::make_shared<DynamicSequenceShape>());
/// \brief SequequeShape defines base class of multiple-shape classes.

View File

@ -0,0 +1,111 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.ops.operations as ops
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore import context
np.random.seed(3)
class MSDynRankNet(Cell):
def __init__(self, is_training=True):
super(MSDynRankNet, self).__init__()
self.is_training = is_training
self.batch_norm = ops.BatchNorm(is_training=self.is_training)
self.reduce_mean = ops.ReduceMean(keep_dims=False)
self.relu = ops.ReLU()
def construct(self, input_x, scale, offset, mean, variance, indices):
unique_indices = self.relu(indices)
reduced_in = self.reduce_mean(input_x, unique_indices)
reduced_scale = self.reduce_mean(scale, unique_indices)
reduced_offset = self.reduce_mean(offset, unique_indices)
reduced_mean = self.reduce_mean(mean, unique_indices)
reduced_variance = self.reduce_mean(variance, unique_indices)
out, _, _, _, _ = self.batch_norm(reduced_in, reduced_scale, reduced_offset, reduced_mean, reduced_variance)
return out
class NetFactory:
def __init__(self, x, scale, offset, mean, variance, indices, dtype=np.float32, is_training=False):
super(NetFactory, self).__init__()
self.x = x
self.scale = scale
self.offset = offset
self.mean = mean
self.variance = variance
self.indices = indices
self.dtype = dtype
self.is_training = is_training
self.nh2nc = [0, 3, 1, 2]
self.nc2nh = [0, 2, 3, 1]
def mindspore_case(self):
ms_x = Tensor(self.x)
ms_indices = Tensor(self.indices)
ms_scale = Tensor(self.scale)
ms_offset = Tensor(self.offset)
ms_mean = Tensor(self.mean)
ms_variance = Tensor(self.variance)
ms_dyn_x = Tensor(shape=[None for _ in ms_x.shape], dtype=ms_x.dtype)
ms_dyn_scale = Tensor(shape=[None for _ in ms_scale.shape], dtype=ms_scale.dtype)
ms_dyn_offset = Tensor(shape=[None for _ in ms_offset.shape], dtype=ms_offset.dtype)
ms_dyn_mean = Tensor(shape=[None for _ in ms_mean.shape], dtype=ms_mean.dtype)
ms_dyn_variance = Tensor(shape=[None for _ in ms_variance.shape], dtype=ms_variance.dtype)
ms_dyn_indices = Tensor(shape=[None], dtype=ms_indices.dtype)
ms_net = MSDynRankNet(is_training=self.is_training)
ms_net.set_inputs(ms_dyn_x, ms_dyn_scale, ms_dyn_offset, ms_dyn_mean, ms_dyn_variance, ms_dyn_indices)
ms_out = ms_net(ms_x, ms_scale, ms_offset, ms_mean, ms_variance, ms_indices)
return ms_out.asnumpy()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_batch_norm_dynamic_rank():
"""
Feature: test batch norm dynamic rank
Description: test batch norm dynamic rank with input tensor's type float32
Expectation: none.
"""
input_x = np.random.randn(3, 3, 4, 3, 3).astype(np.float32)
scale_ = np.ones((4, 4)).astype(np.float32)
offset_ = np.ones((4, 4)).astype(np.float32)
mean_ = np.ones((4, 4)).astype(np.float32)
variance_ = np.ones((4, 4)).astype(np.float32)
indices_ = np.unique(np.random.randint(1, 2, (1,)).astype(np.int32))
# graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
graph_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_, indices=indices_,
dtype=np.float32)
graph_mode_out = graph_mode_net.mindspore_case()
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
pynative_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_,
indices=indices_,
dtype=np.float32)
pynative_mode_out = pynative_mode_net.mindspore_case()
assert np.allclose(pynative_mode_out, graph_mode_out, 1e-4, 1e-4)

View File

@ -0,0 +1,99 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import torch
from mindspore import Tensor
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P
np.random.seed(3)
context.set_context(mode=context.GRAPH_MODE)
class MSBiasAddDynRankNet(Cell):
def __init__(self, data_format="NCHW"):
super(MSBiasAddDynRankNet, self).__init__()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.bias_add = P.BiasAdd(data_format=data_format)
self.relu = P.ReLU()
def construct(self, input_a, input_b, indices):
relu_indices = self.relu(indices)
reduce_a = self.reduce_sum(input_a, relu_indices)
out = self.bias_add(reduce_a, input_b)
return out
class TorchAddNet(torch.nn.Module):
def __init__(self):
super(TorchAddNet, self).__init__()
self.keep_dims = False
def forward(self, input_a, input_b, indices):
relu_indices = torch.relu(indices)
reduce_a = torch.sum(input_a, relu_indices.tolist(), keepdim=self.keep_dims)
out = torch.add(reduce_a, input_b)
return out
class BiasAddOpFactory:
def __init__(self, in_shape, indices, dtype=np.float32, data_format="NCHW"):
super(BiasAddOpFactory, self).__init__()
self.dtype = dtype
self.input_x = np.random.randn(*in_shape).astype(self.dtype)
self.data_format = data_format
self.indices = indices
self.input_b = np.random.randn(in_shape[-1]).astype(self.dtype)
self.loss = 1e-4
def ms_biass_add_forward(self):
a = Tensor(self.input_x)
b = Tensor(self.input_b)
indices = Tensor(self.indices)
dyn_a = Tensor(shape=[None for _ in a.shape], dtype=a.dtype)
dyn_b = Tensor(shape=[None for _ in b.shape], dtype=b.dtype)
dyn_indices = Tensor(shape=[None], dtype=indices.dtype)
ms_net = MSBiasAddDynRankNet(data_format=self.data_format)
ms_net.set_inputs(dyn_a, dyn_b, dyn_indices)
out = ms_net(a, b, indices)
return out.asnumpy()
def torch_bias_add_forward(self):
torch_net = TorchAddNet()
out = torch_net(torch.from_numpy(self.input_x), torch.from_numpy(self.input_b), torch.from_numpy(self.indices))
return out.detach().numpy()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bias_add_dyn_rank():
"""
Feature: test bias add dynamic rank
Description: test bias add dynamic rank with input tensor's type float32
Expectation: none.
"""
in_shape = (16, 16, 16, 16, 16)
indices_np = np.unique(np.random.randint(0, 2, size=3).astype(np.int32))
factory = BiasAddOpFactory(in_shape=in_shape, indices=indices_np, dtype=np.float32, data_format="NCHW")
ms_out = factory.ms_biass_add_forward()
torch_out = factory.torch_bias_add_forward()
np.allclose(ms_out, torch_out, factory.loss, factory.loss)

View File

@ -0,0 +1,98 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import torch
from mindspore import Tensor
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P
np.random.seed(3)
context.set_context(mode=context.GRAPH_MODE)
class MSReduceSumNet(Cell):
def __init__(self, keep_dims=False):
super(MSReduceSumNet, self).__init__()
self.reduce_sum = P.ReduceSum(keep_dims=keep_dims)
self.reduce = P.ReduceSum(keep_dims=False)
def construct(self, x, indices, axis):
x = self.reduce(x, axis)
return self.reduce_sum(x, indices)
class TorchReduceSumNet(torch.nn.Module):
def __init__(self, keep_dims=False):
super(TorchReduceSumNet, self).__init__()
self.keep_dims = keep_dims
def forward(self, input_x, indices, axis):
x = torch.sum(input_x, axis.tolist(), False)
out = torch.sum(x, indices.tolist(), self.keep_dims)
return out
class ReduceOpFactory:
def __init__(self, input_x, indices, axis, keep_dims, dtype=np.float32, loos=1e-4):
super(ReduceOpFactory, self).__init__()
self.out_grad = None
self.input_x = input_x
self.indices = indices
self.axis = axis
self.keep_dims = keep_dims
self.dtype = dtype
self.loss = loos
def ms_reduce_sum_forward(self):
net = MSReduceSumNet(self.keep_dims)
in_tensor = Tensor(self.input_x)
in_indices = Tensor(self.indices)
in_axis = Tensor(self.axis)
in_tensor_dyn = Tensor(shape=[None for _ in in_tensor.shape], dtype=in_tensor.dtype)
in_indices_dyn = Tensor(shape=[None for _ in in_indices.shape], dtype=in_indices.dtype)
in_axis_dyn = Tensor(shape=[None for _ in in_axis.shape], dtype=in_axis.dtype)
net.set_inputs(in_tensor_dyn, in_indices_dyn, in_axis_dyn)
out = net(in_tensor, in_indices, in_axis)
return out.asnumpy()
def torch_reduce_sum_forward(self):
net = TorchReduceSumNet(self.keep_dims)
out = net(torch.from_numpy(self.input_x.astype(self.dtype)), self.indices, self.axis)
return out.detach().numpy()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_reduce_sum_dyn_rank():
"""
Feature: test reduce sum dynamic rank
Description: test reduce sum dynamic rank with input tensor's type float32
Expectation: none.
"""
dtype = np.float32
x = np.random.randn(22, 20, 28, 36, 24, 23).astype(dtype)
indices = np.array([0, -1])
axis = np.unique(np.random.randint(0, 2, size=5).astype(np.int32))
factory = ReduceOpFactory(x, indices, axis, keep_dims=True, dtype=dtype, loos=1e-4)
ms_data = factory.ms_reduce_sum_forward()
torch_data = factory.torch_reduce_sum_forward()
np.allclose(torch_data, ms_data, factory.loss, factory.loss)