refactor dynamic & static op select

This commit is contained in:
jjfeing 2022-08-22 10:46:12 +08:00
parent 6ec539f79e
commit eabbdb39f3
38 changed files with 2993 additions and 957 deletions

File diff suppressed because one or more lines are too long

View File

@ -787,6 +787,7 @@ constexpr auto kOpFormat_ND = "ND";
constexpr auto kOpFormat_NCHW = "NCHW";
constexpr auto kOpFormat_NHWC = "NHWC";
constexpr auto kOpFormat_HWCN = "HWCN";
constexpr auto kOpFormat_CHWN = "CHWN";
constexpr auto kOpFormat_NC1HWC0 = "NC1HWC0";
constexpr auto kOpFormat_FRAC_Z = "FRACTAL_Z";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
@ -820,6 +821,7 @@ COMMON_EXPORT bool IsOneOfDynamicShapeConstInputToAttrGPU(const std::string &nam
COMMON_EXPORT bool IsOneOfComputeDepend(const std::string &name);
COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format);
COMMON_EXPORT bool IsOneOfFormat(const std::string &format);
COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format);
// The map between kernel's output and input ref relationship.
// Key is the output index while the value is input index which will be used as the reference of output.

View File

@ -680,7 +680,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
auto axis_attr = primitive->GetAttr(kAxis);
if (axis_attr == nullptr) {
MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
return std::vector<int64_t>();
return {};
}
std::vector<int64_t> axis_list;
if (axis_attr->isa<Int64Imm>()) {

View File

@ -99,12 +99,7 @@ enum FusionType {
UNKNOWN_FUSION_TYPE = -1,
};
enum OpPattern {
kCommonPattern = 0,
kFormatAgnosticPattern = 1,
kBroadcastPattern = 2,
kReducePattern = 3,
};
enum OpPattern { kCommonPattern = 0, kFormatAgnosticPattern, kBroadcastPattern, kReducePattern, kDynamicFormatPattern };
// Backend processor
enum Processor {

View File

@ -0,0 +1,60 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_
#define MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_
namespace mindspore::kernel::refactor {
// String keys/values used when parsing op-info registration data.
// op info common
// dynamicShapeSupport
constexpr char kDynamicShapeSupport[] = "dynamicShapeSupport";
constexpr char kFlag[] = "flag";
constexpr char kTrue[] = "true";
constexpr char kFalse[] = "false";
// precision_reduce
constexpr char kPrecisionReduce[] = "precision_reduce";
// opFile
constexpr char kOpFile[] = "opFile";
// "value" field, shared by opFile / opInterface / attr items.
// NOTE: defined exactly once — a second definition of kValue in this
// namespace would be a redefinition error.
constexpr char kValue[] = "value";
// opInterface
constexpr char kOpInterface[] = "opInterface";
// attr key
constexpr char kAttr[] = "attr";
constexpr char kList[] = "list";
// attr item key (reuses kValue above)
constexpr char kDefaultValue[] = "defaultValue";
constexpr char kParamType[] = "paramType";
constexpr char kType[] = "type";
// attr item value
constexpr char kInt[] = "int";
constexpr char kFloat[] = "float";
constexpr char kBool[] = "bool";
constexpr char kStr[] = "str";
constexpr char kListInt[] = "listInt";
constexpr char kListFloat[] = "listFloat";
constexpr char kListBool[] = "listBool";
constexpr char kListListInt[] = "listListInt";
constexpr char kRequired[] = "required";
// input/output
// NOTE(review): the value is "input", which reads like a prefix rather than a
// suffix — confirm the intended naming against the callers.
constexpr char kInputSuffix[] = "input";
constexpr char kDType[] = "dtype";
constexpr char kName[] = "name";
constexpr char kUnknownShapeFormat[] = "unknownshape_format";
constexpr char kFormat[] = "format";
}  // namespace mindspore::kernel::refactor
#endif  // MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_

View File

@ -68,6 +68,7 @@ class OpIOInfo {
const std::string &shape() const { return shape_; }
const std::vector<std::string> &dtypes() const { return dtypes_; }
const std::vector<std::string> &formats() const { return formats_; }
const std::vector<std::string> &unknown_shape_formats() const { return unknown_shape_formats_; }
const std::string &value_depend() const { return value_depend_; }
void set_index(const int index) { index_ = index; }
@ -78,6 +79,9 @@ class OpIOInfo {
void set_shape(const std::string &shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
void set_unknown_shape_formats(const std::vector<std::string> &unknown_shape_formats) {
unknown_shape_formats_ = unknown_shape_formats;
}
void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; }
private:
@ -89,6 +93,7 @@ class OpIOInfo {
std::string shape_;
std::vector<std::string> dtypes_;
std::vector<std::string> formats_;
std::vector<std::string> unknown_shape_formats_;
std::string value_depend_ = kIgnored;
};
@ -117,6 +122,7 @@ class OpInfo {
input_to_attr_index_ = opinfo.input_to_attr_index_;
real_input_index_ = opinfo.real_input_index_;
need_check_supported_ = opinfo.need_check_supported();
dynamic_rank_support_ = opinfo.dynamic_rank_support();
is_dynamic_format_ = opinfo.is_dynamic_format();
for (const auto &attr : opinfo.attrs_ptr()) {
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
@ -141,6 +147,7 @@ class OpInfo {
bool dynamic_compile_static() const { return dynamic_compile_static_; }
std::string processor() const { return processor_; }
bool need_check_supported() const { return need_check_supported_; }
bool dynamic_rank_support() const { return dynamic_rank_support_; }
bool is_dynamic_format() const { return is_dynamic_format_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
@ -162,6 +169,7 @@ class OpInfo {
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
void set_processor(const std::string &processor) { processor_ = processor; }
void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; }
void set_dynamic_rank_support(bool dynamic_rank_support) { dynamic_rank_support_ = dynamic_rank_support; }
void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; }
void set_input_to_attr_index(const std::vector<size_t> &input_to_attr_index) {
input_to_attr_index_ = input_to_attr_index;
@ -197,6 +205,7 @@ class OpInfo {
bool dynamic_shape_ = false;
bool dynamic_compile_static_ = false;
bool need_check_supported_ = false;
bool dynamic_rank_support_ = false;
bool is_dynamic_format_ = false;
OpPattern op_pattern_ = kCommonPattern;
std::string processor_;
@ -211,5 +220,6 @@ class OpInfo {
using OpAttrPtr = std::shared_ptr<OpAttr>;
using OpIOInfoPtr = std::shared_ptr<OpIOInfo>;
using OpInfoPtr = std::shared_ptr<OpInfo>;
extern std::vector<std::string> SplitStrToVec(const std::string &input);
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/oplib/oplib.h"
#include <memory>
#include <map>
#include <fstream>
@ -39,11 +39,13 @@ constexpr auto kIsDynamicFormat = "is_dynamic_format";
constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kNeedCheckSupported = "need_check_supported";
constexpr auto kDynamicRankSupport = "dynamic_rank_support";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
constexpr auto kDynamicShape = "dynamic_shape";
constexpr auto kDynamicCompileStatic = "dynamic_compile_static";
constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kUnknownShapeFormat = "unknown_shape_format";
constexpr auto kInputToAttrIndex = "input_to_attr_index";
constexpr auto kRealInputIndex = "real_input_index";
constexpr auto kAttr = "attr";
@ -68,6 +70,9 @@ constexpr auto kNeedCompile = "need_compile";
constexpr auto kShape = "shape";
constexpr auto kProcessor = "processor";
static const std::map<std::string, OpImplyType> OpImplyTypeMap = {
{kTbe, kTBE}, {kAkg, kAKG}, {kAiCPU, kAICPU}, {kCpu, kCPU}, {kGpu, kGPU}};
static std::string ImplTypeToStr(OpImplyType impl_type) {
switch (impl_type) {
case kTBE:
@ -84,37 +89,52 @@ static std::string ImplTypeToStr(OpImplyType impl_type) {
return "unknown";
}
}
// Split a comma-separated op-select item into its elements, stripping blanks
// and normalizing NCHW/ND entries to the default format.
std::vector<std::string> SplitStrToVec(const std::string &input) {
  // Formats that op select may report but which are stored as the default format.
  static const std::map<std::string, std::string> kSpecFormat = {{kOpFormat_NCHW, kOpFormat_DEFAULT},
                                                                 {kOpFormat_ND, kOpFormat_DEFAULT}};
  if (input.empty()) {
    MS_LOG(EXCEPTION) << "Op select ret item is null.";
  }
  // Drop every blank character, then terminate with a separator so the scan
  // below also emits the trailing element.
  std::string buffer = input;
  (void)buffer.erase(remove(buffer.begin(), buffer.end(), ' '), buffer.end());
  (void)buffer.append(",");
  const char kSep = ',';
  std::vector<std::string> items = {};
  std::string::size_type start = 0;
  for (auto pos = buffer.find(kSep); pos != std::string::npos; pos = buffer.find(kSep, start)) {
    auto element = buffer.substr(start, pos - start);
    auto spec_it = kSpecFormat.find(element);
    if (spec_it != kSpecFormat.end()) {
      element = spec_it->second;
    }
    (void)items.emplace_back(element);
    start = pos + 1;
  }
  return items;
}
bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
bool ret = false;
try {
auto op_json = nlohmann::json::parse(json_string);
std::string imply_type_string = op_json.at(kImplyType);
std::string op_name = op_json.at(kOpName);
if (imply_type_string == kTbe) {
OpImplyType imply_type = kTBE;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kAkg) {
OpImplyType imply_type = kAKG;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kAiCPU) {
OpImplyType imply_type = kAICPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kCpu) {
OpImplyType imply_type = kCPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kGpu) {
OpImplyType imply_type = kGPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else {
MS_LOG(ERROR) << "Not support imply_type";
auto find_iter = OpImplyTypeMap.find(imply_type_string);
if (find_iter == OpImplyTypeMap.end()) {
MS_LOG(ERROR) << "Not support imply_type, " << imply_type_string;
return false;
}
if (!ret) {
MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
if (!DecodeOpInfo(op_json, find_iter->second, impl_path)) {
MS_LOG(ERROR) << "RegOp failed: op_name: " << op_json.at(kOpName) << " imply_type " << imply_type_string;
return false;
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "get op json elements failed: " << e.what();
}
return ret;
return true;
}
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
@ -127,6 +147,9 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
op_info->set_need_check_supported(obj.at(kNeedCheckSupported));
if (obj.find(kDynamicRankSupport) != obj.end()) {
op_info->set_dynamic_rank_support(obj.at(kDynamicRankSupport));
}
if (obj.find(kDynamicShape) != obj.end()) {
op_info->set_dynamic_shape(obj.at(kDynamicShape));
@ -136,8 +159,13 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
}
if (obj.find(kIsDynamicFormat) != obj.end()) {
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
auto dynamic_iter = obj.find(kIsDynamicFormat);
if (dynamic_iter != obj.end()) {
bool is_dynamic_format = dynamic_iter->get<bool>();
if (is_dynamic_format) {
op_info->set_op_pattern(kDynamicFormatPattern);
}
op_info->set_is_dynamic_format(is_dynamic_format);
}
if (obj.find(kInputToAttrIndex) != obj.end()) {
@ -158,13 +186,12 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
if (obj.find(kOpPattern) != obj.end()) {
std::string op_pattern = obj.at(kOpPattern);
auto find_iter = kOpPatternMap.find(op_pattern);
if (find_iter == kOpPatternMap.end()) {
if (!op_pattern.empty()) {
MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
}
op_info->set_op_pattern(kCommonPattern);
} else {
if (find_iter != kOpPatternMap.end()) {
op_info->set_op_pattern(find_iter->second);
} else {
if (!op_pattern.empty()) {
MS_LOG(WARNING) << "The error pattern: " << op_pattern;
}
}
}
}
@ -265,6 +292,24 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI
return false;
}
}
if (obj.find(kUnknownShapeFormat) != obj.end()) {
auto unknown_shape_formats_obj = obj.at(kUnknownShapeFormat);
if (unknown_shape_formats_obj.size() != op_info->inputs_ptr().size() + op_info->outputs_ptr().size()) {
MS_LOG(ERROR) << "If unknown shape exist, the size should be equal (input size + output size).";
return false;
}
for (size_t i = 0; i < op_info->inputs_ptr().size(); ++i) {
auto unknown_shape_formats_str = unknown_shape_formats_obj.at(i);
auto unknown_shape_formats = SplitStrToVec(unknown_shape_formats_str);
op_info->inputs_ptr().at(i)->set_unknown_shape_formats(unknown_shape_formats);
}
for (size_t i = 0; i < op_info->outputs_ptr().size(); ++i) {
auto index = i + op_info->inputs_ptr().size();
auto unknown_shape_formats_str = unknown_shape_formats_obj.at(index);
auto unknown_shape_formats = SplitStrToVec(unknown_shape_formats_str);
op_info->outputs_ptr().at(i)->set_unknown_shape_formats(unknown_shape_formats);
}
}
if (CheckRepetition(op_info)) {
MS_LOG(WARNING) << "This op info has been already registered. op name: " << op_info->op_name()
<< ", impl type: " << ImplTypeToStr(op_info->imply_type())

View File

@ -102,6 +102,12 @@ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name,
return op_info;
}
// Convenience overload: resolve the op name from the node itself, then
// delegate to the (name, node) overload.
std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  return FindOp(common::AnfAlgo::GetCNodeName(cnode), cnode);
}
inline std::string GetPrimitiveName(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return "";

View File

@ -36,6 +36,7 @@ class TbeDynamicShapeUtil {
static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
static std::shared_ptr<OpInfo> FindOp(const CNodePtr &cnode);
static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
const TypeId &type);
static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/agnostic_selector/tbe_kernel_agnostic_selector.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Fills *support_format_dtype for a format-agnostic kernel: the single
// input/output pair simply follows the format produced by the input node.
void TbeKernelAgnosticSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  // Guard the out-param too, consistent with the other selectors.
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  SupportFormat support_format;
  auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  if (input_num != 1 || output_num != 1) {
    // Message mentions both counts since both are checked.
    MS_LOG(EXCEPTION) << "Agnostic only supports one input and one output. input_num: " << input_num
                      << ", output num: " << output_num << ", full_name:" << cnode_ptr_->fullname_with_scope();
  }
  // Inherit whatever format the producer of input 0 emits.
  auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
  if (!IsOneOfFormat(format)) {
    MS_LOG(ERROR) << "Got the unknown format " << format;
    return;
  }
  GenerateSupportFormat(format, input_num, format, output_num, &support_format);
  GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Selects supported format/dtype combinations for format-agnostic TBE kernels:
// the node's single input/output pair follows the format of the input node.
class TbeKernelAgnosticSelector {
 public:
  explicit TbeKernelAgnosticSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelAgnosticSelector() = default;
  // Fills *support_format_dtype with per-IO dtype/format lists for cnode_ptr_.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
 private:
  // Node whose kernel format/dtype support is being queried.
  CNodePtr cnode_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_selector/tbe_kernel_common_selector.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
namespace mindspore::kernel {
// Fills *support_format_dtype from the registered op-info: one dtype list and
// one format list per input and per output.
void TbeKernelCommonSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(op_info);
  const bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(cnode_ptr_);
  // TODO(jjfeing) use unknown-shape format as the dynamic format; until then
  // both branches intentionally fall back to the registered static formats.
  for (const auto &input_io : op_info->inputs_ptr()) {
    (void)support_format_dtype->input_dtypes.emplace_back(input_io->dtypes());
    if (is_dynamic_shape) {
      // TODO(jjfeing) switch to input_io->unknown_shape_formats() when ready
      (void)support_format_dtype->input_formats.emplace_back(input_io->formats());
    } else {
      (void)support_format_dtype->input_formats.emplace_back(input_io->formats());
    }
  }
  for (const auto &output_io : op_info->outputs_ptr()) {
    (void)support_format_dtype->output_dtypes.emplace_back(output_io->dtypes());
    if (is_dynamic_shape) {
      // TODO(jjfeing) switch to output_io->unknown_shape_formats() when ready
      (void)support_format_dtype->output_formats.emplace_back(output_io->formats());
    } else {
      (void)support_format_dtype->output_formats.emplace_back(output_io->formats());
    }
  }
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_
#include <utility>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Selects supported format/dtype combinations for ordinary TBE kernels by
// reading the statically registered op-info entry of the node.
class TbeKernelCommonSelector {
 public:
  explicit TbeKernelCommonSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelCommonSelector() = default;
  // Fills *support_format_dtype with per-IO dtype/format lists for cnode_ptr_.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
 private:
  // Node whose kernel format/dtype support is being queried.
  CNodePtr cnode_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_

View File

@ -1,131 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include <set>
#include <string>
#include "base/base.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kNcdhwShapeSize = 5;
// RNN-style formats need both input_size and hidden_size; report whether the
// node (Parameter or CNode) carries usable values for them.
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
  if (node->isa<Parameter>()) {
    auto as_param = node->cast<ParameterPtr>();
    const bool sizes_set = as_param->input_size() > 0 && as_param->hidden_size() > 0;
    return sizes_set;
  }
  if (node->isa<CNode>()) {
    auto as_cnode = node->cast<CNodePtr>();
    return common::AnfAlgo::HasNodeAttr(kAttrInputSize, as_cnode) &&
           common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, as_cnode);
  }
  // Any other node kind cannot provide the sizes.
  return false;
}
} // namespace
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < real_input_num; i++) {
auto format = AnfAlgo::GetInputFormat(node, i);
if (!CheckValidInOutDeviceShape(node, i, false, format)) {
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
<< ", format:" << format;
return false;
}
}
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < real_output_num; i++) {
auto format = AnfAlgo::GetOutputFormat(node, i);
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
<< ", format:" << format;
return false;
}
}
return true;
}
// Returns the node's inferred shape after applying the padding the given
// device format requires, so it can be validated against format rank rules.
// An empty result means the shape is unknown and cannot be checked.
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                   const std::string &format) {
  // Outputs read their own shape; inputs read the producing node's output shape.
  auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
                         : common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
  std::vector<int64_t> infer_shape;
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_ptr);
    infer_shape = shape_ptr->shape();
  }
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // Format-specific padding driven by the IO's reshape type, if required.
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    auto reshape_type =
      is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
    infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
  }
  auto temp_shape = infer_shape;
  // Sub-4D shapes in padded (non "no-padding", non-3D) formats are first
  // padded to 4D by default before the rank check.
  if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && infer_shape.size() < kShape4dDims &&
      !IsOneOf3DFormat(format)) {
    MS_LOG(DEBUG) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
    temp_shape = trans::PaddingShapeTo4dDefault(infer_shape, node);
  }
  // 3D formats expect NCDHW rank (5); pad to 5D when the rank differs.
  if (infer_shape.size() != kNcdhwShapeSize && IsOneOf3DFormat(format)) {
    temp_shape = trans::PaddingShapeTo5dDefault(infer_shape, node);
  }
  return temp_shape;
}
// Checks that the (padded) inferred shape of one IO satisfies the rank/size
// constraints of its device format. Unknown/empty shapes pass by default.
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
                                           const std::string &format) {
  auto infer_shape = GetFinalInferShape(node, index, is_output, format);
  if (infer_shape.empty()) {
    // Shape unknown: nothing to validate here.
    return true;
  }
  // Formats whose device layout requires exactly a 4D host shape.
  std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
                                           kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
                                           kOpFormat_NC1HWC0_C04};
  // Formats whose device layout requires exactly a 5D host shape.
  std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
  if (check_4D_format.find(format) != check_4D_format.end()) {
    return infer_shape.size() == kShape4dDims;
  }
  if (check_5D_format.find(format) != check_5D_format.end()) {
    return infer_shape.size() == kShape5dDims;
  }
  // FRAC_NZ accepts >=2D, or 1D when the single dim is 1 or cube-aligned.
  if (format == kOpFormat_FRAC_NZ) {
    return infer_shape.size() >= kShape2dDims ||
           (infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
  }
  // RNN formats additionally require input_size/hidden_size on the node.
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
  }
  // Any other format imposes no rank constraint.
  return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,47 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#include <string>
#include <vector>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace kernel {
// Per-IO lists of supported formats: one inner vector per input/output.
struct SupportFormat {
  std::vector<std::vector<std::string>> input_format;
  std::vector<std::vector<std::string>> output_format;
};
// One IO's list of supported format strings.
using SupportFormatItem = std::vector<std::string>;
// Host-side validation that a node's inferred shapes are compatible with the
// device formats chosen for its inputs and outputs.
class HostCheck {
 public:
  HostCheck() = default;
  ~HostCheck() = default;
  // Returns true when every input and output shape fits its device format.
  static bool CheckValidDeviceShape(const AnfNodePtr &node);
 private:
  // Validates a single input (is_output=false) or output (is_output=true).
  static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
                                         const std::string &format);
  // Shape after the format-required padding; empty when unknown.
  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                 const std::string &format);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_COMMON_UTILS_H_

View File

@ -0,0 +1,82 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/dynamic_selector/tbe_kernel_dynamic_selector.h"
#include <string>
#include <map>
#include <vector>
#include "include/common/utils/utils.h"
#include "include/common/utils/json_operation_utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
namespace mindspore::kernel {
constexpr auto kDType1 = "dtype";
constexpr auto kFormat1 = "format";
constexpr auto kUnknownSahpeFormat1 = "unknownshape_format";
constexpr auto kPrefixInput1 = "input";
constexpr auto kPrefixOutput1 = "output";
namespace {
// Translates the op-select json string returned by TBE into per-IO lists of
// supported dtypes and formats in *support_format_dtype.
void ParseOpSelectJson(bool is_dynamic_impl, const std::string &op_select_json,
                       SupportFormatDType *support_format_dtype) {
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  nlohmann::json json_obj;
  if (!ParseJson(op_select_json, &json_obj)) {
    MS_LOG(EXCEPTION) << "Parse op_select_json error.";
  }
  if (!json_obj.is_object()) {
    MS_LOG(EXCEPTION) << "JsonStr is not an object, the json is:" << op_select_json;
  }
  for (const auto &entry : json_obj.items()) {
    const std::string &entry_name = entry.key();
    const bool belongs_to_input = (entry_name.find(kPrefixInput1) != std::string::npos);
    const bool belongs_to_output = (entry_name.find(kPrefixOutput1) != std::string::npos);
    if (!belongs_to_input && !belongs_to_output) {
      MS_LOG(EXCEPTION) << "Op select ret json is error, the json is:" << op_select_json;
    }
    auto dtypes = SplitStrToVec(entry.value().at(kDType1));
    // A dynamic implementation may publish a dedicated unknown-shape format
    // list; prefer it over the static one when present.
    std::string formats_str = entry.value().at(kFormat1);
    if (is_dynamic_impl && entry.value().find(kUnknownSahpeFormat1) != entry.value().end()) {
      formats_str = entry.value().at(kUnknownSahpeFormat1);
    }
    auto formats = SplitStrToVec(formats_str);
    if (belongs_to_input) {
      (void)support_format_dtype->input_dtypes.emplace_back(dtypes);
      (void)support_format_dtype->input_formats.emplace_back(formats);
    } else {
      (void)support_format_dtype->output_dtypes.emplace_back(dtypes);
      (void)support_format_dtype->output_formats.emplace_back(formats);
    }
  }
}
}  // namespace
// Queries TBE op-select at runtime and fills *support_format_dtype with the
// dtype/format lists it reports for this node.
void TbeKernelDynamicSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  // Guard both the node and the out-param, consistent with the other selectors;
  // cnode_ptr_ is dereferenced below via fullname_with_scope().
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
  std::string op_format_dtype_str = build_manager.TbeOpSelectFormat(cnode_ptr_);
  if (op_format_dtype_str.empty()) {
    MS_LOG(EXCEPTION) << "Op select format failed, "
                      << "node name: " << cnode_ptr_->fullname_with_scope();
  }
  auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(op_info);
  // Dynamic impl covers both compile-static dynamic kernels and true dynamic shapes.
  bool is_dynamic_impl = op_info->dynamic_compile_static() || common::AnfAlgo::IsDynamicShape(cnode_ptr_);
  ParseOpSelectJson(is_dynamic_impl, op_format_dtype_str, support_format_dtype);
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_
#include <utility>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Selects supported kernel formats/dtypes for TBE ops whose candidates are
// provided dynamically by the TBE op-select query.
class TbeKernelDynamicSelector {
 public:
  explicit TbeKernelDynamicSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelDynamicSelector() = default;
  // Fill *support_format_dtype with combinations reported by op-select.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
 private:
  // Node whose kernel formats are being selected.
  CNodePtr cnode_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_

View File

@ -0,0 +1,393 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_broadcast_selector.h"
#include "include/common/utils/utils.h"
#include "utils/trace_base.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// N (batch) and C (channel) dimension indices in an NCHW-ordered shape.
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
// Fractal/5HD formats tile axes in units of 16 elements.
constexpr int64_t kAlignmented16 = 16;
// Minimum value required in dyn_input_sizes[0] for dynamic-input nodes.
constexpr int64_t kDynamicInputSize = 2;
// Offsets from the end of a shape used by the FRACTAL_NZ checks.
constexpr int64_t kLastIndex = 1;
constexpr int64_t kLastButOneIndex = 2;
// Entry point: collect node info once, then probe every special format; each
// probe appends to support_format only when all shapes permit that format.
// Finally the formats are expanded into format+dtype combinations.
void TbeKernelBroadcastSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  SupportFormat support_format;
  // The original (default) format is always a candidate.
  GetSupportOriFormat(cnode_ptr_, &support_format);
  GetBroadCastNodeInfo();
  GetCheckInfo();
  GetBroadcastSupport5HD(&support_format);
  GetBroadcastSupportFracZ(&support_format);
  GetBroadcastSupportC1HWNCoC0(&support_format);
  GetBroadcastSupportFracNZ(&support_format);
  GetBroadcastSupportNDC1HWC0(&support_format);
  GetBroadcastSupportFracZ3D(&support_format);
  GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
}
// Collect the broadcast node's input/output shapes (scalars padded to [1]).
// For nodes carrying kAttrDynInputSizes (dynamic tuple inputs) only the first
// input's shape is recorded and input_num_ is forced to 1 — presumably all
// dynamic inputs share one shape/format (TODO confirm with callers).
void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
  if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) {
    auto dynamic_size_vec = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode_ptr_, kAttrDynInputSizes);
    if (dynamic_size_vec.empty()) {
      MS_LOG(EXCEPTION) << "Node [" << common::AnfAlgo::GetCNodeName(cnode_ptr_)
                        << "]'s attr [dyn_input_sizes] is empty" << trace::DumpSourceLines(cnode_ptr_);
    }
    if (dynamic_size_vec[0] < kDynamicInputSize) {
      MS_LOG(EXCEPTION) << "Node [" << common::AnfAlgo::GetCNodeName(cnode_ptr_)
                        << "]'s attr [dyn_input_sizes] value less than " << kDynamicInputSize
                        << trace::DumpSourceLines(cnode_ptr_);
    }
    // Fix: local variable renamed — trailing underscore is reserved for members.
    auto dynamic_input_shape0 = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kIndex0);
    PadScalarShape(&dynamic_input_shape0);
    (void)input_shapes_.emplace_back(dynamic_input_shape0);
    input_num_ = 1;
  } else {
    input_num_ = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
    for (size_t i = 0; i < input_num_; ++i) {
      auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
      PadScalarShape(&input_shape);
      (void)input_shapes_.emplace_back(input_shape);
    }
  }
  output_num_ = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  for (size_t i = 0; i < output_num_; ++i) {
    auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&output);
    (void)output_shapes_.emplace_back(output);
  }
}
// Cache the shape predicates used by every format probe below.
void TbeKernelBroadcastSelector::GetCheckInfo() {
  is_same_shape_ = IsSameShape();
  has_scalar_input_ = HasScalarInput();
}
// Probe NC1HWC0 (5HD) support.
// case1: all shapes equal, no scalar input -> 5HD everywhere.
// case2: scalar inputs present -> scalars keep DEFAULT; the rest must be 4-D
//        with a 16-aligned C axis.
// case3: no scalar input -> all inputs must be 4-D and the C axis must not be
//        broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // case1: same shape & no scalar input
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_NC1HWC0, input_num_, kOpFormat_NC1HWC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    // case2: part no scalar input.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_NC1HWC0);
      }
    }
  } else {
    // case3: has no scalar inputs
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return;
      }
    }
    // Fix: the broadcast check and the assign are loop-invariant — run them
    // once, matching the structure of GetBroadcastSupportC1HWNCoC0.
    if (IsInputBroadcastByChannel(kChannelC)) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Probe FRACTAL_Z support. Same cases as 5HD, but FRACTAL_Z tiles both N and
// C, so both must be 16-aligned, and neither may be broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_FRAC_Z, input_num_, kOpFormat_FRAC_Z, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  // Consistency fix: use the flag cached by GetCheckInfo instead of
  // recomputing HasScalarInput(), as every other probe does.
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return;
        }
        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_FRAC_Z);
      }
    }
  } else {
    // case3: has no scalar inputs
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return;
      }
    }
    // Fix: loop-invariant broadcast check hoisted out of the loop; the log
    // message now mentions both checked axes.
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_FRAC_Z);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Probe C1HWNCoC0 support: equal non-scalar shapes map directly; with scalar
// inputs the non-scalar shapes must be 4-D with a 16-aligned N axis; without
// scalars, neither N nor C may be broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_C1HWNCoC0, input_num_, kOpFormat_C1HWNCoC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    for (const auto &cur_shape : input_shapes_) {
      if (IsScalarShape(cur_shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
        continue;
      }
      if (!Is4DShape(cur_shape) || cur_shape[kChannelN] % kAlignmented16 != 0) {
        return;
      }
      (void)input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
    }
  } else {
    bool has_non_4d = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const ShapeVector &cur_shape) { return !Is4DShape(cur_shape); });
    if (has_non_4d) {
      return;
    }
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
  }
  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Probe FRACTAL_NZ support. FRACTAL_NZ tiles the last two dimensions, so
// non-scalar inputs need rank >= 2; with scalar inputs the last two dims must
// be 16-aligned, otherwise the last two dims must not be broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_FRAC_NZ, input_num_, kOpFormat_FRAC_NZ, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (shape.size() < kShape2dDims) {
          return;
        }
        // Both tiled (trailing) axes must be cube-aligned.
        if (shape[shape.size() - kLastIndex] % kAlignmented16 != 0 ||
            shape[shape.size() - kLastButOneIndex] % kAlignmented16 != 0) {
          return;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
      }
    }
  } else {
    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const ShapeVector &elem) { return elem.size() < kShape2dDims; });
    if (less_2dims) {
      MS_LOG(INFO) << "This node dim less 2.";
      return;
    }
    // Channel indices are derived from the first input's rank; shapes of
    // smaller rank are treated as 0 at that axis by IsInputBroadcastByChannel.
    if (IsInputBroadcastByChannel(input_shapes_.front().size() - kLastIndex) ||
        IsInputBroadcastByChannel(input_shapes_.front().size() - kLastButOneIndex)) {
      MS_LOG(INFO) << "This node broadcast last channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Probe NDC1HWC0 support (5-D 5HD). Same-shape case additionally requires
// rank >= 5; scalar inputs keep DEFAULT while the rest must be exactly 5-D
// with a 16-aligned C axis; without scalars the C axis must not be broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    if (input_shapes_.front().size() < kShape5dDims) {
      return;
    }
    GenerateSupportFormat(kOpFormat_NDC1HWC0, input_num_, kOpFormat_NDC1HWC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        // Consistency fix: discard emplace_back's return like the rest of the file.
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is5DShape(shape)) {
          return;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_NDC1HWC0);
      }
    }
  } else {
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return;
      }
    }
    if (IsInputBroadcastByChannel(kChannelC)) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Probe FRACTAL_Z_3D support: 5-D shapes; with scalar inputs both C and N of
// the non-scalar shapes must be 16-aligned; without scalars, neither N nor C
// may be broadcast.
void TbeKernelBroadcastSelector::GetBroadcastSupportFracZ3D(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    if (input_shapes_.front().size() < kShape5dDims) {
      return;
    }
    GenerateSupportFormat(kOpFormat_FRACTAL_Z_3D, input_num_, kOpFormat_FRACTAL_Z_3D, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        // Consistency fix: discard emplace_back's return like the rest of the file.
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is5DShape(shape)) {
          return;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        if (shape[kChannelN] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_FRACTAL_Z_3D);
      }
    }
  } else {
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return;
      }
    }
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      // Message fix: both N and C are checked here, not only C.
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_FRACTAL_Z_3D);
  }
  GenOutputSupportFormat(kOpFormat_FRACTAL_Z_3D, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
// Rank predicates: true when the (scalar-padded) shape has exactly 4 / 5 dims.
bool TbeKernelBroadcastSelector::Is4DShape(const ShapeVector &shape) { return shape.size() == kShape4dDims; }
bool TbeKernelBroadcastSelector::Is5DShape(const ShapeVector &shape) { return shape.size() == kShape5dDims; }
// True when every input shape equals the first input's shape.
// NOTE(review): a dimension pair of (-1, -1) (both dynamic) deliberately
// fails the check — dynamic dims are never treated as "same". Since the first
// shape is also compared against itself, any -1 in the first shape makes the
// whole function return false; confirm this is the intended behavior.
bool TbeKernelBroadcastSelector::IsSameShape() const {
  auto shape = input_shapes_.begin();
  for (const auto &item : input_shapes_) {
    if (shape->size() != item.size()) {
      return false;
    }
    for (size_t i = 0; i < shape->size(); ++i) {
      if (shape->at(i) != item.at(i) || (shape->at(i) == -1 && item.at(i) == -1)) {
        return false;
      }
    }
  }
  return true;
}
// Scalars are padded to the one-element shape [1] by PadScalarShape.
bool TbeKernelBroadcastSelector::IsScalarShape(const ShapeVector &shape) {
  return (shape.size() == 1 && shape[0] == 1);
}
bool TbeKernelBroadcastSelector::HasScalarInput() const {
return std::any_of(input_shapes_.begin(), input_shapes_.end(),
[&](const auto &shape) { return this->IsScalarShape(shape); });
}
// True when dimension `channel` of any input differs from the first input's,
// i.e. the axis would be broadcast. An axis beyond a shape's rank counts as
// 0, so rank mismatches at this axis also report broadcast. A pair of dynamic
// dims (-1, -1) is explicitly treated as broadcast by the second clause.
bool TbeKernelBroadcastSelector::IsInputBroadcastByChannel(size_t channel) const {
  ShapeValueDType left = 0;
  if (channel < input_shapes_.front().size()) {
    left = input_shapes_.front()[channel];
  }
  return std::any_of(input_shapes_.begin() + 1, input_shapes_.end(), [&left, &channel](const ShapeVector &elem) {
    ShapeValueDType right = 0;
    if (channel < elem.size()) {
      right = elem[channel];
    }
    return (left != right || (left == -1 && right == -1));
  });
}
void TbeKernelBroadcastSelector::GenOutputSupportFormat(const std::string &support_format,
SupportFormatItem *output_support_item) const {
MS_EXCEPTION_IF_NULL(output_support_item);
for (const auto &shape : output_shapes_) {
if (IsScalarShape(shape)) {
(void)output_support_item->emplace_back(kOpFormat_DEFAULT);
} else {
(void)output_support_item->emplace_back(support_format);
}
}
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Computes the candidate device formats (5HD, FRACTAL_Z, C1HWNCoC0,
// FRACTAL_NZ, NDC1HWC0, FRACTAL_Z_3D) a broadcast-pattern TBE kernel can use,
// based on the node's input/output shapes.
class TbeKernelBroadcastSelector {
 public:
  explicit TbeKernelBroadcastSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadcastSelector() = default;
  // Fill *support_format_dtype with every supported format/dtype combination.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
 private:
  // Collect input/output shapes (handles dynamic tuple inputs).
  void GetBroadCastNodeInfo();
  // Cache is_same_shape_ / has_scalar_input_.
  void GetCheckInfo();
  // One probe per candidate format; each appends to *support_format when legal.
  void GetBroadcastSupport5HD(SupportFormat *support_format) const;
  void GetBroadcastSupportFracZ(SupportFormat *support_format) const;
  void GetBroadcastSupportC1HWNCoC0(SupportFormat *support_format) const;
  void GetBroadcastSupportFracNZ(SupportFormat *support_format) const;
  void GetBroadcastSupportNDC1HWC0(SupportFormat *support_format) const;
  void GetBroadcastSupportFracZ3D(SupportFormat *support_format) const;
  [[nodiscard]] inline bool IsSameShape() const;
  [[nodiscard]] static inline bool Is4DShape(const ShapeVector &shape);
  [[nodiscard]] static inline bool Is5DShape(const ShapeVector &shape);
  [[nodiscard]] static inline bool IsScalarShape(const ShapeVector &shape);
  [[nodiscard]] inline bool HasScalarInput() const;
  [[nodiscard]] inline bool IsInputBroadcastByChannel(size_t channel) const;
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  // broadcast
  CNodePtr cnode_ptr_ = nullptr;
  size_t input_num_ = 0;
  size_t output_num_ = 0;
  std::vector<ShapeVector> input_shapes_{};
  std::vector<ShapeVector> output_shapes_{};
  // check info
  bool is_same_shape_ = false;
  bool has_scalar_input_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_

View File

@ -0,0 +1,295 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_reduce_selector.h"
#include <string>
#include <vector>
#include <algorithm>
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
#include "kernel/common_utils.h"
#include "common/util/platform_info.h"
namespace mindspore::kernel {
// N (batch) and C (channel) dimension indices in an NCHW-ordered shape.
constexpr int64_t kChannelN = 0;
constexpr int64_t kChannelC = 1;
// Minimal rank usable by FRACTAL_NZ (it tiles the last two dims).
constexpr size_t kReduceNZMinDim = 2;
// Cube unit: fractal formats tile axes in units of 16 elements.
// Consistency fix: constexpr like the sibling constants (was plain const).
constexpr int64_t kCubeSize = 16;
// Entry point: gather node/axis info, probe each special format against the
// reduce semantics, expand formats into format+dtype pairs, then drop
// combinations known to be invalid.
void TbeKernelReduceSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  SupportFormat support_format;
  // step1: set ori support
  GetSupportOriFormat(cnode_ptr_, &support_format);
  // step2: get reduce node info
  GetReduceNodeInfo();
  // step3: generate support format
  GetReduceSupport5HD(&support_format);
  GetReduceSupportNDC1HWC0(&support_format);
  GetReduceSupportFracZ(&support_format);
  GetReduceSupportC1HWNCoC0(&support_format);
  GetReduceSupportFracZ3D(&support_format);
  GetReduceSupportFracNZ(&support_format);
  GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
  // step4: filter combinations the runtime cannot execute.
  FilterInvalidFormatDType(support_format_dtype);
}
// Collect reduce-node shapes and the keep_dims/axis attributes; negative axes
// are normalized using the first input's rank. Ends by caching the shape
// predicates via GetCheckInfo.
void TbeKernelReduceSelector::GetReduceNodeInfo() {
  auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  if (input_num != 1 || output_num != 1) {
    // Reduce is expected to be single-input/single-output; warn but continue.
    MS_LOG(WARNING) << "Reduce operator input/output is not 1, input num: " << input_num
                    << ", output num: " << output_num << ", node info: " << cnode_ptr_->DebugString();
  }
  // get input/output shape
  for (size_t i = 0; i < input_num; ++i) {
    auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&shape);
    (void)input_shape_.emplace_back(shape);
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto shape = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&shape);
    (void)output_shape_.emplace_back(shape);
  }
  // get keep dim attr
  GetReduceAttrKeepDim();
  // get axis attr
  axis_ = GetReduceAttrAxis(cnode_ptr_);
  // Normalize negative axes into [0, rank) using the first input's rank.
  std::transform(axis_.begin(), axis_.end(), axis_.begin(), [&](int64_t elem) {
    if (elem < 0) {
      elem += SizeToLong(input_shape_.at(kIndex0).size());
    }
    return elem;
  });
  GetCheckInfo();
}
// Cache the rank/axis predicates consumed by the format probes.
// NOTE(review): verify the polarity of the first two flags against
// CheckOriginInputShapeDimEqual — the flag names suggest "all ranks equal N"
// and the probes gate on the flags being true (e.g. GetReduceSupport5HD
// requires is_shape_4_dims_); make sure the helper matches that intent.
void TbeKernelReduceSelector::GetCheckInfo() {
  is_shape_4_dims_ = CheckOriginInputShapeDimEqual(kShape4dDims);
  is_shape_5_dims_ = CheckOriginInputShapeDimEqual(kShape5dDims);
  is_shape_less_2_dims_ = CheckOriginInputShapeDimLess(kShape2dDims);
  is_shape_less_4_dims_ = CheckOriginInputShapeDimLess(kShape4dDims);
  is_shape_less_5_dims_ = CheckOriginInputShapeDimLess(kShape5dDims);
  is_reduce_c_channel_ = CheckReduceContainChanel(kChannelC);
  is_reduce_n_channel_ = CheckReduceContainChanel(kChannelN);
}
// Probe NC1HWC0 (5HD): requires the is_shape_4_dims_ flag. When the C axis is
// reduced the output falls back to DEFAULT (the C1/C0 split would collapse);
// otherwise the output stays 5HD.
void TbeKernelReduceSelector::GetReduceSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // Note: if input and output size == 2, the last index only support 5hd(float16) --> default
  if (!is_shape_4_dims_) {
    return;
  }
  auto support_output_format = is_reduce_c_channel_ ? kOpFormat_DEFAULT : kOpFormat_NC1HWC0;
  GenerateSupportFormat(kOpFormat_NC1HWC0, input_shape_.size(), support_output_format, output_shape_.size(),
                        support_format);
}
// Probe NDC1HWC0: requires the 5-D flag and a non-reduced C axis.
void TbeKernelReduceSelector::GetReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!is_shape_5_dims_ || is_reduce_c_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_NDC1HWC0, input_shape_.size(), kOpFormat_NDC1HWC0, output_shape_.size(),
                        support_format);
}
// Probe FRACTAL_Z: rank >= 4, keep_dims set, and neither C nor N reduced.
void TbeKernelReduceSelector::GetReduceSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_shape_less_4_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_FRAC_Z, input_shape_.size(), kOpFormat_FRAC_Z, output_shape_.size(), support_format);
}
// Probe C1HWNCoC0: rank >= 4, keep_dims set, and neither C nor N reduced.
void TbeKernelReduceSelector::GetReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_shape_less_4_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_C1HWNCoC0, input_shape_.size(), kOpFormat_C1HWNCoC0, output_shape_.size(),
                        support_format);
}
// Probe FRACTAL_Z_3D: rank >= 5, keep_dims set, and neither C nor N reduced.
void TbeKernelReduceSelector::GetReduceSupportFracZ3D(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_shape_less_5_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_FRACTAL_Z_3D, input_shape_.size(), kOpFormat_FRACTAL_Z_3D, output_shape_.size(),
                        support_format);
}
// Probe FRACTAL_NZ. When neither of the last two axes is reduced, both input
// and output can be NZ. When a tail axis is reduced, NZ input with DEFAULT
// output is offered, but only if the reduced tail axes are 16-aligned
// (partial cube tiles cannot be reduced exactly).
void TbeKernelReduceSelector::GetReduceSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_shape_less_2_dims_) {
    return;
  }
  if (!keep_dims_) {
    return;
  }
  // Indices of the last and last-but-one axes of the first input.
  int64_t last_channel = SizeToLong(input_shape_.at(0).size()) - 1;
  bool is_reduce_last = CheckReduceContainChanel(last_channel);
  int64_t last_channel_but_one = SizeToLong(input_shape_.at(0).size()) - 2;
  bool is_reduce_last_but_one = CheckReduceContainChanel(last_channel_but_one);
  if (!is_reduce_last && !is_reduce_last_but_one) {
    GenerateSupportFormat(kOpFormat_FRAC_NZ, input_shape_.size(), kOpFormat_FRAC_NZ, output_shape_.size(),
                          support_format);
    return;
  }
  // A reduced tail axis must be an exact multiple of the cube size.
  if (last_channel >= 0 && LongToSize(last_channel) < input_shape_.at(kIndex0).size() &&
      input_shape_.at(kIndex0).at(last_channel) % kCubeSize != 0) {
    return;
  }
  if (last_channel_but_one >= 0 && LongToSize(last_channel_but_one) < input_shape_.at(kIndex0).size() &&
      input_shape_.at(kIndex0).at(last_channel_but_one) % kCubeSize != 0) {
    return;
  }
  GenerateSupportFormat(kOpFormat_FRAC_NZ, input_shape_.size(), kOpFormat_DEFAULT, output_shape_.size(),
                        support_format);
}
// Read the keep_dims attribute; a missing attribute defaults to false.
void TbeKernelReduceSelector::GetReduceAttrKeepDim() {
  if (common::AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) {
    keep_dims_ = common::AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kAttrKeepDims);
    return;
  }
  MS_LOG(WARNING) << "This node doesn't have keep_attr.";
  keep_dims_ = false;
}
// Drop format/dtype combinations that are invalid for reduce:
//  - 5HD float16 input with DEFAULT output;
//  - 5HD inputs failing the UB-size precheck.
// Assumes exactly one input and one output (warned otherwise); the four
// per-candidate vectors are walked in lock step by index i.
void TbeKernelReduceSelector::FilterInvalidFormatDType(SupportFormatDType *support_format_dtype) {
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  if (support_format_dtype->input_dtypes.size() != 1 || support_format_dtype->output_dtypes.size() != 1) {
    MS_LOG(WARNING) << "The reduce node input or output num is not 1.";
    return;
  }
  SupportDTypeItem input_dtypes = support_format_dtype->input_dtypes.at(0);
  SupportFormatItem input_formats = support_format_dtype->input_formats.at(0);
  SupportDTypeItem output_dtypes = support_format_dtype->output_dtypes.at(0);
  SupportFormatItem output_formats = support_format_dtype->output_formats.at(0);
  SupportDTypeItem input_dtypes_new;
  SupportFormatItem input_formats_new;
  SupportDTypeItem output_dtypes_new;
  SupportFormatItem output_formats_new;
  for (size_t i = 0; i < input_dtypes.size(); ++i) {
    auto input_dtype = input_dtypes.at(i);
    auto input_format = input_formats.at(i);
    auto output_dtype = output_dtypes.at(i);
    auto output_format = output_formats.at(i);
    if (input_format == kOpFormat_NC1HWC0 && output_format == kOpFormat_DEFAULT && input_dtype == "float16") {
      // Typo fix in the log message: "ane" -> "and".
      MS_LOG(WARNING) << "Input 5hd, input type fp16 and output default not supported.";
      continue;
    }
    // TODO(jjfeing)
    if (input_format == kOpFormat_NC1HWC0 && !CheckUBSizeEnable(input_format, input_dtype)) {
      MS_LOG(WARNING) << "Input 5hd, total size need to greater than block size.";
      continue;
    }
    (void)input_dtypes_new.emplace_back(input_dtype);
    (void)input_formats_new.emplace_back(input_format);
    (void)output_dtypes_new.emplace_back(output_dtype);
    (void)output_formats_new.emplace_back(output_format);
  }
  // Replace the candidate lists with the filtered ones.
  support_format_dtype->input_dtypes = {input_dtypes_new};
  support_format_dtype->input_formats = {input_formats_new};
  support_format_dtype->output_dtypes = {output_dtypes_new};
  support_format_dtype->output_formats = {output_formats_new};
}
// UB-size precheck placeholder: currently fetches platform info and always
// returns true unless the lookup fails. The parameters are not yet consulted
// — see the TODO below for the intended size computation.
bool TbeKernelReduceSelector::CheckUBSizeEnable(const std::string &input_format, const std::string &input_dtype) {
  // Empty axis or keep_dims: skip the check entirely.
  if (axis_.empty()) {
    return true;
  }
  if (keep_dims_) {
    return true;
  }
  auto soc_version = device::ascend::GetSocVersion();
  fe::PlatformInfo platform_info;
  fe::OptionalInfo opti_compilation_info;
  fe::PlatformInfoManager &inst = fe::PlatformInfoManager::Instance();
  if (inst.GetPlatformInfo(soc_version, platform_info, opti_compilation_info) != 0) {
    MS_LOG(ERROR) << "GetPlatformInfo failed.";
    return false;
  }
  // TODO(jjfeing)
  // auto ub_size = platform_info.ai_core_spec.ub_size;
  // auto data_size = GetDtypeNbyte(input_dtype);
  // auto ori_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
  return true;
}
// True when EVERY input shape has exactly support_dim_size dimensions.
// Bug fix: the original used any_of(size != n), which returned true when some
// rank DIFFERED — the inverse of the name and of how is_shape_4_dims_/
// is_shape_5_dims_ are consumed (e.g. GetReduceSupport5HD bails out unless
// the flag is true, so 5HD was only offered for non-4-D inputs).
bool TbeKernelReduceSelector::CheckOriginInputShapeDimEqual(size_t support_dim_size) const {
  // Note: identify format not check
  return std::all_of(input_shape_.begin(), input_shape_.end(),
                     [&support_dim_size](const auto &shape) { return (shape.size() == support_dim_size); });
}
// True when any input has fewer than support_min_dim_size dimensions.
bool TbeKernelReduceSelector::CheckOriginInputShapeDimLess(size_t support_min_dim_size) const {
  // Note: identify format not check
  for (const auto &shape : input_shape_) {
    if (shape.size() < support_min_dim_size) {
      return true;
    }
  }
  return false;
}
// True when the reduction touches the given (already normalized) axis index.
bool TbeKernelReduceSelector::CheckReduceContainChanel(int64_t channel_index) const {
  // if axis is empty, means reduce all axes, return true;
  if (axis_.empty()) {
    return true;
  }
  // channel index outside the first input's rank: treat as reduced, return true;
  if (channel_index < 0 || input_shape_.at(kIndex0).size() <= LongToSize(channel_index)) {
    return true;
  }
  // if any axis element equals the channel, return true;
  return std::any_of(axis_.begin(), axis_.end(),
                     [&channel_index](const int64_t &elem) { return (elem == channel_index); });
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore::kernel {
// Computes the candidate device formats (5HD, NDC1HWC0, FRACTAL_Z,
// C1HWNCoC0, FRACTAL_Z_3D, FRACTAL_NZ) a reduce-pattern TBE kernel can use,
// based on the node's shapes and its axis/keep_dims attributes.
class TbeKernelReduceSelector {
 public:
  explicit TbeKernelReduceSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelReduceSelector() = default;
  // Fill *support_format_dtype with every supported format/dtype combination.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
 private:
  // Collect shapes and the axis/keep_dims attributes.
  void GetReduceNodeInfo();
  // Cache the rank/axis predicate flags below.
  void GetCheckInfo();
  // One probe per candidate format; each appends to *support_format when legal.
  void GetReduceSupport5HD(SupportFormat *support_format) const;
  void GetReduceSupportNDC1HWC0(SupportFormat *support_format) const;
  void GetReduceSupportFracZ(SupportFormat *support_format) const;
  void GetReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
  void GetReduceSupportFracZ3D(SupportFormat *support_format) const;
  void GetReduceSupportFracNZ(SupportFormat *support_format) const;
  void GetReduceAttrKeepDim();
  // Drop format/dtype combinations the runtime cannot execute.
  void FilterInvalidFormatDType(SupportFormatDType *support_format_dtype);
  bool CheckUBSizeEnable(const std::string &input_format, const std::string &input_dtype);
  [[nodiscard]] inline bool CheckOriginInputShapeDimEqual(size_t support_dim_size) const;
  [[nodiscard]] inline bool CheckOriginInputShapeDimLess(size_t support_min_dim_size) const;
  [[nodiscard]] inline bool CheckReduceContainChanel(int64_t channel_index) const;
  CNodePtr cnode_ptr_;
  std::vector<ShapeVector> input_shape_{};
  std::vector<ShapeVector> output_shape_{};
  // Normalized (non-negative) reduce axes; empty means reduce all axes.
  std::vector<int64_t> axis_{};
  bool keep_dims_ = false;
  // check info
  bool is_shape_4_dims_ = false;
  bool is_shape_5_dims_ = false;
  bool is_shape_less_2_dims_ = false;
  bool is_shape_less_4_dims_ = false;
  bool is_shape_less_5_dims_ = false;
  bool is_reduce_n_channel_ = false;
  bool is_reduce_c_channel_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,11 +14,12 @@
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include <algorithm>
#include "include/common/utils/utils.h"
#include "utils/trace_base.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore {
namespace kernel {
@ -56,7 +57,7 @@ bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
for (size_t i = 0; i < input_num_; ++i) {
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
PadScalarShape(&input_shape);
input_shapes_.emplace_back(input_shape);
(void)input_shapes_.emplace_back(input_shape);
}
}
@ -64,7 +65,7 @@ bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
for (size_t i = 0; i < output_num_; ++i) {
auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
PadScalarShape(&output);
output_shapes_.emplace_back(output);
(void)output_shapes_.emplace_back(output);
}
AssignSupportFormat(kOpFormat_DEFAULT, support_format);
return true;
@ -85,7 +86,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_fo
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (!Is4DShape(shape)) {
return false;
@ -93,7 +94,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_fo
if (shape[kChannelC] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_NC1HWC0);
(void)input_support_format.emplace_back(kOpFormat_NC1HWC0);
}
}
} else {
@ -167,7 +168,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (!Is4DShape(shape)) {
return false;
@ -175,7 +176,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
if (shape[kChannelN] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
(void)input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
}
}
} else {
@ -196,8 +197,8 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
}
GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
(void)support_format->input_format.emplace_back(input_support_format);
(void)support_format->output_format.emplace_back(output_support_format);
return true;
}
@ -216,7 +217,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (shape.size() < kShape2dDims) {
return false;
@ -225,7 +226,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - SECOND_LAST] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_FRAC_NZ);
(void)input_support_format.emplace_back(kOpFormat_FRAC_NZ);
}
}
} else {
@ -269,13 +270,13 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *suppo
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_NCDHW);
(void)input_support_format.emplace_back(kOpFormat_NCDHW);
} else if (!Is5DShape(shape)) {
return false;
} else if (shape[kChannelC] % kAlignmented16 != 0) {
return false;
} else {
input_support_format.emplace_back(kOpFormat_NDC1HWC0);
(void)input_support_format.emplace_back(kOpFormat_NDC1HWC0);
}
}
} else {
@ -322,7 +323,7 @@ bool TbeKernelBroadCastSelecter::IsSameShape() const {
void TbeKernelBroadCastSelecter::PadScalarShape(ShapeVector *shape) const {
MS_EXCEPTION_IF_NULL(shape);
if (shape->empty()) {
shape->emplace_back(1);
(void)shape->emplace_back(1);
}
}
@ -346,9 +347,9 @@ void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &suppo
MS_EXCEPTION_IF_NULL(output_support_item);
for (const auto &shape : output_shapes_) {
if (IsScalarShape(shape)) {
output_support_item->emplace_back(kOpFormat_DEFAULT);
(void)output_support_item->emplace_back(kOpFormat_DEFAULT);
} else {
output_support_item->emplace_back(support_format);
(void)output_support_item->emplace_back(support_format);
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,14 +14,14 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_BROADCAST_SELECTER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_BROADCAST_SELECTER_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore {
namespace kernel {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,12 +15,14 @@
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include <string>
#include <vector>
#include <algorithm>
#include "include/common/utils/utils.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include "kernel/common_utils.h"
namespace mindspore {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,13 +14,13 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_REDUCE_SELECTER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_REDUCE_SELECTER_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore {
namespace kernel {
class TbeKernelReduceSelecter {

View File

@ -20,6 +20,9 @@
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include <iterator>
#include <algorithm>
#include "kernel/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_convert_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
@ -28,6 +31,7 @@
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_property_checker.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_selector_creator.h"
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/json_operation_utils.h"
@ -41,6 +45,7 @@ constexpr auto kPrefixOutput = "output";
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
constexpr int64_t kDynamicInvalidNum = -1;
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
auto tbe_selector = TbeKernelSelect(kernel_node, kernel_info_list);
@ -152,56 +157,83 @@ void TbeKernelSelect::GetKernelHashName() {
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << cnode_ptr_->fullname_with_scope() << "]";
}
kernel_hash_name = json_creator->GetJsonName();
return;
}
void TbeKernelSelect::TbeMetadataInfoEx() {
if (!check_cnode) {
return;
}
GetKernelHashName();
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
full_name_ = cnode_ptr_->fullname_with_scope();
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
MS_EXCEPTION_IF_NULL(op_info_ptr);
// if op pattern is FormatAgnostic, can't use select cache
bool skip_cache = (op_info_ptr->op_pattern() == kFormatAgnosticPattern);
auto iter = select_cache_.find(kernel_hash_name);
if (iter != select_cache_.end() && !skip_cache) {
for (auto &cache_info : iter->second) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
kernel_info_list_->emplace_back(builder.Build());
std::string new_select_process = common::GetEnv("XXX");
if (new_select_process.empty()) {
// Step1: Initialize, get node info
if (!Initialize()) {
MS_LOG(DEBUG) << "Initialize failed, node name: " << full_name_;
return;
}
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node "
<< cnode_ptr_->fullname_with_scope();
return;
}
if (op_info_ptr->is_dynamic_format()) {
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
// Step2: if kernel build info in cache, use cache return
if (GetKernelBuildInfoFromCache()) {
return;
}
// Step3:
auto get_support_format_dtype_func = GetSelectorFunc(cnode_ptr_);
if (!get_support_format_dtype_func) {
MS_LOG(ERROR) << "Get selector func failed, node name: " << full_name_;
return;
}
SupportFormatDType support_format_dtype;
get_support_format_dtype_func(cnode_ptr_, &support_format_dtype);
PrintSupportedFormatDtype(support_format_dtype);
GenerateKernelBuildInfo(support_format_dtype);
FilterInvalidKernelInfo();
AddKernelBuildInfoToCache();
} else {
OpPattern pattern = op_info_ptr->op_pattern();
if (pattern == kCommonPattern) {
GetCommonPatternKernelInfo(*op_info_ptr);
} else if (pattern == kFormatAgnosticPattern) {
GetAgnosticPatternKernelInfo(*op_info_ptr);
} else if (pattern == kBroadcastPattern) {
GetBroadcastPatternKernelInfo(*op_info_ptr);
} else if (pattern == kReducePattern) {
GetReducePatternKernelInfo(*op_info_ptr);
} else {
MS_LOG(INFO) << "Warning: op pattern is invailed.";
if (!check_cnode) {
return;
}
}
// check support
FilterInVaildKernelInfo(*op_info_ptr);
GetKernelHashName();
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
full_name_ = cnode_ptr_->fullname_with_scope();
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
MS_EXCEPTION_IF_NULL(op_info_ptr);
// if op pattern is FormatAgnostic, can't use select cache
bool skip_cache = (op_info_ptr->op_pattern() == kFormatAgnosticPattern);
for (auto &kernel_info : *kernel_info_list_) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(kernel_info);
select_cache_[kernel_hash_name].emplace_back(builder.Build());
auto iter = select_cache_.find(kernel_hash_name);
if (iter != select_cache_.end() && !skip_cache) {
for (auto &cache_info : iter->second) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
(void)kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node "
<< cnode_ptr_->fullname_with_scope();
return;
}
if (op_info_ptr->is_dynamic_format()) {
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
} else {
OpPattern pattern = op_info_ptr->op_pattern();
if (pattern == kCommonPattern) {
GetCommonPatternKernelInfo(*op_info_ptr);
} else if (pattern == kFormatAgnosticPattern) {
GetAgnosticPatternKernelInfo(*op_info_ptr);
} else if (pattern == kBroadcastPattern) {
GetBroadcastPatternKernelInfo(*op_info_ptr);
} else if (pattern == kReducePattern) {
GetReducePatternKernelInfo(*op_info_ptr);
} else {
MS_LOG(INFO) << "Warning: op pattern is invailed.";
}
}
// check support
FilterInvalidKernelInfo();
AddKernelBuildInfoToCache();
}
}
void TbeKernelSelect::AddKernelBuildInfoToCache() {
if (kernel_info_list_->empty()) {
select_cache_[kernel_hash_name] = {};
return;
}
select_cache_[kernel_hash_name] = *kernel_info_list_;
MS_LOG(INFO) << "Add select kernel cache " << kernel_hash_name << " from node " << cnode_ptr_->fullname_with_scope()
<< ", cache size: " << select_cache_.size();
}
@ -248,7 +280,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build());
(void)kernel_info_list_->emplace_back(builder.Build());
}
}
@ -272,8 +304,8 @@ void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
SupportFormatItem output_item;
input_item.assign(op_info.inputs_ptr().size(), format);
output_item.assign(op_info.outputs_ptr().size(), format);
support_format.input_format.emplace_back(input_item);
support_format.output_format.emplace_back(output_item);
(void)support_format.input_format.emplace_back(input_item);
(void)support_format.output_format.emplace_back(output_item);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
@ -306,32 +338,31 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
GetCommonPatternKernelInfo(op_info_new);
}
void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
void TbeKernelSelect::FilterInvalidKernelInfo() {
if (kernel_info_list_->empty()) {
MS_LOG(INFO) << "Warning: get kernel build info failed. Skip check supported. Op name: " << full_name_;
return;
}
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
auto dynamic_inputs = GetNodeDynamicInputs();
auto need_check_supported = op_info.need_check_supported();
for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
for (const auto &kernel_build_info : *kernel_info_list_) {
if (!FilterInvalidShape(kernel_build_info, !dynamic_inputs.empty())) {
continue;
}
if (need_check_supported && !TbeCheckSupported(iter)) {
if (!TbeCheckSupported(kernel_build_info)) {
continue;
}
kernel_info_list.emplace_back(*iter);
(void)kernel_info_list.emplace_back(kernel_build_info);
}
if (kernel_info_list.empty()) {
MS_LOG(DEBUG) << "After tbe check supported, all valid AI CORE kernel infos were filtered out. Node:" << full_name_;
}
(*kernel_info_list_) = kernel_info_list;
(*kernel_info_list_).swap(kernel_info_list);
}
bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
bool TbeKernelSelect::FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input) {
MS_EXCEPTION_IF_NULL(kernel_build_info);
const auto &kernel_build_info_inputs_format = kernel_build_info->GetAllInputFormats();
// dynamic input just need to check first input, because other inputs copy from 1th input;
auto iter_num =
is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
@ -342,7 +373,7 @@ bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build
return false;
}
}
const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
const auto &kernel_build_info_outputs_format = kernel_build_info->GetAllOutputFormats();
for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
auto shape = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
const auto &format = kernel_build_info_outputs_format[j];
@ -354,36 +385,81 @@ bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build
}
bool TbeKernelSelect::IsShapeMatchFormat(const ShapeVector &shape, const std::string &format) {
// if format is default, it means support all format
if (format == kOpFormat_DEFAULT) {
return true;
}
static const std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
// if format is default, it remarkes support all format
if (!IsOneOfFormat(format)) {
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
}
// server not support format with C04 suffix
if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
kServerNotSupportFormat.end()) {
if (IsOneOfServerFormatC04(format)) {
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() > kShape2dDims) {
return true;
// kOpFormat_FRAC_NZ >=2 || %16 == 0
if (format == kOpFormat_FRAC_NZ) {
return (shape.size() >= kShape2dDims) ||
(shape.size() == 1 && (shape[0] == 1 || (shape[0] % SizeToLong(kCubeSize) == 0)));
}
// RNN
if (!IsShapeMatchFormatRNN(shape, format)) {
return false;
}
// not support format:
// 1 3d formats with shape size > 5
if (IsOneOf3DFormat(format) && shape.size() > kShape5dDims) {
// 3D formats with shape size > 5
if (IsOneOf3DFormat(format)) {
return shape.size() <= kShape5dDims;
}
// check format is valid.
if (!IsOneOfFormat(format)) {
MS_LOG(WARNING) << "Got the unknown format " << format;
return false;
}
return true;
}
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) {
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
bool TbeKernelSelect::IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format) {
// kOpFormat_FRACTAL_ZN_RNN >=2
if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
if (!common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode_ptr_) ||
!common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode_ptr_)) {
return false;
}
auto input_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode_ptr_, kAttrInputSize);
auto hidden_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode_ptr_, kAttrHiddenSize);
// kOpFormat_FRACTAL_ZN_RNN >=2
if (format == kOpFormat_FRACTAL_ZN_RNN) {
if (shape.size() < kDim2) {
return false;
}
auto last_but_one_dim = shape[shape.size() - kDim2];
if ((last_but_one_dim != abstract::Shape::SHP_ANY) && (last_but_one_dim != input_size) &&
(last_but_one_dim != hidden_size) && (last_but_one_dim != input_size + hidden_size)) {
return false;
}
}
// kOpFormat_ND_RNN_BIAS shape not empty()
if (format == kOpFormat_ND_RNN_BIAS) {
if (shape.empty()) {
return false;
}
auto last_dim = shape[shape.size() - kDim1];
if (last_dim != abstract::Shape::SHP_ANY && last_dim % hidden_size != 0) {
return false;
}
}
}
return true;
}
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info) {
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
if (!op_info->need_check_supported()) {
return true;
}
MS_EXCEPTION_IF_NULL(kernel_build_info);
// replace kernel_info with current kernel info
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cnode_ptr_.get());
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
auto ret =
HostCheck::CheckValidDeviceShape(cnode_ptr_) && build_manager.TbeOpCheckSupported(cnode_ptr_, &kernel_json);
@ -408,9 +484,12 @@ std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
// get dynamic inputs
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode_ptr_);
MS_EXCEPTION_IF_NULL(primitive);
std::vector<int64_t> dyn_input_sizes;
std::vector<int64_t> dyn_input_sizes = {};
if (primitive->HasAttr(kAttrDynInputSizes)) {
dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
if (dyn_input_sizes.size() != 1) {
MS_LOG(EXCEPTION) << "Now dynamic input only support one in ascend.";
}
}
return dyn_input_sizes;
}
@ -447,10 +526,10 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
}
int64_t dynamic_input_size = dyn_input_sizes[dynamic_input_index];
for (int64_t i = 0; i < dynamic_input_size; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
value_depends->emplace_back(value_depend);
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
(void)formats->emplace_back(kernel_build_info_format);
(void)reshape_types->emplace_back(reshape_type);
(void)value_depends->emplace_back(value_depend);
}
dynamic_input_index++;
real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, LongToSize(dynamic_input_size));
@ -459,19 +538,19 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
}
for (size_t i = 0; i < real_io_tensor_num; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
value_depends->emplace_back(value_depend);
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
(void)formats->emplace_back(kernel_build_info_format);
(void)reshape_types->emplace_back(reshape_type);
(void)value_depends->emplace_back(value_depend);
}
real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, real_io_tensor_num);
}
} else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
// require or optional io
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
value_depends->emplace_back(value_depend);
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
(void)formats->emplace_back(kernel_build_info_format);
(void)reshape_types->emplace_back(reshape_type);
(void)value_depends->emplace_back(value_depend);
real_io_tensor_index++;
} else {
MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
@ -511,7 +590,7 @@ void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io
for (const auto &formats : support_format_item) {
auto format = formats.at(index);
for (size_t j = 0; j < dtype.size(); ++j) {
format_new.emplace_back(format);
(void)format_new.emplace_back(format);
}
}
op_io_info_new->set_formats(format_new);
@ -537,7 +616,7 @@ std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_se
if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
obj = kDynamicFormatMap.at(obj);
}
ret.emplace_back(obj);
(void)ret.emplace_back(obj);
begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
sep_pos = op_select_tmp.find(sep, begin);
}
@ -620,7 +699,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
select_input.dtypes = SplitStrToVec(input_dtype_item);
std::string input_format_item = item.value().at(kFormat);
select_input.formats = SplitStrToVec(input_format_item);
inputs.emplace_back(select_input);
(void)inputs.emplace_back(select_input);
} else {
SelectOpIOInfo select_output;
select_output.name = item.value().at(kName);
@ -628,7 +707,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
select_output.dtypes = SplitStrToVec(input_dtype_item);
std::string input_format_item = item.value().at(kFormat);
select_output.formats = SplitStrToVec(input_format_item);
outputs.emplace_back(select_output);
(void)outputs.emplace_back(select_output);
}
}
@ -671,25 +750,159 @@ void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io
op_io_info_new->set_formats(support_format);
}
void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
if (support_format.input_format.size() != support_format.output_format.size()) {
MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
<< support_format.output_format.size() << ") size not match.";
void TbeKernelSelect::PrintSupportedFormatDtype(const SupportFormatDType &support_format_dtype) {
MS_LOG(DEBUG) << "full_name: " << full_name_;
MS_LOG(DEBUG) << "==============input dtype=============";
for (auto input : support_format_dtype.input_dtypes) {
std::stringstream ss;
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
}
for (size_t i = 0; i < support_format.input_format.size(); ++i) {
auto input_items = support_format.input_format.at(i);
auto output_items = support_format.output_format.at(i);
std::string print_str = "[";
for (const auto &input : input_items) {
(void)print_str.append(input);
(void)print_str.append(", ");
}
(void)print_str.append("] -->");
for (const auto &output : output_items) {
(void)print_str.append(output);
(void)print_str.append(", ");
}
MS_LOG(INFO) << "Support format: " << print_str;
MS_LOG(DEBUG) << "==============input format=============";
for (auto input : support_format_dtype.input_formats) {
std::stringstream ss;
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
}
MS_LOG(DEBUG) << "==============output dtype=============";
for (auto input : support_format_dtype.output_dtypes) {
std::stringstream ss;
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
}
MS_LOG(DEBUG) << "==============output format=============";
for (auto input : support_format_dtype.output_formats) {
std::stringstream ss;
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
}
}
bool TbeKernelSelect::Initialize() {
// Init 1.op_name, 2.full_name, 3.op_info, 4.kernel_json, 5.kernel_hash_name
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
full_name_ = cnode_ptr_->fullname_with_scope();
op_info_ = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
if (!op_info_) {
return false;
}
auto json_creator = std::make_shared<SelectTbeJsonCreator>();
MS_EXCEPTION_IF_NULL(json_creator);
auto ret = json_creator->GenJson(cnode_ptr_, &kernel_json);
if (!ret) {
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << cnode_ptr_->fullname_with_scope() << "]";
}
kernel_hash_name = json_creator->GetJsonName();
return true;
}
bool TbeKernelSelect::GetKernelBuildInfoFromCache() {
// Note: kFormatAgnosticPattern need select, like cast ...
if (op_info_->op_pattern() == kFormatAgnosticPattern) {
return false;
}
auto iter = select_cache_.find(kernel_hash_name);
if (iter == select_cache_.end()) {
return false;
}
for (const auto &cache_info : iter->second) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
(void)kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node " << cnode_ptr_->fullname_with_scope();
return true;
}
void TbeKernelSelect::GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype) {
auto dyn_input_sizes = GetNodeDynamicInputs();
// get real input/output num
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
auto op_info_input_num = support_format_dtype.input_dtypes.size();
auto op_info_output_num = support_format_dtype.output_dtypes.size();
if (op_info_output_num == 0 && op_info_input_num == 0) {
MS_LOG(EXCEPTION) << "input and output is null, please check, " << full_name_;
}
auto select_support_num = support_format_dtype.output_dtypes.at(0).size();
auto dynamic_input_num = dyn_input_sizes.empty() ? kDynamicInvalidNum : dyn_input_sizes.at(0);
for (size_t support_index = 0; support_index < select_support_num; ++support_index) {
KernelBuildInfoItem input_kernel_build_info;
KernelBuildInfoItem output_kernel_build_info;
// size_t io_index = 0;
// size_t real_put_index = 0;
for (size_t io_index = 0, real_put_index = 0; real_put_index < real_input_num && io_index < op_info_input_num;) {
auto op_io_info = op_info_->inputs_ptr().at(io_index);
auto support_dtype = support_format_dtype.input_dtypes.at(io_index).at(support_index);
auto support_format = support_format_dtype.input_formats.at(io_index).at(support_index);
ConstructIOKernelBuildInfo(op_io_info, support_dtype, support_format, dynamic_input_num, &input_kernel_build_info,
&io_index, &real_put_index);
}
for (size_t io_index = 0, real_put_index = 0; real_put_index < real_output_num && io_index < op_info_output_num;) {
auto op_io_info = op_info_->outputs_ptr().at(io_index);
auto support_dtype = support_format_dtype.output_dtypes.at(io_index).at(support_index);
auto support_format = support_format_dtype.output_formats.at(io_index).at(support_index);
int64_t dynamic_output_num = dynamic_input_num;
if (op_io_info->param_type() == kParamTypeDynamic) {
if (op_info_->outputs_ptr().size() != 1) {
MS_LOG(EXCEPTION) << "Dynamic output num only support 1, output name: " << op_io_info->name()
<< ", node name: " << full_name_;
}
dynamic_output_num = SizeToLong(real_output_num);
}
ConstructIOKernelBuildInfo(op_io_info, support_dtype, support_format, dynamic_output_num,
&output_kernel_build_info, &io_index, &real_put_index);
}
ConstructKernelBuildInfo(input_kernel_build_info, output_kernel_build_info);
}
}
void TbeKernelSelect::ConstructKernelBuildInfo(const KernelBuildInfoItem &input_kernel_build_info,
const KernelBuildInfoItem &output_kernel_build_info) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetProcessor(AICORE);
std::string fusion_name = op_info_->fusion_type();
auto fusion_type = GetFusionTypeByName(fusion_name);
if (fusion_type != UNKNOWN_FUSION_TYPE) {
builder.SetFusionType(fusion_type);
}
builder.SetOpPattern(op_info_->op_pattern());
builder.SetKernelType(TBE_KERNEL);
builder.SetInputsDeviceType(input_kernel_build_info.device_types);
builder.SetInputsFormat(input_kernel_build_info.formats);
builder.SetInputsReshapeType(input_kernel_build_info.reshape_types);
builder.SetInputsValueDepend(input_kernel_build_info.value_depends);
builder.SetOutputsDeviceType(output_kernel_build_info.device_types);
builder.SetOutputsFormat(output_kernel_build_info.formats);
builder.SetOutputsReshapeType(output_kernel_build_info.reshape_types);
(void)kernel_info_list_->emplace_back(builder.Build());
}
void TbeKernelSelect::ConstructIOKernelBuildInfo(const OpIOInfoPtr &op_io_info, const std::string &support_dtype,
const std::string &support_format, int64_t dynamic_num,
KernelBuildInfoItem *kernel_build_info_item, size_t *io_index,
size_t *real_put_index) const {
MS_EXCEPTION_IF_NULL(kernel_build_info_item);
MS_EXCEPTION_IF_NULL(io_index);
MS_EXCEPTION_IF_NULL(real_put_index);
if (op_io_info->param_type() == kParamTypeDynamic) {
if (dynamic_num == kDynamicInvalidNum) {
MS_LOG(EXCEPTION) << "Get node dynamic inputs num failed, node name: " << full_name_;
}
for (int64_t i = 0; i < dynamic_num; ++i) {
(void)kernel_build_info_item->formats.emplace_back(support_format);
(void)kernel_build_info_item->device_types.emplace_back(tbe::DtypeToTypeId(support_dtype));
(void)kernel_build_info_item->reshape_types.emplace_back(op_io_info->reshape_type());
(void)kernel_build_info_item->value_depends.emplace_back(op_io_info->value_depend());
}
(*real_put_index) += LongToSize(dynamic_num);
} else {
(void)kernel_build_info_item->formats.emplace_back(support_format);
(void)kernel_build_info_item->device_types.emplace_back(tbe::DtypeToTypeId(support_dtype));
(void)kernel_build_info_item->reshape_types.emplace_back(op_io_info->reshape_type());
(void)kernel_build_info_item->value_depends.emplace_back(op_io_info->value_depend());
(*real_put_index) += 1;
}
(*io_index) += 1;
}
} // namespace mindspore::kernel

View File

@ -22,17 +22,15 @@
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
namespace mindspore {
namespace kernel {
namespace mindspore::kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
bool TbeCheckIsSupportedSpec(const CNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool TbeCheckIsSupportedAny(const CNodePtr &kernel_node);
class TbeKernelSelect {
using OpInfoPtr = std::shared_ptr<OpInfo>;
using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;
public:
TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
@ -47,10 +45,11 @@ class TbeKernelSelect {
void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
void GetReducePatternKernelInfo(const OpInfo &op_info);
void FilterInVaildKernelInfo(const OpInfo &op_info);
bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input);
static bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
void FilterInvalidKernelInfo();
bool FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input);
bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
bool IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info);
static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
std::vector<int64_t> GetNodeDynamicInputs();
bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
@ -71,8 +70,6 @@ class TbeKernelSelect {
void GetKernelHashName();
bool CheckCNode();
static void PrintSupportedFormat(const SupportFormat &support_format);
CNodePtr cnode_ptr_;
std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;
std::string node_name_;
@ -81,8 +78,23 @@ class TbeKernelSelect {
std::string kernel_hash_name;
bool check_cnode;
inline static mindspore::HashMap<std::string, std::vector<std::shared_ptr<KernelBuildInfo>>> select_cache_ = {};
private:
bool Initialize();
bool GetKernelBuildInfoFromCache();
void GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype);
void ConstructIOKernelBuildInfo(const OpIOInfoPtr &op_io_info, const std::string &support_dtype,
const std::string &support_format, int64_t dynamic_num,
KernelBuildInfoItem *kernel_build_info_item, size_t *io_index,
size_t *real_put_index) const;
OpInfoPtr op_info_ = nullptr;
void ConstructKernelBuildInfo(const KernelBuildInfoItem &input_kernel_build_info,
const KernelBuildInfoItem &output_kernel_build_info);
void AddKernelBuildInfoToCache();
void PrintSupportedFormatDtype(const SupportFormatDType &support_format_dtype);
};
} // namespace kernel
} // namespace mindspore
} // namespace mindspore::kernel
#endif // MINDSPORE_TBE_KERNEL_SELECT_H

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -0,0 +1,229 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include <set>
#include <string>
#include "base/base.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
namespace mindspore::kernel {
namespace {
constexpr size_t kNcdhwShapeSize = 5;
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
return param->input_size() > 0 && param->hidden_size() > 0;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
}
return false;
}
} // namespace
// Validates every input device shape first, then every output device shape.
// Returns false (with a warning) on the first mismatch.
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
  const size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
    const auto input_format = AnfAlgo::GetInputFormat(node, input_idx);
    if (!CheckValidInOutDeviceShape(node, input_idx, false, input_format)) {
      MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
                      << ", format:" << input_format;
      return false;
    }
  }
  const size_t output_num = common::AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    const auto output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (!CheckValidInOutDeviceShape(node, output_idx, true, output_format)) {
      MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
                      << ", format:" << output_format;
      return false;
    }
  }
  return true;
}
// Returns the infer shape of one input/output, padded as the device format requires.
// Steps: 1) read the detail shape; 2) pad with the reshape type when the format needs it;
// 3) fall back to default 4D/5D padding for special formats whose rank is too small.
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                   const std::string &format) {
  // Outputs read the node's own shape; inputs read the previous node's output shape.
  auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
                         : common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
  std::vector<int64_t> infer_shape;
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_ptr);
    infer_shape = shape_ptr->shape();
  }
  if (infer_shape.empty()) {
    // Empty shape: nothing to pad; callers treat an empty result as "always valid".
    return infer_shape;
  }
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    auto reshape_type =
      is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
    infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
  }
  auto temp_shape = infer_shape;
  // Special (non-default) formats expect at least a 4D base shape; pad by default rules.
  if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && infer_shape.size() < kShape4dDims &&
      !IsOneOf3DFormat(format)) {
    MS_LOG(DEBUG) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
    temp_shape = trans::PaddingShapeTo4dDefault(infer_shape, node);
  }
  // 3D device formats (NCDHW family) expect a 5D base shape.
  if (infer_shape.size() != kNcdhwShapeSize && IsOneOf3DFormat(format)) {
    temp_shape = trans::PaddingShapeTo5dDefault(infer_shape, node);
  }
  return temp_shape;
}
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format) {
auto infer_shape = GetFinalInferShape(node, index, is_output, format);
if (infer_shape.empty()) {
return true;
}
std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
kOpFormat_NC1HWC0_C04};
std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
if (check_4D_format.find(format) != check_4D_format.end()) {
return infer_shape.size() == kShape4dDims;
}
if (check_5D_format.find(format) != check_5D_format.end()) {
return infer_shape.size() == kShape5dDims;
}
if (format == kOpFormat_FRAC_NZ) {
return infer_shape.size() >= kShape2dDims ||
(infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
}
if (format == kOpFormat_FRACTAL_ZN_RNN) {
return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
}
if (format == kOpFormat_ND_RNN_BIAS) {
return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
}
return true;
}
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(support_format);
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
MS_EXCEPTION_IF_NULL(op_info);
// Note: in reduce/broadcast case: dynamic/optional input/output not care.
if (op_info->inputs_ptr().size() != input_num || op_info->outputs_ptr().size() != output_num) {
MS_LOG(WARNING) << "Input/output maybe have optional or dynamic io, input num:" << input_num
<< ", output_num:" << output_num << ", op_info->inputs size: " << op_info->inputs_ptr().size()
<< ", op_info->output size: " << op_info->outputs_ptr().size();
}
SupportFormatItem input_item(input_num, kOpFormat_DEFAULT);
(void)support_format->input_format.emplace_back(input_item);
SupportFormatItem output_item(output_num, kOpFormat_DEFAULT);
(void)support_format->output_format.emplace_back(output_item);
}
// Represents a scalar (rank-0) shape as [1] so downstream format logic sees one dim.
void PadScalarShape(ShapeVector *shape) {
  MS_EXCEPTION_IF_NULL(shape);
  if (!shape->empty()) {
    return;
  }
  (void)shape->emplace_back(1);
}
void GenerateSupportFormat(const std::string &support_input_format, size_t input_num,
const std::string &support_output_format, size_t output_num, SupportFormat *support_format) {
SupportFormatItem input_item(input_num, support_input_format);
(void)support_format->input_format.emplace_back(input_item);
SupportFormatItem output_item(output_num, support_output_format);
(void)support_format->output_format.emplace_back(output_item);
}
// For every io slot, builds its dtype column by repeating the registered dtype list once
// per supported format, keeping dtype and format tables aligned.
void ConstructSupportDTypes(const std::vector<OpIOInfoPtr> &puts, size_t format_size,
                            std::vector<SupportDTypeItem> *support_dtypes) {
  MS_EXCEPTION_IF_NULL(support_dtypes);
  for (const auto &put : puts) {
    MS_EXCEPTION_IF_NULL(put);
    const auto &dtypes = put->dtypes();
    SupportDTypeItem support_dtype_item = {};
    support_dtype_item.reserve(dtypes.size() * format_size);
    for (size_t i = 0; i < format_size; ++i) {
      // Append instead of the former insert-at-begin: every pass inserts the same dtype
      // list, so the result is identical, but appending avoids quadratic element shifting.
      (void)support_dtype_item.insert(support_dtype_item.cend(), dtypes.cbegin(), dtypes.cend());
    }
    (void)support_dtypes->emplace_back(std::move(support_dtype_item));
  }
}
void ConstructSupportFormats(size_t put_size, const std::vector<SupportFormatItem> &support_format, size_t type_size,
std::vector<SupportFormatItem> *support_formats) {
MS_EXCEPTION_IF_NULL(support_formats);
for (size_t i = 0; i < put_size; ++i) {
SupportFormatItem support_format_item = {};
for (const auto &formats : support_format) {
for (size_t j = 0; j < type_size; ++j) {
(void)support_format_item.emplace_back(formats.at(i));
}
}
(void)support_formats->emplace_back(support_format_item);
}
}
void GenerateSupportFormatDType(const OpInfoPtr &op_info, const SupportFormat &support_format, bool is_dynamic_shape,
SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(op_info);
if (op_info->inputs_ptr().size() != support_format.input_format.at(0).size()) {
MS_LOG(WARNING) << "Op name: " << op_info->op_name()
<< " has optional input, but in the graph, this input not exist."
<< "op info input num: " << op_info->inputs_ptr().size()
<< "graph node input num: " << support_format.input_format.at(0).size();
}
// TODO(jjfeing) usr format as dynamic format
auto type_size = op_info->outputs_ptr().at(0)->dtypes().size();
auto format_size = support_format.input_format.size();
auto input_size = op_info->inputs_ptr().size();
auto output_size = op_info->outputs_ptr().size();
ConstructSupportDTypes(op_info->inputs_ptr(), format_size, &support_format_dtype->input_dtypes);
ConstructSupportDTypes(op_info->outputs_ptr(), format_size, &support_format_dtype->output_dtypes);
ConstructSupportFormats(input_size, support_format.input_format, type_size, &support_format_dtype->input_formats);
ConstructSupportFormats(output_size, support_format.output_format, type_size, &support_format_dtype->output_formats);
if (support_format_dtype->input_dtypes.size() != op_info->inputs_ptr().size() ||
support_format_dtype->output_dtypes.size() != op_info->outputs_ptr().size() ||
support_format_dtype->output_dtypes.at(0).size() != support_format_dtype->output_formats.at(0).size()) {
MS_LOG(ERROR) << "GenerateSupportFormatDType failed.";
}
}
// Convenience overload: looks up the op info registered for this node, then delegates.
void GenerateSupportFormatDType(const CNodePtr &cnode, const SupportFormat &support_format,
                                SupportFormatDType *support_format_dtype) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  const bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(cnode);
  const auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode);
  GenerateSupportFormatDType(op_info, support_format, is_dynamic_shape, support_format_dtype);
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,78 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_
#include <string>
#include <vector>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/oplib/opinfo.h"
namespace mindspore::kernel {
// Formats of every input (or output) for one kernel candidate.
using SupportFormatItem = std::vector<std::string>;
// Candidate format table: one SupportFormatItem per input/output.
struct SupportFormat {
  std::vector<SupportFormatItem> input_format;
  std::vector<SupportFormatItem> output_format;
};
// Dtypes of every input (or output) for one kernel candidate.
using SupportDTypeItem = std::vector<std::string>;
// Aligned format/dtype tables for all inputs and outputs of an op; indexes into
// *_dtypes and *_formats correspond one-to-one.
struct SupportFormatDType {
  std::vector<SupportDTypeItem> input_dtypes;
  std::vector<SupportFormatItem> input_formats;
  std::vector<SupportDTypeItem> output_dtypes;
  std::vector<SupportFormatItem> output_formats;
};
// Per-io build info accumulated while assembling a KernelBuildInfo.
struct KernelBuildInfoItem {
  std::vector<std::string> formats;
  std::vector<TypeId> device_types;
  std::vector<std::string> reshape_types;
  std::vector<std::string> value_depends;
};
// Host-side validation that a node's infer shapes are compatible with the
// device formats chosen for its inputs/outputs.
class HostCheck {
 public:
  HostCheck() = default;
  ~HostCheck() = default;
  static bool CheckValidDeviceShape(const AnfNodePtr &node);

 private:
  static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
                                         const std::string &format);
  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                 const std::string &format);
};
// Fills `support_format` with one DEFAULT-format entry per real input/output of the node.
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format);
// Pads a rank-0 (scalar) shape to [1].
void PadScalarShape(ShapeVector *shape);
// Fills `support_format` with a single candidate using the given input/output formats.
void GenerateSupportFormat(const std::string &support_input_format, size_t input_num,
                           const std::string &support_output_format, size_t output_num, SupportFormat *support_format);
// Builds per-io dtype columns, repeating each registered dtype list once per format.
void ConstructSupportDTypes(const std::vector<OpIOInfoPtr> &puts, size_t format_size,
                            std::vector<SupportDTypeItem> *support_dtypes);
// Builds per-io format columns, repeating each candidate format once per dtype.
void ConstructSupportFormats(size_t put_size, const std::vector<SupportFormatItem> &support_format, size_t type_size,
                             std::vector<SupportFormatItem> *support_formats);
// Builds the aligned format/dtype tables from registered op info plus candidate formats.
void GenerateSupportFormatDType(const OpInfoPtr &op_info, const SupportFormat &support_format, bool is_dynamic_shape,
                                SupportFormatDType *support_format_dtype);
// Overload that first looks up the op info registered for `cnode`.
void GenerateSupportFormatDType(const CNodePtr &cnode, const SupportFormat &support_format,
                                SupportFormatDType *support_format_dtype);
}  // namespace mindspore::kernel
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_

View File

@ -0,0 +1,76 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_selector_creator.h"
#include <map>
#include "kernel/oplib/opinfo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/agnostic_selector/tbe_kernel_agnostic_selector.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_selector/tbe_kernel_common_selector.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/dynamic_selector/tbe_kernel_dynamic_selector.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_broadcast_selector.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_reduce_selector.h"
namespace mindspore::kernel {
namespace {
void GetCommonSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode);
auto selector = TbeKernelCommonSelector(cnode);
selector.GetSupportedFormatDType(support_format_dtype);
}
void GetAgnosticSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode);
auto selector = TbeKernelAgnosticSelector(cnode);
selector.GetSupportedFormatDType(support_format_dtype);
}
void GetDynamicSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode);
auto selector = TbeKernelDynamicSelector(cnode);
selector.GetSupportedFormatDType(support_format_dtype);
}
void GetBroadcastSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode);
auto selector = TbeKernelBroadcastSelector(cnode);
selector.GetSupportedFormatDType(support_format_dtype);
}
void GetReduceSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode);
auto selector = TbeKernelReduceSelector(cnode);
selector.GetSupportedFormatDType(support_format_dtype);
}
const std::map<OpPattern, GetSupportedFormatDTypeFunc> selector_funcs = {
{kCommonPattern, GetCommonSupportedFormatDType},
{kFormatAgnosticPattern, GetAgnosticSupportedFormatDType},
{kBroadcastPattern, GetBroadcastSupportedFormatDType},
{kReducePattern, GetReduceSupportedFormatDType},
{kDynamicFormatPattern, GetDynamicSupportedFormatDType}};
} // namespace
// Returns the selector function matching the node's registered op pattern,
// or nullptr when the op info is missing or the pattern has no selector.
GetSupportedFormatDTypeFunc GetSelectorFunc(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode);
  if (op_info == nullptr) {
    return nullptr;
  }
  const auto iter = selector_funcs.find(op_info->op_pattern());
  if (iter == selector_funcs.end()) {
    return nullptr;
  }
  return iter->second;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_
#include <functional>
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
#include "ir/anf.h"
namespace mindspore::kernel {
// Callback that fills `support_format_dtype` with the supported format/dtype
// combinations of `cnode`.
using GetSupportedFormatDTypeFunc =
  std::function<void(const CNodePtr &cnode, SupportFormatDType *support_format_dtype)>;
// Returns the selector callback matching the node's registered op pattern,
// or nullptr when no op info / no matching selector exists.
GetSupportedFormatDTypeFunc GetSelectorFunc(const CNodePtr &cnode);
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_

View File

@ -322,10 +322,6 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
trans_node = reshape_node;
}
if (trans_opname == prim::kPrimTransDataRNN->name()) {
common::AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
common::AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
}
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
@ -417,6 +413,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fracz_group), trans_node);
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fracz_group), trans_node);
}
} else if (op_name == prim::kPrimTransDataRNN->name()) {
common::AnfAlgo::CopyNodeAttr(kAttrHiddenSize, orig_node, trans_node);
common::AnfAlgo::CopyNodeAttr(kAttrInputSize, orig_node, trans_node);
}
if (is_dynamic_shape) {
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);

View File

@ -186,29 +186,22 @@ bool IsOneOfHWSpecialFormat(const std::string &format) {
}
bool IsOneOfFormat(const std::string &format) {
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,
kOpFormat_NC1KHKWHWC0,
kOpFormat_ND,
kOpFormat_NCHW,
kOpFormat_NHWC,
kOpFormat_HWCN,
kOpFormat_NC1HWC0,
kOpFormat_FRAC_Z,
kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04,
kOpFormat_NDHWC,
kOpFormat_FRACTAL_ZN_LSTM,
kOpFormat_FRACTAL_ZN_RNN,
kOpFormat_ND_RNN_BIAS,
kOpFormat_NDC1HWC0,
kOpFormat_NCDHW,
kOpFormat_FRACTAL_Z_3D,
kOpFormat_DHWNC,
kOpFormat_DHWCN};
const std::set<std::string> kOpFormatList = {
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_CHWN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
kOpFormat_FRACTAL_ZN_RNN, kOpFormat_ND_RNN_BIAS, kOpFormat_NDC1HWC0,
kOpFormat_NCDHW, kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
kOpFormat_DHWCN};
auto iter = kOpFormatList.find(format);
return iter != kOpFormatList.end();
}
// True for the server-side C04 device formats.
bool IsOneOfServerFormatC04(const std::string &format) {
  const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
  return kServerFormatC04List.count(format) > 0;
}
} // namespace mindspore

View File

@ -26,7 +26,6 @@ realdiv_op_info = TBERegOp("RealDiv") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "z", False, "required", "all") \
.op_pattern("broadcast") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \

View File

@ -502,6 +502,7 @@ class TBERegOp(RegOp):
self.kernel_name_ = ''
self.partial_flag_ = False
self.reshape_type_ = ''
self.dynamic_rank_support_ = False
self.dynamic_shape_ = False
self.dynamic_compile_static_ = False
self.need_check_supported_ = False
@ -509,6 +510,29 @@ class TBERegOp(RegOp):
self.op_pattern_ = ""
self.real_input_index_ = []
self.input_to_attr_index_ = []
self.unknown_shape_formats_ = []
def unknown_shape_formats(self, unknown_shape_formats):
    """
    Register one group of input/output formats supported when the operator runs with unknown shape.

    Args:
        unknown_shape_formats (list): format of each input/output under unknown shape. Default: ().
    """
    RegOp._is_list(unknown_shape_formats)
    self.unknown_shape_formats_.append(unknown_shape_formats)
    return self
def dynamic_rank_support(self, dynamic_rank_support):
    """
    Set whether the operator supports dynamic rank (shape with unknown dimension count).

    Args:
        dynamic_rank_support (bool): True if dynamic rank is supported. Default: False.
    """
    self._is_bool(dynamic_rank_support)
    self.dynamic_rank_support_ = dynamic_rank_support
    return self
def real_input_index(self, real_input_index):
"""
@ -1146,3 +1170,49 @@ class DataType:
C64_Default = ("complex64", "DefaultFormat")
C128_Default = ("complex128", "DefaultFormat")
class Format:
    r"""
    Names of the device data formats that may appear in ``unknown_shape_formats``
    of Ascend op registrations.

    current support:

    .. code-block::

        C1HWNCoC0 = "C1HWNCoC0"
        ChannelLast = "ChannelLast"
        Default = "DefaultFormat"
        DHWCN = "DHWCN"
        FHD = "NC1HWC0"
        FracNZ = "FRACTAL_NZ"
        FRACTAL_Z_3D = "FRACTAL_Z_3D"
        FracZ = "FRACTAL_Z"
        FracZNLSTM = "FRACTAL_ZN_LSTM"
        FracZNRNN = "FRACTAL_ZN_RNN"
        HWCN = "HWCN"
        NCDHW = "NCDHW"
        NCHW = "NCHW"
        ND_RNNBIAS = "ND_RNN_BIAS"
        NDC1HWC0 = "NDC1HWC0"
        NDHWC = "NDHWC"
        NHWC = "NHWC"
        NULL = "NULL"
    """
    C1HWNCoC0 = "C1HWNCoC0"
    ChannelLast = "ChannelLast"
    Default = "DefaultFormat"
    DHWCN = "DHWCN"
    FHD = "NC1HWC0"
    FracNZ = "FRACTAL_NZ"
    FRACTAL_Z_3D = "FRACTAL_Z_3D"
    FracZ = "FRACTAL_Z"
    FracZNLSTM = "FRACTAL_ZN_LSTM"
    FracZNRNN = "FRACTAL_ZN_RNN"
    HWCN = "HWCN"
    NCDHW = "NCDHW"
    NCHW = "NCHW"
    ND_RNNBIAS = "ND_RNN_BIAS"
    NDC1HWC0 = "NDC1HWC0"
    NDHWC = "NDHWC"
    NHWC = "NHWC"
    NULL = "NULL"

View File

@ -378,6 +378,6 @@ def test_train_feed(num_classes=65536):
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
expect_out = [11.344119, 10.747661, 11.134097]
expect_out = [11.344119, 10.747664, 11.132818]
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)