o
    i                     @  s   d dl mZ d dlmZ d dlmZmZ d dlZd dlm	Z	 d dl
mZ d dlmZ edd	d
\ZZg dZG dd dejjZG dd dejjZG dd dZG dd deZG dd deZG dd deZdS )    )annotations)partial)AnyCallableN)replace_modules_temp)optional_import)ModelWithHookstqdmtrange)name)VanillaGrad
SmoothGradGuidedBackpropGradGuidedBackpropSmoothGradc                   @  s$   e Zd Zedd Zedd ZdS )_AutoGradReLUc                 C  s*   |dk |}t||}| || |S Nr   )type_astorchmulsave_for_backward)ctxxpos_maskoutput r   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/visualize/gradient_based.pyforward   s   z_AutoGradReLU.forwardc                 C  sB   | j \}}|dk|}|dk|}t||}t||}|S r   )saved_tensorsr   r   r   )r   grad_outputr   _Z
pos_mask_1Z
pos_mask_2y
grad_inputr   r   r   backward%   s   
z_AutoGradReLU.backwardN)__name__
__module____qualname__staticmethodr   r"   r   r   r   r   r      s
    
r   c                   @  s   e Zd ZdZdddZdS )		_GradReLUzx
    A customized ReLU with the backward pass imputed for guided backpropagation (https://arxiv.org/abs/1412.6806).
    r   torch.Tensorreturnc                 C  s   t |}|S N)r   apply)selfr   outr   r   r   r   4   s   
z_GradReLU.forwardN)r   r(   r)   r(   )r#   r$   r%   __doc__r   r   r   r   r   r'   /   s    r'   c                   @  sN   e Zd ZdZdddZedd	 Zejd
d	 Z	ddddZddddZ	dS )r   a  
    Given an input image ``x``, calling this class will perform the forward pass, then set to zero
    all activations except one (defined by ``index``) and propagate back to the image to achieve a gradient-based
    saliency map.

    If ``index`` is None, argmax of the output logits will be used.

    See also:

        - Simonyan et al. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
          (https://arxiv.org/abs/1312.6034)
    modeltorch.nn.Moduler)   Nonec                 C  s(   t |tst|ddd| _d S || _d S )Nr   T)target_layer_namesregister_backward)
isinstancer   _model)r,   r/   r   r   r   __init__G   s   

zVanillaGrad.__init__c                 C  s   | j jS r*   )r5   r/   )r,   r   r   r   r/   M   s   zVanillaGrad.modelc                 C  s    t |ts|| j_d S || _d S r*   )r4   r   r5   r/   )r,   mr   r   r   r/   Q   s   

Tr   r(   indextorch.Tensor | int | Noneretain_graphboolkwargsr   c                 K  sB   |j d dkrtdd|_| j|f||d| |j }|S )Nr      zexpect batch size of 1T)	class_idxr:   )shape
ValueErrorrequires_gradr5   graddetach)r,   r   r8   r:   r<   rB   r   r   r   get_gradX   s   
zVanillaGrad.get_gradNc                 K  s   | j ||fi |S r*   )rD   r,   r   r8   r<   r   r   r   __call__c   s   zVanillaGrad.__call__)r/   r0   r)   r1   )T)
r   r(   r8   r9   r:   r;   r<   r   r)   r(   r*   r   r(   r8   r9   r<   r   r)   r(   )
r#   r$   r%   r.   r6   propertyr/   setterrD   rF   r   r   r   r   r   9   s    


r   c                      s8   e Zd ZdZ				dd fddZddddZ  ZS )r   z
    Compute averaged sensitivity map based on ``n_samples`` (Gaussian additive) of noisy versions
    of the input image ``x``.

    See also:

        - Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
    333333?   Tr/   r0   stdev_spreadfloat	n_samplesint	magnituder;   verboser)   r1   c                   sP   t  | || _|| _|| _|  |r#tr#ttd| jj	 d| _
d S t
| _
d S )Nz
Computing )desc)superr6   rL   rN   rP   
has_tranger   r
   	__class__r#   range)r,   r/   rL   rN   rP   rQ   rU   r   r   r6   q   s   
zSmoothGrad.__init__Nr   r(   r8   r9   r<   r   c           
      K  s   | j | |    }t|}| | jD ],}tjd||j	tj
|jd}|| }| }| j||fi |}	|| jrA|	|	 n|	7 }q| jrL|d }|| j S )Nr   )sizedtypedeviceg      ?)rL   maxminitemr   
zeros_likerV   rN   normalr?   float32rZ   rC   rD   rP   )
r,   r   r8   r<   stdevZtotal_gradientsr   noiseZx_plus_noiserB   r   r   r   rF      s   

zSmoothGrad.__call__)rJ   rK   TT)r/   r0   rL   rM   rN   rO   rP   r;   rQ   r;   r)   r1   r*   rG   )r#   r$   r%   r.   r6   rF   __classcell__r   r   rW   r   r   g   s    r   c                      $   e Zd ZdZdd fd
dZ  ZS )r   ag  
    Based on Springenberg and Dosovitskiy et al. https://arxiv.org/abs/1412.6806,
    compute gradient-based saliency maps by backpropagating positive gradients and inputs (see ``_AutoGradReLU``).

    See also:

        - Springenberg and Dosovitskiy et al. Striving for Simplicity: The All Convolutional Net
          (https://arxiv.org/abs/1412.6806)
    Nr   r(   r8   r9   r<   r   r)   c                   N   t | jdt dd t j||fi |W  d    S 1 s w   Y  d S NreluF)strict_matchr   r/   r'   rS   rF   rE   rW   r   r   rF         $zGuidedBackpropGrad.__call__r*   rG   r#   r$   r%   r.   rF   rc   r   r   rW   r   r      s    
r   c                      rd   )r   zg
    Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``.
    Nr   r(   r8   r9   r<   r   r)   c                   re   rf   ri   rE   rW   r   r   rF      rj   z!GuidedBackpropSmoothGrad.__call__r*   rG   rk   r   r   rW   r   r      s    r   )
__future__r   	functoolsr   typingr   r   r   monai.networks.utilsr   monai.utils.moduler   Z%monai.visualize.class_activation_mapsr   r
   rT   __all__autogradFunctionr   nnModuler'   r   r   r   r   r   r   r   r   <module>   s   
.0