U
    PhW=                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlZd dl	m
Z
 d dlmZ d dlmZmZmZmZ d dlmZmZ d dlmZmZ d	gZdd
ddddZd
dddZG dd de
jZG dd de
jZG dd	 d	e
jZdS )    )annotations)Callable)UnionN)UpSample)ActConvNorm
split_args)get_act_layerget_norm_layer)UpsampleMode
has_optionSegResNetDSztuple | listz
int | None)
resolutionn_stagesc                   s   t | }t| }t|dks&tdttt|| tj	  fddt
t D }|r|t kr|d| g|t    }n|d| }|S )aV  
    A helper function to compute a schedule of scale at different downsampling levels,
    given the input resolution.

    .. code-block:: python

        scales_for_resolution(resolution=[1,1,5], n_stages=5)

    Args:
        resolution: input image resolution (in mm)
        n_stages: optionally the number of stages of the network
    r   zResolution must be positivec                   s*   g | ]"}t td | d   kdd qS )      )tuplenpwhere).0inl U/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/segresnet_ds.py
<listcomp>1   s     z)scales_for_resolution.<locals>.<listcomp>r   N)lenr   arrayall
ValueErrorfloorlog2maxastypeint32range)r   r   ndimresscalesr   r   r   scales_for_resolution   s    
"r+   scalec                   s2    fddt t D }dd |D }|| fS )z
    A helper function to compute kernel_size, padding and stride for the given scale

    Args:
        scale: scale from a current scale level
    c                   s    g | ]} | d krdnd qS )r      r   r   kr,   r   r   r   @   s     z aniso_kernel.<locals>.<listcomp>c                 S  s   g | ]}|d  qS r   r   r/   r   r   r   r   A   s     )r'   r   )r-   kernel_sizepaddingr   r,   r   aniso_kernel9   s    r3   c                      s:   e Zd ZdZdddddddd fd	d
Zdd Z  ZS )SegResBlockz
    Residual network block used SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    r.   reluinttuple | strztuple | intNone)spatial_dimsin_channelsnormr1   actreturnc                   s   t    t|ttfr,tdd |D }n|d }t|||d| _t|| _t	t	j
|f |||d|dd| _t|||d| _t|| _t	t	j
|f |||d|dd| _dS )	aY  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size. Defaults to 3.
            act: activation type and arguments. Defaults to ``RELU``.
        c                 s  s   | ]}|d  V  qdS )r   Nr   r/   r   r   r   	<genexpr>^   s     z'SegResBlock.__init__.<locals>.<genexpr>r   )namer9   channelsr   F)r:   out_channelsr1   strider2   biasN)super__init__
isinstancer   listr   norm1r
   act1r   CONVconv1norm2act2conv2)selfr9   r:   r;   r1   r<   r2   	__class__r   r   rE   K   s0    

	
zSegResBlock.__init__c                 C  s8   |}|  | | | | | |}||7 }|S N)rN   rM   rL   rK   rI   rH   )rO   xidentityr   r   r   forwardx   s    (zSegResBlock.forward)r.   r5   )__name__
__module____qualname____doc__rE   rU   __classcell__r   r   rP   r   r4   E   s
   
  -r4   c                
      sV   e Zd ZdZdd	d	d	d
d
dddd fddZdddddZdddddZ  ZS )SegResEncodera~  
    SegResEncoder based on the encoder structure in `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        head_module: optional callable module to apply to the final features.
        anisotropic_scales: optional list of scale for each scale level.
    r.       r   r5   batchr   r   r      Nr6   r7   r   znn.Module | Nonetuple | None)r9   init_filtersr:   r<   r;   blocks_downhead_moduleanisotropic_scalesc	              	     s  t    dkrtdtttd f drHd dd t  tt d  drr d dd ||rt|d nd\}	}
t	t	j
f ||	dd	d
| _t | _tt|D ]}t }|rt|| nd\}	} fddt|| D }tj| |d< |t|d k rVt	t	j
f d d	||	d|d< nt |d< | j| d9 q|| _|| _|| _|| _| _ | _| _d S )Nr   r   r.   %`spatial_dims` can only be 1, 2 or 3.r   affiner   Tinplace)r.   r   r   F)r:   rA   r1   r2   rB   rC   r.   r   r   c              	     s   g | ]}t  d qS )r9   r:   r1   r;   r<   r4   r   _r<   filtersr1   r;   r9   r   r   r      s   z*SegResEncoder.__init__.<locals>.<listcomp>blocksr   )r:   rA   rC   r1   rB   r2   
downsample)rD   rE   r!   r	   r   r   
setdefaultr   r3   r   rJ   	conv_initnn
ModuleListlayersr'   r   
ModuleDict
SequentialIdentityappendrc   r:   rb   ra   r;   r<   r9   )rO   r9   ra   r:   r<   r;   rb   rc   rd   r2   rm   r   levelrB   rp   rP   rn   r   rE      s\    


	
zSegResEncoder.__init__torch.Tensorzlist[torch.Tensor]rS   r=   c                 C  sT   g }|  |}| jD ]&}|d |}|| |d |}q| jd k	rP| |}|S )Nrp   rq   )rs   rv   rz   rc   )rO   rS   outputsr{   r   r   r   _forward   s    




zSegResEncoder._forwardc                 C  s
   |  |S rR   r   rO   rS   r   r   r   rU      s    zSegResEncoder.forward)r.   r\   r   r5   r]   r^   NN)rV   rW   rX   rY   rE   r   rU   rZ   r   r   rP   r   r[      s           "Hr[   c                      sn   e Zd ZdZdddddddddddddd fddZdd Zdd ZdddddZdddddZ  Z	S ) r   a  
    SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    It is similar to https://docs.monai.io/en/stable/networks.html#segresnet, with several
    improvements including deep supervision and non-isotropic kernel support.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of upsample blocks (optional).
        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
                 At dsdepth==1,only a single output is returned.
        preprocess: optional callable function to apply before the model's forward pass
        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
                    image spacing into an approximately isotropic space.
                    Otherwise, by default, the kernel size and downsampling is always isotropic.

    r.   r\   r   r   r5   r]   r^   Ndeconvr6   r7   r   r`   znn.Module | Callable | NonezUpsampleMode | str)r9   ra   r:   rA   r<   r;   rb   	blocks_updsdepth
preprocessupsample_moder   c                   s  t    dkrtd| _|| _|| _|| _ | _| _|| _	t
|	d| _|| _|
| _|d k	rt|ttfs|tdntdd |D stdtttd f d	rĈd d	d
 t  tt d  dr d dd
 d }|rt|t|d}|| _t|| ||d| _t|d }|d krDd| }|| _|d|  t | _ t!|D ]}d |rt"|t|| d  nd\}}t# }t$|d |ddd|d<  fddt!|| D }tj%| |d< t|| |	kr&t&t&j'f |dd
d|d< nt( |d< | j )| qh|dkrt#t( t( t&t&j'f |dd
dd}| j )| d S )Nre   rf   r   zresolution must be a tuplec                 s  s   | ]}|d kV  qdS )r   Nr   )r   rr   r   r   r>   %  s     z'SegResNetDS.__init__.<locals>.<genexpr>zresolution must be positiver   rg   Trh   )r   )r9   ra   r:   r<   r;   rb   rd   )r   r   ri   F)moder9   r:   rA   r1   scale_factorrC   align_cornersupsamplec              	     s   g | ]}t  d qS rj   rk   rl   rn   r   r   r   Z  s   z(SegResNetDS.__init__.<locals>.<listcomp>rp   )r:   rA   r1   rC   head)r   rp   r   )*rD   rE   r!   r9   ra   r:   rA   r<   r;   rb   r$   r   r   r   rF   rG   r   	TypeErrorr    r	   r   r   rr   r   r+   r   rd   r[   encoderr   rt   ru   	up_layersr'   r3   rw   r   rx   r   rJ   ry   rz   )rO   r9   ra   r:   rA   r<   r;   rb   r   r   r   r   r   rd   n_upr   rm   rB   r{   rp   rP   rn   r   rE     s    




 


   
   	zSegResNetDS.__init__c                 C  sL   | j dkr&dt| jd  g| j }n"ttjt| j dd dd}|S )zb
        Calculate the factors (divisors) that the input image shape must be divisible by
        Nr   r   r   )axis)rd   r   rb   r9   rG   r   prodr   )rO   dr   r   r   shape_factoru  s    
"zSegResNetDS.shape_factorc                 C  s*   dd t |jdd |  D }t|S )zx
        Calculate if the input shape is divisible by the minimum factors for the current network configuration
        c                 S  s   g | ]\}}|| d kqS )r   r   )r   r   jr   r   r   r     s     z.SegResNetDS.is_valid_shape.<locals>.<listcomp>r   N)zipshaper   r    )rO   rS   ar   r   r   is_valid_shape  s    "zSegResNetDS.is_valid_shaper|   z-Union[None, torch.Tensor, list[torch.Tensor]]r}   c                 C  s  | j d k	r|  |}| |s8td|j d|   | |}|  |d}t|dkrvt	j
d|j|jdg}g }d}| jD ]X}|d |}||d7 }|d |}t| j| | jkr||d | |d }q|  | jrt|dkr|d S |S )	NzInput spatial dims z must be divisible by r   r   )devicedtyper   rp   r   )r   r   r!   r   r   r   reversepopr   torchzerosr   r   r   r   rz   training)rO   rS   Zx_downr~   r   r{   r   r   r   r     s,    






zSegResNetDS._forwardc                 C  s
   |  |S rR   r   r   r   r   r   rU     s    zSegResNetDS.forward)r.   r\   r   r   r5   r]   r^   Nr   Nr   N)
rV   rW   rX   rY   rE   r   r   r   rU   rZ   r   r   rP   r   r      s$               *r
$)N)
__future__r   collections.abcr   typingr   numpyr   r   torch.nnrt   monai.networks.blocks.upsampler   monai.networks.layers.factoriesr   r   r   r	   monai.networks.layers.utilsr
   r   monai.utilsr   r   __all__r+   r3   Moduler4   r[   r   r   r   r   r   <module>   s   :k