U
    tPhd<                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlm	Z	m
Z
 d dlmZ d dlmZ dd	d
ddZdd	d
ddZdd	d
ddZeeedZG dd deZG dd deZdS )    )annotationsN)
functional)_Loss)gaussian_1dseparable_filtering)LossReduction)look_up_optioninttorch.Tensor)kernel_sizereturnc                 C  s
   t | S )N)torchones)r    r   U/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/image_dissimilarity.pymake_rectangular_kernel   s    r   c                 C  sf   | d d }|d dkr |d8 }t jdd|ft jd|}| | d |d  }tj|||ddS )N      r   dtype)padding)r   r   floatdivFconv1dreshape)r   fsizefr   r   r   r   make_triangular_kernel   s    r   c                 C  s6   t | d }t|| d dddd|  }|d |  S )Ng      @r   sampledF)sigma	truncatedapprox	normalizeg@)r   tensorr   )r   r!   kernelr   r   r   make_gaussian_kernel$   s
    r'   )rectangular
triangulargaussianc                	      s\   e Zd ZdZdddejddfddddd	d	d
d fddZdd ZddddddZ  Z	S )#LocalNormalizedCrossCorrelationLossa  
    Local squared zero-normalized cross-correlation.
    The loss is based on a moving kernel/window over the y_true/y_pred,
    within the window the square of zncc is calculated.
    The kernel can be a rectangular / triangular / gaussian window.
    The final loss is the averaged loss over all windows.

    Adapted from:
        https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
        DeepReg (https://github.com/DeepRegNet/DeepReg)
       r(           gh㈵>r	   strLossReduction | strr   None)spatial_dimsr   kernel_type	reduction	smooth_nr	smooth_drr   c                   s   t  jt|jd || _| jdkr6td| j d|| _| jd dkrZtd| j t|t}|| j| _	d| j	_
|  | _t|| _t|| _d	S )
a1  
        Args:
            spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
            kernel_size: kernel spatial size, must be odd.
            kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.

        r3   >   r   r   r,   zUnsupported ndim: z/-d, only 1-d, 2-d, and 3-d inputs are supportedr   r   zkernel_size must be odd, got FN)super__init__r   valuendim
ValueErrorr   r   kernel_dictr&   Zrequire_gradsget_kernel_vol
kernel_volr   r4   r5   )selfr1   r   r2   r3   r4   r5   _kernel	__class__r   r   r8   @   s    



z,LocalNormalizedCrossCorrelationLoss.__init__c                 C  s>   | j }t| jd D ]}t|d| j d}qt|S )Nr   r   r   )r&   ranger:   r   matmul	unsqueezesum)r?   vol_r   r   r   r=   j   s    z2LocalNormalizedCrossCorrelationLoss.get_kernel_volr
   predtargetr   c                 C  s  |j d | j kr(td| j  d|j |j|jkrNtd|j d|j d|| || ||   }}}| j|| j| }}|g| j  }t||d}	t||d}
t||d}t||d}t||d}|	| }|
| }|||	  }t|||	  tj	| j
|j|jd}t|||
  tj	| j
|j|jd}|| | j ||  }| jtjjkrht| S | jtjjkr| S | jtjjkrt| S td	| j d
dS )z
        Args:
            pred: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        r   zexpecting pred with z' spatial dimensions, got pred of shape "ground truth has differing shape () from pred ())kernels)r   deviceUnsupported reduction: 0, available options are ["mean", "sum", "none"].N)r:   r;   shaper&   tor>   r   r   max	as_tensorr5   r   rP   r4   r3   r   SUMr9   rF   negNONEMEANmean)r?   rJ   rK   t2p2tpr&   r>   rO   Zt_sump_sumZt2_sumZp2_sumZtp_sumZt_avgZp_avgcrossZt_varZp_varZnccr   r   r   forwardp   s>    
 
 z+LocalNormalizedCrossCorrelationLoss.forward)
__name__
__module____qualname____doc__r   rZ   r8   r=   ra   __classcell__r   r   rA   r   r+   3   s    *r+   c                	      s   e Zd ZdZdddejddfdddd	ddd
d fddZddddddZddddddZdddddZ	ddddddZ
  ZS )GlobalMutualInformationLossz
    Differentiable global mutual information loss via Parzen windowing method.

    Reference:
        https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1
    r*            ?gHz>r.   r	   r   r/   r0   )r2   num_binssigma_ratior3   r4   r5   r   c           	        s   t  jt|jd |dkr$tdtdd|}t|dd |dd  | }t|d	d
g| _	|| _
|| _	| j	d	krdd|d   | _|d | _t|| _t|| _dS )a  
        Args:
            kernel_type: {``"gaussian"``, ``"b-spline"``}
                ``"gaussian"``: adapted from DeepReg
                Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1.
                ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK
                References:
                  [1] "Nonrigid multimodality image registration"
                      D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
                      Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620.
                  [2] "PET-CT Image Registration in the Chest Using Free-form Deformations"
                      D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
                      IEEE Transactions in Medical Imaging. Vol.22, No.1,
                      January 2003. pp.120-128.

            num_bins: number of bins for intensity
            sigma_ratio: a hyper param for gaussian function
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.
        r6   r   z!num_bins must > 0, got {num_bins}r-   g      ?r   Nr   r*   b-spliner   )NN.)r7   r8   r   r9   r;   r   linspacer[   r   r2   rj   pretermbin_centersr   r4   r5   )	r?   r2   rj   rk   r3   r4   r5   ro   r!   rA   r   r   r8      s    #"


z$GlobalMutualInformationLoss.__init__r
   z=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]rI   c                 C  sh   | j dkr(| |\}}| |\}}n4| j dkrX| j|dd\}}| j|dd\}}nt||||fS )Nr*   rl   r,   )orderr   )r2   parzen_windowing_gaussianparzen_windowing_b_spliner;   )r?   rJ   rK   Zpred_weightZpred_probabilityZtarget_weightZtarget_probabilityr   r   r   parzen_windowing   s    

z,GlobalMutualInformationLoss.parzen_windowingz!tuple[torch.Tensor, torch.Tensor])imgrp   r   c                 C  sp  t |t | }}d}|| | jd|   }t ||| }t ||| }t ||| j| d }||jd dd}t j| j|j	dddd}	t 
|	| }
t j|
t jd}|dkr||
dk  |
dkd  }nl|dkr4|d	d
|
d   d|
d   |
dk  d
  }|d|
 d |
dk |
dk  d
  }ntd| d|t j|ddd }t j|ddd}||fS )z
        Parzen windowing with b-spline kernel (adapted from ITK)

        Args:
            img: the shape should be B[NDHW].
            order: int.
        r   r   r   r   )rP   r   ri   r,         zDo not support b-spline z-order parzen windowingTdimkeepdim)r   rU   minrj   r   clampr   rS   arangerP   abs
zeros_liker   r;   rF   r[   )r?   rt   rp   _max_minr   bin_sizenorm_minZwindow_termbinsZsample_bin_matrixweightprobabilityr   r   r   rr      s&    
,&z5GlobalMutualInformationLoss.parzen_windowing_b_spline)rt   r   c                 C  sv   t |dd}||jd dd}t | j| || j| d  }|t j|ddd }t j	|ddd}||fS )z
        Parzen windowing with gaussian kernel (adapted from DeepReg implementation)
        Note: the input is expected to range between 0 and 1
        Args:
            img: the shape should be B[NDHW].
        r   r   r   r   Trw   rz   )
r   r|   r   rS   exprn   rT   ro   rF   r[   )r?   rt   r   r   r   r   r   rq   !  s     z5GlobalMutualInformationLoss.parzen_windowing_gaussianc           
      C  s  |j |j kr&td|j  d|j  d| ||\}}}}t|ddd|||j d }t|ddd||}tj|t	|| j
 || j  | j  dd}	| jtjjkrt|	 S | jtjjkr|	 S | jtjjkrt|	 S td	| j d
dS )z
        Args:
            pred: the shape should be B[NDHW].
            target: the shape should be same as the pred shape.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        rL   rM   rN   r   r   r   )r   r   )rx   rQ   rR   N)rS   r;   rs   r   bmmpermuterT   r   rF   logr4   r5   r3   r   rW   r9   rX   rY   rZ   r[   )
r?   rJ   rK   wapawbpbZpabZpapbmir   r   r   ra   1  s     (" z#GlobalMutualInformationLoss.forward)rb   rc   rd   re   r   rZ   r8   rs   rr   rq   ra   rf   r   r   rA   r   rg      s   	 13rg   )
__future__r   r   torch.nnr   r   torch.nn.modules.lossr   monai.networks.layersr   r   monai.utilsr   monai.utils.moduler   r   r   r'   r<   r+   rg   r   r   r   r   <module>   s   		s