U
    Ph                     @  s   d dl mZ d dlmZ d dlmZmZ d dlZd dlm	Z	 d dl
mZ d dlmZ edd	d
\ZZddddgZG dd dejjZG dd dejjZG dd dZG dd deZG dd deZG dd deZdS )    )annotations)partial)AnyCallableN)replace_modules_temp)optional_import)ModelWithHookstqdmtrange)nameVanillaGrad
SmoothGradGuidedBackpropGradGuidedBackpropSmoothGradc                   @  s$   e Zd Zedd Zedd ZdS )_AutoGradReLUc                 C  s*   |dk |}t||}| || |S Nr   )type_astorchmulsave_for_backward)ctxxpos_maskoutput r   S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/visualize/gradient_based.pyforward   s    z_AutoGradReLU.forwardc                 C  sB   | j \}}|dk|}|dk|}t||}t||}|S r   )saved_tensorsr   r   r   )r   grad_outputr   _Z
pos_mask_1Z
pos_mask_2y
grad_inputr   r   r   backward%   s    
z_AutoGradReLU.backwardN)__name__
__module____qualname__staticmethodr   r"   r   r   r   r   r      s   
r   c                   @  s    e Zd ZdZdddddZdS )	_GradReLUzx
    A customized ReLU with the backward pass imputed for guided backpropagation (https://arxiv.org/abs/1412.6806).
    torch.Tensor)r   returnc                 C  s   t |}|S N)r   apply)selfr   outr   r   r   r   4   s    
z_GradReLU.forwardN)r#   r$   r%   __doc__r   r   r   r   r   r'   /   s   r'   c                   @  sh   e Zd ZdZdddddZedd Zejd	d ZdddddddddZddddddddZ	dS )r   a  
    Given an input image ``x``, calling this class will perform the forward pass, then set to zero
    all activations except one (defined by ``index``) and propagate back to the image to achieve a gradient-based
    saliency map.

    If ``index`` is None, argmax of the output logits will be used.

    See also:

        - Simonyan et al. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
          (https://arxiv.org/abs/1312.6034)
    torch.nn.ModuleNone)modelr)   c                 C  s&   t |tst|ddd| _n|| _d S )Nr   T)target_layer_namesregister_backward)
isinstancer   _model)r,   r1   r   r   r   __init__G   s    
zVanillaGrad.__init__c                 C  s   | j jS r*   )r5   r1   )r,   r   r   r   r1   M   s    zVanillaGrad.modelc                 C  s   t |ts|| j_n|| _d S r*   )r4   r   r5   r1   )r,   mr   r   r   r1   Q   s    

Tr(   torch.Tensor | int | Noneboolr   )r   indexretain_graphkwargsr)   c                 K  sB   |j d dkrtdd|_| j|f||d| |j }|S )Nr      zexpect batch size of 1T)	class_idxr;   )shape
ValueErrorrequires_gradr5   graddetach)r,   r   r:   r;   r<   rB   r   r   r   get_gradX   s    
zVanillaGrad.get_gradNr   r:   r<   r)   c                 K  s   | j ||f|S r*   )rD   r,   r   r:   r<   r   r   r   __call__c   s    zVanillaGrad.__call__)T)N)
r#   r$   r%   r.   r6   propertyr1   setterrD   rG   r   r   r   r   r   9   s   

 c                      sH   e Zd ZdZddddddd	d
 fddZddddddddZ  ZS )r   z
    Compute averaged sensitivity map based on ``n_samples`` (Gaussian additive) of noisy versions
    of the input image ``x``.

    See also:

        - Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
    333333?   Tr/   floatintr9   r0   )r1   stdev_spread	n_samples	magnitudeverboser)   c                   sN   t  | || _|| _|| _|  |rDtrDttd| jj	 d| _
nt
| _
d S )Nz
Computing )desc)superr6   rN   rO   rP   
has_tranger   r
   	__class__r#   range)r,   r1   rN   rO   rP   rQ   rU   r   r   r6   q   s    zSmoothGrad.__init__Nr(   r8   r   rE   c           
      K  s   | j | |    }t|}| | jD ]T}tjd||j	tj
|jd}|| }| }| j||f|}	|| jr~|	|	 n|	7 }q0| jr|d }|| j S )Nr   )sizedtypedeviceg      ?)rN   maxminitemr   
zeros_likerV   rO   normalr?   float32rZ   rC   rD   rP   )
r,   r   r:   r<   stdevZtotal_gradientsr   noiseZx_plus_noiserB   r   r   r   rG      s    
zSmoothGrad.__call__)rJ   rK   TT)N)r#   r$   r%   r.   r6   rG   __classcell__r   r   rW   r   r   g   s       c                      s.   e Zd ZdZd	ddddd fddZ  ZS )
r   ag  
    Based on Springenberg and Dosovitskiy et al. https://arxiv.org/abs/1412.6806,
    compute gradient-based saliency maps by backpropagating positive gradients and inputs (see ``_AutoGradReLU``).

    See also:

        - Springenberg and Dosovitskiy et al. Striving for Simplicity: The All Convolutional Net
          (https://arxiv.org/abs/1412.6806)
    Nr(   r8   r   rE   c              
     s>   t | jdt dd  t j||f|W  5 Q R  S Q R X d S NreluF)strict_matchr   r1   r'   rS   rG   rF   rW   r   r   rG      s    zGuidedBackpropGrad.__call__)Nr#   r$   r%   r.   rG   rc   r   r   rW   r   r      s   
c                      s.   e Zd ZdZd	ddddd fddZ  ZS )
r   zg
    Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``.
    Nr(   r8   r   rE   c              
     s>   t | jdt dd  t j||f|W  5 Q R  S Q R X d S rd   rg   rF   rW   r   r   rG      s    z!GuidedBackpropSmoothGrad.__call__)Nrh   r   r   rW   r   r      s   )
__future__r   	functoolsr   typingr   r   r   monai.networks.utilsr   monai.utils.moduler   Z%monai.visualize.class_activation_mapsr   r
   rT   __all__autogradFunctionr   nnModuler'   r   r   r   r   r   r   r   r   <module>   s   
.0