!642 [Auto parallel] Delete the 'simplify_cal' in 'set_algo_parameters' and 'get_algo_parameters'

Merge pull request !642 from Xiaoda/delete-simplify-cal-attribute-in-the-interface
This commit is contained in:
mindspore-ci-bot 2020-04-24 15:11:02 +08:00 committed by Gitee
commit 29dc8048f2
3 changed files with 7 additions and 28 deletions

View File

@ -206,10 +206,6 @@ PYBIND11_MODULE(_c_expression, m) {
"Set the parameter cost_model_gamma of the DP algorithm") "Set the parameter cost_model_gamma of the DP algorithm")
.def("get_costmodel_gamma", &CostModelContext::costmodel_gamma, .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma,
"Get the parameter cost_model_gamma of the DP algorithm.") "Get the parameter cost_model_gamma of the DP algorithm.")
.def("set_simplify_cal", &CostModelContext::set_costmodel_simplify_cal,
"Set the parameter cost_model_simplify_cal of the DP algorithm.")
.def("get_simplify_cal", &CostModelContext::costmodel_simplify_cal,
"Get the parameter cost_model_simplify_cal of the DP algorithm.")
.def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold, .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold,
"Set the parameter cost_model_communi_threshold of the DP algorithm.") "Set the parameter cost_model_communi_threshold of the DP algorithm.")
.def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold, .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold,

View File

@ -45,14 +45,6 @@ class _AlgoParameterConfig():
if self._config_handle is None: if self._config_handle is None:
raise ValueError("Config handle is none!!!") raise ValueError("Config handle is none!!!")
def set_simplify_cal(self, simplify_cal):
self.check_config_handle()
self._config_handle.set_simplify_cal(simplify_cal)
def get_simplify_cal(self):
self.check_config_handle()
return self._config_handle.get_simplify_cal()
def set_fully_use_devices(self, not_fully): def set_fully_use_devices(self, not_fully):
self.check_config_handle() self.check_config_handle()
self._config_handle.set_fully_use_devices(not_fully) self._config_handle.set_fully_use_devices(not_fully)
@ -118,7 +110,6 @@ def _algo_parameter_config():
set_algo_parameters_config_func_map = { set_algo_parameters_config_func_map = {
"simplify_cal": _algo_parameter_config().set_simplify_cal,
"fully_use_devices": _algo_parameter_config().set_fully_use_devices, "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow, "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable, "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
@ -126,14 +117,13 @@ set_algo_parameters_config_func_map = {
get_algo_parameters_config_func_map = { get_algo_parameters_config_func_map = {
"simplify_cal": _algo_parameter_config().get_simplify_cal,
"fully_use_devices": _algo_parameter_config().get_fully_use_devices, "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
"elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow, "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
"tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable, "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
"tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size} "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
@args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int, @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
fully_use_devices=bool, elementwise_op_strategy_follow=bool) fully_use_devices=bool, elementwise_op_strategy_follow=bool)
def set_algo_parameters(**kwargs): def set_algo_parameters(**kwargs):
""" """
@ -143,10 +133,10 @@ def set_algo_parameters(**kwargs):
Attribute name is needed. Attribute name is needed.
Args: Args:
simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True tensor_slice_align_enable (bool): Whether checking tensor slice shape for MatMul. Default: False
tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False tensor_slice_align_size (int): The minimum tensor slice shape of MatMul, the value must be in [1, 1024].
tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16 Default: 16
fully_use_devices (bool): Whether generating strategies that fully use all available devices. Default: True fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices. Default: True
elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
subsequent operators. Default: False subsequent operators. Default: False

View File

@ -97,13 +97,8 @@ def test_two_matmul():
assert costmodel_communi_bias == 1024.0 assert costmodel_communi_bias == 1024.0
set_algo_parameters(simplify_cal=True, set_algo_parameters(tensor_slice_align_enable=False, tensor_slice_align_size=32,
tensor_slice_align_enable=False, fully_use_devices=False, elementwise_op_strategy_follow=False)
tensor_slice_align_size=32,
fully_use_devices=False,
elementwise_op_strategy_follow=False)
para_simplify_cal = get_algo_parameters("simplify_cal")
assert para_simplify_cal == True
para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable") para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
assert para_slice_align_enable == False assert para_slice_align_enable == False
para_slice_align_size = get_algo_parameters("tensor_slice_align_size") para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
@ -114,8 +109,6 @@ def test_two_matmul():
assert elementwise_op_strategy_follow == False assert elementwise_op_strategy_follow == False
reset_algo_parameters() reset_algo_parameters()
para_simplify_cal = get_algo_parameters("simplify_cal")
assert para_simplify_cal == True
para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable") para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
assert para_slice_align_enable == False assert para_slice_align_enable == False
para_slice_align_size = get_algo_parameters("tensor_slice_align_size") para_slice_align_size = get_algo_parameters("tensor_slice_align_size")