o
    i%                     @  sZ   d dl mZ d dlZd dlmZ d dlmZ dd
dZG dd deZG dd deZ	dS )    )annotationsN)_Loss)LossReductionxtorch.Tensordimintreturnc                 C  s   t dd}t dd}t dd}t d}||g||g}}t|| jk r4||g }||g }t|| jk s#|||< |||< | | | |  d S )a  
    Calculate gradients on single dimension of a tensor using central finite difference.
    It moves the tensor along the dimension to calculate the approximate gradient
    dx[i] = (x[i+1] - x[i-1]) / 2.
    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)

    Args:
        x: the shape should be BCH(WD).
        dim: dimension to calculate gradient along.
    Returns:
        gradient_dx: the shape should be BCH(WD)
          Ng       @)slicelenndim)r   r   Zslice_1Z	slice_2_sZ	slice_2_e	slice_allZ	slicing_sZ	slicing_e r   U/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/deform.pyspatial_gradient   s   




r   c                      4   e Zd ZdZdejfd fd	d
ZdddZ  ZS )BendingEnergyLossaO  
    Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference.

    For more information,
    see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    F	normalizebool	reductionLossReduction | strr	   Nonec                      t  jt|jd || _dS ae  
        Args:
            normalize:
                Whether to divide out spatial sizes in order to make the computation roughly
                invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
        )r   Nsuper__init__r   valuer   selfr   r   	__class__r   r   r    :      
zBendingEnergyLoss.__init__predr   c           	        s   j dvrtd j t j d D ]} j| d  dkr,td jdd  q jd  j d krGtd jd  d	 j d   fd
dtd j D }| jrntj j jddd d j d d  }td}t	|D ]V\}}|d7 }| jr| j| | 9 }|t
|| j|  d  }n	|t
||d  }t|d  j D ]!}| jr|dt
|| j|  d   }q|dt
||d   }qqw| jtjjkrt|}|S | jtjjkrt|}|S | jtjjkrtd| j d|S )a  
        Args:
            pred: the shape should be BCH(WD)

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4.
            ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

                 :Expecting 3-d, 4-d or 5-d pred, instead got pred of shape r   r
   r*   z;All spatial dimensions must be > 4, got spatial dimensions NGNumber of vector components, i.e. number of channels of the input DDF, /, does not match number of spatial dimensions, c                      g | ]}t  |qS r   r   .0r   r'   r   r   
<listcomp>b       z-BendingEnergyLoss.forward.<locals>.<listcomp>devicer
   r   r
   r   Unsupported reduction: 0, available options are ["mean", "sum", "none"].)r   
ValueErrorshaperanger   torchtensorr7   reshape	enumerater   r   r   MEANr!   meanSUMsumNONE)	r#   r'   ifirst_order_gradientspatial_dimsenergydim_1gZdim_2r   r3   r   forwardJ   sH   
.
"

zBendingEnergyLoss.forwardr   r   r   r   r	   r   r'   r   r	   r   	__name__
__module____qualname____doc__r   rC   r    rN   __classcell__r   r   r$   r   r   /   s    
r   c                      r   )DiffusionLossah  
    Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference.
    For the original paper, please refer to
    VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
    Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
    IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.

    For more information,
    see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.

    Adapted from:
        VoxelMorph (https://github.com/voxelmorph/voxelmorph)
    Fr   r   r   r   r	   r   c                   r   r   r   r"   r$   r   r   r       r&   zDiffusionLoss.__init__r'   r   c                   s   j dvrtd j t j d D ]} j| d  dkr,td jdd  q jd  j d krGtd jd  d j d   fd	d
td j D }| jrntj j jddd d j d d  }td}t	|D ]\}}|d7 }| jr| j| | 9 }||d  }qw| j
tjjkrt|}|S | j
tjjkrt|}|S | j
tjjkrtd| j
 d|S )a  
        Args:
            pred:
                Predicted dense displacement field (DDF) with shape BCH[WD],
                where C is the number of spatial dimensions.
                Note that diffusion loss can only be calculated
                when the sizes of the DDF along all spatial dimensions are greater than 2.

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
            ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

        r(   r,   r   r
   z;All spatial dimensions must be > 2, got spatial dimensions Nr-   r.   c                   r/   r   r0   r1   r3   r   r   r4      r5   z)DiffusionLoss.forward.<locals>.<listcomp>r6   r8   r9   r   r:   r;   )r   r<   r=   r>   r   r?   r@   r7   rA   rB   r   r   rC   r!   rD   rE   rF   rG   )r#   r'   rH   rI   rJ   Z	diffusionrL   rM   r   r3   r   rN      s<   
.


zDiffusionLoss.forwardrO   rP   rQ   r   r   r$   r   rW      s    rW   )r   r   r   r   r	   r   )

__future__r   r?   torch.nn.modules.lossr   monai.utilsr   r   r   rW   r   r   r   r   <module>   s   
Q