U
    PhF                     @  s   d dl mZ d dlmZmZmZ d dlmZ d dlZ	d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZmZmZmZmZ d dlmZ G d	d
 d
ZdS )    )annotations)CallableMappingSequence)AnyN)
MetaTensor)	eval_mode)ComposeGaussianSmoothLambdaScaleIntensitySpatialCrop)ensure_tuple_repc                   @  s   e Zd ZdZd-dddd	d
ddddddZedddddddZed.dddddddZeddddddddddd
dd Zed!ddd"d#d$d%Z	d/dd'd(d)d*d+d,Z
d&S )0OcclusionSensitivitya	  
    This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity,
    we mean how the probability of a given prediction changes as the occluded section of an image changes. This can be
    useful to understand why a network is making certain decisions.

    As important parts of the image are occluded, the probability of classifying the image correctly will decrease.
    Hence, more negative values imply the corresponding occluded volume was more important in the decision process.

    Two ``torch.Tensor`` will be returned by the ``__call__`` method: an occlusion map and an image of the most probable
    class. Both images will be cropped if a bounding box used, but voxel sizes will always match the input.

    The occlusion map shows the inference probabilities when the corresponding part of the image is occluded. Hence,
    more -ve values imply that region was important in the decision process. The map will have shape ``BCHW(D)N``,
    where ``N`` is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can
    be seen with ``map[...,i]``.

    The most probable class is an image of the probable class when the corresponding part of the image is occluded
    (equivalent to ``occ_map.argmax(dim=-1)``).

    See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via
    Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74.

    Examples:

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import OcclusionSensitivity
        import torch

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        occ_sens = OcclusionSensitivity(nn_module=model_2d)
        occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[2, 40, 1, 62])

        # densenet 3d
        from monai.networks.nets import DenseNet
        from monai.visualize import OcclusionSensitivity

        model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
        occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10)
        occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[1, 3, -1, -1, -1, -1])

    See Also:

        - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity.`
       Tgaussian      ?z	nn.Modulezint | Sequenceintboolzstr | float | Callablefloatzbool | CallableNone)	nn_module	mask_sizen_batchverbosemodeoverlapactivatereturnc                 C  sD   || _ || _|| _|| _|| _|| _t|tr:|dkr:t|| _	dS )at  
        Occlusion sensitivity constructor.

        Args:
            nn_module: Classification model to use for inference
            mask_size: Size of box to be occluded, centred on the central voxel. If a single number
                is given, this is used for all dimensions. If a sequence is given, this is used for each dimension
                individually.
            n_batch: Number of images in a batch for inference.
            verbose: Use progress bar (if ``tqdm`` available).
            mode: what should the occluded region be replaced with? If a float is given, that value will be used
                throughout the occlusion. Else, ``gaussian``, ``mean_img`` and ``mean_patch`` can be supplied:

                * ``gaussian``: occluded region is multiplied by 1 - gaussian kernel. In this fashion, the occlusion
                  will be 0 at the center and will be unchanged towards the edges, varying smoothly between. When
                  gaussian is used, a weighted average will be used to combine overlapping regions. This will be
                  done using the gaussian (not 1-gaussian) as occluded regions count more.
                * ``mean_patch``: occluded region will be replaced with the mean of occluded region.
                * ``mean_img``: occluded region will be replaced with the mean of the whole image.

            overlap: overlap between inferred regions. Should be in range 0<=x<1.
            activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any
                activation. If ``callable``, use callable on inferred outputs.

        )r   
mean_patchmean_imgN)
r   r   r   r   r   r   
isinstancestrNotImplementedErrorr   )selfr   r   r   r   r   r   r    r%   Z/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/visualize/occlusion_sensitivity.py__init__L   s    #zOcclusionSensitivity.__init__ztorch.Tensorr   ztuple[float, torch.Tensor])xvalr   r   c                 C  s.   t j| jdd || j| jd}d|| fS )zVOcclude with a constant occlusion. Multiplicative is zero, additive is constant value.N   devicedtyper   )torchonesshaper,   r-   )r(   r)   r   r/   r%   r%   r&   constant_occlusionz   s    "z'OcclusionSensitivity.constant_occlusionztuple[torch.Tensor, float])r(   r   sigmar   c                   s   t j| jd f|| j| jd}|jdd }tdgdd |D  }d||< tt fdd|D dtd	d
 t	 g}||d }|dfS )z
        For Gaussian occlusion, Multiplicative is 1-Gaussian, additive is zero.
        Default sigma of 0.25 empirically shown to give reasonable kernel, see here:
        https://github.com/Project-MONAI/MONAI/pull/5230#discussion_r984520714.
           r+   Nc                 S  s"   g | ]}t |d  |d  d qS )r*   r3   slice).0sr%   r%   r&   
<listcomp>   s     z;OcclusionSensitivity.gaussian_occlusion.<locals>.<listcomp>g      ?c                   s   g | ]}|  qS r%   r%   r6   br2   r%   r&   r8      s     r;   c                 S  s   |  S Nr%   )r(   r%   r%   r&   <lambda>       z9OcclusionSensitivity.gaussian_occlusion.<locals>.<lambda>r   )
r.   zerosr0   r,   r-   r5   r	   r
   r   r   )r(   r   r2   kernelspatial_shapecenterr   mulr%   r;   r&   gaussian_occlusion   s     &z'OcclusionSensitivity.gaussian_occlusionztorch.Tensor | floatr"   zMapping[str, Any])
cropped_gridr   r(   rC   addr   occ_moder   module_kwargsr   c	                 C  sZ  | j d }	| jd }
t||	d}tdgd tdg|
  }| | }t|D ]\}}t||d tdgdd t||D  }|| }|dkrt||	 
 |\}}t|r|||}n|| | }|dks|dkrtd|||< qN||f|}t|r||}n(|r4|j d dkr*| n|d}|D ]}tj|d	|d	d
}q8|S )aT  
        Predictor function to be passed to the sliding window inferer. Takes a cropped meshgrid,
        referring to the coordinates in the input image. We use the index of the top-left corner
        in combination ``mask_size`` to figure out which region of the image is to be occluded. The
        occlusion is performed on the original image, ``x``, using ``cropped_region * mul + add``. ``mul``
        and ``add`` are sometimes pre-computed (e.g., a constant Gaussian blur), or they are
        sometimes calculated on the fly (e.g., the mean of the occluded patch). For this reason
        ``occ_mode`` is given. Lastly, ``activate`` is used to activate after each call of the model.

        Args:
            cropped_grid: subsection of the meshgrid, where each voxel refers to the coordinate of
                the input image. The meshgrid is created by the ``OcclusionSensitivity`` class, and
                the generation of the subset is determined by ``sliding_window_inference``.
            nn_module: module to call on data.
            x: the image that was originally passed into ``OcclusionSensitivity.__call__``.
            mul: occluded region will be multiplied by this. Can be ``torch.Tensor`` or ``float``.
            add: after multiplication, this is added to the occluded region. Can be ``torch.Tensor`` or ``float``.
            mask_size: Size of box to be occluded, centred on the central voxel. Should be
                a sequence, one value for each spatial dimension.
            occ_mode: might be used to calculate ``mul`` and ``add`` on the fly.
            activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any
                activation. If ``callable``, use callable on inferred outputs.
            module_kwargs: kwargs to be passed onto module when inferring
        r   r*   Nr3   c                 S  s&   g | ]\}}t t|t|| qS r%   )r5   r   )r6   jmr%   r%   r&   r8      s     z2OcclusionSensitivity.predictor.<locals>.<listcomp>r   z,Shouldn't be here, something's gone wrong...)dim)r0   ndimr.   repeat_interleaver5   	enumeratezipr   r1   meanitemcallableRuntimeErrorsigmoidsoftmax	unsqueeze)rE   r   r(   rC   rF   r   rG   r   rH   r   sdimZcorner_coord_slicesZtop_cornersr:   tslicesZ
to_occludeoutrJ   r%   r%   r&   	predictor   s0    $

*


"zOcclusionSensitivity.predictorr   z(tuple[MetaTensor, SpatialCrop, Sequence])gridb_boxr   r   c                 C  s   dd |D }dd t |ddd |D }g }t |ddd || jdd D ]2\}}}|dkrn|| qP|t|| | qPdd t ||D }	t|	d	}
|
| d
 d }t|}t|jdd D ]\}}t||| ||< q||
|fS )zXCrop the meshgrid so we only perform occlusion sensitivity on a subsection of the image.c                 S  s   g | ]}|d  d qS )r3   r*   r%   )r6   rJ   r%   r%   r&   r8      s     z6OcclusionSensitivity.crop_meshgrid.<locals>.<listcomp>c                 S  s   g | ]\}}t || d qS r   max)r6   r:   rJ   r%   r%   r&   r8      s     Nr*   r3   rK   c                 S  s   g | ]\}}t ||qS r%   r4   )r6   r7   er%   r%   r&   r8      s     )
roi_slicesr   )rP   r0   appendminr   listrO   )r^   r_   r   Z	mask_edgebbox_minbbox_maxr:   rJ   r7   r[   croppercroppedir%   r%   r&   crop_meshgrid   s    *
z"OcclusionSensitivity.crop_meshgridNzSequence | Noner   z!tuple[torch.Tensor, torch.Tensor])r(   r_   kwargsr   c                 K  s*  |j d dkrtd|jd }t| j|}tttjdd |j dd D dd	id |j	|j
d
}|dk	r| |||\}}}tdd t|j dd |D rtd|j dd  d| dt| jtr| || j|\}}	nN| jdkr| ||  |\}}	n&| jdkr.| ||\}}	nd\}	}t| jV ddlm}
 |
||| jtj| j| jdkrpdnd| j| j||	||| j| j|d}W 5 Q R X |dk	r||d d }dd |ddd D }dd t|ddd |j dd D }t ||d}||d d }|j!ddd}||fS )a  
        Args:
            x: Image to use for inference. Should be a tensor consisting of 1 batch.
            b_box: Bounding box on which to perform the analysis. The output image will be limited to this size.
                There should be a minimum and maximum for all spatial dimensions: ``[min1, max1, min2, max2,...]``.
                * By default, the whole image will be used. Decreasing the size will speed the analysis up, which might
                    be useful for larger images.
                * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.
                * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension.
                * N.B.: we add half of the mask size to the bounding box to ensure that the region of interest has a
                    sufficiently large area surrounding it.
            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

        Returns:
            * Occlusion map:
                * Shows the inference probabilities when the corresponding part of the image is occluded.
                    Hence, more -ve values imply that region was important in the decision process.
                * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the
                    network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``.
                * If `per_channel==False`, output ``C`` will equal 1: ``B1HW(D)N``
            * Most probable class:
                * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``).
            Both images will be cropped if a bounding box used, but voxel sizes will always match the input.
        r   r3   zExpected batch size of 1.r*   c                 S  s   g | ]}t d |qS r`   )nparange)r6   rl   r%   r%   r&   r8   !  s     z1OcclusionSensitivity.__call__.<locals>.<listcomp>Nindexingijr+   c                 s  s   | ]\}}||kV  qd S r<   r%   )r6   grJ   r%   r%   r&   	<genexpr>*  s     z0OcclusionSensitivity.__call__.<locals>.<genexpr>zImage (spatial shape) z should be bigger than mask .r    r   )NN)sliding_window_inferenceconstant)roi_sizesw_batch_sizer]   r   r   progressr   r(   rF   rC   r   rG   r   rH   c                 S  s   g | ]}t |d qS r`   ra   r9   r%   r%   r&   r8   W  s     c                 S  s    g | ]\}}|d kr|n|qS r`   r%   )r6   r:   r7   r%   r%   r&   r8   X  s     )	roi_startroi_endT)rL   keepdim)"r0   
ValueErrorrM   r   r   r   ro   stackmeshgridr,   r-   rm   anyrP   r!   r   r   r1   rQ   rR   rD   r   r   monai.inferersrv   r   r   r]   r   r   r   inverser   argmax)r$   r(   r_   rn   rX   r   r^   rj   rC   rF   rv   Zsensitivity_imrh   ri   Zmost_probable_classr%   r%   r&   __call__   s^    
," 
(zOcclusionSensitivity.__call__)r   r   Tr   r   T)r   )N)__name__
__module____qualname____doc__r'   staticmethodr1   rD   r]   rm   r   r%   r%   r%   r&   r      s$   3      ."K r   )
__future__r   collections.abcr   r   r   typingr   numpyro   r.   torch.nnnnmonai.data.meta_tensorr   monai.networks.utilsr   monai.transformsr	   r
   r   r   r   monai.utilsr   r   r%   r%   r%   r&   <module>   s   