!11777 stitch fusion

From: @r1chardf1d0
mindspore-ci-bot 2021-02-02 10:49:04 +08:00 committed by Gitee
commit a24ff36d9c
14 changed files with 431 additions and 19 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
import os
from functools import reduce
from .model import PrimLib, Graph, Tensor
@@ -23,12 +24,19 @@ class GraphSplitByPattern:
MODE_BASIC = 1
MODE_COMPOSITE = 2
class StitchInfo:
"""StitchInfo"""
def __init__(self):
self.stitch_ops = set()
self.stitch_atomic_ops = set()
def __init__(self, init_op, is_output):
self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = None
self.stitch_info = self.StitchInfo()
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:
@@ -69,6 +77,12 @@ class GraphSplitByPattern:
for input_area, r in self.in_relations.items():
input_area.out_relations[self] = r
def update_stitch_info(self, stitch_info):
if stitch_info.stitch_ops:
self.stitch_info.stitch_ops.update(stitch_info.stitch_ops)
if stitch_info.stitch_atomic_ops:
self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops)
def fuse(self, area):
"""Fuse `area` to `self`"""
def _update_relation(relations, a, r):
@@ -107,6 +121,7 @@ class GraphSplitByPattern:
self.is_output = True
if area.output_excluded:
self.output_excluded.update(area.output_excluded)
self.update_stitch_info(area.stitch_info)
def check_circle(self, to):
"""Check circle. It returns false if circle exists"""
@@ -181,10 +196,25 @@ class GraphSplitByPattern:
graphmodes = []
for i, area in enumerate(self.areas):
area.ops.sort(key=lambda op: ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes
def dump_subgraphs(self, subgraphs):
"""Dump subgraphs"""
if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
for i, sub in enumerate(subgraphs):
subgraphs_str += "============" + str(i) + "\n"
subgraphs_str += str(sub)
dirname = 'subgraphs'
if not os.path.exists(dirname):
os.makedirs(dirname)
graphname = self.graph.name
filename = dirname + '/' + graphname + '.log'
with open(filename, 'w') as f:
f.write(subgraphs_str)
def split(self):
"""Split graph by pattern"""
self.do_split()
@@ -192,6 +222,7 @@ class GraphSplitByPattern:
# Note: after this function, the input-output relation is no longer maintained.
self.split_output_reshapes()
subgraphs, graphmodes = self.to_subgraphs()
self.dump_subgraphs(subgraphs)
return subgraphs, graphmodes
def split_output_reshapes(self):
@@ -362,15 +393,25 @@ class GraphSplitGpu(GraphSplitByPattern):
return reduce_size >= 1024
return True
def _reduce_nums(ops):
count = 0
for op in ops:
if op.prim.startswith('Reduce'):
count += 1
return count
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if _reduce_nums(dom.ops) > 1:
return None
if _is_atomic_add_available(dom):
return None
is_all_reduce = _tensor_size(dom.ops[0].output) == 1
# exclude large-size all-reduce
if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
@@ -378,6 +419,24 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, False
def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if _tensor_size(dom.ops[0].output) == 1:
return None
if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a):
if _reduce_nums(a.ops) < 2:
# softmax-like pattern: successor has more than 4 ops and a 4-D input
if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
return fused, False
def _transpose(dom):
if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
return None
@@ -398,6 +457,7 @@ class GraphSplitGpu(GraphSplitByPattern):
changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
class GraphSplitAscend(GraphSplitByPattern):
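As a usage note: the new ENABLE_SUBGRAPHS switch and the stitch bookkeeping can be exercised end to end. A minimal sketch, assuming a model.Graph instance `graph` and that GraphSplitGpu is constructed directly (the real entry point and constructor arguments may differ):

import os

os.environ["ENABLE_SUBGRAPHS"] = "on"       # switch read by dump_subgraphs
splitter = GraphSplitGpu(graph)             # assumed constructor, for illustration only
subgraphs, graphmodes = splitter.split()    # also writes subgraphs/<graph.name>.log
for sub, mode in zip(subgraphs, graphmodes):
    # each subgraph Graph now carries the stitch info collected by _reduce_stitch
    if sub.stitch_info and sub.stitch_info.stitch_ops:
        print(sub.name, mode, sub.stitch_info.stitch_ops)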

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -310,11 +310,12 @@ class Operator:
class Graph:
"""Graph"""
def __init__(self, name, ops):
def __init__(self, name, ops, stitch_info=None):
self.name = name
self.ops = ops # in topo order, can not use set
self.inputs = []
self.outputs = []
self.stitch_info = stitch_info
def set_processor(self, processor):
"""Set processor"""
@@ -372,6 +373,12 @@ class Graph:
out_str = ', '.join([repr(t) for t in outputs])
lines = []
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
if self.stitch_info:
if self.stitch_info.stitch_ops:
lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
if self.stitch_info.stitch_atomic_ops:
lines.append(' stitch_atomic_ops -> ' + str(self.stitch_info.stitch_atomic_ops))
for op in self.ops:
lines.append(' ' + str(op))
lines.append('}')
@@ -405,12 +412,20 @@ class Graph:
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
'tensor_name': op.output.name, 'format': t.data_format}]
'tensor_name': op.output.name, 'format': op.output.data_format}]
op_desc.append({'attr': attrs, 'impl_path': '',
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
'platform': 'AKG', 'process': self.processor}
if self.stitch_info and self.stitch_info.stitch_ops:
buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
if self.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
graph_desc['buffer_stitch'] = buffer_stitch
return graph_desc
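For illustration, with one stitched op recorded the emitted graph_desc gains a field shaped like this (tensor names are hypothetical):

# Sketch of the new graph_desc entry; 'stitch_atomic_op' appears only when
# stitch_info.stitch_atomic_ops is non-empty.
graph_desc['buffer_stitch'] = {
    'stitch_op': ['output_0_0'],
    'stitch_atomic_op': ['output_0_1'],
}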

View File

@@ -313,6 +313,14 @@ class CompositeGraph:
self.graph = builder.get()[0]
self.desc = desc
def add_stitch_info(self, subgraph, desc):
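"""Add buffer stitch info to the graph desc"""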
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
if subgraph.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
desc['buffer_stitch'] = buffer_stitch
return desc
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
@@ -368,6 +376,8 @@ class CompositeGraph:
desc[key] = subgraph.name
else:
desc[key] = self.desc[key]
desc = self.add_stitch_info(subgraph, desc)
return desc

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -414,6 +414,35 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
return DecodeFusedNodes(kernel_json);
}
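// Parse the buffer_stitch section (stitch_op / stitch_atomic_op tensor-name lists) out of the kernel json.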
StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
info.stitch_ops = stitch_ops;
}
if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
info.stitch_atomic_ops = stitch_atomic_ops;
}
}
return info;
}
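// Tag a decoded node with attr stitch="common" or stitch="atomic" when its first output tensor name appears in the stitch lists.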
void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
}
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
}
}
bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
@@ -425,6 +454,7 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
}
StitchInfo info = GetStitchInfo(kernel_json);
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
@@ -436,7 +466,9 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
}
res_graphs->push_back(address_node_map.at(ptr_address));
auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
SetStitchAttr(op_desc, info, node);
res_graphs->push_back(node);
}
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
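Putting the decode path together: a hedged sketch of the kernel json fragment these helpers consume (tensor names hypothetical), written as a Python literal for brevity:

kernel_json = {
    'buffer_stitch': {
        'stitch_op': ['output_0_0'],         # matching node gets attr stitch = "common"
        'stitch_atomic_op': ['output_0_1'],  # matching node gets attr stitch = "atomic"
    },
    # ... plus the 'op_desc' entries with 'ptr_address', as walked by DecodeSplitNodes
}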

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,6 +26,10 @@
namespace mindspore {
namespace kernel {
struct StitchInfo {
std::vector<std::string> stitch_ops;
std::vector<std::string> stitch_atomic_ops;
};
class AkgKernelJsonDecoder {
public:
AkgKernelJsonDecoder() { nodes_map_.clear(); }
@@ -40,6 +44,8 @@ class AkgKernelJsonDecoder {
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json);
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node);
std::map<std::string, AnfNodePtr> nodes_map_;
};
} // namespace kernel

View File

@@ -580,6 +580,31 @@ void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &process
(*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
}
void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map,
nlohmann::json *kernel_json) {
std::vector<std::string> stitchs;
for (auto const &anf_node : anf_nodes) {
if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) &&
AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") {
auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
stitchs.emplace_back(name);
}
}
}
if (!stitchs.empty()) {
std::vector<nlohmann::json> v;
for (auto &s : stitchs) {
std::vector<std::string> t;
t.emplace_back(s);
v.emplace_back(t);
}
nlohmann::json stitch_json;
stitch_json[kJsonKeyStitchOp] = v;
(*kernel_json)[kJsonKeyBufferStitch] = stitch_json;
}
}
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
@@ -637,6 +662,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
GenStitchJson(anf_nodes, &node_json_map, kernel_json);
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
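Note the generator wraps each stitched tensor name in its own single-element list; a sketch of the resulting kernel json fragment (names hypothetical):

# Each entry of stitch_op is itself a list, per GenStitchJson above.
kernel_json['buffer_stitch'] = {
    'stitch_op': [['output_0_0'], ['output_0_1']],
}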

View File

@@ -54,6 +54,9 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
constexpr auto kJsonKeyFusionType = "fusion_type";
constexpr auto kJsonKeySubGraph = "sub_graph";
constexpr auto kJsonKeyCoreNum = "core_num";
constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
constexpr auto kJsonKeyStitchOp = "stitch_op";
constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
constexpr auto kAttrInputNames = "input_names";
@@ -98,6 +101,8 @@ class AkgKernelJsonGenerator {
void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr,
nlohmann::json *attr_json, const ValuePtr &attr_value);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json);
void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map,
nlohmann::json *kernel_json);
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size);
bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
#include <tuple>
#include <utility>
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
@@ -28,21 +29,24 @@ namespace mindspore {
namespace opt {
class AtomicCleanInsertter : public Pass {
public:
AtomicCleanInsertter() : Pass("atomic_clean") {}
explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {}
~AtomicCleanInsertter() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
virtual bool Run(const FuncGraphPtr &func_graph);
private:
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
protected:
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index);
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
CNodePtr atomic_add_node_{nullptr};
private:
bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
void CorrectAbstract(const AnfNodePtr &composite_node);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph);
@@ -55,7 +59,6 @@ class AtomicCleanInsertter : public Pass {
bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
const FuncGraphManagerPtr &mng);
CNodePtr atomic_add_node_{nullptr};
size_t reduce_real_output_index_{0};
size_t real_output_num_{0};
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_;

View File

@@ -0,0 +1,200 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <utility>
#include <set>
#include <stack>
#include <string>
#include <tuple>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
// Change kernel build info.
auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
auto origin_processor = origin_kernel_build_info->processor();
std::vector<std::string> &new_inputs_format = origin_inputs_format;
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
std::vector<std::string> new_outputs_format;
std::vector<TypeId> new_outputs_type;
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
new_outputs_format.push_back(origin_outputs_format[i]);
new_outputs_type.push_back(origin_outputs_type[i]);
}
auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
new_info_builder.SetInputsFormat(new_inputs_format);
new_info_builder.SetInputsDeviceType(new_inputs_type);
new_info_builder.SetOutputsFormat(new_outputs_format);
new_info_builder.SetOutputsDeviceType(new_outputs_type);
new_info_builder.SetProcessor(origin_processor);
new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto new_selected_info = new_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
// Add an InplaceAssign node.
AnfNodePtr out_node = atomic_add_node_;  // Use the result data itself, and set attr "fake_output" true.
auto inplace_assign_node =
CreateCNode({NewValueNode(std::make_shared<Primitive>("InplaceAssign")), new_parameter, atomic_add_node_, out_node},
sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), inplace_assign_node);
AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), inplace_assign_node);
return inplace_assign_node;
}
void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
// add input
auto inputs = composite_node->cast<CNodePtr>()->inputs();
inputs.push_back(new_input);
composite_node->cast<CNodePtr>()->set_inputs(inputs);
// add parameter
auto parameter = sub_graph->add_parameter();
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
// elimination.
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
bool connected = false;
for (const auto &[user_node, index] : reduce_user_nodes) {
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, parameter);
if (!connected) {
std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
if (!user_user.empty()) {
auto pair = user_user[0];
AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
}
connected = true;
}
}
// Correct the build info once, after rewiring users; calling it per user would append the new input repeatedly.
CorrectKernelBuildInfo(composite_node, new_input);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) {
auto node = inner_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
auto users = mng_sub->node_users()[target];
std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
[](const std::pair<AnfNodePtr, int> &pair) { return pair; });
return inner_user_nodes;
}
bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
AnfNodePtrList kernel_nodes;
kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
for (auto &n : kernel_nodes) {
if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
atomic_add_node_ = n->cast<CNodePtr>();
stitch_node_ = anf_node;
return true;
}
}
return false;
}
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
bool changed = false;
auto topo_nodes = TopoSort(kernel_graph->get_return());
for (const auto &node : topo_nodes) {
// If the stitch attr exists, insert an atomic clean op according to the attr.
if (IsStitchWithAtomic(node)) {
InsertAtomicClean(kernel_graph, node, mng);
changed = true;
}
}
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
public:
StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
~StitchAtomicCleanInsertter() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
bool IsStitchWithAtomic(const AnfNodePtr &anf_node);
AnfNodePtr stitch_node_{nullptr};
};
using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -717,6 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimMinimumGrad,
prim::kPrimGkDropout,
prim::kPrimDropoutGrad,
prim::kPrimSoftMax,
#endif
};
return expand_ops;

View File

@@ -42,6 +42,7 @@
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/clean_all_in_once.h"
@@ -201,6 +202,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>());
pm->AddPass(std::make_shared<opt::StitchAtomicCleanInsertter>());
pm->AddPass(std::make_shared<opt::DependFormater>()); // Prevent fake loop in parallel fusion.
pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7)));
pm->AddPass(std::make_shared<opt::BindValueToGraph>());

View File

@@ -384,6 +384,7 @@ constexpr auto kAttrIsGrad = "is_grad";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@@ -132,7 +132,7 @@ inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("SoftMax");
inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax");
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");