commit
917aaf6784
|
@ -551,7 +551,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const
|
||||||
}
|
}
|
||||||
|
|
||||||
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
||||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
if (trans::IsNeedPadding(format, infer_shape)) {
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
|
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
|
||||||
}
|
}
|
||||||
auto dtype = GetOutputDeviceDataType(node, output_idx);
|
auto dtype = GetOutputDeviceDataType(node, output_idx);
|
||||||
|
@ -570,7 +570,7 @@ ShapeVector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, si
|
||||||
}
|
}
|
||||||
|
|
||||||
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
||||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
if (trans::IsNeedPadding(format, infer_shape)) {
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
|
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx), node);
|
||||||
}
|
}
|
||||||
auto dtype = GetOutputDeviceDataType(node, output_idx);
|
auto dtype = GetOutputDeviceDataType(node, output_idx);
|
||||||
|
@ -591,7 +591,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShapeForTbeBuild(const A
|
||||||
}
|
}
|
||||||
|
|
||||||
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
||||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
if (trans::IsNeedPadding(format, infer_shape)) {
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
|
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
|
||||||
}
|
}
|
||||||
auto dtype = GetInputDeviceDataType(node, input_idx);
|
auto dtype = GetInputDeviceDataType(node, input_idx);
|
||||||
|
@ -605,7 +605,7 @@ std::vector<int64_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &
|
||||||
return infer_shape;
|
return infer_shape;
|
||||||
}
|
}
|
||||||
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
// if format is default_format or NC1KHKWHWC0, device shape = original shape
|
||||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
if (trans::IsNeedPadding(format, infer_shape)) {
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
|
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx), node);
|
||||||
}
|
}
|
||||||
auto dtype = GetInputDeviceDataType(node, input_idx);
|
auto dtype = GetInputDeviceDataType(node, input_idx);
|
||||||
|
|
|
@ -177,6 +177,11 @@ class COMMON_EXPORT AnfAlgo {
|
||||||
static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
|
static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
|
||||||
static bool IsNodeOutputDynamicShape(const AnfNodePtr &node);
|
static bool IsNodeOutputDynamicShape(const AnfNodePtr &node);
|
||||||
static bool IsDynamicShape(const AnfNodePtr &node);
|
static bool IsDynamicShape(const AnfNodePtr &node);
|
||||||
|
static bool IsDynamicRankNode(const AnfNodePtr &node);
|
||||||
|
static bool IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr);
|
||||||
|
static bool IsNodeOutputDynamicRank(const AnfNodePtr &node);
|
||||||
|
static bool IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
|
||||||
|
static bool IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
|
||||||
static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
|
static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
|
||||||
static bool IsCondControlKernel(const CNodePtr &node);
|
static bool IsCondControlKernel(const CNodePtr &node);
|
||||||
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
|
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
|
||||||
|
|
|
@ -1031,6 +1031,9 @@ constexpr auto kAttrKsizes = "ksizes";
|
||||||
constexpr auto kAttrIsKernelDynamicImpl = "is_kernel_dynamic_impl";
|
constexpr auto kAttrIsKernelDynamicImpl = "is_kernel_dynamic_impl";
|
||||||
constexpr auto kAttrIsKernelDynamicShape = "is_kernel_dynamic_shape";
|
constexpr auto kAttrIsKernelDynamicShape = "is_kernel_dynamic_shape";
|
||||||
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
|
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
|
||||||
|
constexpr auto kAttrIsDynamicRank = "is_dynamic_rank";
|
||||||
|
constexpr auto kAttrInputIsDynamicRank = "input_is_dynamic_rank";
|
||||||
|
constexpr auto kAttrOutputIsDynamicRank = "output_is_dynamic_rank";
|
||||||
constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape";
|
constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape";
|
||||||
constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape";
|
constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape";
|
||||||
constexpr auto kAttrPynativeNextOpName = "next_op";
|
constexpr auto kAttrPynativeNextOpName = "next_op";
|
||||||
|
@ -1323,6 +1326,7 @@ COMMON_EXPORT bool IsOneOfComputeDepend(const std::string &name);
|
||||||
COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format);
|
COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format);
|
||||||
COMMON_EXPORT bool IsOneOfFormat(const std::string &format);
|
COMMON_EXPORT bool IsOneOfFormat(const std::string &format);
|
||||||
COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format);
|
COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format);
|
||||||
|
COMMON_EXPORT bool IsOneOfDynRankNeedPadShape(const std::string &format);
|
||||||
|
|
||||||
// The map between kernel's output and input ref relationship.
|
// The map between kernel's output and input ref relationship.
|
||||||
// Key is the output index while the value is input index which will be used as the reference of output.
|
// Key is the output index while the value is input index which will be used as the reference of output.
|
||||||
|
|
|
@ -178,11 +178,20 @@ void KernelBuildInfo::SetInputsDeviceType(const std::vector<TypeId> &inputs_devi
|
||||||
|
|
||||||
void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) {
|
void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) {
|
||||||
if (index >= outputs_format_.size()) {
|
if (index >= outputs_format_.size()) {
|
||||||
MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output";
|
MS_LOG(EXCEPTION) << "The index [" << index
|
||||||
|
<< "] is exceed the length of output formats list, total size:" << outputs_format_.size();
|
||||||
}
|
}
|
||||||
outputs_format_[index] = format;
|
outputs_format_[index] = format;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void KernelBuildInfo::SetInputFormat(const std::string &format, size_t index) {
|
||||||
|
if (index >= inputs_format_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The index [" << index
|
||||||
|
<< "] is exceed the length of input formats list, total size:" << inputs_format_.size();
|
||||||
|
}
|
||||||
|
inputs_format_[index] = format;
|
||||||
|
}
|
||||||
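A minimal sketch of the bounds-checked setters in use (illustrative only; the accessor below is an assumption, not part of this diff):
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);  // assumed accessor to the selected KernelBuildInfo
  MS_EXCEPTION_IF_NULL(build_info);
  build_info->SetInputFormat(kOpFormat_NC1HWC0, 0);  // valid while index < inputs_format_.size()
  build_info->SetOutputFormat(kOpFormat_ND, 0);      // valid while index < outputs_format_.size()
  // An out-of-range index raises MS_LOG(EXCEPTION) and reports the total list size.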
|
|
||||||
void KernelBuildInfo::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
|
void KernelBuildInfo::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
|
||||||
outputs_format_ = outputs_format;
|
outputs_format_ = outputs_format;
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,6 +118,8 @@ class BACKEND_EXPORT KernelBuildInfo {
|
||||||
|
|
||||||
void SetOutputFormat(const std::string &format, size_t index);
|
void SetOutputFormat(const std::string &format, size_t index);
|
||||||
|
|
||||||
|
void SetInputFormat(const std::string &format, size_t index);
|
||||||
|
|
||||||
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
|
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
|
||||||
|
|
||||||
void SetInputsFormat(const std::vector<std::string> &inputs_format);
|
void SetInputsFormat(const std::vector<std::string> &inputs_format);
|
||||||
|
|
|
@ -183,6 +183,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
||||||
bool is_init = false;
|
bool is_init = false;
|
||||||
bool need_change_nd = false;
|
bool need_change_nd = false;
|
||||||
bool is_5d_input = false;
|
bool is_5d_input = false;
|
||||||
|
bool is_dyn_rank = common::AnfAlgo::IsDynamicRankNode(cnode);
|
||||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||||
for (size_t index = 0; index < input_num; ++index) {
|
for (size_t index = 0; index < input_num; ++index) {
|
||||||
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
||||||
|
@ -194,7 +195,11 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
||||||
if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
|
if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
|
||||||
priority_matched_format = kOpFormat_DEFAULT;
|
priority_matched_format = kOpFormat_DEFAULT;
|
||||||
}
|
}
|
||||||
auto input_shape_size = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
|
const auto &prev_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||||
|
if (IsDynamicRank(prev_shape)) {
|
||||||
|
is_dyn_rank = true;
|
||||||
|
}
|
||||||
|
auto input_shape_size = prev_shape.size();
|
||||||
if (input_shape_size == k5dSize) {
|
if (input_shape_size == k5dSize) {
|
||||||
is_5d_input = true;
|
is_5d_input = true;
|
||||||
}
|
}
|
||||||
|
@ -206,6 +211,9 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
||||||
if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
|
if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
|
||||||
priority_matched_format = kOpFormat_NDC1HWC0;
|
priority_matched_format = kOpFormat_NDC1HWC0;
|
||||||
}
|
}
|
||||||
|
if (is_dyn_rank) {
|
||||||
|
priority_matched_format = kOpFormat_ND;
|
||||||
|
}
|
||||||
common::AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
|
common::AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
|
||||||
return priority_matched_format;
|
return priority_matched_format;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "backend/common/optimizer/helper.h"
|
#include "backend/common/optimizer/helper.h"
|
||||||
#include "include/backend/anf_runtime_algorithm.h"
|
#include "include/backend/anf_runtime_algorithm.h"
|
||||||
|
@ -93,8 +94,36 @@ inline void GetRangeByShape(const AnfNodePtr &anf_node, const ShapeVector &shape
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShapeVector TbeDynamicShapeUtil::UpdateShape(const AnfNodePtr &node, const std::string &format,
|
||||||
|
const ShapeVector &shape, size_t index, bool is_input,
|
||||||
|
bool *is_change_nd) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
const std::set<std::string> op_names = {kTransDataOpName};
|
||||||
|
if (!node->isa<CNode>() || op_names.find(common::AnfAlgo::GetCNodeName(node)) == op_names.end()) {
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
std::string sp_format = format;
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
if (kernel_info->select_kernel_build_info() != nullptr) {
|
||||||
|
auto in_format = AnfAlgo::GetInputFormat(node, 0);
|
||||||
|
auto out_format = AnfAlgo::GetOutputFormat(node, 0);
|
||||||
|
sp_format = IsOneOfHWSpecialFormat(in_format) ? in_format : out_format;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &pad_idx =
|
||||||
|
is_input ? AnfAlgo::GetInputReshapeType(node, index) : AnfAlgo::GetOutputReshapeType(node, index);
|
||||||
|
if (format == kOpFormat_NCHW && shape.size() < kDim4 && IsOneOfDynRankNeedPadShape(sp_format)) {
|
||||||
|
if (is_change_nd) {
|
||||||
|
*is_change_nd = true;
|
||||||
|
}
|
||||||
|
return trans::PaddingShape(shape, sp_format, pad_idx);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
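A rough usage sketch for the helper above (illustrative; the variables mirror the json-creator call sites later in this diff, and index 0 is an assumption):
  // For a TransData CNode, pad the inferred shape according to the node's special format
  // before computing shape ranges; any other node gets its shape back unchanged.
  bool padded = false;
  ShapeVector shape = common::AnfAlgo::GetOutputInferShape(node, 0);
  shape = tbe::TbeDynamicShapeUtil::UpdateShape(node, kOpFormat_NCHW, shape, /*index=*/0,
                                                /*is_input=*/false, &padded);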
|
|
||||||
RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index,
|
RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index,
|
||||||
const std::string &def_format, const TypeId &type) {
|
const std::string &def_format, const std::string &ori_format,
|
||||||
|
const TypeId &type) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
@ -110,13 +139,17 @@ RangePair TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node,
|
||||||
auto prev_node = common::AnfAlgo::GetPrevNodeOutput(anf_node, index);
|
auto prev_node = common::AnfAlgo::GetPrevNodeOutput(anf_node, index);
|
||||||
MS_EXCEPTION_IF_NULL(prev_node.first);
|
MS_EXCEPTION_IF_NULL(prev_node.first);
|
||||||
auto shape = common::AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
auto shape = common::AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
||||||
|
if (anf_node->isa<CNode>()) {
|
||||||
|
shape = UpdateShape(anf_node, ori_format, shape, index, true);
|
||||||
|
}
|
||||||
GetRangeByShape(anf_node, shape, &ret);
|
GetRangeByShape(anf_node, shape, &ret);
|
||||||
|
|
||||||
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);
|
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index,
|
RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index,
|
||||||
const std::string &def_format, const TypeId &type) {
|
const std::string &def_format, const std::string &ori_format,
|
||||||
|
const TypeId &type) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
@ -130,6 +163,9 @@ RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node,
|
||||||
RangePair ret;
|
RangePair ret;
|
||||||
|
|
||||||
auto shape = common::AnfAlgo::GetOutputInferShape(anf_node, index);
|
auto shape = common::AnfAlgo::GetOutputInferShape(anf_node, index);
|
||||||
|
if (anf_node->isa<CNode>()) {
|
||||||
|
shape = UpdateShape(anf_node, ori_format, shape, index, false);
|
||||||
|
}
|
||||||
GetRangeByShape(anf_node, shape, &ret);
|
GetRangeByShape(anf_node, shape, &ret);
|
||||||
|
|
||||||
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);
|
return shapeRangeTransfer.GetRealRange(ret, format, data_type, reshape_type);
|
||||||
|
|
|
@ -32,15 +32,17 @@ class TbeDynamicShapeUtil {
|
||||||
public:
|
public:
|
||||||
TbeDynamicShapeUtil() = default;
|
TbeDynamicShapeUtil() = default;
|
||||||
~TbeDynamicShapeUtil() = default;
|
~TbeDynamicShapeUtil() = default;
|
||||||
|
static ShapeVector UpdateShape(const AnfNodePtr &node, const std::string &format, const ShapeVector &shape,
|
||||||
|
size_t index, bool is_input, bool *is_change_nd = nullptr);
|
||||||
static bool GetDynamicShapeAttr(const CNodePtr &cnode);
|
static bool GetDynamicShapeAttr(const CNodePtr &cnode);
|
||||||
static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
|
static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
|
||||||
static std::shared_ptr<OpInfo> FindOp(const CNodePtr &cnode);
|
static std::shared_ptr<OpInfo> FindOp(const CNodePtr &cnode);
|
||||||
static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
||||||
const TypeId &type);
|
const std::string &ori_format, const TypeId &type);
|
||||||
static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
||||||
const TypeId &type);
|
const std::string &ori_format, const TypeId &type);
|
||||||
};
|
};
|
||||||
} // namespace tbe
|
} // namespace tbe
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -221,10 +221,12 @@ void SingleTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetInputReshapeType(anf_node, real_input_index));
|
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetInputReshapeType(anf_node, real_input_index));
|
||||||
(*input_desc)[kJCValue] = infer_shape[1];
|
(*input_desc)[kJCValue] = infer_shape[1];
|
||||||
}
|
}
|
||||||
|
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true);
|
||||||
(*input_desc)[kJShape] = shape;
|
(*input_desc)[kJShape] = shape;
|
||||||
(*input_desc)[kJFormat] = format;
|
(*input_desc)[kJFormat] = format;
|
||||||
(*input_desc)[kJValid] = true;
|
(*input_desc)[kJValid] = true;
|
||||||
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type);
|
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format,
|
||||||
|
(*input_desc)[kJOriFormat], d_type);
|
||||||
GenInputConstValue(anf_node, real_input_index, input_desc);
|
GenInputConstValue(anf_node, real_input_index, input_desc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,8 +238,8 @@ void SingleTbeJsonCreator::GenOutputDescJson(const AnfNodePtr &anf_node, size_t
|
||||||
auto type_str = GetJsonValue<std::string>(*output_desc, kJDtype);
|
auto type_str = GetJsonValue<std::string>(*output_desc, kJDtype);
|
||||||
auto d_type = tbe::DtypeToTypeId(type_str);
|
auto d_type = tbe::DtypeToTypeId(type_str);
|
||||||
(*output_desc)[kJValid] = true;
|
(*output_desc)[kJValid] = true;
|
||||||
(*output_desc)[kJRange] =
|
(*output_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(
|
||||||
tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, node_out_idx, (*output_desc)[kJFormat], d_type);
|
anf_node, node_out_idx, (*output_desc)[kJFormat], (*output_desc)[kJOriFormat], d_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SingleTbeJsonCreator::AssignInputsJson(const AnfNodePtr &anf_node, const std::vector<nlohmann::json> &inputs_desc,
|
bool SingleTbeJsonCreator::AssignInputsJson(const AnfNodePtr &anf_node, const std::vector<nlohmann::json> &inputs_desc,
|
||||||
|
@ -398,6 +400,7 @@ void SelectTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_o
|
||||||
(*output_desc)[kJFormat] = format;
|
(*output_desc)[kJFormat] = format;
|
||||||
(*output_desc)[kJOriFormat] = def_format;
|
(*output_desc)[kJOriFormat] = def_format;
|
||||||
(*output_desc)[kJOriShape] = ori_shape;
|
(*output_desc)[kJOriShape] = ori_shape;
|
||||||
|
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false);
|
||||||
(*output_desc)[kJShape] = shape;
|
(*output_desc)[kJShape] = shape;
|
||||||
(*output_desc)[kJOutputIndex] = desc_output_idx;
|
(*output_desc)[kJOutputIndex] = desc_output_idx;
|
||||||
}
|
}
|
||||||
|
@ -422,13 +425,15 @@ void SelectTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
|
||||||
if (common::AnfAlgo::GetCNodeName(anf_node) == kMaxPool3DGradGradDOpName) {
|
if (common::AnfAlgo::GetCNodeName(anf_node) == kMaxPool3DGradGradDOpName) {
|
||||||
(*input_desc)[kJOriFormat] = kOpFormat_NDHWC;
|
(*input_desc)[kJOriFormat] = kOpFormat_NDHWC;
|
||||||
}
|
}
|
||||||
|
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*input_desc)[kJOriFormat], shape, real_input_index, true);
|
||||||
(*input_desc)[kJShape] = shape;
|
(*input_desc)[kJShape] = shape;
|
||||||
(*input_desc)[kJFormat] = format;
|
(*input_desc)[kJFormat] = format;
|
||||||
(*input_desc)[kJValid] = true;
|
(*input_desc)[kJValid] = true;
|
||||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, real_input_index);
|
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, real_input_index);
|
||||||
MS_EXCEPTION_IF_NULL(input_node_with_index.first);
|
MS_EXCEPTION_IF_NULL(input_node_with_index.first);
|
||||||
if (!input_node_with_index.first->isa<ValueNode>()) {
|
if (!input_node_with_index.first->isa<ValueNode>()) {
|
||||||
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format, d_type);
|
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format,
|
||||||
|
(*input_desc)[kJOriFormat], d_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
|
bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
|
||||||
|
|
|
@ -487,6 +487,7 @@ void TbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx
|
||||||
(*output_desc)[kJFormat] = format;
|
(*output_desc)[kJFormat] = format;
|
||||||
(*output_desc)[kJOriFormat] = def_format;
|
(*output_desc)[kJOriFormat] = def_format;
|
||||||
(*output_desc)[kJOriShape] = ori_shape;
|
(*output_desc)[kJOriShape] = ori_shape;
|
||||||
|
shape = tbe::TbeDynamicShapeUtil::UpdateShape(anf_node, (*output_desc)[kJOriFormat], shape, node_out_idx, false);
|
||||||
(*output_desc)[kJShape] = shape;
|
(*output_desc)[kJShape] = shape;
|
||||||
(*output_desc)[kJName] = output_desc_name;
|
(*output_desc)[kJName] = output_desc_name;
|
||||||
// !! Note: output_index, only node's output use it
|
// !! Note: output_index, only node's output use it
|
||||||
|
|
|
@ -83,7 +83,7 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_
|
||||||
return infer_shape;
|
return infer_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
if (trans::IsNeedPadding(format, infer_shape)) {
|
||||||
auto reshape_type =
|
auto reshape_type =
|
||||||
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
|
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
|
||||||
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
|
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
|
||||||
|
|
|
@ -36,6 +36,23 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||||
namespace {
|
namespace {
|
||||||
|
struct CreateNodeArgs {
|
||||||
|
FuncGraphPtr func_graph{nullptr};
|
||||||
|
AnfNodePtr node{nullptr};
|
||||||
|
AnfNodePtr input_node{nullptr};
|
||||||
|
AnfNodePtr orig_node{nullptr};
|
||||||
|
KernelSelectPtr kernel_select{nullptr};
|
||||||
|
std::string trans_opname;
|
||||||
|
std::string input_format;
|
||||||
|
std::string dst_format;
|
||||||
|
std::string spec_format;
|
||||||
|
std::string reshape_type;
|
||||||
|
TypeId type_id;
|
||||||
|
ShapeVector out_shape;
|
||||||
|
bool is_dynamic_shape;
|
||||||
|
bool need_padding;
|
||||||
|
};
|
||||||
|
|
||||||
std::string GetTransOpName(const std::string &spec_format) {
|
std::string GetTransOpName(const std::string &spec_format) {
|
||||||
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
|
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
|
||||||
? prim::kPrimTransDataRNN->name()
|
? prim::kPrimTransDataRNN->name()
|
||||||
|
@ -83,6 +100,61 @@ CNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inp
|
||||||
return reshape;
|
return reshape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AnfNodePtr CreateTransDataWithOutReshape(const CreateNodeArgs &args) {
|
||||||
|
// no padding needed, insert TransData only
|
||||||
|
auto trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select,
|
||||||
|
args.need_padding, args.trans_opname);
|
||||||
|
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
|
||||||
|
args.type_id);
|
||||||
|
return trans_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr CreateTransDataWithReshape(const CreateNodeArgs &args) {
|
||||||
|
AnfNodePtr trans_node = nullptr;
|
||||||
|
CNodePtr trans_data = nullptr;
|
||||||
|
if (!args.need_padding) {
|
||||||
|
// no padding needed, insert TransData only
|
||||||
|
trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding,
|
||||||
|
args.trans_opname);
|
||||||
|
trans_node = trans_data;
|
||||||
|
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
|
||||||
|
args.type_id);
|
||||||
|
} else if (args.spec_format == args.dst_format) {
|
||||||
|
// if need padding & default to special format
|
||||||
|
// ori_shape -> reshape[padding shape] -> transdata[device shape]
|
||||||
|
auto padding_shape = trans::PaddingShape(args.out_shape, args.dst_format, args.reshape_type, args.node);
|
||||||
|
std::vector<int> padding_axis;
|
||||||
|
if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) {
|
||||||
|
padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node);
|
||||||
|
}
|
||||||
|
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(padding_shape);
|
||||||
|
auto reshape_node = CreateReshapeNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select,
|
||||||
|
pad_shape_ptr, args.is_dynamic_shape, padding_axis);
|
||||||
|
trans_data = NewTransOpNode(args.func_graph, reshape_node, args.orig_node, args.kernel_select, args.need_padding,
|
||||||
|
args.trans_opname);
|
||||||
|
trans_node = trans_data;
|
||||||
|
trans_data->set_abstract(args.input_node->abstract());
|
||||||
|
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
|
||||||
|
args.type_id);
|
||||||
|
} else {
|
||||||
|
// if need padding & special to default format
|
||||||
|
// device shape -> transdata[padding shape] -> reshape[ori_shape]
|
||||||
|
trans_data = NewTransOpNode(args.func_graph, args.input_node, args.orig_node, args.kernel_select, args.need_padding,
|
||||||
|
args.trans_opname);
|
||||||
|
RefreshKernelBuildInfo(args.kernel_select, args.input_format, args.dst_format, trans_data, args.reshape_type,
|
||||||
|
args.type_id);
|
||||||
|
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(args.out_shape);
|
||||||
|
std::vector<int> padding_axis;
|
||||||
|
if (std::count(args.out_shape.begin(), args.out_shape.end(), -1) > 1) {
|
||||||
|
padding_axis = trans::StringToAxisVector(args.out_shape, args.dst_format, args.reshape_type, args.node);
|
||||||
|
}
|
||||||
|
auto reshape_node = CreateReshapeNode(args.func_graph, trans_data, args.orig_node, args.kernel_select,
|
||||||
|
pad_shape_ptr, args.is_dynamic_shape, padding_axis);
|
||||||
|
trans_node = reshape_node;
|
||||||
|
}
|
||||||
|
return trans_node;
|
||||||
|
}
|
||||||
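A condensed sketch of how the two helpers above are driven (field order follows the CreateNodeArgs declaration; this mirrors the rewritten AddTransOpNodeToGraphWithFormat further down rather than introducing a new call site):
  CreateNodeArgs args = {func_graph,   node,         input_node, orig_node,   kernel_select,
                         trans_opname, input_format, dst_format, spec_format, reshape_type,
                         type_id,      out_shape,    is_dynamic_shape, need_padding};
  // A dynamic-rank output has no static padding shape to reshape through, so the
  // reshape-free path is taken; otherwise the original reshape + TransData logic applies.
  AnfNodePtr trans_node = is_dyn_rank ? CreateTransDataWithOutReshape(args) : CreateTransDataWithReshape(args);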
|
|
||||||
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
|
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(trans_node);
|
MS_EXCEPTION_IF_NULL(trans_node);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -136,6 +208,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
||||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||||
auto real_input = node_with_index.first;
|
auto real_input = node_with_index.first;
|
||||||
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
|
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
|
||||||
|
MS_LOG(DEBUG)
|
||||||
|
<< "ValueNode or Parameter has no inputs, try to insert for ValueNode or Parameter at out anchor, node: "
|
||||||
|
<< node->fullname_with_scope();
|
||||||
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
|
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
common::AnfAlgo::SetNodeInput(node, input_node, index);
|
common::AnfAlgo::SetNodeInput(node, input_node, index);
|
||||||
|
@ -143,8 +218,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
||||||
ShapeVector origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index);
|
ShapeVector origin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index);
|
||||||
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
|
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
|
||||||
if (NeedInsertTransData(origin_shape, dest_format)) {
|
if (NeedInsertTransData(origin_shape, dest_format)) {
|
||||||
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
MS_LOG(DEBUG) << "Need insert TransData change format from [" << dest_format
|
||||||
<< " To DefaultFormat , index: " << index;
|
<< "] to [DefaultFormat], input index:" << index << ", node: " << node->fullname_with_scope();
|
||||||
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
auto transdata = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
||||||
if (real_input->isa<Parameter>()) {
|
if (real_input->isa<Parameter>()) {
|
||||||
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
|
SetGroupAttr(real_input->cast<ParameterPtr>(), input_node, transdata, dest_format);
|
||||||
|
@ -165,7 +240,8 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
||||||
<< node->DebugString() << trace::DumpSourceLines(node);
|
<< node->DebugString() << trace::DumpSourceLines(node);
|
||||||
}
|
}
|
||||||
if (NeedInsertTransData(origin_shape, output_format)) {
|
if (NeedInsertTransData(origin_shape, output_format)) {
|
||||||
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
|
MS_LOG(DEBUG) << "Inserted TransData change format from [" << output_format
|
||||||
|
<< "] to [DefaultFormat], single output index :0";
|
||||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
||||||
}
|
}
|
||||||
return node;
|
return node;
|
||||||
|
@ -243,61 +319,32 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
|
||||||
auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>();
|
auto out_shape_ptr = input_node_out_shape->cast<abstract::ShapePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(out_shape_ptr);
|
MS_EXCEPTION_IF_NULL(out_shape_ptr);
|
||||||
ShapeVector out_shape = out_shape_ptr->shape();
|
ShapeVector out_shape = out_shape_ptr->shape();
|
||||||
|
auto is_dyn_rank = out_shape_ptr->IsDimUnknown();
|
||||||
auto is_dynamic_shape = out_shape_ptr->IsDynamic();
|
auto is_dynamic_shape = out_shape_ptr->IsDynamic();
|
||||||
|
|
||||||
bool need_padding = trans::IsNeedPadding(spec_format, out_shape.size());
|
bool need_padding = trans::IsNeedPadding(spec_format, out_shape);
|
||||||
std::string trans_opname = GetTransOpName(spec_format);
|
std::string trans_opname = GetTransOpName(spec_format);
|
||||||
bool is_insert_output = node == input_node;
|
bool is_insert_output = node == input_node;
|
||||||
auto orig_node = GetOriginNode(is_insert_output, node);
|
auto orig_node = GetOriginNode(is_insert_output, node);
|
||||||
|
AnfNodePtr trans_data = nullptr;
|
||||||
AnfNodePtr trans_node = nullptr;
|
CreateNodeArgs args = {func_graph, node, input_node, orig_node, kernel_select,
|
||||||
CNodePtr trans_data = nullptr;
|
trans_opname, input_format, dst_format, spec_format, reshape_type,
|
||||||
if (!need_padding) {
|
type_id, out_shape, is_dynamic_shape, need_padding};
|
||||||
// don't need padding insert transdata only
|
if (is_dyn_rank) {
|
||||||
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
|
trans_data = CreateTransDataWithOutReshape(args);
|
||||||
trans_node = trans_data;
|
|
||||||
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
|
|
||||||
} else if (spec_format == dst_format) {
|
|
||||||
// if need padding & default to special format
|
|
||||||
// ori_shape -> reshape[padding shape] -> transdata[device shape]
|
|
||||||
|
|
||||||
auto padding_shape = trans::PaddingShape(out_shape, dst_format, reshape_type, node);
|
|
||||||
std::vector<int> padding_axis;
|
|
||||||
if (std::count(padding_shape.begin(), padding_shape.end(), -1) > 1) {
|
|
||||||
padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node);
|
|
||||||
}
|
|
||||||
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(padding_shape);
|
|
||||||
auto reshape_node = CreateReshapeNode(func_graph, input_node, orig_node, kernel_select, pad_shape_ptr,
|
|
||||||
is_dynamic_shape, padding_axis);
|
|
||||||
trans_data = NewTransOpNode(func_graph, reshape_node, orig_node, kernel_select, need_padding, trans_opname);
|
|
||||||
trans_node = trans_data;
|
|
||||||
trans_data->set_abstract(input_node->abstract());
|
|
||||||
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
|
|
||||||
} else {
|
} else {
|
||||||
// if need padding & special to default format
|
trans_data = CreateTransDataWithReshape(args);
|
||||||
// device shape -> transdata[padding shape] -> reshape[ori_shape]
|
|
||||||
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
|
|
||||||
RefreshKernelBuildInfo(kernel_select, input_format, dst_format, trans_data, reshape_type, type_id);
|
|
||||||
abstract::ShapePtr pad_shape_ptr = std::make_shared<abstract::Shape>(out_shape);
|
|
||||||
std::vector<int> padding_axis;
|
|
||||||
if (std::count(out_shape.begin(), out_shape.end(), -1) > 1) {
|
|
||||||
padding_axis = trans::StringToAxisVector(out_shape, dst_format, reshape_type, node);
|
|
||||||
}
|
|
||||||
auto reshape_node = CreateReshapeNode(func_graph, trans_data, orig_node, kernel_select, pad_shape_ptr,
|
|
||||||
is_dynamic_shape, padding_axis);
|
|
||||||
|
|
||||||
trans_node = reshape_node;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
|
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
|
||||||
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
|
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
|
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), trans_data);
|
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), trans_data);
|
||||||
}
|
}
|
||||||
if (is_insert_output) {
|
if (is_insert_output) {
|
||||||
ReFreshInferShape(trans_node, node);
|
ReFreshInferShape(trans_data, node);
|
||||||
}
|
}
|
||||||
return trans_node;
|
return trans_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RefreshKernelBuildInfo(const KernelSelectPtr &kernel_select, const std::string &input_format,
|
void RefreshKernelBuildInfo(const KernelSelectPtr &kernel_select, const std::string &input_format,
|
||||||
|
@ -368,15 +415,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||||
auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
|
auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
|
||||||
MS_EXCEPTION_IF_NULL(out_shape_base);
|
MS_EXCEPTION_IF_NULL(out_shape_base);
|
||||||
ShapeVector out_shape;
|
ShapeVector out_shape;
|
||||||
|
bool is_dyn_rank = false;
|
||||||
bool is_dynamic_shape = false;
|
bool is_dynamic_shape = false;
|
||||||
if (out_shape_base->isa<abstract::Shape>()) {
|
if (out_shape_base->isa<abstract::Shape>()) {
|
||||||
auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
|
auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(out_shape_ptr);
|
MS_EXCEPTION_IF_NULL(out_shape_ptr);
|
||||||
out_shape = out_shape_ptr->shape();
|
out_shape = out_shape_ptr->shape();
|
||||||
is_dynamic_shape = out_shape_ptr->IsDynamic();
|
is_dynamic_shape = out_shape_ptr->IsDynamic();
|
||||||
|
is_dyn_rank = out_shape_ptr->IsDimUnknown();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (need_padding) {
|
if (need_padding && !is_dyn_rank) {
|
||||||
// if need padding we should set the transdata node's shape to the padding shape
|
// if need padding we should set the transdata node's shape to the padding shape
|
||||||
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
||||||
|
|
||||||
|
@ -495,6 +544,7 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||||
size_t in_num = common::AnfAlgo::GetInputNum(cnode); // include monads.
|
size_t in_num = common::AnfAlgo::GetInputNum(cnode); // include monads.
|
||||||
|
MS_LOG(DEBUG) << "Try to insert TransData at input anchor for node: " << cnode->fullname_with_scope();
|
||||||
for (size_t input_index = 0; input_index < in_num; ++input_index) {
|
for (size_t input_index = 0; input_index < in_num; ++input_index) {
|
||||||
// Monad inputs keep unchanged from GetTransInputNodePtr().
|
// Monad inputs keep unchanged from GetTransInputNodePtr().
|
||||||
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
|
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
|
||||||
|
|
|
@ -247,13 +247,16 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsNeedPadding(const std::string &format, size_t shape_size) {
|
bool IsNeedPadding(const std::string &format, const ShapeVector &shape) {
|
||||||
if (shape_size == 0) {
|
if (shape.size() == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (IsDynamicRank(shape) && !IsOneOfDynRankNeedPadShape(format)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || IsOneOfNoPaddingFormat(format)) {
|
if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || IsOneOfNoPaddingFormat(format)) {
|
||||||
return false;
|
return false;
|
||||||
} else if (shape_size < kDim4) {
|
} else if (shape.size() < kDim4) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
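A small behavioural sketch of the updated predicate (values assumed; kShapeRankAny is the rank-unknown sentinel used elsewhere in this diff):
  ShapeVector rank_unknown = {abstract::Shape::kShapeRankAny};
  ShapeVector static_2d = {16, 16};
  trans::IsNeedPadding(kOpFormat_NC1HWC0, rank_unknown);  // true: NC1HWC0 is a dyn-rank pad format
  trans::IsNeedPadding(kOpFormat_FRAC_NZ, rank_unknown);  // false: not in the dyn-rank pad list
  trans::IsNeedPadding(kOpFormat_NC1HWC0, static_2d);     // true: known rank < 4 still needs padding
  trans::IsNeedPadding(kOpFormat_DEFAULT, static_2d);     // false: default format is never padded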
|
@ -288,7 +291,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
||||||
host_shape = common::AnfAlgo::GetOutputInferShape(node, index);
|
host_shape = common::AnfAlgo::GetOutputInferShape(node, index);
|
||||||
}
|
}
|
||||||
auto format = AnfAlgo::GetOutputFormat(node, index);
|
auto format = AnfAlgo::GetOutputFormat(node, index);
|
||||||
if (IsNeedPadding(format, host_shape.size())) {
|
if (IsNeedPadding(format, host_shape)) {
|
||||||
host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index), node);
|
host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index), node);
|
||||||
}
|
}
|
||||||
return host_shape;
|
return host_shape;
|
||||||
|
|
|
@ -39,6 +39,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "include/backend/visible.h"
|
#include "include/backend/visible.h"
|
||||||
|
#include "mindapi/base/shape_vector.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace trans {
|
namespace trans {
|
||||||
|
@ -329,7 +330,7 @@ BACKEND_EXPORT ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t
|
||||||
/**
|
/**
|
||||||
* If need padding
|
* If need padding
|
||||||
* */
|
* */
|
||||||
BACKEND_EXPORT bool IsNeedPadding(const std::string &format, size_t shape_size);
|
BACKEND_EXPORT bool IsNeedPadding(const std::string &format, const ShapeVector &shape);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Padding shape to 5D by default mode
|
* Padding shape to 5D by default mode
|
||||||
|
@ -442,16 +443,21 @@ std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &form
|
||||||
MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format
|
MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format
|
||||||
<< ", detail info: " << node->DebugString();
|
<< ", detail info: " << node->DebugString();
|
||||||
}
|
}
|
||||||
std::vector<T> host_shape;
|
|
||||||
if (IsOneOf3DFormat(format)) {
|
if (IsOneOf3DFormat(format)) {
|
||||||
if (shape.size() >= kDim5) {
|
if (shape.size() >= kDim5) {
|
||||||
return shape;
|
return shape;
|
||||||
}
|
}
|
||||||
host_shape = PaddingShapeTo5d(shape, pad_index);
|
if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
|
||||||
} else {
|
return {-1, -1, -1, -1, -1};
|
||||||
host_shape = PaddingShapeTo4d(shape, pad_index);
|
}
|
||||||
|
return PaddingShapeTo5d(shape, pad_index);
|
||||||
}
|
}
|
||||||
return host_shape;
|
|
||||||
|
if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
|
||||||
|
return {-1, -1, -1, -1};
|
||||||
|
}
|
||||||
|
return PaddingShapeTo4d(shape, pad_index);
|
||||||
}
|
}
|
||||||
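With the branch above, a rank-unknown shape maps straight to an all-unknown padded shape; expected results sketched from the code (assumed, not taken from tests; reshape type left empty):
  ShapeVector rank_unknown = {abstract::Shape::kShapeRankAny};
  auto padded_5d = trans::PaddingShape(rank_unknown, kOpFormat_NDC1HWC0, "");  // -> {-1, -1, -1, -1, -1}
  auto padded_4d = trans::PaddingShape(rank_unknown, kOpFormat_NC1HWC0, "");   // -> {-1, -1, -1, -1}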
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1382,6 +1382,93 @@ bool AnfAlgo::HasDynamicShapeFlag(const PrimitivePtr &prim) {
|
||||||
return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape);
|
return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsNodeDynamicRank(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
MS_LOG(DEBUG) << "Node is not a cnode";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
auto in_dyn_rank = AnfAlgo::IsNodeInputDynamicRank(cnode);
|
||||||
|
auto out_dyn_rank = AnfAlgo::IsNodeOutputDynamicRank(cnode);
|
||||||
|
if (in_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrInputIsDynamicRank, cnode)) {
|
||||||
|
AnfAlgo::SetNodeAttrSafely(kAttrInputIsDynamicRank, MakeValue(true), cnode);
|
||||||
|
MS_LOG(DEBUG) << "Set input dynamic rank attr for node:" << cnode->fullname_with_scope();
|
||||||
|
}
|
||||||
|
if (out_dyn_rank && !AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicRank, cnode)) {
|
||||||
|
AnfAlgo::SetNodeAttrSafely(kAttrOutputIsDynamicRank, MakeValue(true), cnode);
|
||||||
|
MS_LOG(DEBUG) << "Set output dynamic rank attr for node:" << cnode->fullname_with_scope();
|
||||||
|
}
|
||||||
|
return in_dyn_rank || out_dyn_rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfAlgo::IsDynamicRankNode(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (node->isa<Parameter>()) {
|
||||||
|
return IsOutputAnchorDynamicRank(node, 0);
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if ((!HasNodeAttr(kAttrInputIsDynamicRank, cnode)) && (!HasNodeAttr(kAttrOutputIsDynamicRank, cnode))) {
|
||||||
|
auto ret = IsNodeDynamicRank(node);
|
||||||
|
MS_LOG(DEBUG) << "The Node:" << node->fullname_with_scope() << " is dynamic rank: [" << ret << "]";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return GetBooleanAttr(node, kAttrInputIsDynamicRank) || GetBooleanAttr(node, kAttrOutputIsDynamicRank) ||
|
||||||
|
GetBooleanAttr(node, kAttrIsDynamicRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfAlgo::IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only cnode has inputs, node: " << node->fullname_with_scope();
|
||||||
|
}
|
||||||
|
const auto &in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, idx);
|
||||||
|
if (mindspore::IsDynamicRank(in_shape)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfAlgo::IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
const auto &out_shape = common::AnfAlgo::GetOutputInferShape(node, idx);
|
||||||
|
if (mindspore::IsDynamicRank(out_shape)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfAlgo::IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
||||||
|
const auto &inputs = anf_node_ptr->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
const auto &input = inputs[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
if (IsNodeOutputDynamicRank(input)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfAlgo::IsNodeOutputDynamicRank(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto base_shape = node->Shape();
|
||||||
|
if (base_shape == nullptr) {
|
||||||
|
MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (base_shape->isa<abstract::DynamicSequenceShape>()) {
|
||||||
|
auto b_ptr = base_shape->cast<abstract::DynamicSequenceShapePtr>();
|
||||||
|
if (b_ptr->IsDimUnknown()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return base_shape->IsDimUnknown();
|
||||||
|
}
|
||||||
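A short sketch of how the new queries might be combined at a kernel-selection site (hypothetical call site; only the ND fallback mirrors GetPriorityMatchFormat earlier in this diff):
  // Fall back to ND whenever the node, or the anchor being matched, has unknown rank.
  if (common::AnfAlgo::IsDynamicRankNode(cnode) ||
      common::AnfAlgo::IsInputAnchorDynamicRank(cnode, 0) ||
      common::AnfAlgo::IsOutputAnchorDynamicRank(cnode, 0)) {
    priority_matched_format = kOpFormat_ND;
  }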
|
|
||||||
bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) {
|
bool AnfAlgo::IsDynamicShape(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
|
|
|
@ -211,4 +211,11 @@ bool IsOneOfServerFormatC04(const std::string &format) {
|
||||||
static const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
static const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||||
return kServerFormatC04List.find(format) != kServerFormatC04List.end();
|
return kServerFormatC04List.find(format) != kServerFormatC04List.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsOneOfDynRankNeedPadShape(const std::string &format) {
|
||||||
|
const std::set<std::string> kOpFormats = {kOpFormat_NC1HWC0, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z,
|
||||||
|
kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
|
||||||
|
kOpFormat_FRACTAL_Z_3D, kOpFormat_FRACTAL_Z_C04, kOpFormat_NCDHW};
|
||||||
|
return kOpFormats.find(format) != kOpFormats.end();
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -244,6 +244,7 @@ class MS_CORE_API DynamicSequenceShape : public BaseShape {
|
||||||
// element's shape
|
// element's shape
|
||||||
BaseShapePtr element_shape_{nullptr};
|
BaseShapePtr element_shape_{nullptr};
|
||||||
};
|
};
|
||||||
|
using DynamicSequenceShapePtr = std::shared_ptr<DynamicSequenceShape>;
|
||||||
GVAR_DEF(std::shared_ptr<DynamicSequenceShape>, kDynamicSequenceShape, std::make_shared<DynamicSequenceShape>());
|
GVAR_DEF(std::shared_ptr<DynamicSequenceShape>, kDynamicSequenceShape, std::make_shared<DynamicSequenceShape>());
|
||||||
|
|
||||||
/// \brief SequequeShape defines base class of multiple-shape classes.
|
/// \brief SequequeShape defines base class of multiple-shape classes.
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.ops.operations as ops
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
|
||||||
|
np.random.seed(3)
|
||||||
|
|
||||||
|
|
||||||
|
class MSDynRankNet(Cell):
|
||||||
|
def __init__(self, is_training=True):
|
||||||
|
super(MSDynRankNet, self).__init__()
|
||||||
|
self.is_training = is_training
|
||||||
|
self.batch_norm = ops.BatchNorm(is_training=self.is_training)
|
||||||
|
self.reduce_mean = ops.ReduceMean(keep_dims=False)
|
||||||
|
self.relu = ops.ReLU()
|
||||||
|
|
||||||
|
def construct(self, input_x, scale, offset, mean, variance, indices):
|
||||||
|
unique_indices = self.relu(indices)
|
||||||
|
reduced_in = self.reduce_mean(input_x, unique_indices)
|
||||||
|
reduced_scale = self.reduce_mean(scale, unique_indices)
|
||||||
|
reduced_offset = self.reduce_mean(offset, unique_indices)
|
||||||
|
reduced_mean = self.reduce_mean(mean, unique_indices)
|
||||||
|
reduced_variance = self.reduce_mean(variance, unique_indices)
|
||||||
|
out, _, _, _, _ = self.batch_norm(reduced_in, reduced_scale, reduced_offset, reduced_mean, reduced_variance)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NetFactory:
|
||||||
|
def __init__(self, x, scale, offset, mean, variance, indices, dtype=np.float32, is_training=False):
|
||||||
|
super(NetFactory, self).__init__()
|
||||||
|
self.x = x
|
||||||
|
self.scale = scale
|
||||||
|
self.offset = offset
|
||||||
|
self.mean = mean
|
||||||
|
self.variance = variance
|
||||||
|
self.indices = indices
|
||||||
|
self.dtype = dtype
|
||||||
|
self.is_training = is_training
|
||||||
|
self.nh2nc = [0, 3, 1, 2]
|
||||||
|
self.nc2nh = [0, 2, 3, 1]
|
||||||
|
|
||||||
|
def mindspore_case(self):
|
||||||
|
ms_x = Tensor(self.x)
|
||||||
|
ms_indices = Tensor(self.indices)
|
||||||
|
ms_scale = Tensor(self.scale)
|
||||||
|
ms_offset = Tensor(self.offset)
|
||||||
|
ms_mean = Tensor(self.mean)
|
||||||
|
ms_variance = Tensor(self.variance)
|
||||||
|
|
||||||
|
ms_dyn_x = Tensor(shape=[None for _ in ms_x.shape], dtype=ms_x.dtype)
|
||||||
|
ms_dyn_scale = Tensor(shape=[None for _ in ms_scale.shape], dtype=ms_scale.dtype)
|
||||||
|
ms_dyn_offset = Tensor(shape=[None for _ in ms_offset.shape], dtype=ms_offset.dtype)
|
||||||
|
ms_dyn_mean = Tensor(shape=[None for _ in ms_mean.shape], dtype=ms_mean.dtype)
|
||||||
|
ms_dyn_variance = Tensor(shape=[None for _ in ms_variance.shape], dtype=ms_variance.dtype)
|
||||||
|
ms_dyn_indices = Tensor(shape=[None], dtype=ms_indices.dtype)
|
||||||
|
|
||||||
|
ms_net = MSDynRankNet(is_training=self.is_training)
|
||||||
|
ms_net.set_inputs(ms_dyn_x, ms_dyn_scale, ms_dyn_offset, ms_dyn_mean, ms_dyn_variance, ms_dyn_indices)
|
||||||
|
|
||||||
|
ms_out = ms_net(ms_x, ms_scale, ms_offset, ms_mean, ms_variance, ms_indices)
|
||||||
|
return ms_out.asnumpy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_batch_norm_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test batch norm dynamic rank
|
||||||
|
Description: test batch norm dynamic rank with input tensor's type float32
|
||||||
|
Expectation: graph mode and pynative mode outputs are consistent within tolerance.
|
||||||
|
"""
|
||||||
|
input_x = np.random.randn(3, 3, 4, 3, 3).astype(np.float32)
|
||||||
|
scale_ = np.ones((4, 4)).astype(np.float32)
|
||||||
|
offset_ = np.ones((4, 4)).astype(np.float32)
|
||||||
|
mean_ = np.ones((4, 4)).astype(np.float32)
|
||||||
|
variance_ = np.ones((4, 4)).astype(np.float32)
|
||||||
|
indices_ = np.unique(np.random.randint(1, 2, (1,)).astype(np.int32))
|
||||||
|
|
||||||
|
# graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
graph_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_, indices=indices_,
|
||||||
|
dtype=np.float32)
|
||||||
|
graph_mode_out = graph_mode_net.mindspore_case()
|
||||||
|
|
||||||
|
# pynative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
pynative_mode_net = NetFactory(input_x, scale=scale_, offset=offset_, mean=mean_, variance=variance_,
|
||||||
|
indices=indices_,
|
||||||
|
dtype=np.float32)
|
||||||
|
pynative_mode_out = pynative_mode_net.mindspore_case()
|
||||||
|
|
||||||
|
assert np.allclose(pynative_mode_out, graph_mode_out, 1e-4, 1e-4)
|
|
@ -0,0 +1,99 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import torch
from mindspore import Tensor
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P

np.random.seed(3)
context.set_context(mode=context.GRAPH_MODE)

class MSBiasAddDynRankNet(Cell):
    def __init__(self, data_format="NCHW"):
        super(MSBiasAddDynRankNet, self).__init__()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.bias_add = P.BiasAdd(data_format=data_format)
        self.relu = P.ReLU()

    def construct(self, input_a, input_b, indices):
        # The reduction axes come from a run-time tensor, so the rank of
        # reduce_a is unknown at compile time (the dynamic-rank case).
        relu_indices = self.relu(indices)
        reduce_a = self.reduce_sum(input_a, relu_indices)
        out = self.bias_add(reduce_a, input_b)
        return out

class TorchAddNet(torch.nn.Module):
    def __init__(self):
        super(TorchAddNet, self).__init__()
        self.keep_dims = False

    def forward(self, input_a, input_b, indices):
        relu_indices = torch.relu(indices)
        reduce_a = torch.sum(input_a, relu_indices.tolist(), keepdim=self.keep_dims)
        out = torch.add(reduce_a, input_b)
        return out

class BiasAddOpFactory:
    def __init__(self, in_shape, indices, dtype=np.float32, data_format="NCHW"):
        super(BiasAddOpFactory, self).__init__()
        self.dtype = dtype
        self.input_x = np.random.randn(*in_shape).astype(self.dtype)
        self.data_format = data_format
        self.indices = indices
        self.input_b = np.random.randn(in_shape[-1]).astype(self.dtype)
        self.loss = 1e-4

    def ms_bias_add_forward(self):
        a = Tensor(self.input_x)
        b = Tensor(self.input_b)
        indices = Tensor(self.indices)

        dyn_a = Tensor(shape=[None for _ in a.shape], dtype=a.dtype)
        dyn_b = Tensor(shape=[None for _ in b.shape], dtype=b.dtype)
        dyn_indices = Tensor(shape=[None], dtype=indices.dtype)

        ms_net = MSBiasAddDynRankNet(data_format=self.data_format)
        ms_net.set_inputs(dyn_a, dyn_b, dyn_indices)
        out = ms_net(a, b, indices)
        return out.asnumpy()

    def torch_bias_add_forward(self):
        torch_net = TorchAddNet()
        out = torch_net(torch.from_numpy(self.input_x), torch.from_numpy(self.input_b),
                        torch.from_numpy(self.indices))
        return out.detach().numpy()

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bias_add_dyn_rank():
    """
    Feature: test bias add dynamic rank
    Description: test bias add dynamic rank with input tensor's type float32
    Expectation: none.
    """
    in_shape = (16, 16, 16, 16, 16)
    indices_np = np.unique(np.random.randint(0, 2, size=3).astype(np.int32))
    factory = BiasAddOpFactory(in_shape=in_shape, indices=indices_np, dtype=np.float32, data_format="NCHW")
    ms_out = factory.ms_bias_add_forward()
    torch_out = factory.torch_bias_add_forward()

    np.allclose(ms_out, torch_out, factory.loss, factory.loss)
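As a side note for reviewers, the PyTorch reference path in TorchAddNet reduces over the (non-negative) indices and then performs a plain broadcast add. A NumPy equivalent of that reference path, purely illustrative and not part of the commit, would be:

import numpy as np

def numpy_add_reference(input_x, input_b, indices):
    # relu on the non-negative integer indices used in the test is a no-op, so it is omitted
    reduced = np.sum(input_x, axis=tuple(int(i) for i in indices))
    return reduced + input_b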
@ -0,0 +1,98 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import torch
from mindspore import Tensor
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P

np.random.seed(3)
context.set_context(mode=context.GRAPH_MODE)

class MSReduceSumNet(Cell):
    def __init__(self, keep_dims=False):
        super(MSReduceSumNet, self).__init__()
        self.reduce_sum = P.ReduceSum(keep_dims=keep_dims)
        self.reduce = P.ReduceSum(keep_dims=False)

    def construct(self, x, indices, axis):
        x = self.reduce(x, axis)
        return self.reduce_sum(x, indices)

class TorchReduceSumNet(torch.nn.Module):
    def __init__(self, keep_dims=False):
        super(TorchReduceSumNet, self).__init__()
        self.keep_dims = keep_dims

    def forward(self, input_x, indices, axis):
        x = torch.sum(input_x, axis.tolist(), False)
        out = torch.sum(x, indices.tolist(), self.keep_dims)
        return out

class ReduceOpFactory:
    def __init__(self, input_x, indices, axis, keep_dims, dtype=np.float32, loss=1e-4):
        super(ReduceOpFactory, self).__init__()
        self.out_grad = None
        self.input_x = input_x
        self.indices = indices
        self.axis = axis
        self.keep_dims = keep_dims
        self.dtype = dtype
        self.loss = loss

    def ms_reduce_sum_forward(self):
        net = MSReduceSumNet(self.keep_dims)
        in_tensor = Tensor(self.input_x)
        in_indices = Tensor(self.indices)
        in_axis = Tensor(self.axis)

        in_tensor_dyn = Tensor(shape=[None for _ in in_tensor.shape], dtype=in_tensor.dtype)
        in_indices_dyn = Tensor(shape=[None for _ in in_indices.shape], dtype=in_indices.dtype)
        in_axis_dyn = Tensor(shape=[None for _ in in_axis.shape], dtype=in_axis.dtype)

        net.set_inputs(in_tensor_dyn, in_indices_dyn, in_axis_dyn)
        out = net(in_tensor, in_indices, in_axis)
        return out.asnumpy()

    def torch_reduce_sum_forward(self):
        net = TorchReduceSumNet(self.keep_dims)
        out = net(torch.from_numpy(self.input_x.astype(self.dtype)), self.indices, self.axis)
        return out.detach().numpy()

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_reduce_sum_dyn_rank():
    """
    Feature: test reduce sum dynamic rank
    Description: test reduce sum dynamic rank with input tensor's type float32
    Expectation: the MindSpore result is consistent with the PyTorch reference.
    """
    dtype = np.float32
    x = np.random.randn(22, 20, 28, 36, 24, 23).astype(dtype)
    indices = np.array([0, -1])
    axis = np.unique(np.random.randint(0, 2, size=5).astype(np.int32))
    factory = ReduceOpFactory(x, indices, axis, keep_dims=True, dtype=dtype, loss=1e-4)

    ms_data = factory.ms_reduce_sum_forward()
    torch_data = factory.torch_reduce_sum_forward()
    assert np.allclose(torch_data, ms_data, factory.loss, factory.loss)
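For the same kind of cross-check, the double reduction exercised by test_reduce_sum_dyn_rank can be reproduced in NumPy (illustrative only, not part of the commit):

import numpy as np

def numpy_reduce_sum_reference(input_x, indices, axis, keep_dims):
    # first reduction drops dims, matching the inner P.ReduceSum(keep_dims=False)
    x = np.sum(input_x, axis=tuple(int(a) for a in axis))
    # second reduction honours keep_dims, matching the outer ReduceSum / torch.sum
    return np.sum(x, axis=tuple(int(i) for i in indices), keepdims=keep_dims)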