forked from mindspore-Ecosystem/mindspore
Init GraphKernel.
- It provides a unified style to express graph and kernel for user. - It provides a unified IR to represent graph and kernel for developer. - It breaks the boundary between graph and kernel. - It provides more opportunities to do compile optimization.
This commit is contained in:
parent
01216a9a57
commit
a6dfa281ea
|
@ -13,3 +13,6 @@
|
|||
[submodule "graphengine"]
|
||||
path = graphengine
|
||||
url = https://gitee.com/mindspore/graphengine.git
|
||||
[submodule "akg"]
|
||||
path = akg
|
||||
url = https://gitee.com/mindspore/akg.git
|
||||
|
|
|
@ -86,10 +86,14 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
|
|||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
|
||||
endif()
|
||||
|
||||
if (ENABLE_AKG AND ENABLE_D)
|
||||
add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
add_subdirectory(mindspore/ccsrc)
|
||||
if (ENABLE_TESTCASES)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
include(cmake/package.cmake)
|
||||
include(cmake/package.cmake)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Subproject commit c460176523d039c8995f1d71089753725ebc0792
|
5
build.sh
5
build.sh
|
@ -246,6 +246,9 @@ checkopts "$@"
|
|||
echo "---------------- mindspore: build start ----------------"
|
||||
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
|
||||
git submodule update --init graphengine
|
||||
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
|
||||
git submodule update --init --recursive akg
|
||||
fi
|
||||
|
||||
build_exit()
|
||||
{
|
||||
|
@ -308,7 +311,7 @@ build_mindspore()
|
|||
if [[ "X$USE_GLOG" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_AKG" = "Xon" ]]; then
|
||||
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
||||
fi
|
||||
echo "${CMAKE_ARGS}"
|
||||
|
|
|
@ -236,6 +236,16 @@ if (ENABLE_GPU)
|
|||
endif ()
|
||||
endif ()
|
||||
|
||||
if (ENABLE_D AND ENABLE_AKG)
|
||||
set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
|
||||
install(
|
||||
DIRECTORY
|
||||
${AKG_PATH}/akg
|
||||
DESTINATION ${INSTALL_PY_DIR}/..
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
|
||||
install(
|
||||
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing akg compile with json"""
|
||||
import sys
|
||||
def run_compiler(op_json):
|
||||
"""
|
||||
Run AKG compiler to compile op with subprocess, if this process of
|
||||
compilation failed, an exception will be raised
|
||||
|
||||
Args:
|
||||
op_json (str): json string of the op
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
p = __import__("akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
res = func(op_json)
|
||||
if not res:
|
||||
raise ValueError("Compile error")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_compiler(sys.argv[1])
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
|
||||
def _compile_akg_task(*json_strs):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
||||
Parameters:
|
||||
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
||||
"""
|
||||
akg_compiler = os.path.join(os.path.split(
|
||||
os.path.realpath(__file__))[0], "compiler.py")
|
||||
for json_str in json_strs:
|
||||
res = subprocess.run(
|
||||
[sys.executable, akg_compiler, json_str], text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Failed, args: {}!".format(json_str))
|
||||
|
||||
|
||||
def compile_akg_kernel_parallel(json_infos, process, waitime):
|
||||
"""
|
||||
compile kernel use multi processes
|
||||
|
||||
Parameters:
|
||||
json_infos: list. list contain kernel info(task id and json str)
|
||||
process: int. processes num
|
||||
waittime: int. max time the function blocked
|
||||
|
||||
Returns:
|
||||
True for all compile success, False for some failed.
|
||||
"""
|
||||
if not isinstance(json_infos, list):
|
||||
raise ValueError("json_infos must be a list")
|
||||
if not isinstance(process, int):
|
||||
raise ValueError("process must be a num")
|
||||
if not isinstance(waitime, int):
|
||||
raise ValueError("waittime must be a num")
|
||||
|
||||
if process == 0 and json_infos:
|
||||
process = 1
|
||||
|
||||
cpu_proc_num = cpu_count()
|
||||
max_proc_num = 16
|
||||
process = min([cpu_proc_num, max_proc_num, process])
|
||||
|
||||
args = [[] for _ in range(process)]
|
||||
for p, info in enumerate(json_infos):
|
||||
args[p % process].append(info)
|
||||
|
||||
with Pool(processes=process) as pool:
|
||||
res = pool.starmap_async(_compile_akg_task, args)
|
||||
res.get(timeout=waitime)
|
||||
return True
|
|
@ -1,107 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
def _compiletask(platform, *jsons):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
||||
Parameters:
|
||||
platform: str. AKG platform or TBE platform
|
||||
*jsons: str. json str contain kernel info, suitable for json compile
|
||||
api
|
||||
|
||||
"""
|
||||
if platform == "AKG":
|
||||
p = __import__("_akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
for json_item in jsons:
|
||||
res = func(json_item)
|
||||
if not res:
|
||||
raise ValueError("Compile error")
|
||||
if platform == "TBE":
|
||||
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
|
||||
for json_item in jsons:
|
||||
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Tbe compile error")
|
||||
|
||||
|
||||
def compilekernelparallel(jsons, process, waitime):
|
||||
"""
|
||||
compile kernel use multi processes
|
||||
|
||||
Parameters:
|
||||
jsons: list. json str list contain kernel info
|
||||
process: int. processes num
|
||||
waittime: int. max time the function blocked
|
||||
"""
|
||||
if not isinstance(jsons, list):
|
||||
raise ValueError("jsons must be a list")
|
||||
if not isinstance(process, int):
|
||||
raise ValueError("process must be a num")
|
||||
if not isinstance(waitime, int):
|
||||
raise ValueError("waittime must be a num")
|
||||
|
||||
jsons_akg = []
|
||||
jsons_tbe = []
|
||||
for json_ in jsons:
|
||||
j = json.loads(json_)
|
||||
if j["platform"] == "TBE":
|
||||
jsons_tbe.append(json_)
|
||||
continue
|
||||
if j["platform"] == "AKG":
|
||||
jsons_akg.append(json_)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"not support this platform {0}".format(j["platform"]))
|
||||
if jsons_akg:
|
||||
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
|
||||
else:
|
||||
process_akg = 0
|
||||
|
||||
if process_akg == 0 and jsons_akg:
|
||||
process_akg = 1
|
||||
process_tbe = process-process_akg
|
||||
if process_tbe == 0 and jsons_tbe:
|
||||
process_tbe = 1
|
||||
raise RuntimeWarning("we add a process for compile more operator")
|
||||
|
||||
args = [[] for _ in range(process_akg+process_tbe)]
|
||||
args_lens = len(args)
|
||||
for p in range(args_lens):
|
||||
if p < process_tbe:
|
||||
args[p].append("TBE")
|
||||
else:
|
||||
args[p].append("AKG")
|
||||
jsons_tbe_lens = len(jsons_tbe)
|
||||
for p in range(jsons_tbe_lens):
|
||||
args[p % process_tbe].append(jsons_tbe[p])
|
||||
jsons_akg_lens = len(jsons_akg)
|
||||
for p in range(jsons_akg_lens):
|
||||
args[process-p % process_akg-1].append(jsons_akg[p])
|
||||
for p in range(args_lens):
|
||||
args[p] = tuple(args[p])
|
||||
with Pool(processes=process) as pool:
|
||||
res = pool.starmap_async(_compiletask, args)
|
||||
res.get(timeout=waitime)
|
||||
return True
|
|
@ -39,7 +39,7 @@ if(ENABLE_GPU)
|
|||
"device/gpu/*.cu"
|
||||
"kernel/gpu/*.cu"
|
||||
"kernel/akg/gpu/*.cc"
|
||||
"kernel/akg/akgkernelbuild.cc"
|
||||
"kernel/akg/akg_kernel_build.cc"
|
||||
"kernel/akg/akg_kernel_attrs_process.cc"
|
||||
)
|
||||
|
||||
|
|
|
@ -428,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
auto temp_shape = shape;
|
||||
std::vector<size_t> device_shape;
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
|
||||
// For [1] and [1024] shape we can trait it as NZ shape
|
||||
return shape;
|
||||
}
|
||||
if (shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
|
||||
} else {
|
||||
|
|
|
@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
|
|||
}
|
||||
|
||||
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
|
||||
buffer << "#flags :" << std::endl;
|
||||
for (const auto &flag : graph->flags()) {
|
||||
buffer << flag.first << " : " << flag.second << std::endl;
|
||||
buffer << "#attrs :" << std::endl;
|
||||
for (const auto &attr : graph->attrs()) {
|
||||
buffer << attr.first << " : ";
|
||||
if (attr.second->isa<BoolImm>()) {
|
||||
buffer << GetValue<bool>(attr.second);
|
||||
} else if (attr.second->isa<StringImm>()) {
|
||||
buffer << GetValue<std::string>(attr.second);
|
||||
}
|
||||
buffer << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
|
|||
fout << std::endl;
|
||||
|
||||
for (const auto &sg : *sub_graphs) {
|
||||
fout << "subgraph flag:" << std::endl;
|
||||
fout << "subgraph attr:" << std::endl;
|
||||
MS_EXCEPTION_IF_NULL(sg.first);
|
||||
for (const auto &flag : sg.first->flags()) {
|
||||
fout << flag.first << " : " << flag.second << std::endl;
|
||||
for (const auto &attr : sg.first->attrs()) {
|
||||
fout << attr.first << " : ";
|
||||
if (attr.second->isa<BoolImm>()) {
|
||||
fout << GetValue<bool>(attr.second);
|
||||
} else if (attr.second->isa<StringImm>()) {
|
||||
fout << GetValue<std::string>(attr.second);
|
||||
}
|
||||
fout << std::endl;
|
||||
}
|
||||
fout << "subgraph @" << sg.first->ToString() << ".";
|
||||
fout << sg.first->debug_info()->get_id() << "(";
|
||||
|
|
|
@ -548,9 +548,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
|
|||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
ValuePtr value_ptr = nullptr;
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
|
||||
if (primitive != nullptr) {
|
||||
value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
|
||||
} else {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
value_ptr = func_graph->get_attr(kStreamNeedActivedFirst);
|
||||
}
|
||||
if (value_ptr == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -26,10 +26,12 @@
|
|||
#include "kernel/kernel.h"
|
||||
#include "kernel/tbe/tbe_kernel_build.h"
|
||||
#include "kernel/tbe/tbe_kernel_parallel_build.h"
|
||||
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
|
||||
#include "kernel/aicpu/aicpu_kernel_build.h"
|
||||
#include "kernel/hccl/hccl_kernel_build.h"
|
||||
#include "kernel/rts/rt_kernel_build.h"
|
||||
#include "kernel/tbe/tbe_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "operator/ops.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "./common.h"
|
||||
|
@ -91,6 +93,7 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
|
|||
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
std::vector<AnfNodePtr> tbe_nodes;
|
||||
std::vector<AnfNodePtr> akg_nodes;
|
||||
std::vector<AnfNodePtr> other_nodes;
|
||||
for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
|
@ -105,19 +108,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
|
|||
}
|
||||
break;
|
||||
}
|
||||
case KernelType::AKG_KERNEL: {
|
||||
akg_nodes.push_back(anf_node);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
other_nodes.push_back(anf_node);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
bool ret = kernel::TbeOpParallelBuild(tbe_nodes);
|
||||
bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
|
||||
bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
|
||||
auto bin_map = kernel::tbe::KernelMeta::GetInstance();
|
||||
(void)bin_map->ReadIndex(kernel::kCceKernelMeta);
|
||||
for (const auto &anf_node : other_nodes) {
|
||||
kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||
}
|
||||
return ret;
|
||||
return tbe_ret && akg_ret;
|
||||
}
|
||||
|
||||
static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
|
||||
|
@ -234,7 +244,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|||
for (const auto &anf_node : kernel_graph->execution_order()) {
|
||||
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) {
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
||||
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
||||
auto new_value_node = NewValueNode(clear_zero_prim);
|
||||
|
|
|
@ -15,16 +15,27 @@
|
|||
*/
|
||||
|
||||
#include "device/ascend/kernel_select_ascend.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/kernel_query.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "kernel/kernel_query.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
}
|
||||
auto pri_match_format = GetPriorityMatchFormat(kernel_node);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
auto input_anf_node = kernel_node->input(input_index + 1);
|
||||
// we do not take ValueNode into consideration in graph kernel.
|
||||
if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
|
||||
if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
|
||||
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
|
||||
}
|
||||
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
||||
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
|
||||
// we match output fix precision first.
|
||||
auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
|
||||
if (prev_device_type == kTypeUnknown) {
|
||||
prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
||||
}
|
||||
if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
|
||||
}
|
||||
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
|
||||
|
@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
}
|
||||
}
|
||||
|
||||
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
auto real_input_node = input_with_index.first;
|
||||
if (real_input_node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
bool is_ref = false;
|
||||
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
||||
if (op_info != nullptr) {
|
||||
is_ref = op_info->is_ref();
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode &&
|
||||
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
// we set special device info of a input tensor.
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
|
||||
MS_EXCEPTION_IF_NULL(support_index);
|
||||
int index = kUnSupportMixedDataTypeIndex;
|
||||
|
@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
auto real_input_node = input_with_index.first;
|
||||
if (real_input_node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
||||
continue;
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
||||
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
||||
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
|
||||
continue;
|
||||
}
|
||||
// we set special device info of a input tensor.
|
||||
bool is_ref = false;
|
||||
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
||||
if (op_info != nullptr) {
|
||||
is_ref = op_info->is_ref();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
||||
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
||||
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
|
@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
|||
return select_status;
|
||||
}
|
||||
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
||||
if (AnfAlgo::IsGraphKernel(kernel_node)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
SelectGraphKernelInfo(kernel_node, func_graph);
|
||||
return kStatusAllMatched;
|
||||
}
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
|
||||
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
|
||||
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
|
||||
if (select_status == kNoMatched) {
|
||||
|
|
|
@ -27,7 +27,10 @@ enum KernelSelectStatus {
|
|||
kStatusReducePrecision = 1,
|
||||
kStatusRaisePrecision = 2,
|
||||
};
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
|
||||
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
|
||||
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,516 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "device/ascend/kernel_select_ascend.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "kernel/kernel_query.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
TypeId except_type = kTypeUnknown;
|
||||
if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
|
||||
auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
|
||||
if (strExceptDtype == "float16") {
|
||||
except_type = kNumberTypeFloat16;
|
||||
} else if (strExceptDtype == "float32") {
|
||||
except_type = kNumberTypeFloat32;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype;
|
||||
}
|
||||
}
|
||||
|
||||
return except_type;
|
||||
}
|
||||
|
||||
void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
|
||||
if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
|
||||
continue;
|
||||
}
|
||||
// reset format and dtype.
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
// select nodes in subgraph.
|
||||
auto anf_node = node_list[i];
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto fix_precision_type = GetPrimitivePrecision(cnode);
|
||||
if (fix_precision_type != kTypeUnknown) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
|
||||
|
||||
for (size_t index = 0; index < kernel_info_list.size(); ++index)
|
||||
// only math the first input
|
||||
if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
|
||||
kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
|
||||
AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
|
||||
auto selected_kernel_info_ptr = kernel_info_list[index];
|
||||
ResetKernelBuildInfo(cnode);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
|
||||
SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
|
||||
for (size_t i = 1; i <= shape.size(); ++i) {
|
||||
if (i > 2) {
|
||||
break;
|
||||
}
|
||||
if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
|
||||
std::vector<int> frac_nz_axis = axis;
|
||||
auto shape_len = ori_shape.size();
|
||||
for (size_t i = 0; i < axis.size(); ++i) {
|
||||
auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
|
||||
if (axis_idx == shape_len - 1) {
|
||||
frac_nz_axis[i] = axis_idx - 1;
|
||||
frac_nz_axis.push_back(axis_idx + 2);
|
||||
} else if (axis_idx == shape_len - 2) {
|
||||
frac_nz_axis[i] = axis_idx + 1;
|
||||
frac_nz_axis.push_back(axis_idx + 2);
|
||||
} else {
|
||||
frac_nz_axis[i] = axis_idx;
|
||||
}
|
||||
}
|
||||
return frac_nz_axis;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis,
|
||||
bool keep_dims) {
|
||||
std::vector<size_t> result;
|
||||
std::set<size_t> positive_idx;
|
||||
for (const auto &a : axis) {
|
||||
positive_idx.insert(a >= 0 ? a : ori_shape.size() + a);
|
||||
}
|
||||
for (size_t i = 0; i < ori_shape.size(); ++i) {
|
||||
if (positive_idx.count(i) == 0) {
|
||||
result.push_back(ori_shape[i]);
|
||||
} else if (keep_dims) {
|
||||
result.push_back(1);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void UpdateFracNZReduceOp(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
|
||||
if (input_format == kOpFormat_FRAC_NZ) {
|
||||
// Clone primitive to modify it
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
auto new_prim = std::make_shared<Primitive>(*prim);
|
||||
auto new_prim_node = NewValueNode(new_prim);
|
||||
cnode->set_input(0, new_prim_node);
|
||||
|
||||
auto axis_value = new_prim->GetAttr(kAttrAxis);
|
||||
std::vector<int> default_axis;
|
||||
if (axis_value->isa<ValueList>()) {
|
||||
auto value_list = dyn_cast<ValueList>(axis_value);
|
||||
for (const auto &item : value_list->value()) {
|
||||
if (item->isa<Int32Imm>()) {
|
||||
default_axis.push_back(GetValue<int32_t>(item));
|
||||
}
|
||||
}
|
||||
} else if (axis_value->isa<ValueTuple>()) {
|
||||
auto value_tuple = dyn_cast<ValueTuple>(axis_value);
|
||||
for (const auto &item : value_tuple->value()) {
|
||||
if (item->isa<Int32Imm>()) {
|
||||
default_axis.push_back(GetValue<int32_t>(item));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Axis attr type is not correct!";
|
||||
}
|
||||
auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
std::vector<int> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int>>(frac_nz_axis), cnode);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
if (output_shape.size() == 1) {
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(default_format);
|
||||
MS_EXCEPTION_IF_NULL(use_same_format);
|
||||
std::unordered_map<std::string, size_t> all_input_formats;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
if (!input_kernel_node->isa<Parameter>()) {
|
||||
auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
|
||||
++all_input_formats[pre_format];
|
||||
continue;
|
||||
}
|
||||
auto para = input_kernel_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
|
||||
auto pre_format = AnfAlgo::GetOutputFormat(para, 0);
|
||||
++all_input_formats[pre_format];
|
||||
continue;
|
||||
}
|
||||
*use_same_format = false;
|
||||
}
|
||||
|
||||
if (all_input_formats.empty()) {
|
||||
// all inputs are parameter.
|
||||
*default_format = kOpFormat_NC1HWC0;
|
||||
} else {
|
||||
std::vector<std::pair<std::string, size_t>> pairs;
|
||||
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
|
||||
pairs.push_back(std::make_pair(iter->first, iter->second));
|
||||
}
|
||||
auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
|
||||
if (a.second != b.second) {
|
||||
return a.second > b.second;
|
||||
} else if (a.first == kOpFormat_DEFAULT) {
|
||||
return a.second + 1 > b.second;
|
||||
} else if (b.first == kOpFormat_DEFAULT) {
|
||||
return a.second > b.second + 1;
|
||||
}
|
||||
return a.second > b.second;
|
||||
};
|
||||
std::sort(pairs.begin(), pairs.end(), cmp_func);
|
||||
*default_format = pairs.begin()->first;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
if (!input_kernel_node->isa<Parameter>() ||
|
||||
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
|
||||
if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) {
|
||||
*default_format = kOpFormat_DEFAULT;
|
||||
*use_same_format = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
|
||||
const std::string &default_format, bool use_same_format,
|
||||
std::vector<std::string> *graph_input_format,
|
||||
std::vector<TypeId> *graph_input_type) {
|
||||
MS_EXCEPTION_IF_NULL(graph_input_format);
|
||||
MS_EXCEPTION_IF_NULL(graph_input_type);
|
||||
// We set same format to all inputs of graph kernel subgraph, and process this latter.
|
||||
// We set dtype to inputs of graph kernel subgraph same as infer dtypes.
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
if (use_same_format) {
|
||||
bool can_convert = true;
|
||||
if (default_format == kOpFormat_FRAC_NZ) {
|
||||
auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
|
||||
if (!CanConvertDefaultShapeToNZ(infer_shape)) {
|
||||
MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
|
||||
can_convert = false;
|
||||
}
|
||||
}
|
||||
if (can_convert) {
|
||||
graph_input_format->push_back(default_format);
|
||||
} else {
|
||||
graph_input_format->push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!input_kernel_node->isa<Parameter>()) {
|
||||
// subgraph parameter from output of other nodes.
|
||||
graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
|
||||
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto para = input_kernel_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
|
||||
// parameter already selected.
|
||||
graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
|
||||
graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
|
||||
continue;
|
||||
}
|
||||
|
||||
// weight parameter.
|
||||
graph_input_format->push_back(default_format);
|
||||
graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
|
||||
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
|
||||
const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
// select nodes in subgraph.
|
||||
auto anf_node = node_list[i];
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
|
||||
// Update ReduceSum
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
|
||||
continue;
|
||||
}
|
||||
UpdateFracNZReduceOp(cnode);
|
||||
// If ReduceSum's output is 1d and not Default format, convert it to Default format
|
||||
auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
|
||||
if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
|
||||
continue;
|
||||
}
|
||||
auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
// Insert EquivFormat node, then select kernel info again
|
||||
std::vector<AnfNodePtr> trans_inputs;
|
||||
trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
|
||||
trans_inputs.push_back(cnode);
|
||||
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
|
||||
{AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
|
||||
|
||||
if (trans_node->kernel_info() == nullptr) {
|
||||
trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
}
|
||||
SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
|
||||
mng->Replace(cnode, trans_node);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng,
|
||||
const std::string &default_format, std::vector<std::string> *graph_input_format,
|
||||
std::vector<TypeId> *graph_input_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
MS_EXCEPTION_IF_NULL(graph_input_format);
|
||||
MS_EXCEPTION_IF_NULL(graph_input_type);
|
||||
// update graph input format and dtype use inner ops.
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (graph_input_format->size() != input_num) {
|
||||
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
|
||||
<< "], [%" << graph_input_format->size() << "] != [%" << input_num << "]";
|
||||
}
|
||||
std::vector<bool> need_update(input_num, false);
|
||||
auto &node_users = mng->node_users();
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto &input = input_list[i];
|
||||
auto iter = node_users.find(input);
|
||||
if (iter == node_users.end() || iter->second.empty()) {
|
||||
continue;
|
||||
}
|
||||
for (auto &node_user : iter->second) {
|
||||
if (node_user.first->kernel_info() == nullptr ||
|
||||
node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
|
||||
// maybe not a real kernel.
|
||||
continue;
|
||||
}
|
||||
auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
|
||||
if (user_format != (*graph_input_format)[i]) {
|
||||
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
|
||||
<< kernel_node->DebugString()
|
||||
<< "] selected different format. we use defult: " << default_format;
|
||||
(*graph_input_format)[i] = default_format;
|
||||
need_update[i] = true;
|
||||
}
|
||||
|
||||
if (kernel_node->input(i + 1)->isa<Parameter>()) {
|
||||
auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1));
|
||||
if (user_dtype != (*graph_input_type)[i]) {
|
||||
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
|
||||
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
|
||||
<< kernel_node->DebugString()
|
||||
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
|
||||
(*graph_input_type)[i] = default_dtype;
|
||||
need_update[i] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (!need_update[i]) {
|
||||
continue;
|
||||
}
|
||||
need_update[i] = false;
|
||||
|
||||
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
|
||||
<< "] to: " << (*graph_input_format)[i];
|
||||
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
|
||||
<< "] to: " << TypeIdLabel((*graph_input_type)[i]);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
|
||||
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
||||
}
|
||||
|
||||
ResetKernelBuildInfo(kernel_node);
|
||||
// select nodes in subgraph again.
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
auto anf_node = node_list[i];
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t j = 0; j < cnode_input_num; ++j) {
|
||||
auto input_node = cnode->input(j + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (!IsValueNode<tensor::Tensor>(input_node)) {
|
||||
continue;
|
||||
}
|
||||
// reset format and dtype of const tensor.
|
||||
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
|
||||
}
|
||||
SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
|
||||
}
|
||||
}
|
||||
|
||||
void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
|
||||
const std::vector<std::string> &graph_input_format,
|
||||
const std::vector<TypeId> &graph_input_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<std::string> graph_output_format;
|
||||
std::vector<TypeId> graph_output_type;
|
||||
for (size_t i = 0; i < output_index.size(); ++i) {
|
||||
auto const &output = output_index[i];
|
||||
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
|
||||
TypeId output_type(kTypeUnknown);
|
||||
if (output.first->isa<CNode>()) {
|
||||
output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
|
||||
}
|
||||
if (output_type == kTypeUnknown) {
|
||||
output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
|
||||
}
|
||||
graph_output_type.push_back(output_type);
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(graph_input_format);
|
||||
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
||||
graph_info_builder.SetOutputsFormat(graph_output_format);
|
||||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
auto graph_selected_info = graph_info_builder.Build();
|
||||
MS_EXCEPTION_IF_NULL(graph_selected_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
|
||||
SetTensorDeviceInfo(*graph_selected_info, kernel_node);
|
||||
}
|
||||
|
||||
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
// collect input info of funcgraph
|
||||
std::vector<AnfNodePtr> node_list;
|
||||
std::vector<AnfNodePtr> input_list;
|
||||
std::vector<AnfNodePtr> output_list;
|
||||
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
||||
if (input_list.size() != kernel_node->inputs().size() - 1) {
|
||||
MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
|
||||
<< kernel_node->DebugString() << "], [%" << input_list.size() << "] != ["
|
||||
<< kernel_node->inputs().size() << "]";
|
||||
}
|
||||
|
||||
std::string default_format;
|
||||
bool use_same_format = true;
|
||||
GetDefaultFormat(kernel_node, &default_format, &use_same_format);
|
||||
MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format
|
||||
<< "] for ParameterWeight.";
|
||||
|
||||
std::vector<std::string> graph_input_format;
|
||||
std::vector<TypeId> graph_input_type;
|
||||
UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
|
||||
&graph_input_type);
|
||||
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
}
|
||||
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
||||
UpdateEquivFormat(output_index, node_list, func_graph, mng);
|
||||
node_list.clear();
|
||||
input_list.clear();
|
||||
output_list.clear();
|
||||
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
||||
|
||||
// update graph input format and dtype use inner ops.
|
||||
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format,
|
||||
&graph_input_type);
|
||||
|
||||
// set fix_precision for kernel when the me prim has fix_precision attr
|
||||
UpdateKernelInfo(node_list);
|
||||
|
||||
output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
||||
SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -24,7 +24,7 @@ namespace device {
|
|||
namespace ascend {
|
||||
void GraphDescReporter::ReportData() {
|
||||
for (const auto &node : cnode_list_) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
|
||||
MS_LOG(WARNING) << "Skip non tbe kernel";
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() {
|
|||
|
||||
size_t task_index = 0;
|
||||
for (const auto &node : cnode_list_) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
|
||||
MS_LOG(WARNING) << "Skip non tbe kernel";
|
||||
++task_index;
|
||||
continue;
|
||||
|
|
|
@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
|
|||
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
||||
if (anf_node_ptr->inputs().size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2.";
|
||||
// akg process
|
||||
// set atomic clean addr
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) {
|
||||
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs);
|
||||
auto graph = anf_node_ptr->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto node_users = manager->node_users();
|
||||
if (node_users[anf_node_ptr].empty()) {
|
||||
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
|
||||
}
|
||||
auto depend_node = node_users[anf_node_ptr].pop().first;
|
||||
if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
|
||||
MS_LOG(EXCEPTION) << "Checking Depend node failed";
|
||||
}
|
||||
if (node_users[depend_node].empty()) {
|
||||
MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty.";
|
||||
}
|
||||
auto post_node = node_users[depend_node].pop().first;
|
||||
for (auto index : clean_output_indexs) {
|
||||
auto device_address = AnfAlgo::GetOutputAddr(post_node, index);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
input->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(input->addr);
|
||||
input->size = device_address->size_;
|
||||
kernel_inputs->push_back(input);
|
||||
}
|
||||
MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size();
|
||||
}
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
|
||||
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
|
||||
|
@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
|
|||
input->size = device_address->size_;
|
||||
kernel_inputs->push_back(input);
|
||||
}
|
||||
MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
|
||||
MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
|
||||
}
|
||||
// set clean workspace address
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include "device/gpu/gpu_kernel_build.h"
|
||||
#include <string>
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/akg/akgkernelbuild.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
#include "kernel/akg/gpu/akg_gpu_kernel_build.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "operator/ops.h"
|
||||
|
@ -37,7 +37,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) {
|
||||
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) {
|
||||
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
|
||||
if (!gpu_kernel_ptr) {
|
||||
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed";
|
||||
|
|
|
@ -184,7 +184,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
|
||||
if (!result) {
|
||||
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||
kernel_type = AUTO_DIFF_KERNEL;
|
||||
kernel_type = AKG_KERNEL;
|
||||
}
|
||||
|
||||
if (!result) {
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive_base.h"
|
||||
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support intermediate representation definition
|
||||
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
|
||||
|
@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() {
|
|||
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode != nullptr) {
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (value != nullptr) {
|
||||
return cnode->IsApply(value);
|
||||
}
|
||||
return false;
|
||||
const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
return prim != nullptr;
|
||||
}
|
||||
|
||||
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
|
||||
|
|
|
@ -124,6 +124,7 @@ class AnfNode : public Base {
|
|||
|
||||
const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
|
||||
KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
|
||||
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
|
||||
void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
|
||||
|
||||
AbstractBasePtr abstract() const { return abstract_; }
|
||||
|
@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) {
|
|||
std::string GetCNodeFuncName(CNodePtr cnode);
|
||||
|
||||
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value);
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
|
||||
|
||||
// used to check whether an AnfNode is a cnode with a Primitive as first input
|
||||
// used to get PrimitivePtr from a cnode first input
|
||||
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
|
||||
|
||||
// used to check whether an AnfNode is a valuenode having some Primitive value
|
||||
|
|
|
@ -70,7 +70,7 @@ std::string CNode::fullname_with_scope() {
|
|||
}
|
||||
fullname_with_scope_ = name;
|
||||
} else {
|
||||
// cnode input 0 should be primitive ptr
|
||||
// cnode input 0 should be primitive ptr or funcgraph ptr
|
||||
auto value_ptr = input(0)->cast<ValueNodePtr>();
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
|
||||
|
@ -84,11 +84,23 @@ std::string CNode::fullname_with_scope() {
|
|||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
|
||||
auto prim = input_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(scope());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
fullname_with_scope_ =
|
||||
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
fullname_with_scope_ = scope()->name() + "/";
|
||||
if (prim != nullptr) {
|
||||
fullname_with_scope_ += prim->name();
|
||||
} else {
|
||||
auto func_graph = input_value->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_flag != nullptr) {
|
||||
auto fg_name = GetValue<std::string>(fg_flag);
|
||||
fullname_with_scope_ += "GraphKernel_" + fg_name;
|
||||
} else {
|
||||
fullname_with_scope_ += func_graph->ToString();
|
||||
}
|
||||
}
|
||||
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
}
|
||||
|
||||
return fullname_with_scope_;
|
||||
|
|
|
@ -77,9 +77,9 @@ class Bool : public Number {
|
|||
|
||||
TypeId generic_type_id() const override { return kNumberTypeBool; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Bool>(); }
|
||||
std::string ToString() const override { return "Bool_"; }
|
||||
std::string ToReprString() const override { return "bool_"; }
|
||||
std::string DumpText() const override { return "Bool_"; }
|
||||
std::string ToString() const override { return "Bool"; }
|
||||
std::string ToReprString() const override { return "bool"; }
|
||||
std::string DumpText() const override { return "Bool"; }
|
||||
};
|
||||
|
||||
// Int
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace mindspore {
|
|||
* Methods of Graph
|
||||
*/
|
||||
FuncGraph::FuncGraph()
|
||||
: flags_(),
|
||||
: attrs_(),
|
||||
transforms_(),
|
||||
parameter_default_value_(),
|
||||
seen_(0),
|
||||
|
@ -95,13 +95,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
|
|||
return p;
|
||||
}
|
||||
|
||||
bool FuncGraph::has_flag(const std::string &flag) {
|
||||
if (flags_.count(flag)) {
|
||||
return flags_[flag];
|
||||
bool FuncGraph::has_flag(const std::string &key) {
|
||||
auto iter = attrs_.find(key);
|
||||
if (iter != attrs_.cend()) {
|
||||
if (iter->second->isa<BoolImm>()) {
|
||||
return GetValue<bool>(iter->second);
|
||||
}
|
||||
MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FuncGraph::has_attr(const std::string &key) {
|
||||
auto iter = attrs_.find(key);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
|
||||
ValuePtr FuncGraph::get_attr(const std::string &key) {
|
||||
auto iter = attrs_.find(key);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
|
||||
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||
|
|
|
@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
|
|||
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
|
||||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
||||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||
|
||||
namespace abstract {
|
||||
|
@ -195,10 +196,19 @@ class FuncGraph : public FuncGraphBase {
|
|||
void set_is_generate(bool generated) { is_generated_ = generated; }
|
||||
bool is_generated() const { return is_generated_; }
|
||||
|
||||
bool has_flag(const std::string &flag);
|
||||
std::unordered_map<std::string, bool> &flags() { return flags_; }
|
||||
void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; }
|
||||
void set_flags(const std::string &key, const bool value) { flags_[key] = value; }
|
||||
std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
|
||||
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
for (auto &attr : attrs) {
|
||||
attrs_[attr.first] = attr.second;
|
||||
}
|
||||
}
|
||||
bool has_flag(const std::string &key);
|
||||
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
|
||||
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
|
||||
|
||||
bool has_attr(const std::string &key);
|
||||
ValuePtr get_attr(const std::string &key);
|
||||
void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
|
||||
|
||||
std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
|
||||
void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
|
||||
|
@ -317,7 +327,7 @@ class FuncGraph : public FuncGraphBase {
|
|||
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
|
||||
|
||||
std::unordered_map<std::string, bool> flags_;
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::unordered_map<std::string, FuncGraphTransform> transforms_;
|
||||
// parameter default value
|
||||
std::map<std::string, AnfNodePtr> parameter_default_value_;
|
||||
|
|
|
@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
|||
new_node->set_abstract(old_node->abstract());
|
||||
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
|
||||
new_node->set_scope(scope);
|
||||
new_node->set_kernel_info(old_node->kernel_info_ptr());
|
||||
repl_node_[old_node] = new_node;
|
||||
nodes_.emplace_back(old_node, new_node);
|
||||
TraceManager::EndTrace();
|
||||
|
@ -211,7 +212,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
|
|||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
|
||||
*target_func_graph = std::make_shared<FuncGraph>();
|
||||
(*target_func_graph)->set_flags(func_graph->flags());
|
||||
(*target_func_graph)->set_attrs(func_graph->attrs());
|
||||
(*target_func_graph)->set_transforms(func_graph->transforms());
|
||||
(*target_func_graph)->set_has_vararg(func_graph->has_vararg());
|
||||
(*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
|
||||
|
@ -636,9 +637,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
|||
|
||||
if (MsContext::GetInstance()->is_multi_graph_sink()) {
|
||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
}
|
||||
|
||||
return new_func_graph;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -399,8 +399,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
|
|||
depend_inputs.push_back(*iter);
|
||||
}
|
||||
}
|
||||
set_flags(GRAPH_FLAG_HAS_EFFECT, false);
|
||||
set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
|
||||
set_flag(GRAPH_FLAG_HAS_EFFECT, false);
|
||||
set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
|
||||
if (!depend_inputs.empty()) {
|
||||
SetEffectDepends(depend_inputs);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,10 @@ if (ENABLE_D)
|
|||
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"kernel_query.cc"
|
||||
"kernel_fusion.cc"
|
||||
"akg/ascend/*.cc"
|
||||
"akg/akg_kernel_build.cc"
|
||||
"akg/akg_kernel_attrs_process.cc"
|
||||
"akg/akg_kernel_metadata.cc"
|
||||
"tbe/*.cc"
|
||||
"aicpu/*.cc"
|
||||
"rts/*.cc"
|
||||
|
@ -33,7 +37,7 @@ if (ENABLE_GPU)
|
|||
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"gpu/*.cu"
|
||||
"akg/gpu/*.cc"
|
||||
"akg/akgkernelbuild.cc"
|
||||
"akg/akg_kernel_build.cc"
|
||||
"akg/akg_kernel_attrs_process.cc"
|
||||
)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <map>
|
||||
#include "device/kernel_runtime.h"
|
||||
#include "kernel/aicpu/aicpu_kernel_mod.h"
|
||||
#include "kernel/akg/akgkernelbuild.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
#include "proto/tensor.pb.h"
|
||||
#include "proto/tensor_shape.pb.h"
|
||||
#include "proto/attr.pb.h"
|
||||
|
|
|
@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) {
|
|||
dst_type = "float32";
|
||||
} else if (output_type == kFloat16->type_id()) {
|
||||
dst_type = "float16";
|
||||
} else if (output_type == kInt32->type_id()) {
|
||||
dst_type = "int32";
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString();
|
||||
}
|
||||
AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/akg/akgkernelbuild.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
#include <Python.h>
|
||||
#include <sys/types.h>
|
||||
#include <signal.h>
|
||||
|
@ -43,7 +43,9 @@ namespace kernel {
|
|||
constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200;
|
||||
constexpr int32_t ARGS_SIZE = 1;
|
||||
constexpr auto kCompileWithJsonFunc = "compilewithjson";
|
||||
|
||||
// json key
|
||||
constexpr auto kOpDesc = "op_desc";
|
||||
constexpr auto kInputDesc = "input_desc";
|
||||
constexpr auto kShape = "shape";
|
||||
constexpr auto kDataType = "data_type";
|
||||
|
@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc";
|
|||
constexpr auto kName = "name";
|
||||
constexpr auto kTensorName = "tensor_name";
|
||||
constexpr auto kValue = "value";
|
||||
constexpr auto KInpputNames = "input_names";
|
||||
constexpr auto KDynInputSizes = "dyn_input_sizes";
|
||||
constexpr auto KInputNames = "input_names";
|
||||
constexpr auto KInput = "input";
|
||||
constexpr auto KDtype = "dtype";
|
||||
int AkgKernelBuild::op_cnt_ = 0;
|
||||
std::mutex AkgKernelBuild::op_cnt_mtx_;
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::string Vector2Str(const std::vector<T> &inputs) {
|
||||
if (!inputs.empty()) {
|
||||
std::ostringstream oss;
|
||||
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
|
||||
oss << inputs.back();
|
||||
return oss.str();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::string PyObjectToStr(PyObject *const PyObj) {
|
||||
std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) {
|
||||
char *pChar = nullptr;
|
||||
std::string str_res;
|
||||
if (PyObj == nullptr) {
|
||||
|
@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) {
|
|||
return str_res;
|
||||
}
|
||||
|
||||
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
||||
const std::pair<size_t, size_t> &position) {
|
||||
if (node_json.count(tag) == 0) {
|
||||
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
|
||||
return "";
|
||||
}
|
||||
|
||||
auto const &tag_desc = node_json[tag];
|
||||
nlohmann::json first_index;
|
||||
if (tag == kOutputDesc) {
|
||||
first_index = tag_desc;
|
||||
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
|
||||
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
|
||||
return "";
|
||||
} else {
|
||||
first_index = tag_desc[position.first];
|
||||
}
|
||||
|
||||
if (!first_index.is_array() || first_index.size() <= position.second) {
|
||||
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
|
||||
return "";
|
||||
}
|
||||
auto const &second_index = first_index[position.second];
|
||||
if (second_index.count(kTensorName) == 0) {
|
||||
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "].";
|
||||
return "";
|
||||
}
|
||||
|
||||
return second_index[kTensorName];
|
||||
}
|
||||
|
||||
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
||||
nlohmann::json *const node_json) {
|
||||
MS_EXCEPTION_IF_NULL(node_json);
|
||||
if (node_json->count(tag) == 0) {
|
||||
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
|
||||
return;
|
||||
}
|
||||
|
||||
nlohmann::json *tag_desc = &((*node_json)[tag]);
|
||||
nlohmann::json *first_index;
|
||||
if (tag == kOutputDesc) {
|
||||
first_index = tag_desc;
|
||||
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
|
||||
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
|
||||
return;
|
||||
} else {
|
||||
first_index = &((*tag_desc)[position.first]);
|
||||
}
|
||||
|
||||
if (!first_index->is_array() || first_index->size() <= position.second) {
|
||||
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
|
||||
return;
|
||||
}
|
||||
nlohmann::json *second_index = &((*first_index)[position.second]);
|
||||
if (second_index->count(kTensorName) == 0) {
|
||||
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "].";
|
||||
return;
|
||||
}
|
||||
(*second_index)[kTensorName] = new_name;
|
||||
return;
|
||||
}
|
||||
|
||||
int AkgKernelBuild::op_cnt_ = 0;
|
||||
std::mutex AkgKernelBuild::op_cnt_mtx_;
|
||||
|
||||
std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::string device;
|
||||
|
@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
|
|||
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
|
||||
// dtype : float16
|
||||
auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
|
||||
TypePtr type_ptr = TypeIdToType(type_id);
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
std::string dtype = type_ptr->ToString();
|
||||
dtype = Dtype2String(dtype);
|
||||
std::string dtype = TypeId2String(type_id);
|
||||
if (dtype.empty()) {
|
||||
MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. ";
|
||||
return false;
|
||||
|
@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
|
|||
nlohmann::json input_desc_json;
|
||||
input_desc_json[kDataType] = dtype;
|
||||
input_desc_json[kName] = op_input_name;
|
||||
input_desc_json[kTensorName] =
|
||||
op_input_name + "_" + std::to_string(real_input_index) + "_" + std::to_string(input_i);
|
||||
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
|
||||
input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
|
||||
if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
|
||||
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
|
||||
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
|
||||
<< "], value: " << input_desc_json[kValue];
|
||||
|
||||
input_shape.clear();
|
||||
}
|
||||
if (input_shape.empty()) {
|
||||
input_shape.push_back(1);
|
||||
}
|
||||
input_desc_json[kShape] = input_shape;
|
||||
input_list.emplace_back(input_desc_json);
|
||||
real_input_index++;
|
||||
}
|
||||
inputs_json->emplace_back(input_list);
|
||||
real_input_index++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
|
|||
for (size_t i = 0; i < output_tensor_num; i++) {
|
||||
nlohmann::json output_json;
|
||||
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
|
||||
TypePtr type_ptr = TypeIdToType(type_id);
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
std::string dtype = type_ptr->ToString();
|
||||
dtype = Dtype2String(dtype);
|
||||
std::string dtype = TypeId2String(type_id);
|
||||
if (dtype.empty()) {
|
||||
MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. ";
|
||||
return false;
|
||||
|
@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
|
|||
std::string output_name = outputs[i]->name();
|
||||
output_json[kDataType] = dtype;
|
||||
output_json[kName] = output_name;
|
||||
output_json[kTensorName] = output_name + "_" + std::to_string(i);
|
||||
output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
|
||||
output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
|
||||
outputs_json->push_back(output_json);
|
||||
}
|
||||
|
@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const
|
|||
MS_EXCEPTION_IF_NULL(op_info_ptr);
|
||||
|
||||
// get basic params from currentNodeOpDesc
|
||||
(*node_json)["platform"] = "AKG";
|
||||
(*node_json)[kName] = op_name;
|
||||
(*node_json)["fusion_type"] = AnfAlgo::GetFusionType(anf_node);
|
||||
(*node_json)["impl_path"] = op_info_ptr->impl_path();
|
||||
(*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node);
|
||||
(*node_json)["composite"] = false;
|
||||
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
ValuePtr input_names_v = primitive->GetAttr(KInpputNames);
|
||||
ValuePtr input_names_v = primitive->GetAttr(KInputNames);
|
||||
if (input_names_v == nullptr) {
|
||||
MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "].";
|
||||
return false;
|
||||
|
@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod
|
|||
(void)alarm(0);
|
||||
if (pRes == nullptr) {
|
||||
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
|
||||
<< PyObjectToStr(pArg) << ").";
|
||||
<< AkgKernelBuild::PyObjectToStr(pArg) << ").";
|
||||
return nullptr;
|
||||
}
|
||||
if (PyObject_IsTrue(pRes) != 1) {
|
||||
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
|
||||
<< PyObjectToStr(pArg) << ").";
|
||||
<< AkgKernelBuild::PyObjectToStr(pArg) << ").";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto
|
|||
<< "]";
|
||||
return kernel_pack;
|
||||
}
|
||||
|
||||
size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
||||
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
|
||||
size_t index = input_tensor_idx_.size();
|
||||
input_tensor_idx_[input_node] = index;
|
||||
}
|
||||
|
||||
return input_tensor_idx_[input_node];
|
||||
}
|
||||
|
||||
size_t AkgKernelBuild::GetOutputTensorIdxInc() {
|
||||
size_t idx = output_tensor_idx_++;
|
||||
return idx;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -32,29 +32,45 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
class AkgKernelBuild {
|
||||
public:
|
||||
AkgKernelBuild() = default;
|
||||
AkgKernelBuild() {
|
||||
input_tensor_idx_ = {};
|
||||
output_tensor_idx_ = 0;
|
||||
}
|
||||
~AkgKernelBuild() = default;
|
||||
|
||||
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
|
||||
std::vector<size_t> *const output_size);
|
||||
static std::string GetProcessor(const AnfNodePtr &anf_node);
|
||||
static std::string PyObjectToStr(PyObject *const PyObj);
|
||||
|
||||
private:
|
||||
protected:
|
||||
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
|
||||
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
|
||||
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
||||
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
|
||||
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
|
||||
int GetOpCntInc();
|
||||
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
|
||||
size_t GetOutputTensorIdxInc();
|
||||
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
||||
nlohmann::json *const node_json);
|
||||
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
|
||||
|
||||
int GetOpCntInc();
|
||||
std::string GetProcessor(const AnfNodePtr &anf_node);
|
||||
static int op_cnt_;
|
||||
// lock for variable fusionOpCnt in singleton mode
|
||||
static std::mutex op_cnt_mtx_;
|
||||
std::string json_name_;
|
||||
std::string json_info_;
|
||||
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
|
||||
size_t output_tensor_idx_;
|
||||
};
|
||||
|
||||
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
|
||||
std::vector<size_t> *const output_size);
|
||||
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
||||
nlohmann::json *const node_json);
|
||||
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
||||
const std::pair<size_t, size_t> &position);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/akg/akg_kernel_metadata.h"
|
||||
#include <memory>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void AkgMetadataInfo(const CNodePtr &kernel_node,
|
||||
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
|
||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
for (size_t i = 0; i < support_devices.size(); i++) {
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
|
||||
if (op_info_ptr == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) {
|
||||
MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed.";
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "].";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (kernel_info_list->empty()) {
|
||||
MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "].";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
|
|
@ -0,0 +1,385 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <Python.h>
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "kernel/tbe/tbe_utils.h"
|
||||
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
|
||||
#include "kernel/akg/akg_kernel_attrs_process.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
constexpr int32_t PARALLEL_ARGS_SIZE = 3;
|
||||
constexpr int32_t PROCESS_NUM = 16;
|
||||
constexpr int32_t TIME_OUT = 300;
|
||||
|
||||
constexpr auto kOpDesc = "op_desc";
|
||||
constexpr auto kShape = "shape";
|
||||
constexpr auto kDataType = "data_type";
|
||||
constexpr auto kInputDesc = "input_desc";
|
||||
constexpr auto kOutputDesc = "output_desc";
|
||||
constexpr auto kTensorName = "tensor_name";
|
||||
constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel";
|
||||
constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler";
|
||||
|
||||
bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
|
||||
auto it = kAkgKernelAttrsProcessMap.find(op_name);
|
||||
if (it != kAkgKernelAttrsProcessMap.end()) {
|
||||
it->second(anf_node);
|
||||
}
|
||||
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
|
||||
nlohmann::json node_json;
|
||||
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
|
||||
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
|
||||
}
|
||||
|
||||
kernel_json_ = node_json.dump();
|
||||
|
||||
if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
|
||||
MS_LOG(ERROR) << "Cal mem size failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
|
||||
const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list) {
|
||||
if (anf_nodes.empty() || input_list.empty()) {
|
||||
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
|
||||
<< "].";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list ["
|
||||
<< input_list.size() << "].";
|
||||
|
||||
std::map<AnfNodePtr, nlohmann::json> node_json_map;
|
||||
|
||||
for (auto const &anf_node : anf_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (!AnfAlgo::IsRealKernel(anf_node)) {
|
||||
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
|
||||
return false;
|
||||
}
|
||||
auto it = kAkgKernelAttrsProcessMap.find(op_name);
|
||||
if (it != kAkgKernelAttrsProcessMap.end()) {
|
||||
it->second(anf_node);
|
||||
}
|
||||
|
||||
nlohmann::json node_json;
|
||||
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
|
||||
MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed.";
|
||||
return false;
|
||||
}
|
||||
// No need for composite op.
|
||||
node_json.erase("id");
|
||||
node_json.erase("op");
|
||||
node_json.erase("composite");
|
||||
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
if (primitive->GetAttr("fusion") != nullptr) {
|
||||
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
|
||||
}
|
||||
|
||||
node_json_map[anf_node] = node_json;
|
||||
}
|
||||
|
||||
for (auto const &anf_node : anf_nodes) {
|
||||
std::vector<int> dyn_input_sizes;
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
|
||||
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
||||
}
|
||||
|
||||
bool is_dynamic_input = !dyn_input_sizes.empty();
|
||||
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
|
||||
size_t real_input_index = 0;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
|
||||
for (size_t j = 0; j < input_tensor_num; ++j) {
|
||||
auto tmp_input = GetKernelInput(anf_node, real_input_index);
|
||||
std::string tensor_name = GetTensorName(node_json_map[anf_node], kInputDesc, std::make_pair(i, j));
|
||||
if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
|
||||
std::string new_tensor_name =
|
||||
GetTensorName(node_json_map[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second));
|
||||
SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
|
||||
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
|
||||
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
|
||||
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
|
||||
<< anf_node->fullname_with_scope() << "] is out input.";
|
||||
}
|
||||
real_input_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json fused_node_json;
|
||||
std::vector<nlohmann::json> node_json_desc;
|
||||
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
|
||||
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
|
||||
fused_node_json[kOpDesc] = node_json_desc;
|
||||
|
||||
nlohmann::json inputs_json;
|
||||
auto input_index = GetInputIndex(anf_nodes, input_list);
|
||||
for (size_t i = 0; i < input_index.size(); ++i) {
|
||||
auto tmp_input = input_index[i];
|
||||
auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first);
|
||||
std::string dtype = TypeId2String(type_id);
|
||||
nlohmann::json input_desc_json;
|
||||
input_desc_json[kTensorName] = GetTensorName(node_json_map[tmp_input.first], kInputDesc, tmp_input.second);
|
||||
input_desc_json[kDataType] = dtype;
|
||||
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first);
|
||||
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
|
||||
}
|
||||
fused_node_json[kInputDesc] = inputs_json;
|
||||
|
||||
nlohmann::json outputs_json;
|
||||
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
|
||||
for (size_t i = 0; i < output_index.size(); ++i) {
|
||||
auto tmp_output = output_index[i];
|
||||
bool found = false;
|
||||
nlohmann::json output_desc_json;
|
||||
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
|
||||
if (tmp_output.first == input_list[input_i]) {
|
||||
output_desc_json = inputs_json[input_i][0];
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second);
|
||||
std::string dtype = TypeId2String(type_id);
|
||||
output_desc_json[kTensorName] =
|
||||
GetTensorName(node_json_map[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second));
|
||||
output_desc_json[kDataType] = dtype;
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second);
|
||||
if (output_shape.empty()) {
|
||||
output_shape.push_back(1);
|
||||
}
|
||||
output_desc_json[kShape] = output_shape;
|
||||
}
|
||||
outputs_json.emplace_back(output_desc_json);
|
||||
}
|
||||
fused_node_json[kOutputDesc] = outputs_json;
|
||||
|
||||
size_t hash_id = std::hash<std::string>()(fused_node_json.dump());
|
||||
json_name_ = "Fused_";
|
||||
auto fg = anf_nodes[0]->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (attr_val != nullptr) {
|
||||
auto fg_attr = GetValue<std::string>(attr_val);
|
||||
(void)json_name_.append(fg_attr).append("_");
|
||||
}
|
||||
(void)json_name_.append(std::to_string(hash_id));
|
||||
fused_node_json["composite_graph"] = fg->ToString();
|
||||
fused_node_json["op"] = json_name_;
|
||||
fused_node_json["platform"] = "AKG";
|
||||
fused_node_json["process"] = "aicore";
|
||||
fused_node_json["composite"] = true;
|
||||
|
||||
kernel_json_ = fused_node_json.dump();
|
||||
|
||||
if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) {
|
||||
MS_LOG(ERROR) << "Cal mem size failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) {
|
||||
MS_EXCEPTION_IF_NULL(p_args);
|
||||
*p_args = PyTuple_New(PARALLEL_ARGS_SIZE);
|
||||
|
||||
PyObject *arg1 = PyList_New(kernel_jsons.size());
|
||||
for (int i = 0; i < PyList_Size(arg1); ++i) {
|
||||
PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str()));
|
||||
}
|
||||
PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM);
|
||||
PyObject *arg3 = Py_BuildValue("i", TIME_OUT);
|
||||
|
||||
(void)PyTuple_SetItem(*p_args, 0, arg1);
|
||||
(void)PyTuple_SetItem(*p_args, 1, arg2);
|
||||
(void)PyTuple_SetItem(*p_args, 2, arg3);
|
||||
}
|
||||
|
||||
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
|
||||
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
||||
std::vector<std::string> jsons;
|
||||
std::unordered_set<std::string> json_name_set;
|
||||
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes;
|
||||
for (const auto &[builder, anf_node] : build_args) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto json_name = builder.json_name();
|
||||
MS_LOG(DEBUG) << "Akg start compile op: " << json_name;
|
||||
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
||||
if (cached_kernel_pack != nullptr) {
|
||||
MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope["
|
||||
<< anf_node->fullname_with_scope() << "].";
|
||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (json_name_set.count(json_name) != 0) {
|
||||
repeat_nodes.push_back({builder, anf_node});
|
||||
continue;
|
||||
}
|
||||
json_name_set.insert(json_name);
|
||||
auto node_json = builder.kernel_json();
|
||||
kernel::SaveJsonInfo(json_name, node_json);
|
||||
jsons.push_back(node_json);
|
||||
}
|
||||
|
||||
// No nodes need to be compiled!
|
||||
if (jsons.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to call python method to compile nodes parallely.
|
||||
PyObject *p_module = nullptr;
|
||||
PyObject *p_func = nullptr;
|
||||
PyObject *p_arg = nullptr;
|
||||
PyObject *p_res = nullptr;
|
||||
|
||||
p_module = PyImport_ImportModule(kMultiProcModule);
|
||||
if (p_module == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "].";
|
||||
return false;
|
||||
}
|
||||
|
||||
p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc);
|
||||
GenParallelCompileFuncArgs(jsons, &p_arg);
|
||||
MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size()
|
||||
<< " Akg kernels parallelly.";
|
||||
p_res = PyEval_CallObject(p_func, p_arg);
|
||||
if (p_res == nullptr) {
|
||||
PyErr_Print();
|
||||
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
|
||||
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
|
||||
return false;
|
||||
}
|
||||
if (PyObject_IsTrue(p_res) != 1) {
|
||||
PyErr_Print();
|
||||
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
|
||||
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
|
||||
return false;
|
||||
}
|
||||
|
||||
// All unique done here, cache them and set kernel.
|
||||
for (const auto &[builder, anf_node] : build_args) {
|
||||
auto json_name = builder.json_name();
|
||||
auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
||||
if (new_kernel_pack == nullptr) {
|
||||
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
|
||||
<< anf_node->fullname_with_scope() << "].";
|
||||
return false;
|
||||
}
|
||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
|
||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||
MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
|
||||
}
|
||||
|
||||
// Handle repeated nodes.
|
||||
for (const auto &[builder, anf_node] : repeat_nodes) {
|
||||
auto node_json = builder.kernel_json();
|
||||
auto json_name = builder.json_name();
|
||||
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
||||
if (cached_kernel_pack == nullptr) return false;
|
||||
MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
|
||||
<< anf_node->fullname_with_scope() << "].";
|
||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node;
|
||||
for (const auto &anf_node : anf_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
AkgAscendKernelBuilder akg_cce_kernel_builder;
|
||||
KernelPackPtr kernel_pack = nullptr;
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::IsGraphKernel(cnode)) {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> node_list;
|
||||
std::vector<AnfNodePtr> input_list;
|
||||
std::vector<AnfNodePtr> output_list;
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]";
|
||||
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
||||
if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) {
|
||||
MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "].";
|
||||
}
|
||||
} else {
|
||||
if (!akg_cce_kernel_builder.CollectJson(anf_node)) {
|
||||
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "].";
|
||||
}
|
||||
}
|
||||
json_and_node.push_back({akg_cce_kernel_builder, anf_node});
|
||||
}
|
||||
|
||||
if (json_and_node.empty()) {
|
||||
MS_LOG(DEBUG) << "There is no kernel needed to be compiled.";
|
||||
return true;
|
||||
}
|
||||
|
||||
return AkgOpParallelBuild(json_and_node);
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AkgAscendKernelBuilder : public AkgKernelBuild {
|
||||
public:
|
||||
AkgAscendKernelBuilder() = default;
|
||||
~AkgAscendKernelBuilder() = default;
|
||||
|
||||
bool CollectJson(const AnfNodePtr &anf_node);
|
||||
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list);
|
||||
std::string json_name() const { return json_name_; }
|
||||
std::string kernel_json() const { return kernel_json_; }
|
||||
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
|
||||
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
|
||||
|
||||
private:
|
||||
std::string kernel_json_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
};
|
||||
|
||||
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
|
|
@ -0,0 +1,181 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "runtime/rt.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using std::fstream;
|
||||
using std::map;
|
||||
using std::mutex;
|
||||
using std::string;
|
||||
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
|
||||
using tbe::KernelManager;
|
||||
constexpr uint32_t DEFAULT_BLOCK_DIM = 1;
|
||||
/**
|
||||
* @brief infotable contain func_stub\blockdim\kernel file buffer
|
||||
*/
|
||||
AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {}
|
||||
|
||||
void AkgKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; }
|
||||
|
||||
void AkgKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
|
||||
|
||||
void AkgKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; }
|
||||
|
||||
const std::vector<size_t> &AkgKernelMod::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
void DumpData(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
const char *dump_data = getenv("MS_KERNEL_DUMP_DATA");
|
||||
if (dump_data) {
|
||||
int idx = 0;
|
||||
for (const auto &x : inputs) {
|
||||
std::vector<char> buf(x->size);
|
||||
if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size,
|
||||
RT_MEMCPY_DEVICE_TO_HOST)) {
|
||||
MS_LOG(WARNING) << "Call runtime rtMemcpy error.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::string file_name("input_");
|
||||
file_name += std::to_string(idx);
|
||||
std::ofstream file(file_name, std::ios::binary);
|
||||
if (file.is_open()) {
|
||||
(void)file.write(buf.data(), SizeToLong(buf.size()));
|
||||
file.close();
|
||||
idx++;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Open file failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
idx = 0;
|
||||
for (const auto &x : outputs) {
|
||||
std::vector<char> buf(x->size);
|
||||
if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size,
|
||||
RT_MEMCPY_DEVICE_TO_HOST)) {
|
||||
MS_LOG(WARNING) << "Call runtime rtMemcpy error.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::string file_name("output_");
|
||||
file_name += std::to_string(idx);
|
||||
std::ofstream file(file_name, std::ios::binary);
|
||||
if (file.is_open()) {
|
||||
(void)file.write(buf.data(), SizeToLong(buf.size()));
|
||||
file.close();
|
||||
idx++;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Open file failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
if (stream_ptr == 0) {
|
||||
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (kernel_pack_ == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel pack should not be nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
|
||||
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
|
||||
if (func_stub == 0) {
|
||||
MS_LOG(ERROR) << "GenFuncStub failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// pack all addresses into a vector.
|
||||
std::vector<void *> runtime_args;
|
||||
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args),
|
||||
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args),
|
||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||
|
||||
rtL2Ctrl_t *l2ctrl = nullptr;
|
||||
auto stream = reinterpret_cast<rtStream_t *>(stream_ptr);
|
||||
if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(),
|
||||
SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) {
|
||||
MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
|
||||
return false;
|
||||
}
|
||||
|
||||
DumpData(inputs, outputs);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
||||
if (kernel_pack_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "kernel pack should not be nullptr.";
|
||||
}
|
||||
|
||||
std::vector<uint8_t> args;
|
||||
uint32_t args_size = 0;
|
||||
std::vector<uint8_t> sm_desc;
|
||||
void *binary = nullptr;
|
||||
uint32_t binary_size = 0;
|
||||
std::vector<uint8_t> meta_data;
|
||||
std::vector<void *> input_data_addrs;
|
||||
std::vector<void *> output_data_addrs;
|
||||
std::vector<void *> workspace_addrs;
|
||||
|
||||
// pack all addresses into a vector.
|
||||
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs),
|
||||
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
|
||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||
|
||||
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
|
||||
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
|
||||
if (func_stub == 0) {
|
||||
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
|
||||
}
|
||||
|
||||
std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_);
|
||||
|
||||
MS_LOG(DEBUG) << "The block_dim is:" << block_dim;
|
||||
|
||||
TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
|
||||
stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs,
|
||||
output_data_addrs, workspace_addrs);
|
||||
return {task_info_ptr};
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "kernel/ascend_kernel_mod.h"
|
||||
#include "kernel/tbe/tbe_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class AkgKernelMod : public AscendKernelMod {
|
||||
public:
|
||||
explicit AkgKernelMod(const KernelPackPtr &kernel_pack);
|
||||
~AkgKernelMod() final {}
|
||||
|
||||
void SetInputSizeList(const std::vector<size_t> &size_list);
|
||||
void SetOutputSizeList(const std::vector<size_t> &size_list);
|
||||
void SetWorkspaceSizeList(const std::vector<size_t> &size_list);
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
|
||||
|
||||
private:
|
||||
KernelPackPtr kernel_pack_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
|
||||
using AkgKernelModPtr = std::shared_ptr<AkgKernelMod>;
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
|
|
@ -18,7 +18,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/akg/akgkernelbuild.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
#include "kernel/akg/gpu/akg_gpu_kernel_mod.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
|
|
|
@ -23,6 +23,11 @@
|
|||
#include "nlohmann/json.hpp"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "common/utils.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/graph_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -48,12 +53,6 @@ const std::map<TypeId, std::string> type_id_str_map = {
|
|||
{TypeId::kNumberTypeBool, "bool"},
|
||||
};
|
||||
|
||||
const std::map<std::string, std::string> DATATYPE_STRING_MAP{
|
||||
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
|
||||
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
|
||||
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "bool"}, {"Float64", "double"},
|
||||
};
|
||||
|
||||
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
|
||||
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
|
||||
{"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
|
||||
|
@ -243,14 +242,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) {
|
|||
}
|
||||
}
|
||||
|
||||
std::string Dtype2String(const std::string &dtypes) {
|
||||
auto iter = DATATYPE_STRING_MAP.find(dtypes);
|
||||
if (iter == DATATYPE_STRING_MAP.end()) {
|
||||
MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
std::string TypeId2String(TypeId type_id) {
|
||||
auto iter = type_id_str_map.find(type_id);
|
||||
if (iter == type_id_str_map.end()) {
|
||||
|
@ -361,7 +352,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
|
|||
output_num = 1;
|
||||
} else {
|
||||
if (output_idx < real_output_num) {
|
||||
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
|
||||
MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
|
||||
output_num = 1;
|
||||
}
|
||||
}
|
||||
|
@ -403,7 +394,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu
|
|||
}
|
||||
|
||||
if (imply_type == kAKG) {
|
||||
builder->SetKernelType(AUTO_DIFF_KERNEL);
|
||||
builder->SetKernelType(AKG_KERNEL);
|
||||
} else if (imply_type == kAICPU) {
|
||||
builder->SetKernelType(AICPU_KERNEL);
|
||||
} else {
|
||||
|
@ -634,5 +625,256 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
|
|||
}
|
||||
unique_grad->indices_size_ = unique_indices_size + 1;
|
||||
}
|
||||
|
||||
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
|
||||
if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
|
||||
MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
|
||||
}
|
||||
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return AnfAlgo::VisitKernel(anf_node, 0);
|
||||
} else {
|
||||
return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list) {
|
||||
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
|
||||
for (size_t i = 0; i < input_list.size(); ++i) {
|
||||
auto const &input = input_list[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
bool found = false;
|
||||
// using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
|
||||
auto mng = input->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
const NodeUsersMap &users = mng->node_users();
|
||||
auto input_users = users.find(input);
|
||||
if (input_users == users.end() || input_users->second.empty()) {
|
||||
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
|
||||
<< input->func_graph()->ToString() << "] has no users.";
|
||||
}
|
||||
|
||||
for (auto const &input_user : input_users->second) {
|
||||
for (auto const &anf_node : node_list) {
|
||||
if (anf_node != input_user.first) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<int> dyn_input_sizes;
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
|
||||
dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
|
||||
}
|
||||
|
||||
if (dyn_input_sizes.empty()) {
|
||||
input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
|
||||
found = true;
|
||||
break;
|
||||
} else {
|
||||
int used_as_idx = input_user.second - 1;
|
||||
int accum_idx = 0;
|
||||
size_t dyn_i = 0;
|
||||
for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
|
||||
accum_idx += dyn_input_sizes[dyn_i];
|
||||
if (used_as_idx < accum_idx) {
|
||||
input_index.push_back(std::make_pair(
|
||||
anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (dyn_i != dyn_input_sizes.size()) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
|
||||
<< input->func_graph()->ToString() << "] found no related kernel info.";
|
||||
}
|
||||
}
|
||||
return input_index;
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list) {
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> output_index;
|
||||
for (size_t i = 0; i < output_list.size(); ++i) {
|
||||
auto const &output = output_list[i];
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
bool found = false;
|
||||
auto pree_node = AnfAlgo::VisitKernel(output, 0);
|
||||
|
||||
auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
|
||||
if (pos != std::end(node_list)) {
|
||||
output_index.push_back(pree_node);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
|
||||
if (ret != std::end(input_list)) {
|
||||
output_index.push_back(std::make_pair(pree_node.first, 0));
|
||||
found = true;
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
|
||||
<< output->func_graph()->ToString() << "] found no related kernel info.";
|
||||
}
|
||||
}
|
||||
return output_index;
|
||||
}
|
||||
|
||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node_list);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
|
||||
for (auto const &node : node_lists) {
|
||||
if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
|
||||
node_list->push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
||||
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
|
||||
MS_EXCEPTION_IF_NULL(node_list);
|
||||
MS_EXCEPTION_IF_NULL(input_list);
|
||||
MS_EXCEPTION_IF_NULL(output_list);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
GetValidKernelNodes(func_graph, node_list);
|
||||
|
||||
auto parameters = func_graph->parameters();
|
||||
input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
|
||||
|
||||
auto func_output = func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(func_output);
|
||||
if (func_output->isa<CNode>()) {
|
||||
// multi output.
|
||||
auto cnode = func_output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input0 = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
|
||||
auto input_node = cnode->input(input_idx);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
|
||||
}
|
||||
} else {
|
||||
// single output.
|
||||
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
|
||||
}
|
||||
} else {
|
||||
// single output.
|
||||
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
|
||||
}
|
||||
}
|
||||
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(node_json);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->size()) {
|
||||
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
||||
<< cnode->inputs().size() << "][" << cnode->DebugString() << "]";
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
if (!IsValueNode<tensor::Tensor>(input_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type_id = tensor->data_type();
|
||||
auto *data = tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
|
||||
// not const tensor.
|
||||
MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
|
||||
}
|
||||
|
||||
if (type_id == kFloat32->type_id()) {
|
||||
float *val = static_cast<float *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = val[0];
|
||||
MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
|
||||
return true;
|
||||
} else if (type_id == kFloat16->type_id()) {
|
||||
float16 *val = static_cast<float16 *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = static_cast<float>(val[0]);
|
||||
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
|
||||
return true;
|
||||
} else if (type_id == kInt32->type_id()) {
|
||||
int *val = static_cast<int *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = val[0];
|
||||
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
|
||||
return true;
|
||||
}
|
||||
MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node_list);
|
||||
auto output = func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::IsRealKernel(output)) {
|
||||
// single output.
|
||||
node_list->push_back(std::make_pair(output, 0));
|
||||
return;
|
||||
} else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
// multi output.
|
||||
auto &inputs = output_cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||
node_list->push_back(in_with_idx);
|
||||
}
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
|
||||
<< " of graph: " << func_graph->ToString();
|
||||
}
|
||||
|
||||
bool IsWeightBoundary(const AnfNodePtr &node) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,9 +20,12 @@
|
|||
#include <dirent.h>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
@ -79,13 +82,11 @@ bool CheckCache(const std::string &kernel_name);
|
|||
KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
|
||||
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
|
||||
TypeId DtypeToTypeId(const std::string &dtypes);
|
||||
std::string Dtype2String(const std::string &dtypes);
|
||||
std::string Dtype2ShortType(const std::string &dtypes);
|
||||
std::string TypeId2String(TypeId type_id);
|
||||
size_t GetDtypeNbyte(const std::string &dtypes);
|
||||
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
|
||||
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
|
||||
bool IsAtomicNode(const CNodePtr &kernel_node);
|
||||
void SaveJsonInfo(const std::string &json_name, const std::string &info);
|
||||
std::string GetProcessor(const AnfNodePtr &anf_node);
|
||||
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
|
||||
|
@ -94,6 +95,18 @@ void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGr
|
|||
size_t outer_dim);
|
||||
void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
|
||||
size_t outer_dim);
|
||||
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
|
||||
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list);
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list);
|
||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
||||
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
|
||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
|
||||
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
|
||||
bool IsWeightBoundary(const AnfNodePtr &node);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include <fstream>
|
||||
#include "mindspore/ccsrc/kernel/kernel.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/akg/akgkernelbuild.h"
|
||||
#include "kernel/akg/akg_kernel_build.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "securec/include/securec.h"
|
||||
#include "pipeline/parse/python_adapter.h"
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
|
||||
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
|
||||
|
||||
namespace kernel {
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "kernel/rts/rt_kernel_info.h"
|
||||
#include "kernel/hccl/hccl_kernel_metadata.h"
|
||||
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
|
||||
#include "kernel/akg/akg_kernel_metadata.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -59,10 +60,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|||
}
|
||||
}
|
||||
} // namespace
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
|
||||
void KernelQueryAll(const CNodePtr &kernel_node,
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
|
||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||
|
||||
if (kernel_info_list->empty()) {
|
||||
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
||||
if (!kernel_info_list->empty()) {
|
||||
|
@ -82,6 +87,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
if (kernel_info_list->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
|
||||
}
|
||||
}
|
||||
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
|
||||
KernelType kernel_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
|
||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
||||
switch (kernel_type) {
|
||||
case KernelType::AKG_KERNEL:
|
||||
AkgMetadataInfo(kernel_node, kernel_info_list);
|
||||
break;
|
||||
default:
|
||||
KernelQueryAll(kernel_node, kernel_info_list);
|
||||
break;
|
||||
}
|
||||
|
||||
if (kernel_info_list->empty()) {
|
||||
MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!";
|
||||
}
|
||||
// check output
|
||||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
|
||||
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
|
||||
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
|
||||
|
|
|
@ -272,8 +272,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
|
|||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_gpu = (context->device_target() == kGPUDevice);
|
||||
if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) ||
|
||||
(!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
|
||||
if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
|
||||
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
||||
<< ", current op num: " << op_info_.size();
|
||||
return nullptr;
|
||||
|
|
|
@ -347,7 +347,7 @@ static int TypeStrToDstType(const std::string &type_str) {
|
|||
ret = 4;
|
||||
} else if (type_str == "UInt64") {
|
||||
ret = 10;
|
||||
} else if (type_str == "Bool_") {
|
||||
} else if (type_str == "Bool") {
|
||||
ret = 12;
|
||||
} else {
|
||||
MS_LOG(INFO) << "Error type str is invailed: " << type_str;
|
||||
|
|
|
@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
|
|||
const std::map<std::string, std::string> type_str_maps = {
|
||||
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
|
||||
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
|
||||
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"},
|
||||
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"},
|
||||
};
|
||||
|
||||
const std::unordered_map<std::string, size_t> type_nbyte_maps = {
|
||||
|
|
|
@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
|
|||
|
||||
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->debug_info()->set_name("hyper_map");
|
||||
|
||||
AnfNodePtr ptrFnArg = nullptr;
|
||||
|
@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu
|
|||
MS_EXCEPTION_IF_NULL(a_tuple);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("tail");
|
||||
AnfNodePtr ptrTup = ret->add_parameter();
|
||||
|
||||
|
@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list
|
|||
MS_EXCEPTION_IF_NULL(a_list);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("tail");
|
||||
AnfNodePtr ptrList = ret->add_parameter();
|
||||
|
||||
|
@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
|
|||
grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
|
||||
}
|
||||
|
||||
b->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
b->set_output(b->NewCNode(grads));
|
||||
|
||||
fg->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
|
||||
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
|
||||
return fg;
|
||||
|
@ -504,7 +504,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
|
|||
const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args,
|
||||
bool applyJ) {
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
auto weights_node = weights;
|
||||
if (weights == nullptr && !args.empty()) {
|
||||
|
@ -625,7 +625,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
|
||||
std::ostringstream ss;
|
||||
ss << "grad{" << nparam << "}";
|
||||
dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
dfBuilder->debug_info()->set_name(ss.str());
|
||||
ParameterPtr param_graph = dfBuilder->add_parameter();
|
||||
|
||||
|
@ -671,7 +671,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis
|
|||
}
|
||||
|
||||
FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
|
||||
fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg_ptr->debug_info()->set_name("list_map");
|
||||
AnfNodePtr fn = fg_ptr->add_parameter();
|
||||
|
||||
|
@ -741,7 +741,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
|
|||
// cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
|
||||
FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
|
||||
fgtrue_ptr->debug_info()->set_name("ftrue");
|
||||
fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
|
||||
auto inputs = fgtrue_output_cnode->inputs();
|
||||
|
@ -751,7 +751,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
|
|||
|
||||
FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
|
||||
fgfalse_ptr->debug_info()->set_name("ffalse");
|
||||
fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fgfalse_ptr->set_output(resl);
|
||||
|
||||
AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
|
||||
|
@ -808,7 +808,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
|
|||
}
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr p_tup_a = ret->add_parameter();
|
||||
AnfNodePtr p_tup_b = ret->add_parameter();
|
||||
|
||||
|
@ -912,7 +912,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
|
|||
GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr p_tuple = ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
|
||||
|
@ -941,7 +941,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
|
|||
AbstractBasePtrList branches = branches_abs->elements();
|
||||
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr functions = ret_graph->add_parameter();
|
||||
auto index = ret_graph->add_parameter();
|
||||
|
||||
|
|
|
@ -304,7 +304,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi
|
|||
}
|
||||
auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
|
||||
func_graph->set_output(new_cnode);
|
||||
func_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace prim
|
||||
|
|
|
@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &
|
|||
MS_EXCEPTION_IF_NULL(arg0_list);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("append");
|
||||
AnfNodePtr arg0_node = ret->add_parameter();
|
||||
|
||||
|
|
|
@ -51,8 +51,8 @@ AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &f
|
|||
FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
|
||||
// Generate func for leaf nodes
|
||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->debug_info()->set_name("map");
|
||||
AnfNodePtr ptrFnArg = nullptr;
|
||||
if (fn_leaf_ == nullptr) {
|
||||
|
@ -237,8 +237,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c
|
|||
|
||||
FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->debug_info()->set_name("map");
|
||||
|
||||
AnfNodePtr ptrFnArg = nullptr;
|
||||
|
|
|
@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
|
|||
|
||||
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||
auto ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
AnfNodePtr fnNode = ret_graph->add_parameter();
|
||||
std::vector<AnfNodePtr> elems;
|
||||
|
|
|
@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
|
|||
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
|
||||
});
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
|
||||
(void)ret_graph->add_parameter();
|
||||
}
|
||||
|
|
|
@ -50,6 +50,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
|
|||
const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
|
||||
const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
|
||||
const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
|
||||
const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
||||
const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
||||
|
||||
// Type introspection
|
||||
const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
|
@ -166,17 +172,20 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
|
|||
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
|
||||
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
|
||||
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
|
||||
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
|
||||
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
|
||||
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
|
||||
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
|
||||
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
||||
const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
||||
const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
||||
const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||
const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
||||
const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
||||
|
||||
// NN
|
||||
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
|
||||
const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
|
||||
const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
|
||||
const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
|
||||
|
@ -253,6 +262,7 @@ const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
|
|||
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
|
||||
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
||||
const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
|
||||
|
||||
// Comm ops
|
||||
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
|
|
|
@ -59,6 +59,12 @@ extern const PrimitivePtr kPrimBoolNot;
|
|||
extern const PrimitivePtr kPrimBoolAnd;
|
||||
extern const PrimitivePtr kPrimBoolOr;
|
||||
extern const PrimitivePtr kPrimBoolEq;
|
||||
extern const PrimitivePtr kPrimGreater;
|
||||
extern const PrimitivePtr kPrimGreaterEqual;
|
||||
extern const PrimitivePtr kPrimLess;
|
||||
extern const PrimitivePtr kPrimLessEqual;
|
||||
extern const PrimitivePtr kPrimEqual;
|
||||
extern const PrimitivePtr kPrimNotEqual;
|
||||
|
||||
// Type introspection
|
||||
extern const PrimitivePtr kPrimTypeOf;
|
||||
|
@ -157,6 +163,10 @@ extern const PrimitivePtr KPrimTransData;
|
|||
extern const PrimitivePtr kPrimNMSWithMask;
|
||||
extern const PrimitivePtr kPrimPad;
|
||||
extern const PrimitivePtr kPrimArgMaxWithValue;
|
||||
extern const PrimitivePtr kPrimRealDiv;
|
||||
extern const PrimitivePtr kPrimSqrt;
|
||||
extern const PrimitivePtr kPrimReciprocal;
|
||||
extern const PrimitivePtr kPrimExpandDims;
|
||||
|
||||
// Maths
|
||||
extern const PrimitivePtr kPrimTensorAdd;
|
||||
|
@ -183,9 +193,11 @@ extern const PrimitivePtr kPrimCumProd;
|
|||
extern const PrimitivePtr kPrimSubscalar;
|
||||
extern const PrimitivePtr kPrimInplaceAdd;
|
||||
extern const PrimitivePtr kPrimInplaceSub;
|
||||
extern const PrimitivePtr kPrimPow;
|
||||
|
||||
// NN
|
||||
extern const PrimitivePtr kPrimFlatten;
|
||||
extern const PrimitivePtr kPrimSoftmax;
|
||||
extern const PrimitivePtr kPrimLogSoftmax;
|
||||
extern const PrimitivePtr kPrimLogSoftmaxGrad;
|
||||
extern const PrimitivePtr kPrimApplyCenteredRMSProp;
|
||||
|
@ -263,6 +275,7 @@ extern const PrimitivePtr kPrimInDict;
|
|||
extern const PrimitivePtr kPrimNotInDict;
|
||||
extern const PrimitivePtr kPrimMixedPrecisionCast;
|
||||
extern const PrimitivePtr kPrimIsConsant;
|
||||
extern const PrimitivePtr kPrimEquivFormat;
|
||||
|
||||
// Comm ops
|
||||
extern const PrimitivePtr kPrimAllReduce;
|
||||
|
|
|
@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
|
||||
k_graph_ = std::make_shared<FuncGraph>();
|
||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
||||
}
|
||||
TraceManager::EndTrace();
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
|
||||
tape_ = std::make_shared<FuncGraph>();
|
||||
// Add "_Grad" postfix
|
||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
|
||||
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
||||
}
|
||||
TraceManager::EndTrace();
|
||||
|
||||
dout_ = tape_->add_parameter();
|
||||
|
@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
|||
(void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
|
||||
(void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
|
||||
// Reset defer_inline to enable successive inlining
|
||||
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||
primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||
|
||||
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
||||
functor->Init();
|
||||
|
|
|
@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
|||
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
|
||||
if (MsContext::GetInstance()->is_multi_graph_sink()) {
|
||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
f->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -78,7 +78,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(cons);
|
||||
|
||||
auto dt = data->abstract();
|
||||
MS_EXCEPTION_IF_NULL(dt);
|
||||
if (dt == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!dt->isa<AbstractClass>()) {
|
||||
MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "optimizer/graph_kernel_reuse.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "./common.h"
|
||||
#include "utils/graph_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
|
||||
bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
|
||||
if (a->abstract() && b->abstract()) {
|
||||
auto a_type = a->abstract()->GetTypeTrack();
|
||||
auto b_type = b->abstract()->GetTypeTrack();
|
||||
|
||||
if (a_type != b_type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto a_shape = a->abstract()->GetShapeTrack();
|
||||
auto b_shape = b->abstract()->GetShapeTrack();
|
||||
if (a_shape != nullptr && a_shape == b_shape) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
|
||||
b_shape->isa<abstract::Shape>()) {
|
||||
return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
|
||||
bool changed = false;
|
||||
auto fgs = manager->func_graphs();
|
||||
for (FuncGraphPtr &fg : fgs) {
|
||||
if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
continue;
|
||||
}
|
||||
std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
|
||||
if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
|
||||
FuncGraphPtr new_fg = nullptr;
|
||||
for (auto &cfg : graph_kernel_ops[key]) {
|
||||
// If two graphs have different size then continue
|
||||
auto fg_topos = TopoSort(fg->get_return());
|
||||
auto cfg_topos = TopoSort(cfg->get_return());
|
||||
if (fg_topos.size() != cfg_topos.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compare const tensor
|
||||
bool has_same = true;
|
||||
for (size_t i = 0; i < fg_topos.size(); ++i) {
|
||||
if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
|
||||
if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
|
||||
has_same = false;
|
||||
break;
|
||||
}
|
||||
|
||||
auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
|
||||
auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
|
||||
if (!tensor1->ValueEqual(*tensor2)) {
|
||||
has_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_same) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto fg_input = fg->parameters();
|
||||
auto cfg_input = cfg->parameters();
|
||||
if (fg_input.size() != cfg_input.size()) {
|
||||
continue;
|
||||
}
|
||||
// Compare input
|
||||
for (size_t i = 0; i < fg_input.size(); ++i) {
|
||||
if (!CompareNode(fg_input[i], cfg_input[i])) {
|
||||
has_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!has_same) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compare output
|
||||
if (!CompareNode(fg->output(), cfg->output())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find reusable fg
|
||||
new_fg = cfg;
|
||||
break;
|
||||
}
|
||||
|
||||
if (new_fg != nullptr) {
|
||||
// Replace current fg with existing fg
|
||||
auto users = fg->func_graph_cnodes_index();
|
||||
for (auto &iter : users) {
|
||||
auto cnode = iter.first->first->cast<CNodePtr>();
|
||||
auto new_input = cnode->inputs();
|
||||
auto main_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(main_graph);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
|
||||
new_input[1] = NewValueNode(new_fg);
|
||||
} else {
|
||||
new_input[0] = NewValueNode(new_fg);
|
||||
}
|
||||
auto new_cnode = main_graph->NewCNode(new_input);
|
||||
manager->Replace(iter.first->first, new_cnode);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
} else {
|
||||
// Add current fg to map
|
||||
graph_kernel_ops[key].push_back(fg);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
graph_kernel_ops[key] = {fg};
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(root);
|
||||
|
||||
return DoReplace(manager);
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
|
||||
|
||||
#include <mindspore/ccsrc/session/anf_runtime_algorithm.h>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
// Common subexpression elimination.
|
||||
class GraphKernelReuse {
|
||||
public:
|
||||
GraphKernelReuse() : count(0) {}
|
||||
virtual ~GraphKernelReuse() = default;
|
||||
|
||||
bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
|
||||
bool chg = ReuseGraphKernel(root, optimizer->resource()->manager());
|
||||
return chg;
|
||||
}
|
||||
|
||||
bool CompareNode(const AnfNodePtr a, const AnfNodePtr other);
|
||||
bool DoReplace(const FuncGraphManagerPtr manager);
|
||||
|
||||
bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops;
|
||||
int count;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
|
|
@ -41,6 +41,8 @@
|
|||
#include "optimizer/irpass/incorporate_call.h"
|
||||
#include "optimizer/irpass/grad_var_prepare.h"
|
||||
#include "optimizer/irpass/param_replace.h"
|
||||
#include "optimizer/irpass/mark_interface_fusion.h"
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -48,7 +50,7 @@ namespace irpass {
|
|||
OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
|
||||
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
|
||||
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
|
||||
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
|
||||
special_op_eliminate_ =
|
||||
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
|
||||
|
@ -90,7 +92,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
replace_refkey_by_param_ =
|
||||
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
|
||||
|
||||
// Gradient transforms
|
||||
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
|
||||
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
||||
|
@ -115,6 +116,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// Incorporation
|
||||
incorporate_getitem_set_ =
|
||||
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
||||
incorporate_getitem_from_param_ =
|
||||
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
|
||||
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
|
||||
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
|
||||
|
||||
|
@ -124,6 +127,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
// Convert
|
||||
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
|
||||
|
||||
// Unused parameter eliminate
|
||||
unused_parameter_eliminate_ =
|
||||
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
|
||||
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// AddN eliminate
|
||||
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// Mark interface fusion
|
||||
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
|
||||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
|
|
|
@ -84,6 +84,7 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Incorporation
|
||||
SubstitutionPtr incorporate_getitem_set_;
|
||||
SubstitutionPtr incorporate_getitem_from_param_;
|
||||
SubstitutionPtr incorporate_call_;
|
||||
SubstitutionPtr incorporate_call_switch_;
|
||||
|
||||
|
@ -92,6 +93,16 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Convert
|
||||
SubstitutionPtr print_tuple_wrapper_;
|
||||
|
||||
// Unused parameter eliminate
|
||||
SubstitutionPtr unused_parameter_eliminate_;
|
||||
SubstitutionPtr unused_output_eliminate_;
|
||||
|
||||
// AddN eliminate
|
||||
SubstitutionPtr addn_eliminate_;
|
||||
|
||||
// Fusion
|
||||
SubstitutionPtr mark_interface_fusion_;
|
||||
};
|
||||
|
||||
// the collection of irpass for resolve action
|
||||
|
@ -145,6 +156,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
|||
return IsValueNode<FuncGraph>(inp0);
|
||||
}
|
||||
|
||||
// Check if CNode Input 0 is Func Graph of graph kernel.
|
||||
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inp0 = node->cast<CNodePtr>()->input(0);
|
||||
if (IsValueNode<FuncGraph>(inp0)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inp0);
|
||||
if (fg == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if CNode Input 0 is CNode
|
||||
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
|
|
|
@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// Support class used for checking if all values of a Tensor are equal `check_value_`
|
||||
// Supported data types: double, float/float32, int/int32
|
||||
class CheckTensorConstant {
|
||||
public:
|
||||
explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
|
||||
~CheckTensorConstant() = default;
|
||||
bool IsTensorConstant(const ValuePtr &value) {
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
return false;
|
||||
}
|
||||
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
|
||||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||
if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else if (tensor_type == TypeId::kNumberTypeFloat64) {
|
||||
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
|
||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||
if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
|
||||
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
|
||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||
if (data2[i] != check_value_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Un-support Data Types
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsTensorScalarConstant(const ValuePtr &value) {
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
return false;
|
||||
}
|
||||
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
|
||||
if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
|
||||
return false;
|
||||
}
|
||||
return IsTensorConstant(value);
|
||||
}
|
||||
|
||||
private:
|
||||
int check_value_;
|
||||
};
|
||||
|
||||
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
|
||||
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
|
||||
class TensorMultiplyByZeroOrOne : public AnfVisitor {
|
||||
public:
|
||||
TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {}
|
||||
~TensorMultiplyByZeroOrOne() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||
|
||||
if (is_zero_) {
|
||||
if (x_->func_graph() != node->func_graph()) {
|
||||
return nullptr;
|
||||
}
|
||||
return NewTensorFilledWithData(node);
|
||||
}
|
||||
if (is_one_) {
|
||||
return NewTensorFilledWithData(node, x_);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (is_zero_ || is_one_) {
|
||||
x_ = node;
|
||||
return;
|
||||
}
|
||||
|
||||
if (IsParam(node)) {
|
||||
x_ = node;
|
||||
return;
|
||||
}
|
||||
|
||||
if (IsCNode(node)) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
x_ = node;
|
||||
return;
|
||||
}
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
if (CheckTensorConstant(0).IsTensorConstant(value)) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
|
||||
is_one_ = true;
|
||||
return;
|
||||
}
|
||||
x_ = node;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override {
|
||||
auto value = vnode->value();
|
||||
if (CheckTensorConstant(0).IsTensorConstant(value)) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
|
||||
is_one_ = true;
|
||||
return;
|
||||
}
|
||||
x_ = vnode;
|
||||
}
|
||||
void Reset() {
|
||||
x_ = nullptr;
|
||||
is_one_ = false;
|
||||
is_zero_ = false;
|
||||
}
|
||||
|
||||
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) {
|
||||
if (!node->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
|
||||
return tensor_ptr->data_c(writable);
|
||||
}
|
||||
|
||||
// Make a new tensor (when possible) with the same shape as of `node`
|
||||
// If x is nullptr then fill new tensor will "0"
|
||||
// If x is a tensor with empty shape then fill new tensor with the single value of x
|
||||
// If x is a tensor with same shape as `node` then return x as result
|
||||
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) {
|
||||
if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
|
||||
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
|
||||
|
||||
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
|
||||
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
|
||||
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
|
||||
|
||||
if (x == nullptr) {
|
||||
std::memset(data, 0, mem_size);
|
||||
auto new_vnode = NewValueNode(new_tensor_ptr);
|
||||
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
|
||||
return new_vnode;
|
||||
}
|
||||
// x is not nullptr
|
||||
if (x->isa<CNode>()) {
|
||||
if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
std::vector<int> x_shape = x_abstract->shape()->shape();
|
||||
|
||||
if (x_shape != tensor_shape) {
|
||||
return nullptr;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
if (!x->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto x_value = x->cast<ValueNodePtr>()->value();
|
||||
if (!x_value->isa<tensor::Tensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
|
||||
|
||||
if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
|
||||
return nullptr;
|
||||
}
|
||||
char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
|
||||
if (x_tensor_ptr->DataSize() == 1) {
|
||||
for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
|
||||
memcpy(source_data, data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr));
|
||||
}
|
||||
} else {
|
||||
memcpy(source_data, data, mem_size);
|
||||
}
|
||||
auto new_vnode = NewValueNode(new_tensor_ptr);
|
||||
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
|
||||
return new_vnode;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_zero_{false}, is_one_{false};
|
||||
ValuePtr zero_;
|
||||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimScalarAdd, X, 0}
|
||||
// {prim::kPrimScalarAdd, 0, X}
|
||||
class AddByZero : public AnfVisitor {
|
||||
|
@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor {
|
|||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) {
|
||||
if (node->isa<ValueNode>() &&
|
||||
((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
|
@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor {
|
|||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
x_ = node;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override {
|
||||
auto value = vnode->value();
|
||||
if (CheckTensorConstant(0).IsTensorConstant(value)) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
x_ = nullptr;
|
||||
is_zero_ = false;
|
||||
|
@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor {
|
|||
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
|
||||
class ConstantDuplicateMul : public AnfVisitor {
|
||||
public:
|
||||
// Support function to multiply two constant tensors: partially support broadcasting shapes
|
||||
template <typename T>
|
||||
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
|
||||
int out_data_size) {
|
||||
T *data_1 = reinterpret_cast<T *>(in_data_1);
|
||||
T *data_2 = reinterpret_cast<T *>(in_data_2);
|
||||
T *data_out = new T[out_data_size];
|
||||
|
||||
if (in_data_1_size == 1) {
|
||||
for (int i = 0; i < out_data_size; i++) {
|
||||
data_out[i] = data_1[0];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out_data_size; i++) {
|
||||
data_out[i] = data_1[i];
|
||||
}
|
||||
}
|
||||
if (in_data_2_size == 1) {
|
||||
for (int i = 0; i < out_data_size; i++) {
|
||||
data_out[i] *= data_2[0];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out_data_size; i++) {
|
||||
data_out[i] *= data_2[i];
|
||||
}
|
||||
}
|
||||
*out_data = reinterpret_cast<void *>(data_out);
|
||||
return;
|
||||
}
|
||||
|
||||
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) {
|
||||
if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
|
||||
(vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto value_1 = GetValueNode(vnode_1);
|
||||
auto value_2 = GetValueNode(vnode_2);
|
||||
|
||||
if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
|
||||
auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
|
||||
|
||||
auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
|
||||
TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
|
||||
TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
|
||||
TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
|
||||
|
||||
if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
|
||||
(tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
|
||||
|
||||
int data_out_size = 1;
|
||||
for (auto it : tensor_out_shape) {
|
||||
data_out_size *= it;
|
||||
}
|
||||
if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
|
||||
return nullptr;
|
||||
}
|
||||
if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *data_out;
|
||||
|
||||
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
|
||||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
|
||||
Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
|
||||
tensor_ptr_2->DataSize(), &data_out, data_out_size);
|
||||
} else {
|
||||
if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
|
||||
Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
|
||||
tensor_ptr_2->DataSize(), &data_out, data_out_size);
|
||||
} else {
|
||||
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
|
||||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
|
||||
Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
|
||||
tensor_ptr_2->DataSize(), &data_out, data_out_size);
|
||||
} else {
|
||||
// Un-support data types
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
|
||||
size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
|
||||
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
|
||||
memcpy(data, data_out, mem_size);
|
||||
|
||||
auto new_vnode = NewValueNode(new_tensor_ptr);
|
||||
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
|
||||
return new_vnode;
|
||||
}
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
// {prim::kPrimMul, Tensor1, {...}}
|
||||
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
|
||||
if (vnode_ == nullptr || cnode_ == nullptr) {
|
||||
if (vnode_ == nullptr || c_p_node_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!IsCNode(c_p_node_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto tensor1 = vnode_;
|
||||
auto mul = cnode_;
|
||||
auto mul = c_p_node_->cast<CNodePtr>();
|
||||
|
||||
Reset();
|
||||
// {prim::kPrimMul, Tensor2, {...}}
|
||||
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
|
||||
if (vnode_ == nullptr || cnode_ == nullptr) {
|
||||
if (vnode_ == nullptr || c_p_node_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor2 = vnode_;
|
||||
auto cnode = cnode_;
|
||||
auto c_p_node = c_p_node_;
|
||||
|
||||
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
|
||||
auto fg = node->func_graph();
|
||||
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
|
||||
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
|
||||
|
||||
auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node);
|
||||
if (new_mul_tensor == nullptr) {
|
||||
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
|
||||
return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg);
|
||||
}
|
||||
return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg);
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
|
@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor {
|
|||
vnode_ = node;
|
||||
}
|
||||
|
||||
if (IsCNode(node)) {
|
||||
cnode_ = node->cast<CNodePtr>();
|
||||
if (IsCNode(node) || IsParam(node)) {
|
||||
c_p_node_ = node;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
vnode_ = nullptr;
|
||||
cnode_ = nullptr;
|
||||
c_p_node_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr vnode_;
|
||||
CNodePtr cnode_;
|
||||
AnfNodePtr c_p_node_;
|
||||
};
|
||||
|
||||
class PowerOneEliminate : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (!IsValueNode<Scalar>(inputs[2])) {
|
||||
return nullptr;
|
||||
}
|
||||
auto scalar = GetValueNode<ScalarPtr>(inputs[2]);
|
||||
if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) {
|
||||
return inputs[1];
|
||||
} else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) {
|
||||
return inputs[1];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// grad = AllReduce(grad) / worker_number
|
||||
|
@ -341,17 +699,21 @@ class ArithmeticSimplify {
|
|||
public:
|
||||
ArithmeticSimplify()
|
||||
: multiply_by_zero_or_one_(),
|
||||
tensor_multiply_by_zero_or_one_(),
|
||||
add_by_zero_(),
|
||||
tensor_add_by_zero_(),
|
||||
identity_(prim::kPrimIdentity),
|
||||
opt_update_zero_tensor_(),
|
||||
constant_duplicate_mul_() {
|
||||
constant_duplicate_mul_(),
|
||||
power_one_() {
|
||||
eliminaters_.emplace_back(multiply_by_zero_or_one_);
|
||||
eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_);
|
||||
eliminaters_.emplace_back(add_by_zero_);
|
||||
eliminaters_.emplace_back(tensor_add_by_zero_);
|
||||
eliminaters_.emplace_back(identity_);
|
||||
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
||||
eliminaters_.emplace_back(constant_duplicate_mul_);
|
||||
eliminaters_.emplace_back(power_one_);
|
||||
}
|
||||
~ArithmeticSimplify() = default;
|
||||
|
||||
|
@ -368,11 +730,13 @@ class ArithmeticSimplify {
|
|||
|
||||
private:
|
||||
MultiplyByZeroOrOne multiply_by_zero_or_one_;
|
||||
TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_;
|
||||
AddByZero add_by_zero_;
|
||||
TensorAddByZero tensor_add_by_zero_;
|
||||
PrimEliminater identity_;
|
||||
OptUpdateZeroTensor opt_update_zero_tensor_;
|
||||
ConstantDuplicateMul constant_duplicate_mul_;
|
||||
PowerOneEliminate power_one_;
|
||||
std::vector<TransformFuncType> eliminaters_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
@ -28,7 +29,6 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor {
|
|||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
|
||||
|
||||
if (node->func_graph() != nullptr && idx_ >= 0 && fg_ != nullptr) {
|
||||
auto new_fg = getitem_transform_(fg_, idx_);
|
||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||
return node->func_graph()->NewCNode(args_);
|
||||
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
|
||||
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// If graph kernel has muti output, do not split.
|
||||
// some graph kernel output has EnvInstance node or DeadCode node should split.
|
||||
auto output = fg_->output();
|
||||
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
auto outputs = output_cnode->inputs();
|
||||
int real_output_cnt = 0;
|
||||
for (size_t i = 1; i < outputs.size(); ++i) {
|
||||
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
|
||||
real_output_cnt++;
|
||||
if (real_output_cnt > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto new_fg = getitem_transform_(fg_, idx_);
|
||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||
return node->func_graph()->NewCNode(args_);
|
||||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
|
@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor {
|
|||
internal::GetitemTransform getitem_transform_;
|
||||
};
|
||||
|
||||
class IncorporateGetitemFromParam : public AnfVisitor {
|
||||
public:
|
||||
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &node_users = mng->node_users();
|
||||
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
|
||||
args_.push_back(cnode->input(input_idx + 1));
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &user : node_users[param]) {
|
||||
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
|
||||
// we do not process this case.
|
||||
args_.push_back(cnode->input(input_idx + 1));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// update new args.
|
||||
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
|
||||
// case 1
|
||||
replace_parameters_[input_idx] = true;
|
||||
need_update_ = true;
|
||||
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
||||
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
|
||||
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
|
||||
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
|
||||
} else {
|
||||
// case 2
|
||||
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
||||
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
|
||||
auto fg_output = prev_fg->output();
|
||||
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
|
||||
<< " should be a make tuple, but got: " << fg_output->DebugString();
|
||||
return;
|
||||
}
|
||||
replace_parameters_[input_idx] = true;
|
||||
need_update_ = true;
|
||||
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
|
||||
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
|
||||
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
|
||||
auto new_getitem =
|
||||
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))});
|
||||
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(SizeToInt(output_i)));
|
||||
new_getitem->input(2)->set_abstract(aptr);
|
||||
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
|
||||
args_.push_back(new_getitem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Reset();
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto &inputs = cnode->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
if (fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto parameters = fg->parameters();
|
||||
if (parameters.size() != inputs.size() - 1) {
|
||||
return nullptr;
|
||||
}
|
||||
replace_parameters_ = std::vector<bool>(parameters.size(), false);
|
||||
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
|
||||
auto node_fg = node->func_graph();
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
|
||||
Process(node_fg, cnode, parameters[i - 1], i - 1);
|
||||
} else {
|
||||
args_.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!need_update_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
|
||||
auto node_users = mng->node_users();
|
||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
size_t curr_input_idx{0};
|
||||
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
|
||||
if (!replace_parameters_[param_i]) {
|
||||
if (parameters[param_i]->abstract() != nullptr) {
|
||||
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
|
||||
}
|
||||
new_parameters.push_back(new_fg_parameters[param_i]);
|
||||
curr_input_idx++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// make a new parameter.
|
||||
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
|
||||
auto new_param = std::make_shared<Parameter>(new_fg);
|
||||
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
|
||||
|
||||
// update users of new parameter.
|
||||
for (auto &user : node_users[new_fg_parameters[param_i]]) {
|
||||
idx_ = -1;
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int32Imm>})(user.first);
|
||||
if (idx_ == -1) {
|
||||
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
|
||||
<< " must be tuple getitem here, but got: " << user.first->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (input_i == IntToSize(idx_)) {
|
||||
for (auto &sub_user : node_users[user.first]) {
|
||||
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sub_user_cnode);
|
||||
sub_user_cnode->set_input(sub_user.second, new_param);
|
||||
(void)mng->Replace(sub_user.first, sub_user_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
|
||||
new_parameters.push_back(new_param);
|
||||
curr_input_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
mng->SetParameters(new_fg, new_parameters);
|
||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||
auto new_call = node_fg->NewCNode(args_);
|
||||
new_call->set_abstract(node->abstract());
|
||||
return new_call;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); }
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {}
|
||||
|
||||
void Reset() {
|
||||
replace_parameters_.clear();
|
||||
args_.clear();
|
||||
inputs_num_.clear();
|
||||
need_update_ = false;
|
||||
idx_ = -1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<bool> replace_parameters_{};
|
||||
std::vector<AnfNodePtr> args_{};
|
||||
std::vector<size_t> inputs_num_{};
|
||||
bool need_update_{false};
|
||||
int idx_{-1};
|
||||
};
|
||||
|
||||
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
|
||||
class IncorporateGetitemSwitch : public AnfVisitor {
|
||||
public:
|
||||
|
|
|
@ -86,20 +86,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
|
|||
|
||||
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
auto &flags = node->func_graph()->flags();
|
||||
if (flags.find("inline_inside") != flags.end()) {
|
||||
return flags["inline_inside"];
|
||||
}
|
||||
return false;
|
||||
return node->func_graph()->has_flag("inline_inside");
|
||||
}
|
||||
|
||||
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) {
|
||||
auto &flags = fg->flags();
|
||||
if (flags.find("core") != flags.end()) {
|
||||
return flags["core"];
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
|
||||
|
||||
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
|
||||
|
||||
|
@ -123,6 +113,13 @@ class InlinerBase : public AnfVisitor {
|
|||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
return nullptr;
|
||||
}
|
||||
// Do not inline GraphKernel to Cell.
|
||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// If the GraphKernel only contains a return node, we make it inlined.
|
||||
if (fg->nodes().size() - fg->parameters().size() > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Reset();
|
||||
bool is_match = false;
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "operator/composite/composite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
static int count = 0;
|
||||
|
||||
std::string GetFusionNumber() {
|
||||
std::stringstream ss;
|
||||
ss << std::setw(4) << std::setfill('0') << count;
|
||||
std::string num = ss.str();
|
||||
++count;
|
||||
|
||||
return "_" + num;
|
||||
}
|
||||
|
||||
// Mark CNodes which can be merged in kernel build
|
||||
class MarkInterfaceFusion : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto condition = cnode->input(1);
|
||||
std::string cmp;
|
||||
std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"},
|
||||
{"LessEqual", "LE"}, {"Less", "LT"},
|
||||
{"Equal", "EQ"}, {"NotEqual", "NE"}};
|
||||
if (IsPrimitiveCNode(condition)) {
|
||||
auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>());
|
||||
if (cmp_list.count(prim_name) != 0) {
|
||||
// Mark Select and compare node
|
||||
cmp = cmp_list[prim_name];
|
||||
auto cnt = GetFusionNumber();
|
||||
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition);
|
||||
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
|
||||
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &) override {}
|
||||
|
||||
private:
|
||||
AnfNodePtr y_{nullptr};
|
||||
};
|
||||
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor {
|
|||
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
||||
bool has_zero_like_{false};
|
||||
};
|
||||
|
||||
// {PrimAddN, {kPrimMakeTuple, Xs}}
|
||||
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
|
||||
// case0: AddN(inputs)(inputs size < 2) -> error
|
||||
// case1: AddN(inputs)(all inputs is ValueNode) -> error
|
||||
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
|
||||
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
|
||||
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
||||
class AddNEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
if (fg->recursive()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
need_update_ = false;
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
changed |= Process(new_fg);
|
||||
} while (changed);
|
||||
|
||||
if (!need_update_) {
|
||||
return nullptr;
|
||||
} else {
|
||||
auto new_sx = inputs;
|
||||
new_sx[0] = NewValueNode(new_fg);
|
||||
return node->func_graph()->NewCNode(new_sx);
|
||||
}
|
||||
}
|
||||
|
||||
bool Process(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto nodes = TopoSort(func_graph->output());
|
||||
bool changed = false;
|
||||
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
auto node = nodes[i];
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &tuple_input = cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
|
||||
auto &tuple_inputs = tuple_input_cnode->inputs();
|
||||
if (tuple_inputs.size() < 3) {
|
||||
// case0: inputs size < 2, error
|
||||
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
|
||||
}
|
||||
|
||||
int valuenode_num =
|
||||
std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
|
||||
if (IsValueNode<tensor::Tensor>(node)) {
|
||||
return accumulator + 1;
|
||||
} else {
|
||||
return accumulator;
|
||||
}
|
||||
});
|
||||
if (IntToSize(valuenode_num) == tuple_inputs.size()) {
|
||||
// case1: all inputs is ValueNode, error
|
||||
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
|
||||
}
|
||||
|
||||
if (tuple_inputs.size() == 3) {
|
||||
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
|
||||
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
|
||||
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
|
||||
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
|
||||
tuple_inputs[2]};
|
||||
mng->Replace(node, func_graph->NewCNode(new_xs));
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
||||
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
|
||||
if (first_valuenode == tuple_inputs.end()) {
|
||||
// no ValueNode input found.
|
||||
continue;
|
||||
} else {
|
||||
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
||||
std::vector<AnfNodePtr> make_tuple_new_xs{
|
||||
NewValueNode(prim::kPrimMakeTuple),
|
||||
};
|
||||
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
||||
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
|
||||
if (node != *first_valuenode) {
|
||||
make_tuple_new_xs.push_back(node);
|
||||
}
|
||||
});
|
||||
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
|
||||
auto new_addn = func_graph->NewCNode(
|
||||
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
|
||||
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
|
||||
auto new_add =
|
||||
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
|
||||
(void)mng->Replace(node, new_add);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
need_update_ |= changed;
|
||||
return changed;
|
||||
}
|
||||
|
||||
private:
|
||||
bool need_update_{false};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor {
|
|||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (x_ == nullptr) {
|
||||
if (!IsVNode(node) && x_ == nullptr) {
|
||||
if (IsValueNode<tensor::Tensor>(node)) {
|
||||
is_tensor_ = true;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include "optimizer/irpass.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "operator/composite/composite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor {
|
|||
this->y_ = node;
|
||||
return true;
|
||||
};
|
||||
|
||||
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
|
||||
return y_;
|
||||
}
|
||||
|
|
|
@ -142,7 +142,7 @@ class ResetDeferInline : public AnfVisitor {
|
|||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
fg->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
@ -41,7 +42,7 @@ class SpecializeTransform {
|
|||
~SpecializeTransform() = default;
|
||||
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
|
||||
std::vector<PrimitivePtr> prim_args) {
|
||||
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
|
||||
if (cache_.count(func_graph) == 0) {
|
||||
cache_[func_graph] = {};
|
||||
}
|
||||
|
@ -69,6 +70,13 @@ class SpecializeTransform {
|
|||
(void)mng->Replace(params[i], arg);
|
||||
continue;
|
||||
}
|
||||
if (value_args[i] != nullptr) {
|
||||
auto const_tensor = *value_args[i];
|
||||
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
|
||||
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
|
||||
(void)mng->Replace(params[i], arg);
|
||||
continue;
|
||||
}
|
||||
new_params.push_back(params[i]);
|
||||
}
|
||||
|
||||
|
@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
|
||||
std::vector<FuncGraphPtr> graph_args;
|
||||
std::vector<PrimitivePtr> prim_args;
|
||||
std::vector<tensor::TensorPtr> value_node_args;
|
||||
std::vector<AnfNodePtr> new_xs;
|
||||
bool hasVNode = false;
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
|
@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
|
||||
graph_args.push_back(fg_vnode);
|
||||
prim_args.emplace_back(nullptr);
|
||||
value_node_args.emplace_back(nullptr);
|
||||
hasVNode = true;
|
||||
} else if (IsValueNode<Primitive>(inputs[i])) {
|
||||
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.push_back(p_vnode);
|
||||
value_node_args.emplace_back(nullptr);
|
||||
hasVNode = true;
|
||||
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
|
||||
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.emplace_back(nullptr);
|
||||
value_node_args.emplace_back(t_vnode);
|
||||
hasVNode = true;
|
||||
} else {
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.emplace_back(nullptr);
|
||||
value_node_args.emplace_back(nullptr);
|
||||
new_xs.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args);
|
||||
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
|
||||
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
||||
|
||||
return node->func_graph()->NewCNode(new_xs);
|
||||
|
@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
private:
|
||||
internal::SpecializeTransform specialize_transform_;
|
||||
};
|
||||
|
||||
// Eliminate unused parameters.
|
||||
// {G, Xs}
|
||||
class UnusedParasEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
||||
std::vector<AnfNodePtr> parameters = fg->parameters();
|
||||
size_t size = parameters.size();
|
||||
if (size != inputs.size() - 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> new_xs;
|
||||
std::vector<bool> keep_parameters;
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &node_users = mng->node_users();
|
||||
bool has_unused_para = false;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto iter = node_users.find(parameters[i]);
|
||||
if (iter != node_users.end() && !iter->second.empty()) {
|
||||
keep_parameters.push_back(true);
|
||||
new_xs.push_back(inputs[i + 1]);
|
||||
continue;
|
||||
}
|
||||
keep_parameters.push_back(false);
|
||||
has_unused_para = true;
|
||||
}
|
||||
|
||||
if (!has_unused_para) {
|
||||
return nullptr;
|
||||
}
|
||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
|
||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
if (keep_parameters[i]) {
|
||||
if (parameters[i]->abstract() != nullptr) {
|
||||
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
|
||||
}
|
||||
new_parameters.push_back(new_fg_parameters[i]);
|
||||
}
|
||||
}
|
||||
mng->SetParameters(new_fg, new_parameters);
|
||||
|
||||
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
||||
return node->func_graph()->NewCNode(new_xs);
|
||||
}
|
||||
};
|
||||
|
||||
// Eliminate unused outputs.
|
||||
// {G, Xs}
|
||||
class UnusedOutputEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
if (fg->recursive()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
auto new_fg_output = new_fg->output();
|
||||
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto output_cnode = new_fg_output->cast<CNodePtr>();
|
||||
auto &node_users = mng->node_users();
|
||||
if (node_users.count(node) == 0 || node_users[node].empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unordered_set<int> used_output_idx;
|
||||
std::vector<std::pair<AnfNodePtr, int>> all_users;
|
||||
for (auto &node_user : node_users[node]) {
|
||||
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto user_cnode = node_user.first->cast<CNodePtr>();
|
||||
size_t used_idx = GetValue<int>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
|
||||
used_output_idx.insert(used_idx);
|
||||
all_users.push_back(std::make_pair(node_user.first, used_idx));
|
||||
}
|
||||
|
||||
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
|
||||
// all output has users.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (used_output_idx.empty()) {
|
||||
// we do not process this case.
|
||||
return nullptr;
|
||||
} else if (used_output_idx.size() == 1) {
|
||||
// after eliminate, only one output left.
|
||||
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
|
||||
// update users.
|
||||
for (auto &ret_user : all_users) {
|
||||
(void)mng->Replace(ret_user.first, node);
|
||||
}
|
||||
} else {
|
||||
// after eliminate, create new multi output.
|
||||
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
|
||||
std::unordered_map<int, int> new_idx_map;
|
||||
for (auto idx : used_output_idx) {
|
||||
new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1);
|
||||
new_output_inputs.push_back(output_cnode->input(idx + 1));
|
||||
}
|
||||
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
|
||||
// update users.
|
||||
for (auto &ret_user : all_users) {
|
||||
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
|
||||
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
|
||||
}
|
||||
}
|
||||
|
||||
auto new_sx = inputs;
|
||||
new_sx[0] = NewValueNode(new_fg);
|
||||
return node->func_graph()->NewCNode(new_sx);
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -89,7 +89,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
|
|||
class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
||||
public:
|
||||
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
|
||||
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {}
|
||||
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {}
|
||||
virtual ~Optimizer() = default;
|
||||
|
||||
void Init(const OptPassGroupMap &passes, bool run_only_once) {
|
||||
|
@ -132,6 +132,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
}
|
||||
|
||||
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
|
||||
if (!is_enable_) {
|
||||
return func_graph;
|
||||
}
|
||||
// Optimizer step counter;
|
||||
int counter = -1;
|
||||
bool changes = true;
|
||||
|
@ -171,7 +174,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
};
|
||||
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
|
||||
if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) {
|
||||
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
|
||||
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
|
||||
auto fg_name =
|
||||
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
|
@ -211,6 +214,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
void enable_watch_renormalize() { is_watch_renormalize_ = true; }
|
||||
void disable_watch_renormalize() { is_watch_renormalize_ = false; }
|
||||
bool is_watch_renormalize() { return is_watch_renormalize_; }
|
||||
void set_enable(bool enable) { is_enable_ = enable; }
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
|
@ -220,6 +224,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
bool run_only_once_;
|
||||
std::vector<AnfNodePtr> untyped_nodes_;
|
||||
bool is_watch_renormalize_;
|
||||
bool is_enable_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
|
|||
DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
|
||||
|
||||
// allreduce fusion only run once
|
||||
root->flags()[ALLREDUCE_FUSION_RUN_ONCE_ONLY] = true;
|
||||
root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
|
||||
res->results()[pipeline::kStepParallelGraph] = root;
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
|
|
|
@ -158,8 +158,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
MS_EXCEPTION_IF_NULL(ptr);
|
||||
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) ||
|
||||
func_graph->flags()[TRAINING]) {
|
||||
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
|
||||
func_graph->has_flag(TRAINING)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
|||
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
|
||||
|
||||
root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true;
|
||||
root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
|
||||
return changes;
|
||||
}
|
||||
|
||||
|
|
|
@ -2270,10 +2270,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
|
||||
if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
|
||||
if (HasStrategy(root)) {
|
||||
MS_LOG(INFO) << "strategies ignored in " << parallel_mode
|
||||
MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
|
||||
<< ", set_strategy() only valid in [semi_]auto_parallel.";
|
||||
}
|
||||
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true;
|
||||
root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
|
||||
}
|
||||
|
||||
return changes;
|
||||
|
@ -2330,11 +2330,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
DumpGraph(root, std::string(STEP_PARALLEL_END));
|
||||
|
||||
// step parallel only run once
|
||||
root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true;
|
||||
root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
|
||||
res->results()[pipeline::kStepParallelGraph] = root;
|
||||
|
||||
// in auto parallel mode, no need to check if stategies set
|
||||
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true;
|
||||
root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
|
||||
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
|
|
|
@ -151,7 +151,10 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
|
||||
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
|
||||
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
|
||||
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.");
|
||||
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
|
||||
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
|
||||
"Set the GraphKernel switch to on or off.")
|
||||
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -278,7 +278,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
|||
if (bprop_graph != nullptr) {
|
||||
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
|
||||
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
|
||||
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
}
|
||||
}
|
||||
*data = func_graph;
|
||||
|
|
|
@ -1448,15 +1448,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
|
||||
for (auto &item : flags) {
|
||||
if (!py::isinstance<py::str>(item.first) || !py::isinstance<py::bool_>(item.second)) {
|
||||
if (!py::isinstance<py::str>(item.first)) {
|
||||
MS_LOG(ERROR) << "Type error in flags dict convert";
|
||||
return false;
|
||||
}
|
||||
auto name = py::cast<std::string>(item.first);
|
||||
auto value = py::cast<bool>(item.second);
|
||||
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
|
||||
|
||||
func_graph->set_flags(name, value);
|
||||
if (py::isinstance<py::bool_>(item.second)) {
|
||||
auto value = py::cast<bool>(item.second);
|
||||
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
|
||||
func_graph->set_flag(name, value);
|
||||
} else if (py::isinstance<py::str>(item.second)) {
|
||||
auto value = py::cast<std::string>(item.second);
|
||||
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
|
||||
func_graph->set_attr(name, MakeValue(value));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
@ -223,8 +223,8 @@ class Parser {
|
|||
FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
|
||||
// In order to keep effect order in the sub-graphs which generated by control flow.
|
||||
// We copy the flags from the top graph to the sub-graphs.
|
||||
if (func_graph_ && !func_graph_->flags().empty()) {
|
||||
block->func_graph()->set_flags(func_graph_->flags());
|
||||
if (func_graph_ && !func_graph_->attrs().empty()) {
|
||||
block->func_graph()->set_attrs(func_graph_->attrs());
|
||||
}
|
||||
func_block_list_.push_back(block);
|
||||
return block;
|
||||
|
|
|
@ -25,12 +25,14 @@
|
|||
#include <functional>
|
||||
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
#include "pipeline/resource.h"
|
||||
#include "pipeline/validator.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/cse.h"
|
||||
#include "optimizer/graph_kernel_reuse.h"
|
||||
#include "optimizer/clean.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/control_depend.h"
|
||||
|
@ -38,6 +40,7 @@
|
|||
#include "parallel/step_auto_parallel.h"
|
||||
#include "parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
#include "utils/any.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pipeline {
|
||||
|
@ -162,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig interface_fusion = opt::OptPassConfig({
|
||||
irpass.mark_interface_fusion_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())},
|
||||
{"interface_fusion", interface_fusion},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"cse", opt::OptPassConfig(opt::CSE(false))},
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig elim_1 = opt::OptPassConfig({
|
||||
irpass.addn_eliminate_,
|
||||
irpass.incorporate_getitem_from_param_,
|
||||
});
|
||||
opt::OptPassConfig elim_2 = opt::OptPassConfig({
|
||||
irpass.unused_parameter_eliminate_,
|
||||
irpass.unused_output_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"elim_1", elim_1},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"elim_2", elim_2},
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
}
|
||||
|
||||
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
|
||||
OptPassGroupMap map({
|
||||
|
@ -191,8 +228,19 @@ void InitOpt(const ResourcePtr &res) {
|
|||
opt::irpass::OptimizeIRPassLib irpass;
|
||||
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
|
||||
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
|
||||
g_pass_opts["opt_graph_kernel_a"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
|
||||
g_pass_opts["opt_graph_kernel_b"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
|
||||
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
|
||||
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
||||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
|
||||
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -224,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
|
|||
|
||||
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
|
||||
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
||||
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
|
||||
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
||||
|
||||
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
||||
|
||||
bool AddControlDependPass(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -270,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
|
|||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"add_control_depend", AddControlDependPass},
|
||||
{"cconv", CconvPass}};
|
||||
{"cconv", CconvPass},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_control_depend", AddControlDependPass}};
|
||||
|
||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
|
|
|
@ -488,7 +488,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
|
|||
#ifdef ENABLE_INFER
|
||||
// Now don't use the graph because the exec ge function don't take effect
|
||||
MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph);
|
||||
if (ENABLE_TRAIN != info.at(phase)->func_graph->flags()["training"]) {
|
||||
if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) {
|
||||
MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries";
|
||||
ConfigManager::GetInstance().ResetConfig();
|
||||
return py::none();
|
||||
|
|
|
@ -165,7 +165,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
|
||||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
return joined_args_spec_list;
|
||||
}
|
||||
|
@ -178,7 +178,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
|
||||
if (!(joined_args_spec_list == args_spec_list)) {
|
||||
trace_.push_back(joined_args_spec_list);
|
||||
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
return joined_args_spec_list;
|
||||
|
|
|
@ -479,7 +479,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
|||
if (undetermined_fgs) {
|
||||
auto fg_parent = fg->parent();
|
||||
MS_EXCEPTION_IF_NULL(fg_parent);
|
||||
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
|
||||
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/ascend/ir_fission/bn_split.h"
|
||||
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
|
||||
|
@ -63,6 +64,9 @@
|
|||
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||
#include "pre_activate/pass/eliminate_redundant_op.h"
|
||||
#include "pre_activate/pass/common_subexpression_elimination.h"
|
||||
#include "pre_activate/pass/fuse_graph_kernel.h"
|
||||
#include "pre_activate/pass/fuse_basic.h"
|
||||
#include "pre_activate/pass/add_atomic_clean.h"
|
||||
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
|
||||
#include "pre_activate/ascend/format_type/check_consistency.h"
|
||||
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
|
||||
|
@ -88,6 +92,8 @@
|
|||
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
|
||||
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||
#include "pre_activate/ascend/ir_fission/split_fission.h"
|
||||
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
|
||||
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
@ -164,6 +170,19 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
|
|||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
|
||||
MS_EXCEPTION_IF_NULL(common_process);
|
||||
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
|
||||
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
|
||||
optimizer->AddPassManager(common_process);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
|
@ -332,7 +351,94 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
std::string file_path =
|
||||
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_path, kernel_graph, true);
|
||||
DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id()));
|
||||
DumpIRProto(kernel_graph, "after_hwopt");
|
||||
kernel_graph->DumpFuncGraph("hwopt_d_end");
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
if (save_graphs) {
|
||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" +
|
||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||
".ir";
|
||||
DumpIR(file_path, kernel_graph);
|
||||
}
|
||||
|
||||
// Fuse graph kernels with basic ops
|
||||
FuseGraphKernel(kernel_graph, is_before_kernel_select);
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
|
||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||
".ir";
|
||||
DumpIR(file_path, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
if (save_graphs) {
|
||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" +
|
||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||
".ir";
|
||||
DumpIR(file_path, kernel_graph, true);
|
||||
}
|
||||
|
||||
// Fuse basic ops with basic ops
|
||||
FuseBasic(kernel_graph, is_before_kernel_select);
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" +
|
||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||
".ir";
|
||||
DumpIR(file_path, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
if (save_graphs) {
|
||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" +
|
||||
std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_path, kernel_graph);
|
||||
}
|
||||
|
||||
AddAtomicClean(kernel_graph);
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_path =
|
||||
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_path, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,12 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select = false);
|
||||
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select = false);
|
||||
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "operator/ops.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "session/kernel_graph.h"
|
||||
|
@ -229,7 +230,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
} else {
|
||||
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
}
|
||||
// if kernel info is null , it remarks this function is running ut
|
||||
if (cast->kernel_info() == nullptr) {
|
||||
|
@ -284,22 +285,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
TypeId origin_type;
|
||||
const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
|
||||
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
auto real_input_node = kernel_with_index.first;
|
||||
if (is_weight_boundary(real_input_node)) {
|
||||
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// weight
|
||||
origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
|
||||
origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
|
||||
if (origin_type == kTypeUnknown) {
|
||||
origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
|
||||
}
|
||||
} else {
|
||||
// feature map
|
||||
origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
|
@ -307,9 +303,13 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
||||
const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
|
||||
const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
|
||||
if (origin_type != device_type) {
|
||||
// In graph kernel, we check parameter,
|
||||
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
||||
new_inputs.push_back(cur_input);
|
||||
} else if (origin_type != device_type) {
|
||||
auto cast =
|
||||
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, origin_type);
|
||||
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
cast->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/utils.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "common/utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -74,11 +77,21 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
|
|||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
|
||||
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
||||
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(node) << "["
|
||||
<< node->DebugString() << "]";
|
||||
|
||||
std::vector<AnfNodePtr> todos = {node};
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
||||
}
|
||||
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
|
||||
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
||||
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
|
||||
<< cnode->DebugString() << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "device/kernel_info.h"
|
||||
#include "pre_activate/ascend/ascend_helper.h"
|
||||
|
@ -27,34 +28,45 @@
|
|||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "utils/utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<bool> &need_insert_cast) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
AbstractBasePtrList abstract_list;
|
||||
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
|
||||
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
||||
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
||||
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||
AnfNodePtr replace_node = nullptr;
|
||||
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
||||
auto idx = NewValueNode(SizeToInt(output_idx));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int32Imm>(output_idx);
|
||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
||||
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get());
|
||||
AnfNodePtr replace_node = nullptr;
|
||||
if (origin_type != device_type) {
|
||||
replace_node =
|
||||
AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, origin_type);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
|
||||
if (need_insert_cast[output_idx]) {
|
||||
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
||||
}
|
||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
||||
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||
if (origin_type != device_type) {
|
||||
replace_node =
|
||||
AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
} else {
|
||||
replace_node = getitem;
|
||||
}
|
||||
} else {
|
||||
replace_node = getitem;
|
||||
}
|
||||
|
@ -65,9 +77,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
|||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
return make_tuple;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<bool> &need_insert_cast) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
|
||||
|
@ -76,14 +89,23 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
MS_EXCEPTION_IF_NULL(cnode->Type());
|
||||
// Single output
|
||||
if (!cnode->Type()->isa<Tuple>()) {
|
||||
if (!need_insert_cast[0]) {
|
||||
return cnode;
|
||||
}
|
||||
|
||||
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
||||
}
|
||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
||||
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
|
||||
AnfNodePtr replace_node = cnode;
|
||||
if (origin_type != device_type) {
|
||||
replace_node =
|
||||
AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, origin_type);
|
||||
AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
|
@ -91,7 +113,57 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
return replace_node;
|
||||
}
|
||||
// Multiple output
|
||||
return InsertCastForMultipleOutput(func_graph, cnode);
|
||||
return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
// insert cast for ops in graph kernel.
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
|
||||
for (auto &t : todo) {
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
|
||||
// process input
|
||||
CNodePtr t_cnode = t->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(t_cnode);
|
||||
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
|
||||
AnfNodePtr t_new_node_1 = nullptr;
|
||||
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
|
||||
// process output
|
||||
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
|
||||
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
|
||||
if (iter != graph_rets.end()) {
|
||||
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
|
||||
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
|
||||
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
|
||||
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
|
||||
need_insert_cast[iter->second] = false;
|
||||
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
|
||||
need_insert_cast[iter->second] = false;
|
||||
}
|
||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
||||
} else {
|
||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
||||
}
|
||||
|
||||
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
|
||||
(void)mng->Replace(t, t_new_node_1);
|
||||
}
|
||||
}
|
||||
|
||||
// insert cast for graph kernel.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -106,13 +178,27 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
|||
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return ProcessGraphKernelOp(func_graph, node);
|
||||
} else {
|
||||
// insert cast for single op.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
}
|
||||
// insert cast for single op.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node);
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,6 +133,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|||
return nullptr;
|
||||
}
|
||||
auto next_cnode = next_node->cast<CNodePtr>();
|
||||
if (AnfAlgo::IsGraphKernel(next_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto next_op_name = AnfAlgo::GetCNodeName(next_node);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
kernel_query->Query(next_cnode, &kernel_info_list);
|
||||
|
@ -206,6 +209,9 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
|
|||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(prior_op);
|
||||
if (AnfAlgo::IsGraphKernel(prior_op)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
kernel_query->Query(prior_op, &kernel_info_list);
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/utils.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
|
||||
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
|
||||
return cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
|
||||
if (input_shape.size() != 5) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrMultiples);
|
||||
if (multiples.size() == 4 && multiples[1] == 1) {
|
||||
multiples.push_back(1);
|
||||
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
|
||||
}
|
||||
|
||||
return cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (op_name == prim::kPrimTile->name()) {
|
||||
return ModifyTileOpAttrs(cnode);
|
||||
} else if (op_name == prim::kPrimReduceSum->name()) {
|
||||
// kPrimReduceMean
|
||||
// kPrimReduceSum
|
||||
// kPrimReduceAll
|
||||
// kPrimReduceMax
|
||||
// kPrimReduceMin
|
||||
return ModifyReduceOpsAttrs(cnode);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(fg, &todos);
|
||||
for (auto &t : todos) {
|
||||
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
|
||||
if (new_node != nullptr && new_node != t) {
|
||||
(void)manager->Replace(t, new_node);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ModifyOpAttrs : public PatternProcessPass {
|
||||
public:
|
||||
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
|
||||
~ModifyOpAttrs() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "pre_activate/common/helper.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (op_name != prim::kPrimReshape->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
|
||||
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return cnode->input(1);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(fg, &todos);
|
||||
for (auto &t : todos) {
|
||||
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
|
||||
if (new_node != nullptr && new_node != t) {
|
||||
(void)manager->Replace(t, new_node);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class RemoveNoUseReshapeOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
|
||||
~RemoveNoUseReshapeOp() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
|
@ -121,6 +121,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
|
|||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<CNodePtr> cast_nodes;
|
||||
|
|
|
@ -102,9 +102,12 @@ bool UnVisited(const BaseRef &n) {
|
|||
auto prim_py = value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
return !prim_py->HasAttr(kAttrVisited);
|
||||
} else {
|
||||
return false;
|
||||
} else if (IsValueNode<FuncGraph>(in)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(in);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
return !func_graph->has_flag(kAttrVisited);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -188,9 +191,12 @@ bool Visited(const BaseRef &n) {
|
|||
auto prim_py = value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_py);
|
||||
return prim_py->HasAttr(kAttrVisited);
|
||||
} else {
|
||||
return false;
|
||||
} else if (IsValueNode<FuncGraph>(in)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(in);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
return func_graph->has_flag(kAttrVisited);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue