o
    ,iP                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlm  mZ	 d dl
mZ d dlmZ d dlmZmZ d dlmZ g dZG d	d
 d
ejZeZG dd dejZeZdS )    )annotations)SequenceN)Convolution)UpSample)DVF2DDFWarp)SkipConnection)VoxelMorphUNetvoxelmorphunet
VoxelMorph
voxelmorphc                      sh   e Zd ZdZ									d4d5 fd d!Zd6d&d'Zd7d*d+Zd8d,d-Zd7d.d/Zd9d2d3Z	  Z
S ):r	   aY  
    The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details.

    A concatenated pair of images (moving and fixed) is first passed through a UNet. The output of the UNet is then
    passed through a series of convolution blocks to produce the final prediction of the displacement field (DDF) or the
    stationary velocity field (DVF).

    In the original implementation, downsample is achieved through maxpooling, here one has the option to use either
    maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the
    original implementation. Note that for upsampling, the authors of VoxelMorph used nearest neighbor interpolation
    instead of transposed convolution. In this implementation, only nearest neighbor interpolation is supported in order
    to be consistent with the original implementation.

    An instance of this class can be used as a backbone network for constructing a VoxelMorph network. See the
    documentation of :py:class:`monai.networks.nets.VoxelMorph` for more details and an example on how to construct a
    VoxelMorph network.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels in the input volume after concatenation of moving and fixed images.
        unet_out_channels: number of channels in the output of the UNet.
        channels: number of channels in each layer of the UNet. See the following example for more details.
        final_conv_channels: number of channels in each layer of the final convolution block.
        final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU.
            Since VoxelMorph was originally implemented in tensorflow where the default negative slope for
            LeakyReLU was 0.2, we use the same default value here.
        kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3.
        up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3.
        act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2.
        norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None.
        dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout).
        bias: whether to use bias in all convolution layers in the UNet. Defaults to True.
        use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True.
            Using maxpooling is the consistent with the original implementation of VoxelMorph.
            But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False).
        adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA".
    	LEAKYRELU   N        TNDAspatial_dimsintin_channelsunet_out_channelschannelsSequence[int]final_conv_channelsfinal_conv_acttuple | str | Nonekernel_sizeSequence[int] | intup_kernel_sizeacttuple | strnormdropoutfloatbiasbooluse_maxpooladn_orderingstrreturnNonec                   s|  t    |dvrtd|d dkrtdt|dk r!tdt|d dkr-tdt|tr<t||kr<tdt|trKt||krKtd	|_|_|_|_	|_
|_t|	tro|	 d
kroddddfn|	_|
_|_|_|_|_|_t|tr| d
krddddfn|_d fdd dfdd}t ||j	dd||jj_d S )N)   r   z#spatial_dims must be either 2 or 3.r)   r   z#in_channels must be divisible by 2.z2the length of `channels` should be no less than 2.z8the elements of `channels` should be specified in pairs.z9the length of `kernel_size` should equal to `dimensions`.z<the length of `up_kernel_size` should equal to `dimensions`.r   	leakyrelug?T)negative_slopeinplaceincr   outcr   r   is_topr#   r'   	nn.Modulec           
        st   |dd \}}|| }t |dkr |||dd dd}n||}| ||}|||}	||	|S )a>  
            Builds the UNet structure recursively.

            Args:
                inc: number of input channels.
                outc: number of output channels.
                channels: sequence of channels for each pair of down and up layers.
                is_top: True if this is the top block.
            r   r)   NFr/   )len_get_bottom_layer_get_down_layer_get_up_layer_get_connection_block)
r-   r.   r   r/   Z	next_c_inZ
next_c_outupcsubblockdownup_create_blockself `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/voxelmorph.pyr<   }   s   z.VoxelMorphUNet.__init__.<locals>._create_blockc                   s   t  }t|D ]!\}}|d| t j| | j j j j	 j
 jd	 |} q|dt j| | jd j j	 j
 jd	 |S )a  
            Builds the final convolution blocks.

            Args:
                inc: number of input channels, should be the same as `unet_out_channels`.
                outc: number of output channels, should be the same as `spatial_dims`.
                channels: sequence of channels for each convolution layer.

            Note: there is no activation after the last convolution layer as per the original implementation.
            Zfinal_conv_)r   r   r   r    r"   r%   Zfinal_conv_outN)nn
Sequential	enumerate
add_moduler   
dimensionsr   r   r   r    r"   r%   )r-   r.   r   modic)r=   r>   r?   _create_final_conv   s@   z3VoxelMorphUNet.__init__.<locals>._create_final_convr1   )
r-   r   r.   r   r   r   r/   r#   r'   r0   )r-   r   r.   r   r   r   r'   r0   )super__init__
ValueErrorr2   
isinstancer   rD   r   r   r   r   r   r&   upperr   r   r    r"   r$   r%   r   r   r@   rA   net)r=   r   r   r   r   r   r   r   r   r   r   r    r"   r$   r%   rH   	__class__r;   r?   rJ   C   sV   
0
zVoxelMorphUNet.__init__	down_pathr0   up_pathr8   c                 C  s   t |t||S )a  
        Returns the block object defining a layer of the UNet structure including the implementation of the skip
        between encoding (down) and decoding (up) sides of the network.

        Args:
            down_path: encoding half of the layer
            up_path: decoding half of the layer
            subblock: block defining the next layer in the network.

        Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
        )r@   rA   r   )r=   rQ   rR   r8   r>   r>   r?   r6      s   z$VoxelMorphUNet._get_connection_blockout_channelsr/   c                 C  s   | j s|rdnd}t| j|||| j| j| j| j| j| jd
}| j r>|s>| jdkr3t	
t	jddd|n
t	
t	jddd|}|S )a  
        In each down layer, the input is first downsampled using maxpooling,
        then passed through a convolution block, unless this is the top layer
        in which case the input is passed through a convolution block only
        without maxpooling first.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            is_top: True if this is the top block.
           r)   )stridesr   r   r   r    r"   r%   r   )r   stride)r$   r   rD   r   r   r   r    r"   r%   r@   rA   	MaxPool3d	MaxPool2dr=   r   rS   r/   rU   rE   r>   r>   r?   r4      s&   

zVoxelMorphUNet._get_down_layerc              	   C  s4   | j ||dd}t| j||ddddd}t||S )z
        Bottom layer (bottleneck) in voxelmorph consists of a typical down layer followed by an upsample layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
        Fr1   r)   nontrainablenearestNscale_factormodeinterp_modealign_corners)r4   r   rD   r@   rA   )r=   r   rS   rE   upsampler>   r>   r?   r3     s   
z VoxelMorphUNet._get_bottom_layerc                 C  sX   d}t | j|||| j| j| j| j| jd| jd}|s*t	|t
| j||ddddd}|S )	a  
        In each up layer, the input is passed through a convolution block before upsampled,
        unless this is the top layer in which case the input is passed through a convolution block only
        without upsampling.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            is_top: True if this is the top block.
        rT   F)rU   r   r   r   r    r"   is_transposedr%   r)   rZ   r[   Nr\   )r   rD   r   r   r   r    r"   r%   r@   rA   r   rY   r>   r>   r?   r5     s8   zVoxelMorphUNet._get_up_layerconcatenated_pairstorch.Tensorc                 C  s   |  |}|S )N)rN   )r=   rc   xr>   r>   r?   forwardL  s   
zVoxelMorphUNet.forward)	r   r   r   r   Nr   TTr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r#   r%   r&   r'   r(   )rQ   r0   rR   r0   r8   r0   r'   r0   )r   r   rS   r   r/   r#   r'   r0   )r   r   rS   r   r'   r0   )rc   rd   r'   rd   )__name__
__module____qualname____doc__rJ   r6   r4   r3   r5   rf   __classcell__r>   r>   rO   r?   r	      s$    - 


'
/r	   c                      s6   e Zd ZdZ				dd fddZdddZ  ZS )r   a  
    A re-implementation of VoxelMorph framework for medical image registration as described in
    https://arxiv.org/pdf/1809.05231.pdf. For more details, please refer to VoxelMorph: A Learning Framework for
    Deformable Medical Image Registration, Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
    IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.

    This class is intended to be a general framework, based on which a deformable image registration
    network can be built. Given a user-specified backbone network (e.g., UNet in the original VoxelMorph paper), this
    class serves as a wrapper that concatenates the input pair of moving and fixed images, passes through the backbone
    network, integrate the predicted stationary velocity field (DVF) from the backbone network to obtain the
    displacement field (DDF), and, finally, warp the moving image using the DDF.

    To construct a VoxelMorph network, one need to first construct a backbone network
    (e.g., a :py:class:`monai.networks.nets.VoxelMorphUNet`) and pass it to the constructor of
    :py:class:`monai.networks.nets.VoxelMorph`. The backbone network should be able to take a pair of moving and fixed
    images as input and produce a DVF (or DDF, details to be discussed later) as output.

    When `forward` is called, the input moving and fixed images are first concatenated along the channel dimension and
    passed through the specified backbone network to produce the prediction of the displacement field (DDF) in the
    non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the stationary velocity field (DVF) in the
    diffeomorphic variant (i.e. when `integration_steps` is set to a positive integer). The DVF is then integrated using
    a scaling-and-squaring approach via a :py:class:`monai.networks.blocks.warp.DVF2DDF` module to produce the DDF.
    Finally, the DDF is used to warp the moving image to the fixed image using a
    :py:class:`monai.networks.blocks.warp.Warp` module. Optionally, the integration from DVF to DDF can be
    performed on reduced resolution by specifying `half_res` to be True, in which case the output DVF from the backbone
    network is first linearly interpolated to half resolution before integration. The output DDF is then linearly
    interpolated again back to full resolution before being used to warp the moving image.

    Args:
        backbone: a backbone network.
        integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring.
            Defaults to 7. If set to 0, the network will be non-diffeomorphic.
        half_res: whether to perform integration on half resolution. Defaults to False.
        spatial_dims: number of spatial dimensions, defaults to 3.

    Example::

        from monai.networks.nets import VoxelMorphUNet, VoxelMorph

        # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper
        # https://arxiv.org/pdf/1809.05231.pdf

        # First, a backbone network is constructed. In this case, we use a VoxelMorphUNet as the backbone network.
        backbone = VoxelMorphUNet(
            spatial_dims=3,
            in_channels=2,
            unet_out_channels=32,
            channels=(16, 32, 32, 32, 32, 32),  # this indicates the down block at the top takes 16 channels as
                                                # input, the corresponding up block at the top produces 32
                                                # channels as output, the second down block takes 32 channels as
                                                # input, and the corresponding up block at the same level
                                                # produces 32 channels as output, etc.
            final_conv_channels=(16, 16)
        )

        # Then, a full VoxelMorph network is constructed using the specified backbone network.
        net = VoxelMorph(
            backbone=backbone,
            integration_steps=7,
            half_res=False
        )

        # A forward pass through the network would look something like this
        moving = torch.randn(1, 1, 160, 192, 224)
        fixed = torch.randn(1, 1, 160, 192, 224)
        warped, ddf = net(moving, fixed)

    N   Fr   backbone!VoxelMorphUNet | nn.Module | Noneintegration_stepsr   half_resr#   r   r'   r(   c                   sz   t    |d ur|nt|ddddd| _|| _|| _|| _| jdkr%dnd| _| jr4t| jd	d
d| _	t
d	d
d| _d S )Nr)       )   rq   rq   rq   rq   rq   )rr   rr   )r   r   r   r   r   r   TFbilinearzeros)	num_stepsr^   padding_mode)r^   rv   )rI   rJ   r	   rm   r   rp   ro   diffeomorphicr   dvf2ddfr   warp)r=   rm   ro   rp   r   rO   r>   r?   rJ     s$   
zVoxelMorph.__init__movingrd   fixed!tuple[torch.Tensor, torch.Tensor]c                 C  s
  |j |j krtd|j  d|j  d| tj||gdd}|j d | jkr6td| j d|j d  d|j d	d  |j d	d  krYtd
|j d	d   d|j d	d   d| jrgtj|ddddd }| j	ro| 
|}| jr}tj|d dddd}| |||fS )NzfThe spatial shape of the moving image should be the same as the spatial shape of the fixed image. Got z and z	 instead.rT   )dimzqThe number of channels in the output of the backbone network should be equal to the number of spatial dimensions z. Got z channels instead.r)   zvThe spatial shape of the output of the backbone network should be equal to the spatial shape of the input images. Got z instead of .g      ?	trilinearT)r]   r^   r`   g       @)shaperK   rm   torchcatr   rp   Finterpolaterw   rx   ry   )r=   rz   r{   re   r>   r>   r?   rf     s@   
zVoxelMorph.forward)Nrl   Fr   )
rm   rn   ro   r   rp   r#   r   r   r'   r(   )rz   rd   r{   rd   r'   r|   )rg   rh   ri   rj   rJ   rf   rk   r>   r>   rO   r?   r   T  s    G!r   )
__future__r   collections.abcr   r   torch.nnr@   torch.nn.functional
functionalr   "monai.networks.blocks.convolutionsr   monai.networks.blocks.upsampler   Zmonai.networks.blocks.warpr   r   "monai.networks.layers.simplelayersr   __all__Moduler	   r
   r   r   r>   r>   r>   r?   <module>   s"     7 	