o
    i>                     @  s   d dl mZ d dlZd dlmZmZ d dlmZ d dlZ	d dl
Z
d dlmZ d dlm  mZ d dlmZ d dlmZ d dlmZ d dlmZ g d	ZdddZG dd dZG dd dZG dd deZG dd deZG dd deZdS )    )annotationsN)CallableSequence)cast)NdarrayTensor)ScaleIntensity)ensure_tuple)default_upsampler)CAMGradCAM	GradCAMppModelWithHooksdefault_normalizerxr   returnc                 C  s>   ddd}t | tjrtj||    | jdS || S )	a/  
    A linear intensity scaling by mapping the (min, max) to (1, 0).
    If the input data is PyTorch Tensor, the output data will be Tensor on the same device,
    otherwise, output data will be numpy array.

    Note: This will flip magnitudes (i.e., smallest will become biggest and vice versa).
    data
np.ndarrayr   c                   s(   t ddd tj fdd| D ddS )Ng      ?        )minvmaxvc                   s   g | ]} |qS  r   ).0iscalerr   g/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/visualize/class_activation_maps.py
<listcomp>*   s    z8default_normalizer.<locals>._compute.<locals>.<listcomp>r   )axis)r   npstack)r   r   r   r   _compute(   s   z$default_normalizer.<locals>._compute)deviceN)r   r   r   r   )
isinstancetorchTensor	as_tensordetachcpunumpyr!   )r   r    r   r   r   r      s   
	 r   c                   @  sV   e Zd ZdZ		d d!d
dZdd Zdd Zd"ddZd#ddZd$ddZ	dd Z
dS )%r   zy
    A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information.
    F	nn_module	nn.Moduletarget_layer_namesstr | Sequence[str]register_forwardboolregister_backwardc                 C  s   || _ t|| _i | _i | _d| _d| _|| _|| _g }|	 D ]4\}}|| jvr*q |
| | jrId|jv rA|jd rAd|jd< || | | jrT|| | q | jrmt|t| jkrotd| j d dS dS dS )ag  

        Args:
            nn_module: the model to be wrapped.
            target_layer_names: the names of the layer to cache.
            register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
            register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.
        NinplaceFz<Not all target_layers exist in the network module: targets: .)modelr   target_layers	gradientsactivationsscore	class_idxr/   r-   named_modulesappend__dict__register_full_backward_hookbackward_hookregister_forward_hookforward_hooklenwarningswarn)selfr)   r+   r-   r/   _registerednamemodr   r   r   __init__7   s.   



zModelWithHooks.__init__c                       fdd}|S )Nc                   s   |d j  < d S )Nr   r4   )_moduleZ_grad_inputgrad_outputrD   rB   r   r   _hooka   s   z+ModelWithHooks.backward_hook.<locals>._hookr   rB   rD   rL   r   rK   r   r<   _      zModelWithHooks.backward_hookc                   rG   )Nc                   s   |j  < d S Nr5   )rI   _inputoutputrK   r   r   rL   h   s   z*ModelWithHooks.forward_hook.<locals>._hookr   rM   r   rK   r   r>   f   rN   zModelWithHooks.forward_hooklayer_id&str | Callable[[nn.Module], nn.Module]r   c                 C  sX   t |r	|| jS t|tr$| j D ]\}}||kr#ttj|  S qtd| d)z

        Args:
            layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`,
                this method will return the module `self.model.fc`.

        Returns:
            a submodule from self.model.
        zCould not find r1   )	callabler2   r"   strr8   r   nnModuleNotImplementedError)rB   rS   rD   rE   r   r   r   	get_layerm   s   


zModelWithHooks.get_layerlogitstorch.Tensorr7   intc                 C  s   |d d |f   S rO   )squeeze)rB   r[   r7   r   r   r   class_score   s   zModelWithHooks.class_scoreNc           
        s    j j} j    j |fi |}|d u r|dd n| _d\}} jr3t fdd jD } jrr 	|t
t j _ j    j j|d  jD ]}	|	 jvretd|	 d|	 d	 qRt fd
d jD }|ry j   |||fS )N   )NNc                 3  s    | ]} j | V  qd S rO   rP   r   layerrB   r   r   	<genexpr>   s    z*ModelWithHooks.__call__.<locals>.<genexpr>)retain_graphzBackward hook for z& is not triggered; `requires_grad` of z should be `True`.c                 3  s$    | ]}| j v r j | V  qd S rO   rH   rb   rd   r   r   re      s   " )r2   trainingevalmaxr7   r-   tupler3   r/   r_   r   r]   r6   	zero_gradsumbackwardr4   r@   rA   train)
rB   r   r7   rf   kwargsrn   r[   actigradrc   r   rd   r   __call__   s*   





zModelWithHooks.__call__c                 C  s   | j S rO   )r2   rd   r   r   r   get_wrapped_net      zModelWithHooks.get_wrapped_net)FF)r)   r*   r+   r,   r-   r.   r/   r.   )rS   rT   r   r*   )r[   r\   r7   r]   r   r\   )NF)__name__
__module____qualname____doc__rF   r<   r>   rZ   r_   rr   rs   r   r   r   r   r   2   s    (


r   c                   @  sF   e Zd ZdZeedfdddZdddZdddZdd Z	dd Z
dS )CAMBasez%
    Base class for CAM methods.
    Tr)   r*   r3   rV   	upsamplerr   postprocessingr/   r.   r   Nonec                 C  s8   |  t |tst||d|d| _n|| _|| _|| _d S )NT)r-   r/   )r"   r   r)   rz   r{   )rB   r)   r3   rz   r{   r/   r   r   r   rF      s   


zCAMBase.__init__r'   ra   c                 K  s$   | j tj|d|ifd|i|jS )a  
        Computes the actual feature map size given `nn_module` and the target_layer name.
        Args:
            input_size: shape of the input tensor
            device: the device used to initialise the input tensor
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
        Returns:
            shape of the actual feature map.
        r!   	layer_idx)compute_mapr#   zerosshape)rB   
input_sizer!   r}   ro   r   r   r   feature_map_size   s   $zCAMBase.feature_map_sizeNc                 C     t  )a  
        Compute the actual feature map with input tensor `x`.

        Args:
            x: input to `nn_module`.
            class_idx: index of the class to be visualized. Default to `None` (computing `class_idx` from `argmax`)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            activation maps (raw outputs without upsampling/post-processing.)
        rY   )rB   r   r7   r}   r   r   r   r~      s   zCAMBase.compute_mapc                 C  s&   |j dd  }| ||}| |S )N   )r   rz   r{   )rB   acti_mapr   Zimg_spatialr   r   r   _upsample_and_post_process   s   
z"CAMBase._upsample_and_post_processc                 C  r   rO   r   rd   r   r   r   rr      rt   zCAMBase.__call__)r)   r*   r3   rV   rz   r   r{   r   r/   r.   r   r|   )r'   ra   Nra   )ru   rv   rw   rx   r	   r   rF   r   r~   r   rr   r   r   r   r   ry      s    

ry   c                      s>   e Zd ZdZdeefd fddZdddZdddZ  Z	S )r
   a  
    Compute class activation map from the last fully-connected layers before the spatial pooling.
    This implementation is based on:

        Zhou et al., Learning Deep Features for Discriminative Localization. CVPR '16,
        https://arxiv.org/abs/1512.04150

    Examples

    .. code-block:: python

        import torch

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import CAM

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import seresnet50
        from monai.visualize import CAM

        model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    N.B.: To help select the target layer, it may be useful to list all layers:

    .. code-block:: python

        for name, _ in model.named_modules(): print(name)

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    fcr)   r*   r3   rV   	fc_layersstr | Callablerz   r   r{   r   r|   c                   s    t  j||||dd || _dS )a'  
        Args:
            nn_module: the model to be visualized
            target_layers: name of the model layer to generate the feature map.
            fc_layers: a string or a callable used to get fully-connected weights to compute activation map
                from the target_layers (without pooling).  and evaluate it at every spatial location.
            upsampler: An upsampling method to upsample the output image. Default is
                N dimensional linear (bilinear, trilinear, etc.) depending on num spatial
                dimensions of input.
            postprocessing: a callable that applies on the upsampled output image.
                Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and
                smallest input will become 1).
        F)r)   r3   rz   r{   r/   N)superrF   r   )rB   r)   r3   r   rz   r{   	__class__r   r   rF     s   
zCAM.__init__Nra   c                   s   | j |fi |\}}}|| }|d u r|dd }|j^}}	}
tj|||	dddd}| j | j tj fdd|D ddtjfddt	|D ddj|dg|
R  S )	Nr`   ra   r   )dimc                   s   g | ]} |d  qS )).r   r   )r   a)r   r   r   r   *  s    z#CAM.compute_map.<locals>.<listcomp>c                   s$   g | ]\}} |||d  f qS )r`   r   )r   r   b)rR   r   r   r   +  s   $ r   )
r)   ri   r   r#   splitreshaperZ   r   r   	enumerate)rB   r   r7   r}   ro   r[   rp   _r   cspatialr   )r   rR   r   r~   "  s    zCAM.compute_mapc                 K  s"   | j |||fi |}| ||S )a  
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualized. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

        Returns:
            activation maps
        r~   r   )rB   r   r7   r}   ro   r   r   r   r   rr   .  s   zCAM.__call__)r)   r*   r3   rV   r   r   rz   r   r{   r   r   r|   r   )
ru   rv   rw   rx   r	   r   rF   r~   rr   __classcell__r   r   r   r   r
      s    -
r
   c                   @  s$   e Zd ZdZd	ddZd
ddZdS )r   a  
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM).
    This implementation is based on:

        Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization,
        https://arxiv.org/abs/1610.02391

    Examples

    .. code-block:: python

        import torch

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import GradCAM

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import seresnet50
        from monai.visualize import GradCAM

        model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = GradCAM(nn_module=model_2d, target_layers="layer4")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    N.B.: To help select the target layer, it may be useful to list all layers:

    .. code-block:: python

        for name, _ in model.named_modules(): print(name)

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.CAM`

    NFra   c                 K  s   | j |f||d|\}}}|| || }}|j^}	}
}||	|
ddj|	|
gdgt| R  }|| jddd}t|S )Nr7   rf   ra   r   r`   Tkeepdim)r)   r   viewmeanr?   rl   Frelu)rB   r   r7   rf   r}   ro   r   rp   rq   r   r   r   weightsr   r   r   r   r~   i  s   .
zGradCAM.compute_mapc                 K  s&   | j |f|||d|}| ||S )aD  
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualized. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            retain_graph: whether to retain_graph for torch module backward call.
            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

        Returns:
            activation maps
        )r7   rf   r}   r   )rB   r   r7   r}   rf   ro   r   r   r   r   rr   q  s   zGradCAM.__call__NFra   )Nra   F)ru   rv   rw   rx   r~   rr   r   r   r   r   r   ?  s    
)r   c                   @  s   e Zd ZdZdddZdS )r   aW  
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM++).
    This implementation is based on:

        Chattopadhyay et al., Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks,
        https://arxiv.org/abs/1710.11063

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    NFra   c                 K  s  | j |f||d|\}}}|| || }}|j^}	}
}|d}|d||d|	|
ddj|	|
gdgt| R   }t|dk|t	|}|
|d }tttj| j j | }|| |	|
ddj|	|
gdgt| R  }|| jddd	}t|S )
Nr   r      ra   r`   r   gHz>Tr   )r)   r   powmulr   rl   r?   r#   where	ones_likedivr   r   r   r$   r6   exp)rB   r   r7   rf   r}   ro   r   rp   rq   r   r   r   Zalpha_nrZalpha_dralphaZ	relu_gradr   r   r   r   r   r~     s   
D2
zGradCAMpp.compute_mapr   )ru   rv   rw   rx   r~   r   r   r   r   r     s    r   )r   r   r   r   )
__future__r   r@   collections.abcr   r   typingr   r(   r   r#   torch.nnrW   torch.nn.functional
functionalr   monai.configr   monai.transformsr   monai.utilsr   Zmonai.visualize.visualizerr	   __all__r   r   ry   r
   r   r   r   r   r   r   <module>   s&   
j>eD