U
    PhBT                     @  s,  d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	Z	 d dl
m	  mZ d dlmZ d dlmZmZmZmZ ddd	gZG d
d de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd de	jZG dd	 d	e	jZdd Zdd Z e Z!Z"dS )    )annotationsN)Sequence)Union)FCN)ActConvNormPoolAHnetAhnetAHNetc                      s:   e Zd ZdZdddddddd fd	d
Zdd Z  ZS )Bottleneck3x3x1      NintzSequence[int] | intznn.Sequential | NoneNone)spatial_dimsinplanesplanesstride
downsamplereturnc           
        s   t    ttj|f }ttj|f }ttj|f }ttj	 }	|||ddd| _
||| _|||d| d  |d| d  dd| _||| _|||d ddd| _||d | _|	dd	| _|| _|| _|d
| d  d
| d  d| _d S )Nr   F)kernel_sizebias   r   r   r   r   r   r   r   paddingr   r   Tinplacer   r      r   r   )super__init__r   CONVr   BATCHr	   MAXr   RELUconv1bn1conv2bn2conv3bn3relur   r   pool)
selfr   r   r   r   r   	conv_type	norm_type	pool_type	relu_type	__class__ N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/ahnet.pyr%      s,    



zBottleneck3x3x1.__init__c                 C  s   |}|  |}| |}| |}| |}| |}| |}| |}| |}| jd k	r| |}| | kr| 	|}||7 }| |}|S N)
r*   r+   r0   r,   r-   r.   r/   r   sizer1   )r2   xresidualoutr9   r9   r:   forward@   s     











zBottleneck3x3x1.forward)r   N)__name__
__module____qualname__	expansionr%   r@   __classcell__r9   r9   r7   r:   r      s
     !r   c                      s&   e Zd Zdddd fddZ  ZS )
Projectionr   )r   num_input_featuresnum_output_featuresc              
     sp   t    ttj|f }ttj|f }ttj }| d|| | d|dd | d|||dddd d S )	Nnormr0   Tr   convr   Fr   r   r   )	r$   r%   r   r&   r   r'   r   r)   
add_module)r2   r   rG   rH   r3   r4   r6   r7   r9   r:   r%   [   s    

zProjection.__init__rA   rB   rC   r%   rE   r9   r9   r7   r:   rF   Y   s   rF   c                      s,   e Zd Zddddddd fddZ  ZS )
DenseBlockr   float)r   
num_layersrG   bn_sizegrowth_ratedropout_probc           	        sH   t    t|D ]0}t||||  |||}| d|d  | qd S )Nzdenselayer%dr   )r$   r%   rangePseudo3DLayerrL   )	r2   r   rP   rG   rQ   rR   rS   ilayerr7   r9   r:   r%   i   s    	
 
   zDenseBlock.__init__rM   r9   r9   r7   r:   rN   g   s   rN   c                      s*   e Zd Zdddddd fddZ  ZS )UpTransition	transposer   strr   rG   rH   upsample_modec           
   
     s   t    ttj|f }ttj|f }ttj }| d|| | d|dd | d|||dddd |d	krttj	|f }| d
|||dddd n(d }	|dkrd}	| d
t
jd||	d d S )NrI   r0   Tr   rJ   r   FrK   rY   upr"   	trilinearbilinearscale_factormodealign_cornersr$   r%   r   r&   r   r'   r   r)   rL   	CONVTRANSnnUpsample
r2   r   rG   rH   r\   r3   r4   r6   conv_trans_typerd   r7   r9   r:   r%   |   s"    

 zUpTransition.__init__)rY   rM   r9   r9   r7   r:   rX   z   s    rX   c                      s*   e Zd Zdddddd fddZ  ZS )FinalrY   r   rZ   r[   c           
        s   t    ttj|f }ttj|f }ttj }| d|| | d|dd | d|||d| d  dd| d  d	d
 |dkrttj	|f }| d|||ddd	d n(d }	|dkrd}	| dt
jd||	d d S )NrI   r0   Tr   rJ   r   r   r   Fr   rY   r]   r"   rK   r^   ra   re   ri   r7   r9   r:   r%      s6    

 zFinal.__init__)rY   rM   r9   r9   r7   r:   rk      s    rk   c                      s2   e Zd Zdddddd fddZdd Z  ZS )rU   r   rO   )r   rG   rR   rQ   rS   c           	        s  t    ttj|f }ttj|f }ttj }||| _|dd| _	|||| dddd| _
||| | _|dd| _||| |d| d  dd| d  dd| _||| _|dd| _|||d	| d  dd
| d  dd| _||| _|dd| _|||dddd| _|| _d S )NTr   r   FrK   r   r   r   )r   r   r   )r   r   r   )r$   r%   r   r&   r   r'   r   r)   r+   relu1r*   r-   relu2r,   r/   relu3r.   bn4relu4conv4rS   )	r2   r   rG   rR   rQ   rS   r3   r4   r6   r7   r9   r:   r%      s>    


	
	
zPseudo3DLayer.__init__c                 C  s   |}|  |}| |}| |}| |}| |}| |}| |}| |}| |}|| }| 	|}| 
|}| |}d| _| jdkrtj|| j| jd}t||gdS )N        )ptrainingr   )r+   rl   r*   r-   rm   r,   r/   rn   r.   ro   rp   rq   rS   Fdropoutrt   torchcat)r2   r=   ZinxZx3x3x1Zx1x1x3Znew_featuresr9   r9   r:   r@      s$    












zPseudo3DLayer.forwardrA   rB   rC   r%   r@   rE   r9   r9   r7   r:   rU      s   'rU   c                      s:   e Zd Zdddddd fddZdddd	d
Z  ZS )PSPrY   r   rZ   )r   psp_block_numin_chr\   c                   sX  t    t | _ttj|f }ttj|f }t | _	t | _
t|D ]j}d|d  d|d  df| d  }| j	|||d | j
||dd| d  dd| d  d qL|| _|| _|| _| jdkrTttj|f }	t|D ]f}d|d  d|d  df| d  }d|d  d|d  d	f| d  }
| j|	dd|||
d qd S )
Nr"   r   r   r#   )r   r   r   r   r   r   r   rY   r   )r$   r%   rg   
ModuleList
up_modulesr   r&   r	   r(   pool_modulesproject_modulesrT   appendr   r{   r\   rf   )r2   r   r{   r|   r\   r3   r5   rV   r<   rj   pad_sizer7   r9   r:   r%      s*    



$$$$zPSP.__init__ztorch.Tensor)r=   r   c           	      C  s   g }| j dkrHt| j| j| jD ]$\}}}||||}|| q n^t| j| jD ]N\}}|jdd  }d }| j dkr~d}tj||||| j |d}|| qVt	j
|dd}|S )NrY   r"   r^   T)r<   rc   rd   r   dim)r\   zipr   r   r   r   shaperu   interpolaterw   rx   )	r2   r=   outputsZproject_moduleZpool_moduleZ	up_moduleoutputZinterpolate_sizerd   r9   r9   r:   r@     s&    


zPSP.forward)rY   ry   r9   r9   r7   r:   rz      s   rz   c                
      s^   e Zd ZdZdd	d
d
d
d
dddd fddZddd
d
d
ddddZdd Zdd Z  ZS )r   a4	  
    AHNet based on `Anisotropic Hybrid Network <https://arxiv.org/pdf/1711.08580.pdf>`_.
    Adapted from `lsqshr's official code <https://github.com/lsqshr/AH-Net/blob/master/net3d.py>`_.
    Except from the original network that supports 3D inputs, this implementation also supports 2D inputs.
    According to the `tests for deconvolutions <https://github.com/Project-MONAI/MONAI/issues/1023>`_, using
    ``"transpose"`` rather than linear interpolations is faster. Therefore, this implementation sets ``"transpose"``
    as the default upsampling method.

    To meet the requirements of the structure, the input size for each spatial dimension
    (except the last one) should be: divisible by 2 ** (psp_block_num + 3) and no less than 32 in ``transpose`` mode,
    and should be divisible by 32 and no less than 2 ** (psp_block_num + 3) in other upsample modes.
    In addition, the input size for the last spatial dimension should be divisible by 32, and at least one spatial size
    should be no less than 64.

    Args:
        layers: number of residual blocks for 4 layers of the network (layer1...layer4). Defaults to ``(3, 4, 6, 3)``.
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        in_channels: number of input channels for the network. Default to 1.
        out_channels: number of output channels for the network. Defaults to 1.
        psp_block_num: the number of pyramid volumetric pooling modules used at the end of the network before the final
            output layer for extracting multiscale features. The number should be an integer that belongs to [0,4]. Defaults
            to 4.
        upsample_mode: [``"transpose"``, ``"bilinear"``, ``"trilinear"``, ``nearest``]
            The mode of upsampling manipulations.
            Using the last two modes cannot guarantee the model's reproducibility. Defaults to ``transpose``.

            - ``"transpose"``, uses transposed convolution layers.
            - ``"bilinear"``, uses bilinear interpolate.
            - ``"trilinear"``, uses trilinear interpolate.
            - ``"nearest"``, uses nearest interpolate.
        pretrained: whether to load pretrained weights from ResNet50 to initialize convolution layers, default to False.
        progress: If True, displays a progress bar of the download of pretrained weights to stderr.
    r   r      r   r   r   r   rY   FTtupler   rZ   bool)layersr   in_channelsout_channelsr{   r\   
pretrainedprogressc	                    s  d| _ t   ttj|f }	ttj|f }
ttj|f }ttj	|f }t
t
j }ttjdf }ttjdf }|| _|| _|	| _|| _|| _|| _|| _|| _|  |dkrtd|dkrtd|	|dd| d  d| d  d	| d  d
d| _|d| d  d| d  d| _|d| _|dd| _|dkrR|d| d  dd| _n|d| d  ddd| _| jtd|d dd| _| jtd|d dd| _| jtd|d dd| _| jtd|d dd| _d}d}d}d}d}d}d}d}t ||||| _!t"|||||d| _#|||  }t ||||| _$t"|||||d| _%|||  }t ||||| _&t"|||||d| _'|||  }t(|||| _)t"|||||d| _*|||  }t ||||| _+t"|||||d| _,|||  }t-||||| _.t/||| ||| _0| 1 D ]r}t2||	|
frP|j3d |j3d  |j4 }|j5j67dt89d |  n&t2||r|j5j6:d |j;j6<  q|rt=d|d!}| >| d S )"N@   r"   )r"   r   z spatial_dims can only be 2 or 3.)r   r   r"   r   r   z:psp_block_num should be an integer that belongs to [0, 4].)   r   r   )r"   r"   r   r   Fr   r!   r#   Tr   )rY   nearest)r"   r"   r"   )r   r   r   r   r}   r   )r         i   r      r   i   i   rr   g       @)r   r   )?r   r$   r%   r   r&   rf   r   r'   r	   r(   r   r)   conv2d_typenorm2d_typer3   r4   r6   r5   r   r{   AssertionErrorr*   pool1bn0r0   maxpool_make_layerr   layer1layer2layer3layer4rX   up0rN   dense0up1dense1up2dense2rF   trans1dense3up3dense4rz   psprk   finalmodules
isinstancer   r   weightdatanormal_mathsqrtfill_r   zero_r   	copy_from) r2   r   r   r   r   r{   r\   r   r   r3   rj   r4   r5   r6   r   r   ZdensegrowthZdensebnZndenselayerZnum_init_featuresZnoutres1Znoutres2Znoutres3Znoutres4Z	noutdenseZ
noutdense1Z
noutdense2Z
noutdense3Z
noutdense4mnZnet2dr7   r9   r:   r%   R  s    

"

zAHNet.__init__ztype[Bottleneck3x3x1]znn.Sequential)blockr   blocksr   r   c              	   C  s   d }|dks| j ||j krt| j| j ||j d||dfd | j dd| jdd|fd | j dd|fd | j d| ||j }g }||| j| j |||dfd | j | ||j | _ t	d|D ]}||| j| j | qtj| S )Nr   FrK   r#   )
r   rD   rg   
Sequentialr3   r   r5   r4   r   rT   )r2   r   r   r   r   r   r   _r9   r9   r:   r     s0     "zAHNet._make_layerc                 C  s  |  |}| |}| |}| |}|}| |}|}| |}| |}| |}| |}| 	|| }| 
|}	| |	| }
| |
}| || }| |}| || }| |}| || }| |}| jdkr| |}tj||fdd}n|}| |S )Nr   r   r   )r*   r   r   r0   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r{   r   rw   rx   r   )r2   r=   Zconv_xZpool_xZfm1Zfm2Zfm3Zfm4sum0d0Zsum1d1Zsum2d2Zsum3d3Zsum4d4r   r9   r9   r:   r@     s4    














zAHNet.forwardc                 C  sB  t |j t | j  }}|jjddddddd }|d|jd dddg|_t	|j
| j
 tddD ]}dt| }g }g }t|d	 |  D ] }	t|	| j| jfr||	 qt| d	 |  D ] }
t|
| j| jfr||
 qt||D ]:\}	}
t|	| jr t|	|
 t|	| jr t	|	|
 q qrd S )
Nr   r   r   r"   r   r      rW   _modules)nextr*   
parametersr   	unsqueezepermuteclonerepeatr   copy_bn_paramr   rT   rZ   varsr   r   r   r   r   r4   r3   r   copy_conv_param)r2   netp2dp3dweightsrV   Z	layer_numZlayer_2dZlayer_3dm1m2r9   r9   r:   r     s&     
zAHNet.copy_from)r   r   r   r   r   rY   FT)r   )	rA   rB   rC   __doc__r%   r   r@   r   rE   r9   r9   r7   r:   r   /  s   $        "k#c                 C  sD   t |  | D ],\}}|jjdd d d  |jd d < qd S )Nr   r   )r   r   r   r   r   Zmodule2dZmodule3dr   r   r9   r9   r:   r     s    r   c                 C  s8   t |  | D ] \}}|jd d  |jd d < qd S r;   )r   r   r   r   r9   r9   r:   r     s    r   )#
__future__r   r   collections.abcr   typingr   rw   torch.nnrg   torch.nn.functional
functionalru   Zmonai.networks.blocks.fcnr   monai.networks.layers.factoriesr   r   r   r	   __all__Moduler   r   rF   rN   rX   rk   rU   rz   r   r   r   r
   r   r9   r9   r9   r:   <module>   s*   
=$C4 k