mirror of https://github.com/microsoft/autogen.git
accommodate nni usage pattern (#209)
This commit is contained in:
parent
a9d39b71da
commit
0ba58e0ace
|
@ -884,7 +884,7 @@ except ImportError:
|
|||
pass
|
||||
|
||||
def extract_scalar_reward(x: Dict):
|
||||
return x.get("reward")
|
||||
return x.get("default")
|
||||
|
||||
|
||||
class BlendSearchTuner(BlendSearch, NNITuner):
|
||||
|
@ -949,12 +949,19 @@ class BlendSearchTuner(BlendSearch, NNITuner):
|
|||
config[key] = qrandn(*v)
|
||||
else:
|
||||
raise ValueError(f"unsupported type in search_space {_type}")
|
||||
add_cost_to_space(config, {}, {})
|
||||
# low_cost_partial_config is passed to constructor,
|
||||
# which is before update_search_space() is called
|
||||
init_config = self._ls.init_config
|
||||
add_cost_to_space(config, init_config, self._cat_hp_cost)
|
||||
self._ls = self.LocalSearch(
|
||||
{},
|
||||
init_config,
|
||||
self._ls.metric,
|
||||
self._mode,
|
||||
config,
|
||||
self._ls.prune_attr,
|
||||
self._ls.min_resource,
|
||||
self._ls.max_resource,
|
||||
self._ls.resource_multiple_factor,
|
||||
cost_attr=self.cost_attr,
|
||||
seed=self._ls.seed,
|
||||
)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from flaml.searcher.blendsearch import CFO
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from ray import __version__ as ray_version
|
||||
assert ray_version >= '1.0.0'
|
||||
|
||||
assert ray_version >= "1.0.0"
|
||||
from ray.tune import sample
|
||||
except (ImportError, AssertionError):
|
||||
from flaml.tune import sample
|
||||
|
@ -15,7 +17,7 @@ except (ImportError, AssertionError):
|
|||
|
||||
def test_searcher():
|
||||
searcher = Searcher()
|
||||
searcher = Searcher(metric=['m1', 'm2'], mode=['max', 'min'])
|
||||
searcher = Searcher(metric=["m1", "m2"], mode=["max", "min"])
|
||||
searcher.set_search_properties(None, None, None)
|
||||
searcher.suggest = searcher.on_pause = searcher.on_unpause = lambda _: {}
|
||||
searcher.on_trial_complete = lambda trial_id, result, error: None
|
||||
|
@ -30,64 +32,73 @@ except (ImportError, AssertionError):
|
|||
searcher.set_state({})
|
||||
print(searcher.get_state())
|
||||
import optuna
|
||||
|
||||
config = {
|
||||
"a": optuna.distributions.UniformDistribution(6, 8),
|
||||
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
|
||||
}
|
||||
searcher = OptunaSearch(
|
||||
config, points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
evaluated_rewards=[{'m': 2}], metric='m', mode='max'
|
||||
config,
|
||||
points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
evaluated_rewards=[{"m": 2}],
|
||||
metric="m",
|
||||
mode="max",
|
||||
)
|
||||
config = {
|
||||
"a": sample.uniform(6, 8),
|
||||
"b": sample.loguniform(1e-4, 1e-2)
|
||||
}
|
||||
config = {"a": sample.uniform(6, 8), "b": sample.loguniform(1e-4, 1e-2)}
|
||||
searcher = OptunaSearch(
|
||||
config, points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
evaluated_rewards=[{'m': 2}], metric='m', mode='max'
|
||||
config,
|
||||
points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
evaluated_rewards=[{"m": 2}],
|
||||
metric="m",
|
||||
mode="max",
|
||||
)
|
||||
searcher = OptunaSearch(
|
||||
define_search_space, points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
define_search_space,
|
||||
points_to_evaluate=[{"a": 6, "b": 1e-3}],
|
||||
# evaluated_rewards=[{'m': 2}], metric='m', mode='max'
|
||||
mode='max'
|
||||
mode="max",
|
||||
)
|
||||
searcher = OptunaSearch()
|
||||
# searcher.set_search_properties('m', 'min', define_search_space)
|
||||
searcher.set_search_properties('m', 'min', config)
|
||||
searcher.suggest('t1')
|
||||
searcher.on_trial_complete('t1', None, False)
|
||||
searcher.suggest('t2')
|
||||
searcher.on_trial_complete('t2', None, True)
|
||||
searcher.suggest('t3')
|
||||
searcher.on_trial_complete('t3', {'m': np.nan})
|
||||
searcher.save('test/tune/optuna.pickle')
|
||||
searcher.restore('test/tune/optuna.pickle')
|
||||
searcher.set_search_properties("m", "min", config)
|
||||
searcher.suggest("t1")
|
||||
searcher.on_trial_complete("t1", None, False)
|
||||
searcher.suggest("t2")
|
||||
searcher.on_trial_complete("t2", None, True)
|
||||
searcher.suggest("t3")
|
||||
searcher.on_trial_complete("t3", {"m": np.nan})
|
||||
searcher.save("test/tune/optuna.pickle")
|
||||
searcher.restore("test/tune/optuna.pickle")
|
||||
searcher = BlendSearch(
|
||||
metric="m",
|
||||
global_search_alg=searcher, metric_constraints=[("c", "<", 1)])
|
||||
metric="m", global_search_alg=searcher, metric_constraints=[("c", "<", 1)]
|
||||
)
|
||||
searcher.set_search_properties(metric="m2", config=config)
|
||||
searcher.set_search_properties(config={"time_budget_s": 0})
|
||||
c = searcher.suggest('t1')
|
||||
c = searcher.suggest("t1")
|
||||
searcher.on_trial_complete("t1", {"config": c}, True)
|
||||
c = searcher.suggest('t2')
|
||||
c = searcher.suggest("t2")
|
||||
searcher.on_trial_complete(
|
||||
"t2", {"config": c, "m2": 1, "c": 2, "time_total_s": 1})
|
||||
"t2", {"config": c, "m2": 1, "c": 2, "time_total_s": 1}
|
||||
)
|
||||
config1 = config.copy()
|
||||
config1['_choice_'] = 0
|
||||
config1["_choice_"] = 0
|
||||
searcher._expand_admissible_region(
|
||||
lower={'root': [{'a': 0.5}, {'a': 0.4}]},
|
||||
upper={'root': [{'a': 0.9}, {'a': 0.8}]},
|
||||
space={'root': config1},
|
||||
lower={"root": [{"a": 0.5}, {"a": 0.4}]},
|
||||
upper={"root": [{"a": 0.9}, {"a": 0.8}]},
|
||||
space={"root": config1},
|
||||
)
|
||||
searcher = CFO(
|
||||
metric='m', mode='min', space=config,
|
||||
points_to_evaluate=[{'a': 7, 'b': 1e-3}, {'a': 6, 'b': 3e-4}],
|
||||
evaluated_rewards=[1, 1])
|
||||
metric="m",
|
||||
mode="min",
|
||||
space=config,
|
||||
points_to_evaluate=[{"a": 7, "b": 1e-3}, {"a": 6, "b": 3e-4}],
|
||||
evaluated_rewards=[1, 1],
|
||||
)
|
||||
searcher.suggest("t1")
|
||||
searcher.suggest("t2")
|
||||
searcher.on_trial_result('t3', {})
|
||||
searcher.on_trial_result("t3", {})
|
||||
c = searcher.generate_parameters(1)
|
||||
searcher.receive_trial_result(1, c, {'reward': 0})
|
||||
searcher.receive_trial_result(1, c, {"default": 0})
|
||||
searcher.update_search_space(
|
||||
{
|
||||
"a": {
|
||||
|
@ -99,7 +110,7 @@ except (ImportError, AssertionError):
|
|||
"_type": "randint",
|
||||
},
|
||||
"c": {
|
||||
"_value": [.1, 3],
|
||||
"_value": [0.1, 3],
|
||||
"_type": "uniform",
|
||||
},
|
||||
"d": {
|
||||
|
|
Loading…
Reference in New Issue