forked from mindspore-Ecosystem/mindspore
commit
4f032cf3f8
|
@ -95,6 +95,16 @@ constexpr auto kJSocVersion = "socVersion";
|
|||
constexpr auto kSOC_VERSION = "SOC_VERSION";
|
||||
constexpr auto kJIsDynamicShape = "is_dynamic_shape";
|
||||
|
||||
bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_LOG(INFO) << "Check if need change default format";
|
||||
if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
|
||||
return attr == kOpFormat_NCDHW;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
|
||||
nlohmann::json *kernel_json) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
|
@ -161,10 +171,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode>
|
|||
bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
|
||||
const string &op_input_name, size_t input_i,
|
||||
std::vector<nlohmann::json> *input_list) {
|
||||
auto def_format = kOpFormat_NCHW;
|
||||
auto dtype = GetDeviceInputType(anf_node, real_input_index);
|
||||
auto format = GetDeviceInputFormat(anf_node, real_input_index);
|
||||
auto shape = GetDeviceInputShape(anf_node, real_input_index);
|
||||
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
def_format = kOpFormat_NCDHW;
|
||||
}
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
|
@ -172,7 +186,7 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode>
|
|||
input_desc_json[kJDtype] = dtype;
|
||||
input_desc_json[kJName] = op_input_name + std::to_string(input_i);
|
||||
input_desc_json[kJOriShape] = ori_shape;
|
||||
input_desc_json[kJOriFormat] = kOpFormat_NCHW;
|
||||
input_desc_json[kJOriFormat] = def_format;
|
||||
input_desc_json[kJShape] = shape;
|
||||
input_desc_json[kJFormat] = format;
|
||||
input_desc_json[kJValid] = value;
|
||||
|
@ -379,6 +393,10 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
|
|||
std::vector<nlohmann::json> *output_list) {
|
||||
MS_EXCEPTION_IF_NULL(output_idx);
|
||||
MS_EXCEPTION_IF_NULL(output_list);
|
||||
auto def_format = kOpFormat_NCHW;
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
def_format = kOpFormat_NCDHW;
|
||||
}
|
||||
for (size_t i = 0; i < output_obj_num; i++) {
|
||||
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
|
||||
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
|
||||
|
@ -397,7 +415,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
|
|||
output_obj[kJShape] = shape;
|
||||
output_obj[kJFormat] = format;
|
||||
output_obj[kJOriShape] = ori_shape;
|
||||
output_obj[kJOriFormat] = kOpFormat_NCHW;
|
||||
output_obj[kJOriFormat] = def_format;
|
||||
output_obj[kJName] = output_ptr->name();
|
||||
output_obj[kJValid] = true;
|
||||
output_obj[kJParamType] = output_ptr->param_type();
|
||||
|
@ -580,6 +598,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod
|
|||
format = kOpFormat_NCHW;
|
||||
}
|
||||
}
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
format = kOpFormat_NCDHW;
|
||||
}
|
||||
return format;
|
||||
}
|
||||
|
||||
|
@ -619,6 +640,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
|
|||
format = kOpFormat_NCHW;
|
||||
}
|
||||
}
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
format = kOpFormat_NCDHW;
|
||||
}
|
||||
return format;
|
||||
}
|
||||
|
||||
|
@ -818,6 +842,10 @@ void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) {
|
|||
void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
|
||||
size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) {
|
||||
GenPreDescJson(output_desc);
|
||||
auto def_format = kOpFormat_NCHW;
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
def_format = kOpFormat_NCDHW;
|
||||
}
|
||||
// data_type
|
||||
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
|
||||
(*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
|
||||
|
@ -828,7 +856,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
|
|||
}
|
||||
(*output_desc)[kJName] = output_desc_name;
|
||||
// ori_format
|
||||
(*output_desc)[kJOriFormat] = kOpFormat_NCHW;
|
||||
(*output_desc)[kJOriFormat] = def_format;
|
||||
// ori_shape
|
||||
auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
|
||||
if (ori_shape.empty()) {
|
||||
|
|
|
@ -248,13 +248,57 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
|
|||
|
||||
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
|
||||
MS_EXCEPTION_IF_NULL(support_format);
|
||||
return false;
|
||||
if (IsSameShape()) {
|
||||
if (!HasScalarInput()) {
|
||||
AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
SupportFormatItem input_support_format;
|
||||
SupportFormatItem output_support_format;
|
||||
if (HasScalarInput()) {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
input_support_format.emplace_back(kOpFormat_NCDHW);
|
||||
} else if (!Is5DShape(shape)) {
|
||||
return false;
|
||||
} else if (shape[kChannelC] % kAlignmented16 != 0) {
|
||||
return false;
|
||||
} else {
|
||||
input_support_format.emplace_back(kOpFormat_NDC1HWC0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (!Is5DShape(shape)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto shape_tmp = input_shapes_[0];
|
||||
auto broadcast_c_axis = std::any_of(
|
||||
input_shapes_.begin(), input_shapes_.end(),
|
||||
[&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
|
||||
if (broadcast_c_axis) {
|
||||
MS_LOG(INFO) << "This node broadcast c channel.";
|
||||
return false;
|
||||
}
|
||||
input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
|
||||
}
|
||||
GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
|
||||
support_format->input_format.emplace_back(input_support_format);
|
||||
support_format->output_format.emplace_back(output_support_format);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
|
||||
return shape.size() == kShape4dDims;
|
||||
}
|
||||
|
||||
bool TbeKernelBroadCastSelecter::Is5DShape(const std::vector<size_t> &shape) const {
|
||||
return shape.size() == kShape5dDims;
|
||||
}
|
||||
|
||||
bool TbeKernelBroadCastSelecter::IsSameShape() const {
|
||||
auto shape = input_shapes_.begin();
|
||||
for (const auto &item : input_shapes_) {
|
||||
|
|
|
@ -40,6 +40,7 @@ class TbeKernelBroadCastSelecter {
|
|||
bool IsSameShape() const;
|
||||
void PadScalarShape(std::vector<size_t> *shape) const;
|
||||
bool Is4DShape(const std::vector<size_t> &shape) const;
|
||||
bool Is5DShape(const std::vector<size_t> &shape) const;
|
||||
bool IsScalarShape(const std::vector<size_t> &shape) const;
|
||||
bool HasScalarInput() const;
|
||||
void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
|
||||
|
|
|
@ -72,8 +72,18 @@ bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format)
|
|||
|
||||
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
|
||||
MS_EXCEPTION_IF_NULL(support_format);
|
||||
// like to 5HD
|
||||
return false;
|
||||
if (!Is5DShape(input_shape_)) {
|
||||
return false;
|
||||
}
|
||||
if (!keep_dims_ || axis_.empty()) {
|
||||
return false;
|
||||
}
|
||||
auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); });
|
||||
if (reduce_c_axis) {
|
||||
return false;
|
||||
}
|
||||
AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
|
||||
|
@ -142,6 +152,8 @@ void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_for
|
|||
|
||||
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; }
|
||||
|
||||
bool TbeKernelReduceSelecter::Is5DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape5dDims; }
|
||||
|
||||
void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->empty()) {
|
||||
|
|
|
@ -39,6 +39,7 @@ class TbeKernelReduceSelecter {
|
|||
void GetReduceAttrKeepDim();
|
||||
void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
|
||||
bool Is4DShape(const std::vector<size_t> &shape) const;
|
||||
bool Is5DShape(const std::vector<size_t> &shape) const;
|
||||
void PadScalarShape(std::vector<size_t> *shape) const;
|
||||
CNodePtr cnode_ptr_;
|
||||
std::vector<size_t> input_shape_{};
|
||||
|
|
|
@ -187,6 +187,9 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
|
|||
if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
|
||||
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
|
||||
}
|
||||
if (!broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format)) {
|
||||
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support NDC1HWC0.";
|
||||
}
|
||||
PrintSupportedFormat(support_format);
|
||||
OpInfo op_info_new;
|
||||
CreateNewOpInfo(op_info, support_format, &op_info_new);
|
||||
|
@ -281,10 +284,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
|
|||
return true;
|
||||
}
|
||||
// not support format:
|
||||
// 1 NDHWC with shape size != 5
|
||||
// 3 !NDHWC with shape size > 4
|
||||
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
|
||||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
|
||||
// 1 NCDHW with shape size != 5
|
||||
if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) {
|
||||
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
namespace {
|
||||
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
|
||||
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
|
||||
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
||||
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
|
||||
std::vector<AnfNodePtr> trans_inputs;
|
||||
|
@ -70,9 +70,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
CNodePtr trans_data = nullptr;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Init
|
||||
std::string default_format = kOpFormat_DEFAULT;
|
||||
|
||||
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
default_format = kOpFormat_NCDHW;
|
||||
}
|
||||
}
|
||||
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
|
||||
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
|
||||
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
|
||||
std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
|
||||
: AnfAlgo::GetOutputReshapeType(node, insert_index);
|
||||
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
|
||||
|
|
|
@ -369,6 +369,26 @@ void KernelGraph::CheckLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
|
||||
kernel_build_info_builder->SetOutputsFormat({format});
|
||||
kernel_build_info_builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
|
||||
}
|
||||
|
||||
void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) {
|
||||
auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i);
|
||||
MS_EXCEPTION_IF_NULL(in_node);
|
||||
if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) {
|
||||
ReSetParameterValueNodeFormatAndType(in_node, format);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||
auto cnode = FuncGraph::NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -378,6 +398,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
SetKernelInfoForNode(cnode);
|
||||
if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
ResetInFormat(cnode, kOpFormat_NCDHW);
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
return cnode;
|
||||
}
|
||||
|
|
|
@ -273,6 +273,7 @@ class KernelGraph : public FuncGraph {
|
|||
// remove value node form graph
|
||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||
void SetKernelInfoForNode(const AnfNodePtr &node) const;
|
||||
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
|
||||
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes);
|
||||
|
|
|
@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
// NCDHW
|
||||
if (shape.size() != 5) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
const size_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[4]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
|
||||
// NCDHW -> Frac_Z_3D
|
||||
if (shape.size() != 5) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[4]);
|
||||
device_shape.push_back(N1);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < kNdhwc) {
|
||||
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
||||
}
|
||||
|
@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
|
||||
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
|
||||
{kOpFormat_NDHWC, NdhwcDeviceShape}};
|
||||
{kOpFormat_NCDHW, NcdhwDeviceShape},
|
||||
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
|
||||
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
|
||||
|
||||
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
|
||||
return shape;
|
||||
|
@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
if (shape.size() != kNchwDims) {
|
||||
if (shape.size() != kNchwDims && shape.size() != 5) {
|
||||
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = PaddingShapeTo4dByDefault(shape);
|
||||
}
|
||||
|
@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|||
const std::map<std::string, FormatTransfer> format_trans_map{
|
||||
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
||||
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
||||
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
|
@ -514,11 +553,11 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|||
|
||||
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
||||
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
||||
const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
|
||||
{kOpFormat_FRAC_NZ, FracNzToNchw},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
|
||||
{kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
|
||||
const std::map<std::string, FormatTransfer> format_trans_map{
|
||||
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
|
||||
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
|
@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != 5) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = abstract::TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto d = args.host_shape[2];
|
||||
auto h = args.host_shape[3];
|
||||
auto w = args.host_shape[4];
|
||||
auto c1 = args.device_shape[2];
|
||||
auto c0 = args.device_shape[5];
|
||||
const size_t cdhw = c * d * h * w;
|
||||
const size_t dhw = d * h * w;
|
||||
const size_t hw = h * w;
|
||||
const size_t dc1hwc0 = d * c1 * h * w * c0;
|
||||
const size_t c1hwc0 = c1 * h * w * c0;
|
||||
const size_t hwc0 = h * w * c0;
|
||||
const size_t wc0 = w * c0;
|
||||
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
size_t n_head = n_i * cdhw;
|
||||
for (size_t c_i = 0; c_i < c; c_i++) {
|
||||
size_t c_head = n_head + c_i * dhw;
|
||||
for (size_t d_i = 0; d_i < d; d_i++) {
|
||||
size_t d_head = c_head + d_i * hw;
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
size_t h_head = d_head + h_i * w;
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
size_t dst_i = h_head + w_i;
|
||||
size_t c1_i = c_i / c0;
|
||||
size_t c0_i = c_i % c0;
|
||||
auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
|
||||
SetData(size, false, src_idx, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != 5) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = abstract::TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto d = args.host_shape[2];
|
||||
auto h = args.host_shape[3];
|
||||
auto w = args.host_shape[4];
|
||||
auto c0 = kCubeSize;
|
||||
auto c1 = DivCeil(c, c0);
|
||||
const size_t cdhw = c * d * h * w;
|
||||
const size_t dhw = d * h * w;
|
||||
const size_t hw = h * w;
|
||||
const size_t dc1hwc0 = d * c1 * h * w * c0;
|
||||
const size_t c1hwc0 = c1 * h * w * c0;
|
||||
const size_t hwc0 = h * w * c0;
|
||||
const size_t wc0 = w * c0;
|
||||
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
size_t n_head = n_i * dc1hwc0;
|
||||
for (size_t d_i = 0; d_i < d; d_i++) {
|
||||
size_t d_head = n_head + d_i * c1hwc0;
|
||||
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
size_t c1_head = d_head + c1_i * hwc0;
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
size_t h_head = c1_head + h_i * wc0;
|
||||
for (size_t w_i = 0; w_i < w; w_i++) {
|
||||
size_t w_head = h_head + w_i * c0;
|
||||
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
size_t dst_i = c0_i + w_head;
|
||||
size_t c_i = c0_i + c1_i * c0;
|
||||
size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
|
||||
auto pad_zero = c_i >= c;
|
||||
SetData(size, pad_zero, src_i, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,6 +66,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result);
|
|||
bool NchwToFracZc04(const FormatArgs &args, void *result);
|
||||
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
|
||||
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
|
||||
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
|
||||
|
||||
// device to host
|
||||
bool ToNchw(const FormatArgs &args, void *result);
|
||||
bool FracZToNchw(const FormatArgs &args, void *result);
|
||||
|
@ -73,6 +75,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result);
|
|||
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
|
||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -292,7 +292,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
|
|||
if (host_shape.empty()) {
|
||||
host_shape.emplace_back(1);
|
||||
}
|
||||
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
|
||||
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
|
||||
if (type_id_ == type) {
|
||||
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
sync_ok = true;
|
||||
|
@ -454,7 +454,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js
|
|||
|
||||
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
|
||||
std::vector<size_t> device_shape;
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
|
||||
device_shape = trans::TransShapeToDevice(*host_shape, format_);
|
||||
} else {
|
||||
if (host_shape_.empty()) {
|
||||
|
@ -531,7 +531,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
|
|||
if (host_shape.empty()) {
|
||||
host_shape.emplace_back(1);
|
||||
}
|
||||
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
|
||||
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
|
||||
if (type_id_ == type) {
|
||||
SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE);
|
||||
sync_ok = true;
|
||||
|
@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
|
|||
host_shape.emplace_back(1);
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
|
||||
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
|
||||
device_shape = trans::TransShapeToDevice(host_shape, format_);
|
||||
} else {
|
||||
host_shape = trans::PaddingShapeTo4d(host_shape);
|
||||
|
|
|
@ -81,6 +81,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|||
string priority_matched_format = kOpFormat_NC1HWC0;
|
||||
bool is_init = false;
|
||||
bool need_change_nd = false;
|
||||
bool is_5d_input = false;
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
|
||||
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
||||
if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
|
||||
|
@ -93,14 +94,21 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|||
priority_matched_format = kOpFormat_DEFAULT;
|
||||
}
|
||||
auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
|
||||
if (input_shape_size == 5) {
|
||||
is_5d_input = true;
|
||||
}
|
||||
need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
|
||||
}
|
||||
if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
|
||||
priority_matched_format = kOpFormat_DEFAULT;
|
||||
}
|
||||
if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
|
||||
priority_matched_format = kOpFormat_NDC1HWC0;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
|
||||
return priority_matched_format;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
|
||||
* if equal then next num location
|
||||
|
@ -157,7 +165,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
|
||||
}
|
||||
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
|
||||
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT ||
|
||||
kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
|
||||
}
|
||||
}
|
||||
|
@ -376,7 +385,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|||
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
||||
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
|
||||
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
||||
}
|
||||
builder->SetOutputsFormat(output_format);
|
||||
|
@ -386,7 +397,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
|
||||
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
||||
}
|
||||
builder->SetOutputsFormat(output_format);
|
||||
|
|
|
@ -386,11 +386,23 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
|
|||
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
|
||||
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
|
||||
constexpr auto kOpFormat_NDHWC = "NDHWC";
|
||||
constexpr auto kOpFormat_NCDHW = "NCDHW";
|
||||
constexpr auto kOpFormat_DHWNC = "DHWNC";
|
||||
constexpr auto kOpFormat_DHWCN = "DHWCN";
|
||||
constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
|
||||
constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
|
||||
constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
|
||||
const std::set<std::string> kOpFormatList = {
|
||||
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
|
||||
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
|
||||
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
|
||||
|
||||
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0,
|
||||
kOpFormat_ND, kOpFormat_NCHW,
|
||||
kOpFormat_NHWC, kOpFormat_HWCN,
|
||||
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
|
||||
kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
|
||||
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04,
|
||||
kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
|
||||
kOpFormat_NDC1HWC0, kOpFormat_NCDHW,
|
||||
kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
|
||||
kOpFormat_DHWCN};
|
||||
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
|
||||
const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
||||
kApplyMomentumOpName,
|
||||
|
@ -427,8 +439,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
|||
kSparseApplyProximalAdagradOpName};
|
||||
|
||||
const std::set<std::string> kHWSpecialFormatSet = {
|
||||
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
|
||||
kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
|
||||
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
|
||||
|
||||
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
|
||||
|
|
Loading…
Reference in New Issue