forked from mindspore-Ecosystem/mindspore
fix the summary operator is not work in constant folding scene
The summary operator will be optimized when it return the origin value in constant folding scene. So I return a None value to avoid this.
This commit is contained in:
parent
11d78e9c36
commit
3c4f621d75
|
@ -32,6 +32,13 @@ def _check_summary_param(name, value, class_name):
|
|||
validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)
|
||||
|
||||
|
||||
# Note: The return value of the summary operator is not used,
|
||||
# so there's nothing special about the return `dtype` or `shape`, any value is ok.
|
||||
# The `value` should be set to None, else summary operators may be optimized at compile graph phase,
|
||||
# it cause summary operators can not record data in constant folding scene.
|
||||
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
|
||||
|
||||
|
||||
class ScalarSummary(PrimitiveWithInfer):
|
||||
"""
|
||||
Output scalar to protocol buffer through scalar summary operator.
|
||||
|
@ -67,7 +74,7 @@ class ScalarSummary(PrimitiveWithInfer):
|
|||
raise ValueError(f"For 'value' the type should be scalar, "
|
||||
f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.")
|
||||
|
||||
return value
|
||||
return SUMMARY_RETURN_VALUE
|
||||
|
||||
|
||||
class ImageSummary(PrimitiveWithInfer):
|
||||
|
@ -104,7 +111,7 @@ class ImageSummary(PrimitiveWithInfer):
|
|||
raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__},"
|
||||
f" but got {len(v_shape)}.")
|
||||
|
||||
return value
|
||||
return SUMMARY_RETURN_VALUE
|
||||
|
||||
|
||||
class TensorSummary(PrimitiveWithInfer):
|
||||
|
@ -142,7 +149,7 @@ class TensorSummary(PrimitiveWithInfer):
|
|||
raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
|
||||
f"shape should not be [].")
|
||||
|
||||
return value
|
||||
return SUMMARY_RETURN_VALUE
|
||||
|
||||
|
||||
class HistogramSummary(PrimitiveWithInfer):
|
||||
|
@ -180,7 +187,7 @@ class HistogramSummary(PrimitiveWithInfer):
|
|||
raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
|
||||
f"shape should not be [].")
|
||||
|
||||
return value
|
||||
return SUMMARY_RETURN_VALUE
|
||||
|
||||
|
||||
class InsertGradientOf(PrimitiveWithInfer):
|
||||
|
|
Loading…
Reference in New Issue