forked from mindspore-Ecosystem/mindspore
fix get loss error and NoneType error cause by _proceesor_specified_data
fix get loss error when it not a scalar and fix process specified data failed when the action is False, and collect_specified_data parameter is not None
This commit is contained in:
parent
fe82d82155
commit
cd868aea52
|
@ -239,7 +239,8 @@ class SummaryCollector(Callback):
|
||||||
|
|
||||||
unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
|
unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
|
||||||
if unexpected_params:
|
if unexpected_params:
|
||||||
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.')
|
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported, '
|
||||||
|
f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
|
||||||
|
|
||||||
if 'histogram_regular' in specified_data:
|
if 'histogram_regular' in specified_data:
|
||||||
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
|
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
|
||||||
|
@ -250,7 +251,8 @@ class SummaryCollector(Callback):
|
||||||
check_value_type(item, specified_data.get(item), bool)
|
check_value_type(item, specified_data.get(item), bool)
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
result = dict(self._DEFAULT_SPECIFIED_DATA).update(specified_data)
|
result = dict(self._DEFAULT_SPECIFIED_DATA)
|
||||||
|
result.update(specified_data)
|
||||||
else:
|
else:
|
||||||
result = specified_data
|
result = specified_data
|
||||||
return result
|
return result
|
||||||
|
@ -444,15 +446,12 @@ class SummaryCollector(Callback):
|
||||||
self._is_parse_loss_success = False
|
self._is_parse_loss_success = False
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(output, (int, float)):
|
if isinstance(output, (int, float, Tensor)):
|
||||||
loss = output
|
loss = output
|
||||||
elif isinstance(output, (list, tuple)):
|
elif isinstance(output, (list, tuple)) and output:
|
||||||
# If the output is a list, since the default network returns loss first,
|
# If the output is a list, since the default network returns loss first,
|
||||||
# we assume that the first one is loss.
|
# we assume that the first one is loss.
|
||||||
loss = output[0]
|
loss = output[0]
|
||||||
elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)):
|
|
||||||
loss_numpy = output.asnumpy()
|
|
||||||
loss = float(np.atleast_1d(loss_numpy)[0])
|
|
||||||
else:
|
else:
|
||||||
logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
|
logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
|
||||||
self._is_parse_loss_success = False
|
self._is_parse_loss_success = False
|
||||||
|
@ -461,6 +460,8 @@ class SummaryCollector(Callback):
|
||||||
if not isinstance(loss, Tensor):
|
if not isinstance(loss, Tensor):
|
||||||
loss = Tensor(loss)
|
loss = Tensor(loss)
|
||||||
|
|
||||||
|
precision = 4
|
||||||
|
loss = Tensor(round(np.mean(loss.asnumpy()), precision))
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _get_optimizer(self, cb_params):
|
def _get_optimizer(self, cb_params):
|
||||||
|
|
|
@ -50,6 +50,9 @@ def get_value():
|
||||||
_VALUE_CACHE = list()
|
_VALUE_CACHE = list()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
_SPECIFIED_DATA = SummaryCollector._DEFAULT_SPECIFIED_DATA
|
||||||
|
_SPECIFIED_DATA['collect_metric'] = False
|
||||||
|
|
||||||
|
|
||||||
class CustomNet(Cell):
|
class CustomNet(Cell):
|
||||||
"""Define custom netwrok."""
|
"""Define custom netwrok."""
|
||||||
|
@ -190,8 +193,8 @@ class TestSummaryCollector:
|
||||||
data = {'unexpected_key': True}
|
data = {'unexpected_key': True}
|
||||||
with pytest.raises(ValueError) as exc:
|
with pytest.raises(ValueError) as exc:
|
||||||
SummaryCollector(summary_dir, collect_specified_data=data)
|
SummaryCollector(summary_dir, collect_specified_data=data)
|
||||||
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
|
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
|
||||||
assert expected_msg == str(exc.value)
|
assert expected_msg in str(exc.value)
|
||||||
|
|
||||||
@pytest.mark.parametrize("custom_lineage_data", [
|
@pytest.mark.parametrize("custom_lineage_data", [
|
||||||
123,
|
123,
|
||||||
|
@ -273,12 +276,16 @@ class TestSummaryCollector:
|
||||||
assert name == 'train_dataset'
|
assert name == 'train_dataset'
|
||||||
|
|
||||||
@pytest.mark.parametrize("net_output, expected_loss", [
|
@pytest.mark.parametrize("net_output, expected_loss", [
|
||||||
|
(None, None),
|
||||||
(1, Tensor(1)),
|
(1, Tensor(1)),
|
||||||
|
(1.5, Tensor(1.5)),
|
||||||
|
(Tensor(1), Tensor(1)),
|
||||||
([1], Tensor(1)),
|
([1], Tensor(1)),
|
||||||
([Tensor(1)], Tensor(1)),
|
([Tensor(1)], Tensor(1)),
|
||||||
(Tensor([1]), Tensor(1)),
|
({}, None),
|
||||||
|
(Tensor([[1, 2], [3, 4]]), Tensor(2.5)),
|
||||||
|
([Tensor([[3, 4, 3]]), Tensor([3, 4])], Tensor(3.33333)),
|
||||||
(tuple([1]), Tensor(1)),
|
(tuple([1]), Tensor(1)),
|
||||||
(None, None)
|
|
||||||
])
|
])
|
||||||
def test_get_loss(self, net_output, expected_loss):
|
def test_get_loss(self, net_output, expected_loss):
|
||||||
"""Test get loss success and failed."""
|
"""Test get loss success and failed."""
|
||||||
|
@ -375,3 +382,20 @@ class TestSummaryCollector:
|
||||||
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
||||||
assert expected_names == [data[1] for data in result]
|
assert expected_names == [data[1] for data in result]
|
||||||
assert expected_values == [data[2] for data in result]
|
assert expected_values == [data[2] for data in result]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("specified_data, action, expected_result", [
|
||||||
|
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||||
|
(None, False, {}),
|
||||||
|
({}, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||||
|
({}, False, {}),
|
||||||
|
({'collect_metric': False}, True, _SPECIFIED_DATA),
|
||||||
|
({'collect_metric': True}, False, {'collect_metric': True})
|
||||||
|
])
|
||||||
|
def test_process_specified_data(self, specified_data, action, expected_result):
|
||||||
|
"""Test process specified data."""
|
||||||
|
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||||
|
summary_collector = SummaryCollector(summary_dir,
|
||||||
|
collect_specified_data=specified_data,
|
||||||
|
keep_default_action=action)
|
||||||
|
|
||||||
|
assert summary_collector._collect_specified_data == expected_result
|
||||||
|
|
Loading…
Reference in New Issue