fix get loss error and NoneType error cause by _proceesor_specified_data

fix get loss error when it not a scalar and fix process specified data
failed when the action is False, and collect_specified_data parameter is
not None
This commit is contained in:
ougongchang 2020-06-30 19:44:28 +08:00
parent fe82d82155
commit cd868aea52
2 changed files with 36 additions and 11 deletions

View File

@ -239,7 +239,8 @@ class SummaryCollector(Callback):
unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
if unexpected_params:
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.')
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported, '
f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
if 'histogram_regular' in specified_data:
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
@ -250,7 +251,8 @@ class SummaryCollector(Callback):
check_value_type(item, specified_data.get(item), bool)
if action:
result = dict(self._DEFAULT_SPECIFIED_DATA).update(specified_data)
result = dict(self._DEFAULT_SPECIFIED_DATA)
result.update(specified_data)
else:
result = specified_data
return result
@ -444,15 +446,12 @@ class SummaryCollector(Callback):
self._is_parse_loss_success = False
return None
if isinstance(output, (int, float)):
if isinstance(output, (int, float, Tensor)):
loss = output
elif isinstance(output, (list, tuple)):
elif isinstance(output, (list, tuple)) and output:
# If the output is a list, since the default network returns loss first,
# we assume that the first one is loss.
loss = output[0]
elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)):
loss_numpy = output.asnumpy()
loss = float(np.atleast_1d(loss_numpy)[0])
else:
logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
self._is_parse_loss_success = False
@ -461,6 +460,8 @@ class SummaryCollector(Callback):
if not isinstance(loss, Tensor):
loss = Tensor(loss)
precision = 4
loss = Tensor(round(np.mean(loss.asnumpy()), precision))
return loss
def _get_optimizer(self, cb_params):

View File

@ -50,6 +50,9 @@ def get_value():
_VALUE_CACHE = list()
return value
_SPECIFIED_DATA = SummaryCollector._DEFAULT_SPECIFIED_DATA
_SPECIFIED_DATA['collect_metric'] = False
class CustomNet(Cell):
"""Define custom netwrok."""
@ -190,8 +193,8 @@ class TestSummaryCollector:
data = {'unexpected_key': True}
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data=data)
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
assert expected_msg == str(exc.value)
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
assert expected_msg in str(exc.value)
@pytest.mark.parametrize("custom_lineage_data", [
123,
@ -273,12 +276,16 @@ class TestSummaryCollector:
assert name == 'train_dataset'
@pytest.mark.parametrize("net_output, expected_loss", [
(None, None),
(1, Tensor(1)),
(1.5, Tensor(1.5)),
(Tensor(1), Tensor(1)),
([1], Tensor(1)),
([Tensor(1)], Tensor(1)),
(Tensor([1]), Tensor(1)),
({}, None),
(Tensor([[1, 2], [3, 4]]), Tensor(2.5)),
([Tensor([[3, 4, 3]]), Tensor([3, 4])], Tensor(3.33333)),
(tuple([1]), Tensor(1)),
(None, None)
])
def test_get_loss(self, net_output, expected_loss):
"""Test get loss success and failed."""
@ -375,3 +382,20 @@ class TestSummaryCollector:
assert PluginEnum.HISTOGRAM.value == result[0][0]
assert expected_names == [data[1] for data in result]
assert expected_values == [data[2] for data in result]
@pytest.mark.parametrize("specified_data, action, expected_result", [
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
(None, False, {}),
({}, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
({}, False, {}),
({'collect_metric': False}, True, _SPECIFIED_DATA),
({'collect_metric': True}, False, {'collect_metric': True})
])
def test_process_specified_data(self, specified_data, action, expected_result):
"""Test process specified data."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_collector = SummaryCollector(summary_dir,
collect_specified_data=specified_data,
keep_default_action=action)
assert summary_collector._collect_specified_data == expected_result