forked from mindspore-Ecosystem/mindspore
!22836 support dynamic compile static
Merge pull request !22836 from laiyongqiang/dynamic_compile_static
This commit is contained in:
commit
9bc8ecd233
|
@ -224,9 +224,10 @@ def get_module_name(compute_op_info):
|
|||
:param compute_op_info:
|
||||
:return:
|
||||
"""
|
||||
dynamic_compile_static = compute_op_info["dynamic_compile_static"]
|
||||
unknown_shape = compute_op_info["unknown_shape"]
|
||||
op_module_name = compute_op_info["module_name"]
|
||||
if unknown_shape:
|
||||
if dynamic_compile_static or unknown_shape:
|
||||
op_module_name = op_module_name.split(".")[0] + ".dynamic." + op_module_name.split(".")[-1]
|
||||
return op_module_name
|
||||
|
||||
|
|
|
@ -102,6 +102,7 @@ class OpInfo {
|
|||
kernel_name_ = opinfo.kernel_name();
|
||||
partial_flag_ = opinfo.partial_flag_;
|
||||
dynamic_shape_ = opinfo.dynamic_shape_;
|
||||
dynamic_compile_static_ = opinfo.dynamic_compile_static_;
|
||||
op_pattern_ = opinfo.op_pattern();
|
||||
processor_ = opinfo.processor_;
|
||||
need_check_supported_ = opinfo.need_check_supported();
|
||||
|
@ -125,6 +126,7 @@ class OpInfo {
|
|||
std::string kernel_name() const { return kernel_name_; }
|
||||
OpPattern op_pattern() const { return op_pattern_; }
|
||||
bool dynamic_shape() const { return dynamic_shape_; }
|
||||
bool dynamic_compile_static() const { return dynamic_compile_static_; }
|
||||
std::string processor() const { return processor_; }
|
||||
bool need_check_supported() const { return need_check_supported_; }
|
||||
bool is_dynamic_format() const { return is_dynamic_format_; }
|
||||
|
@ -134,6 +136,7 @@ class OpInfo {
|
|||
const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
|
||||
|
||||
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
|
||||
void set_dynamic_compile_static_(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; }
|
||||
void set_op_name(const std::string &op_name) { op_name_ = op_name; }
|
||||
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
|
||||
void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
|
||||
|
@ -158,7 +161,8 @@ class OpInfo {
|
|||
bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
|
||||
return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
|
||||
this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
|
||||
this->dynamic_shape_ == other_info->dynamic_shape_;
|
||||
this->dynamic_shape_ == other_info->dynamic_shape_ &&
|
||||
this->dynamic_compile_static_ == other_info->dynamic_compile_static_;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -172,6 +176,7 @@ class OpInfo {
|
|||
std::string kernel_name_;
|
||||
bool partial_flag_ = false;
|
||||
bool dynamic_shape_ = false;
|
||||
bool dynamic_compile_static_ = false;
|
||||
bool need_check_supported_ = false;
|
||||
bool is_dynamic_format_ = false;
|
||||
OpPattern op_pattern_ = kCommonPattern;
|
||||
|
|
|
@ -41,6 +41,7 @@ constexpr auto kNeedCheckSupported = "need_check_supported";
|
|||
constexpr auto kBroadcast = "broadcast";
|
||||
constexpr auto kReduce = "reduce";
|
||||
constexpr auto kDynamicShape = "dynamic_shape";
|
||||
constexpr auto kDynamicCompileStatic = "dynamic_compile_static";
|
||||
constexpr auto kDtypeFormat = "dtype_format";
|
||||
constexpr auto kAttr = "attr";
|
||||
constexpr auto kIputs = "inputs";
|
||||
|
@ -121,6 +122,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
op_info->set_dynamic_shape(obj.at(kDynamicShape));
|
||||
}
|
||||
|
||||
if (obj.find(kDynamicCompileStatic) != obj.end()) {
|
||||
op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
|
||||
}
|
||||
|
||||
if (obj.find(kIsDynamicFormat) != obj.end()) {
|
||||
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
|
||||
}
|
||||
|
|
|
@ -331,7 +331,7 @@ void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::
|
|||
auto iter = tbe::opTypeAdapter.find(op_name);
|
||||
(*compute_json)[kJType] = (iter != tbe::opTypeAdapter.end()) ? iter->second : op_name;
|
||||
(*compute_json)[kJPyModulePath] = python_module_path;
|
||||
(*compute_json)[kJDynamicCompileStatic] = false;
|
||||
(*compute_json)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
|
||||
(*compute_json)[kJInt64Mode] = false;
|
||||
(*compute_json)[kJName] = cnode->fullname_with_scope();
|
||||
(*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));
|
||||
|
|
|
@ -1053,7 +1053,7 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode
|
|||
// replace special op type for buffer fusion op
|
||||
auto type = GetRealOpType(origin_type);
|
||||
(*compute_op_str)[kJtype] = type;
|
||||
(*compute_op_str)[kJDynamicCompileStatic] = false;
|
||||
(*compute_op_str)[kJDynamicCompileStatic] = op_info_ptr->dynamic_compile_static();
|
||||
auto func_name = op_info_ptr->kernel_name();
|
||||
(*compute_op_str)[kJFuncName] = func_name;
|
||||
(*compute_op_str)[kJInt64Mode] = false;
|
||||
|
|
|
@ -412,6 +412,7 @@ class TBERegOp(RegOp):
|
|||
self.partial_flag_ = False
|
||||
self.reshape_type_ = ''
|
||||
self.dynamic_shape_ = False
|
||||
self.dynamic_compile_static_ = False
|
||||
self.need_check_supported_ = False
|
||||
self.is_dynamic_format_ = False
|
||||
self.op_pattern_ = ""
|
||||
|
@ -494,6 +495,17 @@ class TBERegOp(RegOp):
|
|||
self.dynamic_shape_ = dynamic_shape
|
||||
return self
|
||||
|
||||
def dynamic_compile_static(self, dynamic_compile_static=False):
|
||||
"""
|
||||
Whether the operator supports dynamic compile static.
|
||||
|
||||
Args:
|
||||
dynamic_compile_static (bool): Value of dynamic compile static. Default: false.
|
||||
"""
|
||||
self._is_bool(dynamic_compile_static)
|
||||
self.dynamic_compile_static_ = dynamic_compile_static
|
||||
return self
|
||||
|
||||
def need_check_supported(self, need_check_supported=False):
|
||||
"""
|
||||
Whether the operator need check supports.
|
||||
|
|
Loading…
Reference in New Issue