diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index a7188614175..e44b5dd80d9 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """Cost model splitter"""
+import os
 from functools import reduce
 from .model import PrimLib, Graph, Tensor
@@ -23,12 +24,19 @@ class GraphSplitByPattern:
         MODE_BASIC = 1
         MODE_COMPOSITE = 2
 
+        class StitchInfo:
+            """StitchInfo"""
+            def __init__(self):
+                self.stitch_ops = set()
+                self.stitch_atomic_ops = set()
+
         def __init__(self, init_op, is_output):
             self.pattern = PrimLib.iter_type(init_op)
             self.ops = [init_op]
             self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
             self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
             self.mode = None
+            self.stitch_info = self.StitchInfo()
             self.is_output = is_output
             self.output_excluded = set()
             if self.pattern == PrimLib.REDUCE:
@@ -69,6 +77,12 @@ class GraphSplitByPattern:
             for input_area, r in self.in_relations.items():
                 input_area.out_relations[self] = r
 
+        def update_stitch_info(self, stitch_info):
+            if stitch_info.stitch_ops:
+                self.stitch_info.stitch_ops.update(stitch_info.stitch_ops)
+            if stitch_info.stitch_atomic_ops:
+                self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops)
+
         def fuse(self, area):
             """Fuse `area` to `self`"""
             def _update_relation(relations, a, r):
@@ -107,6 +121,7 @@ class GraphSplitByPattern:
                 self.is_output = True
             if area.output_excluded:
                 self.output_excluded.update(area.output_excluded)
+            self.update_stitch_info(area.stitch_info)
 
         def check_circle(self, to):
             """Check circle. It returns false if circle exists"""
@@ -181,10 +196,25 @@ class GraphSplitByPattern:
         graphmodes = []
         for i, area in enumerate(self.areas):
             area.ops.sort(key=lambda op: ids[op])
-            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
+            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
             graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
         return subgraphs, graphmodes
 
+    def dump_subgraphs(self, subgraphs):
+        """Dump subgraphs"""
+        if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
+            subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
+            for i, sub in enumerate(subgraphs):
+                subgraphs_str += str("============") + str(i) + "\n"
+                subgraphs_str += str(sub)
+            dirname = 'subgraphs'
+            if not os.path.exists(dirname):
+                os.makedirs(dirname)
+            graphname = self.graph.name
+            filename = dirname + '/' + graphname + '.log'
+            with open(filename, 'w') as f:
+                f.write(subgraphs_str)
+
     def split(self):
         """Split graph by pattern"""
         self.do_split()
@@ -192,6 +222,7 @@ class GraphSplitByPattern:
         # Note: after this function, the input output relation is not maintained.
         self.split_output_reshapes()
         subgraphs, graphmodes = self.to_subgraphs()
+        self.dump_subgraphs(subgraphs)
         return subgraphs, graphmodes
 
     def split_output_reshapes(self):
@@ -362,15 +393,25 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return reduce_size >= 1024
             return True
 
+        def _reduce_nums(ops):
+            count = 0
+            for op in ops:
+                if op.prim.startswith('Reduce'):
+                    count += 1
+            return count
+
         def _reduce_output(dom):
             if dom.pattern != PrimLib.REDUCE:
                 return None
+            if _reduce_nums(dom.ops) > 1:
+                return None
             if _is_atomic_add_available(dom):
                 return None
             is_all_reduce = _tensor_size(dom.ops[0].output) == 1
             # excluded large size all reduce
             if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
                 return None
+
             fused = []
             for a, r in dom.out_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
@@ -378,6 +419,24 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fused.append(a)
             return fused, False
 
+        def _reduce_stitch(dom):
+            if dom.pattern != PrimLib.REDUCE:
+                return None
+            if _tensor_size(dom.ops[0].output) == 1:
+                return None
+            if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
+                return None
+
+            fused = []
+            for a, r in dom.out_relations.items():
+                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a):
+                    if _reduce_nums(a.ops) < 2:
+                        # softmax
+                        if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
+                            dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
+                            fused.append(a)
+            return fused, False
+
         def _transpose(dom):
             if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
                 return None
@@ -398,6 +457,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             changed = self.fuse(_broadcast_width) or changed
             if use_poly_reduce:
                 changed = self.fuse(_reduce_output) or changed
+                changed = self.fuse(_reduce_stitch) or changed
         self.fuse(_transpose)
 
 class GraphSplitAscend(GraphSplitByPattern):
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index f55391e9d70..7aeb669de18 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -310,11 +310,12 @@ class Operator:
 
 class Graph:
     """Graph"""
-    def __init__(self, name, ops):
+    def __init__(self, name, ops, stitch_info=None):
         self.name = name
         self.ops = ops  # in topo order, can not use set
         self.inputs = []
         self.outputs = []
+        self.stitch_info = stitch_info
 
     def set_processor(self, processor):
         """Set processor"""
@@ -372,6 +373,12 @@ class Graph:
         out_str = ', '.join([repr(t) for t in outputs])
         lines = []
         lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
+        if self.stitch_info:
+            if self.stitch_info.stitch_ops:
+                lines.append('  stitch -> ' + str(self.stitch_info.stitch_ops))
+            if self.stitch_info.stitch_atomic_ops:
+                lines.append('  stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
+
         for op in self.ops:
             lines.append('  ' + str(op))
         lines.append('}')
@@ -405,12 +412,20 @@ class Graph:
                 in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                  'tensor_name': t.name, 'format': t.data_format}])
             out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
-                         'tensor_name': op.output.name, 'format': t.data_format}]
+                         'tensor_name': op.output.name, 'format': op.output.data_format}]
             op_desc.append({'attr': attrs, 'impl_path': '', 'input_desc': in_desc,
                             'name': op.prim, 'output_desc': out_desc})
+
         graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                       'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                       'platform': 'AKG', 'process': self.processor}
+
+        if self.stitch_info and self.stitch_info.stitch_ops:
+            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
+            if self.stitch_info.stitch_atomic_ops:
+                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
+            graph_desc['buffer_stitch'] = buffer_stitch
+
         return graph_desc
diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py
index f66cdb1edae..0b8c283bbcb 100644
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -313,6 +313,14 @@ class CompositeGraph:
         self.graph = builder.get()[0]
         self.desc = desc
 
+    def add_stitch_info(self, subgraph, desc):
+        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
+            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
+            if subgraph.stitch_info.stitch_atomic_ops:
+                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
+            desc['buffer_stitch'] = buffer_stitch
+        return desc
+
     def dump(self, subgraph):
         """Dump Graph to json"""
         desc = {}
@@ -368,6 +376,8 @@ class CompositeGraph:
                 desc[key] = subgraph.name
             else:
                 desc[key] = self.desc[key]
+
+        desc = self.add_stitch_info(subgraph, desc)
         return desc
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
index ba7bb02ba34..7562ed6745c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -414,6 +414,35 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
   return DecodeFusedNodes(kernel_json);
 }
 
+StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
+  StitchInfo info;
+  if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
+    nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
+    if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
+      std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
+      info.stitch_ops = stitch_ops;
+    }
+    if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
+      std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
+      info.stitch_atomic_ops = stitch_atomic_ops;
+    }
+  }
+  return info;
+}
+
+void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
+  std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
+  if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
+  std::string tensor_name = output_descs[0][kJsonKeyTensorName];
+  if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
+    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
+  }
+  if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
+      info.stitch_atomic_ops.end()) {
+    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
+  }
+}
+
 bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
                                             const std::map<std::string, AnfNodePtr> &address_node_map,
                                             AnfNodePtrList *res_graphs) {
@@ -425,6 +454,7 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
     MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
     return false;
   }
+  StitchInfo info = GetStitchInfo(kernel_json);
   for (const auto &op_desc : op_node_descs) {
     if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
       MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
@@ -436,7 +466,9 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
       MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
       return false;
     }
-    res_graphs->push_back(address_node_map.at(ptr_address));
+    auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
+    SetStitchAttr(op_desc, info, node);
+    res_graphs->push_back(node);
   }
   MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
   return true;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h
index ac3e9331d35..74fad47a6af 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,10 @@
 namespace mindspore {
 namespace kernel {
+struct StitchInfo {
+  std::vector<std::string> stitch_ops;
+  std::vector<std::string> stitch_atomic_ops;
+};
 class AkgKernelJsonDecoder {
  public:
   AkgKernelJsonDecoder() { nodes_map_.clear(); }
@@ -40,6 +44,8 @@ class AkgKernelJsonDecoder {
   ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
   CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
   AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
+  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json);
+  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node);
   std::map<std::string, AnfNodePtr> nodes_map_;
 };
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
index 6d959adeccd..65f0c7d7f0d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
@@ -580,6 +580,31 @@ void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &process
   (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
 }
 
+void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
+                                           std::map<AnfNodePtr, nlohmann::json> *node_json_map,
+                                           nlohmann::json *kernel_json) {
+  std::vector<std::string> stitchs;
+  for (auto const &anf_node : anf_nodes) {
+    if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) &&
+        AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") {
+      auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
+      if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
+        stitchs.emplace_back(name);
+      }
+    }
+  }
+  if (!stitchs.empty()) {
+    std::vector<std::vector<std::string>> v;
+    for (auto &s : stitchs) {
+      std::vector<std::string> t;
+      t.emplace_back(s);
+      v.emplace_back(t);
+    }
+    nlohmann::json stitch_json;
+    stitch_json[kJsonKeyStitchOp] = v;
+    (*kernel_json)[kJsonKeyBufferStitch] = stitch_json;
+  }
+}
 bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                               const std::vector<AnfNodePtr> &input_list,
                                               const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
@@ -637,6 +662,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf
   (*kernel_json)[kJsonKeyComposite] = true;
   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
+  GenStitchJson(anf_nodes, &node_json_map, kernel_json);
+
   if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
     MS_LOG(ERROR) << "Cal mem size failed.";
     return false;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
index 9f6f49cccc3..0efbc84393b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
@@ -54,6 +54,9 @@
 constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
 constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
 constexpr auto kJsonKeyCoreNum = "core_num";
+constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
+constexpr auto kJsonKeyStitchOp = "stitch_op";
+constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
 
 constexpr auto kAttrInputNames = "input_names";
@@ -98,6 +101,8 @@ class AkgKernelJsonGenerator {
   void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int64_t> &dyn_input_sizes, const OpAttrPtr &op_attr,
                    nlohmann::json *attr_json, const ValuePtr &attr_value);
   bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json);
+  void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map,
+                     nlohmann::json *kernel_json);
   bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size);
   bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
   void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h
index 34f0ea11faa..e6dffde93e3 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <string>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/session/kernel_graph.h"
@@ -28,21 +29,24 @@
 namespace mindspore {
 namespace opt {
 class AtomicCleanInsertter : public Pass {
  public:
-  AtomicCleanInsertter() : Pass("atomic_clean") {}
+  explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {}
   ~AtomicCleanInsertter() override = default;
-  bool Run(const FuncGraphPtr &func_graph) override;
+  virtual bool Run(const FuncGraphPtr &func_graph);
 
- private:
-  void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
-                          const FuncGraphManagerPtr &mng);
-  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
+ protected:
+  virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
+  virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
+                                  const FuncGraphManagerPtr &mng);
   void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
                          const FuncGraphManagerPtr &mng);
   void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
                  const AnfNodePtr &user_node, int index);
-  void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
+  CNodePtr atomic_add_node_{nullptr};
+
+ private:
+  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
   void CorrectAbstract(const AnfNodePtr &composite_node);
-  void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
   CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
+  void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
   void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                               const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
   std::tuple FindPatronNode(const KernelGraphPtr &main_graph);
@@ -55,7 +59,6 @@ class AtomicCleanInsertter : public Pass {
   bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
                                  const FuncGraphManagerPtr &mng);
 
-  CNodePtr atomic_add_node_{nullptr};
   size_t reduce_real_output_index_{0};
   size_t real_output_num_{0};
   std::vector> to_process_order_;
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc
new file mode 100644
index 00000000000..935a621132f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.cc
@@ -0,0 +1,200 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "base/core_ops.h"
+#include "ir/tensor.h"
+#include "utils/utils.h"
+#include "utils/log_adapter.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "debug/anf_ir_dump.h"
+
+namespace mindspore {
+namespace opt {
+void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
+                                                        const AnfNodePtr &new_input) {
+  // Change kernel build info.
+  auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
+  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
+  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
+  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
+  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
+  auto origin_processor = origin_kernel_build_info->processor();
+
+  std::vector<std::string> &new_inputs_format = origin_inputs_format;
+  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
+  std::vector<std::string> new_outputs_format;
+  std::vector<TypeId> new_outputs_type;
+  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
+    new_outputs_format.push_back(origin_outputs_format[i]);
+    new_outputs_type.push_back(origin_outputs_type[i]);
+  }
+
+  auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
+  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
+  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
+
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
+  new_info_builder.SetInputsFormat(new_inputs_format);
+  new_info_builder.SetInputsDeviceType(new_inputs_type);
+  new_info_builder.SetOutputsFormat(new_outputs_format);
+  new_info_builder.SetOutputsDeviceType(new_outputs_type);
+  new_info_builder.SetProcessor(origin_processor);
+  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
+  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
+  auto new_selected_info = new_info_builder.Build();
+  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
+}
+
+CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
+                                                                             const AnfNodePtr &new_parameter) {
+  // add inplaceassign
+  AnfNodePtr out_node = atomic_add_node_;  // Use result data itself, and set attr "fake_out" true.
+  auto inplace_assign_node =
+    CreateCNode({NewValueNode(std::make_shared<Primitive>("InplaceAssign")), new_parameter, atomic_add_node_, out_node},
+                sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
+  AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), inplace_assign_node);
+  AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
+  AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), inplace_assign_node);
+  return inplace_assign_node;
+}
+
+void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
+                                                    const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) {
+  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
+  auto mng_sub = sub_graph->manager();
+  if (mng_sub == nullptr) {
+    mng_sub = Manage(sub_graph, false);
+    sub_graph->set_manager(mng_sub);
+  }
+
+  // add input
+  auto inputs = composite_node->cast<CNodePtr>()->inputs();
+  inputs.push_back(new_input);
+  composite_node->cast<CNodePtr>()->set_inputs(inputs);
+
+  // add parameter
+  auto parameter = sub_graph->add_parameter();
+  parameter->set_abstract(new_input->abstract());
+  parameter->set_kernel_info(new_input->kernel_info_ptr());
+
+  auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
+
+  // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
+  // elimination.
+  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
+  bool connected = false;
+  for (const auto &[user_node, index] : reduce_user_nodes) {
+    auto user_cnode = user_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(user_cnode);
+    user_cnode->set_input(index, parameter);
+    if (!connected) {
+      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
+      if (!user_user.empty()) {
+        auto pair = user_user[0];
+        AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
+      }
+      connected = true;
+    }
+    CorrectKernelBuildInfo(composite_node, new_input);
+  }
+
+  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
+  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
+  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
+  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
+}
+
+std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
+                                                                                        const CNodePtr &target) {
+  auto node = inner_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(node);
+  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto mng_sub = sub_graph->manager();
+  if (mng_sub == nullptr) {
+    mng_sub = Manage(sub_graph, false);
+    sub_graph->set_manager(mng_sub);
+  }
+  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
+  auto users = mng_sub->node_users()[target];
+  std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
+                 [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
+  return inner_user_nodes;
+}
+
+bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
+  if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
+  auto node = anf_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(node);
+  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  AnfNodePtrList kernel_nodes;
+  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
+  for (auto &n : kernel_nodes) {
+    if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
+        AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!"; + atomic_add_node_ = n->cast(); + stitch_node_ = anf_node; + return true; + } + } + return false; +} + +bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { + auto kernel_graph = std::dynamic_pointer_cast(func_graph); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + + bool changed = false; + auto topo_nodes = TopoSort(kernel_graph->get_return()); + for (const auto &node : topo_nodes) { + // if stitch attr exists, add atomic clean op depends on the attr + if (IsStitchWithAtomic(node)) { + InsertAtomicClean(kernel_graph, node, mng); + changed = true; + } + } + + if (changed) { + mng->RemoveRoots(); + mng->KeepRoots({func_graph}); + } + + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h new file mode 100644 index 00000000000..85b5f4d6904 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
+
+#include
+#include
+#include
+#include
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
+ public:
+  StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
+  ~StitchAtomicCleanInsertter() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+
+ private:
+  void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
+  CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
+  std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
+  void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
+                          const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
+  bool IsStitchWithAtomic(const AnfNodePtr &anf_node);
+
+  AnfNodePtr stitch_node_{nullptr};
+};
+using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>;
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 28c2ff8e2a4..cbffb2345f4 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -717,6 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimMinimumGrad,
     prim::kPrimGkDropout,
    prim::kPrimDropoutGrad,
+    prim::kPrimSoftMax,
 #endif
   };
   return expand_ops;
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 236c70f197e..d525c2156fe 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -42,6 +42,7 @@
 #include "backend/optimizer/gpu/add_relu_v2_fusion.h"
 #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
 #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
+#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
 #include "backend/optimizer/graph_kernel/clean_all_in_once.h"
@@ -201,6 +202,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_
   // will be exposed, use GetitemTuple Pass to delete them.
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::StitchAtomicCleanInsertter>());
   pm->AddPass(std::make_shared());
   // Prevent fake loop in parallel fusion.
   pm->AddPass(std::make_shared(kGPUDevice, opt::ParallelConfig(7)));
   pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 3307b0399fb..4facb834b7e 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -384,6 +384,7 @@ constexpr auto kAttrIsGrad = "is_grad";
 constexpr auto kAttrRecompute = "recompute";
 constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
 constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
+constexpr auto kAttrStitch = "stitch";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 688b7fe6428..77523464bbf 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -132,7 +132,7 @@ inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
 
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
-inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("SoftMax");
+inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax");
 inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
 inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
 inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
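For reference, a minimal sketch of the 'buffer_stitch' entry that Graph.dump / CompositeGraph.add_stitch_info attach to the kernel JSON and that AkgKernelJsonDecoder::GetStitchInfo reads back. The surrounding fields are abbreviated and the tensor and kernel names below are hypothetical, not taken from the patch:

    # Illustrative sketch only; tensor/kernel names are made up.
    graph_desc = {
        'composite': True,
        'op': 'Fused_Softmax_split_0',   # hypothetical subgraph name
        'platform': 'AKG',
        # ... 'input_desc', 'op_desc', 'output_desc' omitted ...
        'buffer_stitch': {
            'stitch_op': ['output_0_1'],           # from stitch_info.stitch_ops
            'stitch_atomic_op': ['output_0_0'],    # from stitch_info.stitch_atomic_ops, only when present
        },
    }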