!22836 support dynamic compile static

Merge pull request !22836 from laiyongqiang/dynamic_compile_static
This commit is contained in:
i-robot 2021-09-03 09:06:12 +00:00 committed by Gitee
commit 9bc8ecd233
6 changed files with 27 additions and 4 deletions

View File

@ -224,9 +224,10 @@ def get_module_name(compute_op_info):
:param compute_op_info:
:return:
"""
dynamic_compile_static = compute_op_info["dynamic_compile_static"]
unknown_shape = compute_op_info["unknown_shape"]
op_module_name = compute_op_info["module_name"]
if unknown_shape:
if dynamic_compile_static or unknown_shape:
op_module_name = op_module_name.split(".")[0] + ".dynamic." + op_module_name.split(".")[-1]
return op_module_name

View File

@ -102,6 +102,7 @@ class OpInfo {
kernel_name_ = opinfo.kernel_name();
partial_flag_ = opinfo.partial_flag_;
dynamic_shape_ = opinfo.dynamic_shape_;
dynamic_compile_static_ = opinfo.dynamic_compile_static_;
op_pattern_ = opinfo.op_pattern();
processor_ = opinfo.processor_;
need_check_supported_ = opinfo.need_check_supported();
@ -125,6 +126,7 @@ class OpInfo {
std::string kernel_name() const { return kernel_name_; }
OpPattern op_pattern() const { return op_pattern_; }
bool dynamic_shape() const { return dynamic_shape_; }
bool dynamic_compile_static() const { return dynamic_compile_static_; }
std::string processor() const { return processor_; }
bool need_check_supported() const { return need_check_supported_; }
bool is_dynamic_format() const { return is_dynamic_format_; }
@ -134,6 +136,7 @@ class OpInfo {
const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
void set_dynamic_compile_static_(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; }
void set_op_name(const std::string &op_name) { op_name_ = op_name; }
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
@ -158,7 +161,8 @@ class OpInfo {
bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
this->dynamic_shape_ == other_info->dynamic_shape_;
this->dynamic_shape_ == other_info->dynamic_shape_ &&
this->dynamic_compile_static_ == other_info->dynamic_compile_static_;
}
private:
@ -172,6 +176,7 @@ class OpInfo {
std::string kernel_name_;
bool partial_flag_ = false;
bool dynamic_shape_ = false;
bool dynamic_compile_static_ = false;
bool need_check_supported_ = false;
bool is_dynamic_format_ = false;
OpPattern op_pattern_ = kCommonPattern;

View File

@ -41,6 +41,7 @@ constexpr auto kNeedCheckSupported = "need_check_supported";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
constexpr auto kDynamicShape = "dynamic_shape";
constexpr auto kDynamicCompileStatic = "dynamic_compile_static";
constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr";
constexpr auto kIputs = "inputs";
@ -121,6 +122,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_dynamic_shape(obj.at(kDynamicShape));
}
if (obj.find(kDynamicCompileStatic) != obj.end()) {
op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
}
if (obj.find(kIsDynamicFormat) != obj.end()) {
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
}

View File

@ -331,7 +331,7 @@ void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::
auto iter = tbe::opTypeAdapter.find(op_name);
(*compute_json)[kJType] = (iter != tbe::opTypeAdapter.end()) ? iter->second : op_name;
(*compute_json)[kJPyModulePath] = python_module_path;
(*compute_json)[kJDynamicCompileStatic] = false;
(*compute_json)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
(*compute_json)[kJInt64Mode] = false;
(*compute_json)[kJName] = cnode->fullname_with_scope();
(*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));

View File

@ -1053,7 +1053,7 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode
// replace special op type for buffer fusion op
auto type = GetRealOpType(origin_type);
(*compute_op_str)[kJtype] = type;
(*compute_op_str)[kJDynamicCompileStatic] = false;
(*compute_op_str)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
auto func_name = op_info_ptr->kernel_name();
(*compute_op_str)[kJFuncName] = func_name;
(*compute_op_str)[kJInt64Mode] = false;

View File

@ -412,6 +412,7 @@ class TBERegOp(RegOp):
self.partial_flag_ = False
self.reshape_type_ = ''
self.dynamic_shape_ = False
self.dynamic_compile_static_ = False
self.need_check_supported_ = False
self.is_dynamic_format_ = False
self.op_pattern_ = ""
@ -494,6 +495,17 @@ class TBERegOp(RegOp):
self.dynamic_shape_ = dynamic_shape
return self
def dynamic_compile_static(self, dynamic_compile_static=False):
"""
Whether the operator supports dynamic compile static.
Args:
dynamic_compile_static (bool): Value of dynamic compile static. Default: false.
"""
self._is_bool(dynamic_compile_static)
self.dynamic_compile_static_ = dynamic_compile_static
return self
def need_check_supported(self, need_check_supported=False):
"""
Whether the operator need check supports.