U
    tPh%                     @  sb   d dl mZ d dlZd dlmZ d dlmZ dddddd	ZG d
d deZG dd deZ	dS )    )annotationsN)_Loss)LossReductiontorch.Tensorint)xdimreturnc                 C  s   t dd}t dd}t dd}t d}||g||g }}t|| jk r\||g }||g }q8|||< |||< | | | |  d S )a  
    Calculate gradients on single dimension of a tensor using central finite difference.
    It moves the tensor along the dimension to calculate the approximate gradient
    dx[i] = (x[i+1] - x[i-1]) / 2.
    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)

    Args:
        x: the shape should be BCH(WD).
        dim: dimension to calculate gradient along.
    Returns:
        gradient_dx: the shape should be BCH(WD)
          Ng       @)slicelenndim)r   r   Zslice_1Z	slice_2_sZ	slice_2_e	slice_allZ	slicing_sZ	slicing_e r   H/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/deform.pyspatial_gradient   s    



r   c                      sB   e Zd ZdZdejfdddd fddZd	d	d
ddZ  ZS )BendingEnergyLossaO  
    Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference.

    For more information,
    see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    FboolLossReduction | strNone	normalize	reductionr	   c                   s   t  jt|jd || _dS ae  
        Args:
            normalize:
                Whether to divide out spatial sizes in order to make the computation roughly
                invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
        )r   Nsuper__init__r   valuer   selfr   r   	__class__r   r   r   :   s    zBendingEnergyLoss.__init__r   predr	   c           	        s    j dkrtd j t j d D ]0} j| d  dkr(td jdd  q( jd  j d krtd jd  d	 j d   fd
dtd j D }| jrtj j jddd d j d d  }td}t	|D ]\}}|d7 }| jr6| j| | 9 }|t
|| j|  d  }n|t
||d  }t|d  j D ]F}| jr|dt
|| j|  d   }n|dt
||d   }qXq| jtjjkrt|}n>| jtjjkrt|}n"| jtjjkrtd| j d|S )a  
        Args:
            pred: the shape should be BCH(WD)

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4.
            ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

                 :Expecting 3-d, 4-d or 5-d pred, instead got pred of shape r   r
   r)   z;All spatial dimensions must be > 4, got spatial dimensions NGNumber of vector components, i.e. number of channels of the input DDF, /, does not match number of spatial dimensions, c                   s   g | ]}t  |qS r   r   .0r   r&   r   r   
<listcomp>b   s     z-BendingEnergyLoss.forward.<locals>.<listcomp>devicer
   r   r
   r   Unsupported reduction: 0, available options are ["mean", "sum", "none"].)r   
ValueErrorshaperanger   torchtensorr4   reshape	enumerater   r   r   MEANr    meanSUMsumNONE)	r"   r&   ifirst_order_gradientspatial_dimsenergydim_1gZdim_2r   r1   r   forwardJ   s<    
.
"zBendingEnergyLoss.forward	__name__
__module____qualname____doc__r   r@   r   rK   __classcell__r   r   r#   r   r   /   s   
r   c                      sB   e Zd ZdZdejfdddd fddZd	d	d
ddZ  ZS )DiffusionLossah  
    Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference.
    For the original paper, please refer to
    VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
    Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
    IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.

    For more information,
    see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.

    Adapted from:
        VoxelMorph (https://github.com/voxelmorph/voxelmorph)
    Fr   r   r   r   c                   s   t  jt|jd || _dS r   r   r!   r#   r   r   r      s    zDiffusionLoss.__init__r   r%   c                   s   j dkrtd j t j d D ]0} j| d  dkr(td jdd  q( jd  j d krtd jd  d j d   fd	d
td j D }| jrtj j jddd d j d d  }td}t	|D ]6\}}|d7 }| jr| j| | 9 }||d  }q| j
tjjkrBt|}n>| j
tjjkr^t|}n"| j
tjjkrtd| j
 d|S )a  
        Args:
            pred:
                Predicted dense displacement field (DDF) with shape BCH[WD],
                where C is the number of spatial dimensions.
                Note that diffusion loss can only be calculated
                when the sizes of the DDF along all spatial dimensions are greater than 2.

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
            ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

        r'   r+   r   r
   z;All spatial dimensions must be > 2, got spatial dimensions Nr,   r-   c                   s   g | ]}t  |qS r   r.   r/   r1   r   r   r2      s     z)DiffusionLoss.forward.<locals>.<listcomp>r3   r5   r6   r   r7   r8   )r   r9   r:   r;   r   r<   r=   r4   r>   r?   r   r   r@   r    rA   rB   rC   rD   )r"   r&   rE   rF   rG   	diffusionrI   rJ   r   r1   r   rK      s2    
.
zDiffusionLoss.forwardrL   r   r   r#   r   rR      s   rR   )

__future__r   r<   torch.nn.modules.lossr   monai.utilsr   r   r   rR   r   r   r   r   <module>   s   Q