include explainer in cmake and modify __init__ in explanation classes.

This commit is contained in:
lixiaohui 2020-10-26 15:18:56 +08:00
parent d2b1e783e7
commit c8b62074c9
5 changed files with 23 additions and 0 deletions

View File

@ -263,6 +263,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/ops
${CMAKE_SOURCE_DIR}/mindspore/communication
${CMAKE_SOURCE_DIR}/mindspore/profiler
${CMAKE_SOURCE_DIR}/mindspore/explainer
${CMAKE_SOURCE_DIR}/mindspore/compression
DESTINATION ${INSTALL_PY_DIR}
COMPONENT mindspore

View File

@ -32,6 +32,8 @@ class Attribution:
def __init__(self, network):
self._verify_model(network)
self._model = network
self._model.set_train(False)
self._model.set_grad(False)
@staticmethod
def _verify_model(model):

View File

@ -55,6 +55,11 @@ class GradCAM(IntermediateLayerAttribution):
layer (str): The layer name to generate the explanation at. Default: ''.
If default, the explantion will be generated at the input layer.
Notes:
The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`.
If you want to train the `network` afterwards, please reset it back to training mode through the opposite
operations.
Examples:
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")

View File

@ -64,6 +64,11 @@ class Gradient(Attribution):
Args:
network (Cell): The black-box model to be explained.
Notes:
The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`.
If you want to train the `network` afterwards, please reset it back to training mode through the opposite
operations.
Examples:
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")

View File

@ -69,6 +69,11 @@ class Deconvolution(ModifiedReLU):
Args:
network (Cell): The black-box model to be explained.
Notes:
The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`.
If you want to train the `network` afterwards, please reset it back to training mode through the opposite
operations.
Examples:
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")
@ -98,6 +103,11 @@ class GuidedBackprop(ModifiedReLU):
Args:
network (Cell): The black-box model to be explained.
Notes:
The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`.
If you want to train the `network` afterwards, please reset it back to training mode through the opposite
operations.
Examples:
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")