forked from mindspore-Ecosystem/mindspore
fix get loss error and NoneType error cause by _proceesor_specified_data
fix get loss error when it not a scalar and fix process specified data failed when the action is False, and collect_specified_data parameter is not None
This commit is contained in:
parent
fe82d82155
commit
cd868aea52
|
@ -239,7 +239,8 @@ class SummaryCollector(Callback):
|
|||
|
||||
unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
|
||||
if unexpected_params:
|
||||
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.')
|
||||
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported, '
|
||||
f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
|
||||
|
||||
if 'histogram_regular' in specified_data:
|
||||
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
|
||||
|
@ -250,7 +251,8 @@ class SummaryCollector(Callback):
|
|||
check_value_type(item, specified_data.get(item), bool)
|
||||
|
||||
if action:
|
||||
result = dict(self._DEFAULT_SPECIFIED_DATA).update(specified_data)
|
||||
result = dict(self._DEFAULT_SPECIFIED_DATA)
|
||||
result.update(specified_data)
|
||||
else:
|
||||
result = specified_data
|
||||
return result
|
||||
|
@ -444,15 +446,12 @@ class SummaryCollector(Callback):
|
|||
self._is_parse_loss_success = False
|
||||
return None
|
||||
|
||||
if isinstance(output, (int, float)):
|
||||
if isinstance(output, (int, float, Tensor)):
|
||||
loss = output
|
||||
elif isinstance(output, (list, tuple)):
|
||||
elif isinstance(output, (list, tuple)) and output:
|
||||
# If the output is a list, since the default network returns loss first,
|
||||
# we assume that the first one is loss.
|
||||
loss = output[0]
|
||||
elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)):
|
||||
loss_numpy = output.asnumpy()
|
||||
loss = float(np.atleast_1d(loss_numpy)[0])
|
||||
else:
|
||||
logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
|
||||
self._is_parse_loss_success = False
|
||||
|
@ -461,6 +460,8 @@ class SummaryCollector(Callback):
|
|||
if not isinstance(loss, Tensor):
|
||||
loss = Tensor(loss)
|
||||
|
||||
precision = 4
|
||||
loss = Tensor(round(np.mean(loss.asnumpy()), precision))
|
||||
return loss
|
||||
|
||||
def _get_optimizer(self, cb_params):
|
||||
|
|
|
@ -50,6 +50,9 @@ def get_value():
|
|||
_VALUE_CACHE = list()
|
||||
return value
|
||||
|
||||
_SPECIFIED_DATA = SummaryCollector._DEFAULT_SPECIFIED_DATA
|
||||
_SPECIFIED_DATA['collect_metric'] = False
|
||||
|
||||
|
||||
class CustomNet(Cell):
|
||||
"""Define custom netwrok."""
|
||||
|
@ -190,8 +193,8 @@ class TestSummaryCollector:
|
|||
data = {'unexpected_key': True}
|
||||
with pytest.raises(ValueError) as exc:
|
||||
SummaryCollector(summary_dir, collect_specified_data=data)
|
||||
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
|
||||
assert expected_msg == str(exc.value)
|
||||
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
|
||||
assert expected_msg in str(exc.value)
|
||||
|
||||
@pytest.mark.parametrize("custom_lineage_data", [
|
||||
123,
|
||||
|
@ -273,12 +276,16 @@ class TestSummaryCollector:
|
|||
assert name == 'train_dataset'
|
||||
|
||||
@pytest.mark.parametrize("net_output, expected_loss", [
|
||||
(None, None),
|
||||
(1, Tensor(1)),
|
||||
(1.5, Tensor(1.5)),
|
||||
(Tensor(1), Tensor(1)),
|
||||
([1], Tensor(1)),
|
||||
([Tensor(1)], Tensor(1)),
|
||||
(Tensor([1]), Tensor(1)),
|
||||
({}, None),
|
||||
(Tensor([[1, 2], [3, 4]]), Tensor(2.5)),
|
||||
([Tensor([[3, 4, 3]]), Tensor([3, 4])], Tensor(3.33333)),
|
||||
(tuple([1]), Tensor(1)),
|
||||
(None, None)
|
||||
])
|
||||
def test_get_loss(self, net_output, expected_loss):
|
||||
"""Test get loss success and failed."""
|
||||
|
@ -375,3 +382,20 @@ class TestSummaryCollector:
|
|||
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
||||
assert expected_names == [data[1] for data in result]
|
||||
assert expected_values == [data[2] for data in result]
|
||||
|
||||
@pytest.mark.parametrize("specified_data, action, expected_result", [
|
||||
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||
(None, False, {}),
|
||||
({}, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||
({}, False, {}),
|
||||
({'collect_metric': False}, True, _SPECIFIED_DATA),
|
||||
({'collect_metric': True}, False, {'collect_metric': True})
|
||||
])
|
||||
def test_process_specified_data(self, specified_data, action, expected_result):
|
||||
"""Test process specified data."""
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
summary_collector = SummaryCollector(summary_dir,
|
||||
collect_specified_data=specified_data,
|
||||
keep_default_action=action)
|
||||
|
||||
assert summary_collector._collect_specified_data == expected_result
|
||||
|
|
Loading…
Reference in New Issue