o
    ,i1Q                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dl	Z	d dl
mZ d dlmZ d dlmZmZmZmZ d dlmZmZ d dlmZmZ d	d
gZddddZdddZG dd dejZG dd dejZG dd	 d	ejZG dd
 d
eZ dS )    )annotationsN)Callable)Union)UpSample)ActConvNorm
split_args)get_act_layerget_norm_layer)UpsampleMode
has_optionSegResNetDSSegResNetDS2
resolutiontuple | listn_stages
int | Nonec                   s   t | }t| }t|dkstdttt|| tj	  fddt
t D }|rH|t krH|d| g|t    }|S |d| }|S )aV  
    A helper function to compute a schedule of scale at different downsampling levels,
    given the input resolution.

    .. code-block:: python

        scales_for_resolution(resolution=[1,1,5], n_stages=5)

    Args:
        resolution: input image resolution (in mm)
        n_stages: optionally the number of stages of the network
    r   zResolution must be positivec                   s*   g | ]}t td | d   kdd qS )      )tuplenpwhere).0inl b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/segresnet_ds.py
<listcomp>2   s   * z)scales_for_resolution.<locals>.<listcomp>r   N)lenr   arrayall
ValueErrorfloorlog2maxastypeint32range)r   r   ndimresscalesr   r   r   scales_for_resolution   s   
"r.   scalec                   s2    fddt t D }dd |D }|| fS )z
    A helper function to compute kernel_size, padding and stride for the given scale

    Args:
        scale: scale from a current scale level
    c                   s    g | ]} | d krdnd qS )r      r   r   kr/   r   r   r   A   s     z aniso_kernel.<locals>.<listcomp>c                 S  s   g | ]}|d  qS r    r   r1   r   r   r   r   B   s    )r*   r!   )r/   kernel_sizepaddingr   r3   r   aniso_kernel:   s   
r6   c                      s0   e Zd ZdZ		dd fddZdd Z  ZS )SegResBlockz
    Residual network block used SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    r0   reluspatial_dimsintin_channelsnormtuple | strr4   tuple | intactreturnNonec                   s   t    t|ttfrtdd |D }n|d }t|||d| _t|| _t	t	j
|f |||d|dd| _t|||d| _t|| _t	t	j
|f |||d|dd| _dS )	aY  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size. Defaults to 3.
            act: activation type and arguments. Defaults to ``RELU``.
        c                 s  s    | ]}|d  V  qdS )r   Nr   r1   r   r   r   	<genexpr>_       z'SegResBlock.__init__.<locals>.<genexpr>r   )namer9   channelsr   F)r;   out_channelsr4   strider5   biasN)super__init__
isinstancer   listr   norm1r
   act1r   CONVconv1norm2act2conv2)selfr9   r;   r<   r4   r?   r5   	__class__r   r   rJ   L   s0   

	
zSegResBlock.__init__c                 C  s8   |}|  | | | | | |}||7 }|S N)rS   rR   rQ   rP   rN   rM   )rT   xidentityr   r   r   forwardy   s   (zSegResBlock.forward)r0   r8   )r9   r:   r;   r:   r<   r=   r4   r>   r?   r=   r@   rA   )__name__
__module____qualname____doc__rJ   rZ   __classcell__r   r   rU   r   r7   F   s    
-r7   c                      sH   e Zd ZdZ								d d! fddZd"ddZd"ddZ  ZS )#SegResEncodera~  
    SegResEncoder based on the encoder structure in `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        head_module: optional callable module to apply to the final features.
        anisotropic_scales: optional list of scale for each scale level.
    r0       r   r8   batchr   r   r      Nr9   r:   init_filtersr;   r?   r=   r<   blocks_downr   head_modulenn.Module | Noneanisotropic_scalestuple | Nonec	              	     s  t    dvrtdtttd f dr$d dd t  tt d  dr9 d dd ||rCt|d nd\}	}
t	t	j
f ||	dd	d
| _t | _tt|D ]W}t }|rqt|| nd\}	} fddt|| D }tj| |d< |t|d k rt	t	j
f d d	||	d|d< nt |d< | j| d9 qc|| _|| _|| _|| _| _ | _| _d S )Nr   r   r0   %`spatial_dims` can only be 1, 2 or 3.r   affiner   Tinplace)r0   r   r   F)r;   rF   r4   r5   rG   rH   r0   r   r   c              	        g | ]}t  d qS )r9   r;   r4   r<   r?   r7   r   _r?   filtersr4   r<   r9   r   r   r          z*SegResEncoder.__init__.<locals>.<listcomp>blocksr   )r;   rF   rH   r4   rG   r5   
downsample)rI   rJ   r$   r	   r   r   
setdefaultr   r6   r   rO   	conv_initnn
ModuleListlayersr*   r!   
ModuleDict
SequentialIdentityappendrg   r;   rf   re   r<   r?   r9   )rT   r9   re   r;   r?   r<   rf   rg   ri   r5   rt   r   levelrG   rx   rU   ru   r   rJ      s\   


	

zSegResEncoder.__init__rX   torch.Tensorr@   list[torch.Tensor]c                 C  sT   g }|  |}| jD ]}|d |}|| |d |}q
| jd ur(| |}|S )Nrx   ry   )r{   r~   r   rg   )rT   rX   outputsr   r   r   r   _forward   s   




zSegResEncoder._forwardc                 C  
   |  |S rW   r   rT   rX   r   r   r   rZ         
zSegResEncoder.forward)r0   ra   r   r8   rb   rc   NN)r9   r:   re   r:   r;   r:   r?   r=   r<   r=   rf   r   rg   rh   ri   rj   )rX   r   r@   r   )r[   r\   r]   r^   rJ   r   rZ   r_   r   r   rU   r   r`      s    
Hr`   c                      s`   e Zd ZdZ													
		d+d, fddZdd  Zd!d" Zd-d'd(Zd-d)d*Z  Z	S ).r   a  
    SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    It is similar to https://docs.monai.io/en/stable/networks.html#segresnet, with several
    improvements including deep supervision and non-isotropic kernel support.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of upsample blocks (optional).
        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
                 At dsdepth==1,only a single output is returned.
        preprocess: optional callable function to apply before the model's forward pass
        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
                    image spacing into an approximately isotropic space.
                    Otherwise, by default, the kernel size and downsampling is always isotropic.

    r0   ra   r   r   r8   rb   rc   Ndeconvr9   r:   re   r;   rF   r?   r=   r<   rf   r   	blocks_uprj   dsdepth
preprocessnn.Module | Callable | Noneupsample_modeUpsampleMode | strr   c                   s  t    dvrtd| _|| _|| _|| _ | _| _|| _	t
|	d| _|| _|
| _|d urJt|ttfs=tdtdd |D sJtdtttd f d	rad d	d
 t  tt d  drv d dd
 d }|rt|t|d}|| _t|| ||d| _t|d }|d u rd| }|| _|d|  t | _ t!|D ]j}d |rt"|t|| d  nd\}}t# }t$|d |ddd|d<  fddt!|| D }tj%| |d< t|| |	krt&t&j'f |dd
d|d< nt( |d< | j )| q|dkrAt#t( t( t&t&j'f |dd
dd}| j )| d S d S )Nrk   rl   r   zresolution must be a tuplec                 s  s    | ]}|d kV  qdS )r   Nr   )r   rr   r   r   rB   &  rC   z'SegResNetDS.__init__.<locals>.<genexpr>zresolution must be positiver   rm   Trn   )r   )r9   re   r;   r?   r<   rf   ri   )r   r   ro   F)moder9   r;   rF   r4   scale_factorrH   align_cornersupsamplec              	     rp   rq   rr   rs   ru   r   r   r   [  rw   z(SegResNetDS.__init__.<locals>.<listcomp>rx   )r;   rF   r4   rH   head)r   rx   r   )*rI   rJ   r$   r9   re   r;   rF   r?   r<   rf   r'   r   r   r   rK   rL   r   	TypeErrorr#   r	   r   r   rz   r   r.   r!   ri   r`   encoderr   r|   r}   	up_layersr*   r6   r   r   r   r   rO   r   r   )rT   r9   re   r;   rF   r?   r<   rf   r   r   r   r   r   ri   n_upr   rt   rG   r   rx   rU   ru   r   rJ     s   






	zSegResNetDS.__init__c                 C  sN   | j du rdt| jd  g| j }|S ttjt| j dd dd}|S )zb
        Calculate the factors (divisors) that the input image shape must be divisible by
        Nr   r   r   )axis)ri   r!   rf   r9   rL   r   prodr"   )rT   dr   r   r   shape_factorv  s
   
"zSegResNetDS.shape_factorc                 C  s*   dd t |jdd |  D }t|S )zx
        Calculate if the input shape is divisible by the minimum factors for the current network configuration
        c                 S  s   g | ]
\}}|| d kqS )r   r   )r   r   jr   r   r   r     s    z.SegResNetDS.is_valid_shape.<locals>.<listcomp>r   N)zipshaper   r#   )rT   rX   ar   r   r   is_valid_shape  s   "zSegResNetDS.is_valid_shaperX   r   r@   -Union[None, torch.Tensor, list[torch.Tensor]]c                 C  s  | j d ur
|  |}| |std|j d|   | |}|  |d}t|dkr;t	j
d|j|jdg}g }d}| jD ],}|d |}||d7 }|d |}t| j| | jkrj||d | |d }qB|  | jr|t|dkr|d S |S )	NInput spatial dims  must be divisible by r   r   devicedtyper   rx   r   )r   r   r$   r   r   r   reversepopr!   torchzerosr   r   r   r   r   training)rT   rX   x_downr   r   r   r   r   r   r     s,   






zSegResNetDS._forwardc                 C  r   rW   r   r   r   r   r   rZ     r   zSegResNetDS.forwardr0   ra   r   r   r8   rb   rc   Nr   Nr   Nr9   r:   re   r:   r;   r:   rF   r:   r?   r=   r<   r=   rf   r   r   rj   r   r:   r   r   r   r   r   rj   )rX   r   r@   r   )
r[   r\   r]   r^   rJ   r   r   r   rZ   r_   r   r   rU   r   r      s&    r

$c                      sT   e Zd ZdZ													
		d,d- fddZ	d.d/d'd(Zd0d*d+Z  ZS )1r   aL  
    SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D
     <https://arxiv.org/abs/2406.05285>`_.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of upsample blocks (optional).
        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
                 At dsdepth==1,only a single output is returned.
        preprocess: optional callable function to apply before the model's forward pass
        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
                    image spacing into an approximately isotropic space.
                    Otherwise, by default, the kernel size and downsampling is always isotropic.

    r0   ra   r   r   r8   rb   rc   Nr   r9   r:   re   r;   rF   r?   r=   r<   rf   r   r   rj   r   r   r   r   r   r   c                   s@   t  j|||||||||	|
||d tdd | jD | _d S )N)r9   re   r;   rF   r?   r<   rf   r   r   r   r   r   c                 S  s   g | ]}t |qS r   )copydeepcopy)r   layerr   r   r   r     s    z)SegResNetDS2.__init__.<locals>.<listcomp>)rI   rJ   r|   r}   r   up_layers_auto)rT   r9   re   r;   rF   r?   r<   rf   r   r   r   r   r   rU   r   r   rJ     s   zSegResNetDS2.__init__TrX   r   
with_pointbool
with_labelr@   ctuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]c           
      C  s  | j dur
|  |}| |std|j d|   | |}|  |d}t|dkr;t	j
d|j|jdg}g }g }|}|r~|rI| }d}| jD ]+}	|	d |}|||  }|	d |}t| j| | jkru||	d	 | |d }qN|  |}|rd}| jD ]+}	|	d |}|||  }|	d |}t| j| | jkr||	d	 | |d }q|  t|dkr|d n|t|dkr|d fS |fS )
z
        Args:
            x: input tensor.
            with_point: if true, return the point branch output.
            with_label: if true, return the label branch output.
        Nr   r   r   r   r   r   rx   r   )r   r   r$   r   r   r   r   r   r!   r   r   r   r   cloner   r   r   r   )
rT   rX   r   r   r   r   Zoutputs_autox_r   r   r   r   r   rZ     sH   
	







2zSegResNetDS2.forwardFc                 C  sR   | j  D ]	}| o| |_q| j D ]}| |_q| j D ]}| |_q dS )z
        Args:
            auto_freeze: if true, freeze the image encoder and the auto-branch.
            point_freeze: if true, freeze the image encoder and the point-branch.
        N)r   
parametersrequires_gradr   r   )rT   Zauto_freezeZpoint_freezeparamr   r   r   set_auto_grad  s   

zSegResNetDS2.set_auto_gradr   r   )TT)rX   r   r   r   r   r   r@   r   )FF)r[   r\   r]   r^   rJ   rZ   r   r_   r   r   rU   r   r     s$    !9rW   )r   r   r   r   )r/   r   )!
__future__r   r   collections.abcr   typingr   numpyr   r   torch.nnr|   monai.networks.blocks.upsampler   monai.networks.layers.factoriesr   r   r   r	   monai.networks.layers.utilsr
   r   monai.utilsr   r   __all__r.   r6   Moduler7   r`   r   r   r   r   r   r   <module>   s&   
:k E