From c8b62074c9bf32fbe48e01bcaef20261e798abf0 Mon Sep 17 00:00:00 2001 From: lixiaohui Date: Mon, 26 Oct 2020 15:18:56 +0800 Subject: [PATCH] include explainer in cmake and modify __init__ in explanation classes. --- cmake/package.cmake | 1 + .../explainer/explanation/_attribution/_attribution.py | 2 ++ .../explanation/_attribution/_backprop/gradcam.py | 5 +++++ .../explanation/_attribution/_backprop/gradient.py | 5 +++++ .../_attribution/_backprop/modified_relu.py | 10 ++++++++++ 5 files changed, 23 insertions(+) diff --git a/cmake/package.cmake b/cmake/package.cmake index eaf36c37796..b3cacd8cb87 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -263,6 +263,7 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/communication ${CMAKE_SOURCE_DIR}/mindspore/profiler + ${CMAKE_SOURCE_DIR}/mindspore/explainer ${CMAKE_SOURCE_DIR}/mindspore/compression DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore diff --git a/mindspore/explainer/explanation/_attribution/_attribution.py b/mindspore/explainer/explanation/_attribution/_attribution.py index 97b2e4d51fa..78e4131103e 100644 --- a/mindspore/explainer/explanation/_attribution/_attribution.py +++ b/mindspore/explainer/explanation/_attribution/_attribution.py @@ -32,6 +32,8 @@ class Attribution: def __init__(self, network): self._verify_model(network) self._model = network + self._model.set_train(False) + self._model.set_grad(False) @staticmethod def _verify_model(model): diff --git a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py index 1ac5c06ed26..07789ad0fe3 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py @@ -55,6 +55,11 @@ class GradCAM(IntermediateLayerAttribution): layer (str): The layer name to generate the explanation at. Default: ''. If default, the explantion will be generated at the input layer. + Notes: + The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`. + If you want to train the `network` afterwards, please reset it back to training mode through the opposite + operations. + Examples: >>> net = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") diff --git a/mindspore/explainer/explanation/_attribution/_backprop/gradient.py b/mindspore/explainer/explanation/_attribution/_backprop/gradient.py index 33bd9d7e241..00ca94629a8 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/gradient.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/gradient.py @@ -64,6 +64,11 @@ class Gradient(Attribution): Args: network (Cell): The black-box model to be explained. + Notes: + The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`. + If you want to train the `network` afterwards, please reset it back to training mode through the opposite + operations. + Examples: >>> net = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") diff --git a/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py b/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py index ac8d8c5aa81..f84e6dbc3c5 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py @@ -69,6 +69,11 @@ class Deconvolution(ModifiedReLU): Args: network (Cell): The black-box model to be explained. + Notes: + The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`. + If you want to train the `network` afterwards, please reset it back to training mode through the opposite + operations. + Examples: >>> net = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") @@ -98,6 +103,11 @@ class GuidedBackprop(ModifiedReLU): Args: network (Cell): The black-box model to be explained. + Notes: + The parsed `network` will be set to eval mode through `network.set_grad(False)` and `network.set_train(False)`. + If you want to train the `network` afterwards, please reset it back to training mode through the opposite + operations. + Examples: >>> net = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt")