o
    ,i]                     @  st  d dl mZ d dlmZ d dlmZmZ d dlmZ d dl	Z	d dl
mZ d dlm  mZ d dl	mZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZmZ d dlmZmZ d dlmZ d dl m!Z!m"Z" g dZ#G dd dej$Z%G dd deZ&G dd de"Z'G dd de'Z(G dd de!Z)G dd dej$Z*G dd deZ+G dd deZ,G d d! d!ej$Z-dS )"    )annotations)OrderedDict)CallableSequence)partialN)Tensor)ADN)
SimpleASPP)BackboneWithFPN)Convolution)ExtraFPNBlockFeaturePyramidNetwork)ConvNorm)get_norm_layer)ResNetResNetBottleneck)	AttentionModule	Daf3dASPPDaf3dResNetBottleneckDaf3dResNetDilatedBottleneckDaf3dResNetDaf3dBackboneDaf3dFPNDaf3dBackboneWithFPNDAF3Dc                      s8   e Zd ZdZddddfdf fdd	Zd	d
 Z  ZS )r   a   
    Attention Module as described in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'
    <https://arxiv.org/pdf/1907.01743.pdf>. Returns refined single layer feature (SLF) and attentive map

    Args:
        spatial_dims: dimension of inputs.
        in_channels: number of input channels (channels of slf and mlf).
        out_channels: number of output channels (channels of attentive map and refined slf).
        norm: normalization type.
        act: activation type.
    group    @   
num_groupsnum_channelsPRELUc                   s   t    tt|||d||dt|||dd||dt|||ddddd| _tt|||d||dt|||dd||dt|||dd||d| _d S )N   )kernel_sizenormact   )r$   paddingr%   r&   ASIGMOID)r$   r(   adn_orderingr&   )super__init__nn
Sequentialr   attentive_maprefine)selfspatial_dimsin_channelsout_channelsr%   r&   	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/daf3d.pyr-   :   s   

zAttentionModule.__init__c                 C  s8   |  t||fd}| t||| fd}||fS )Nr#   )r0   torchcatr1   )r2   slfmlfattoutr8   r8   r9   forwardQ   s   zAttentionModule.forward__name__
__module____qualname____doc__r-   r@   __classcell__r8   r8   r6   r9   r   -   s    r   c                      s8   e Zd ZdZ						dd fddZdd Z  ZS )r   a  
    Atrous Spatial Pyramid Pooling module as used in 'Deep Attentive Features for Prostate Segmentation in
    3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>. Core functionality as in SimpleASPP, but after each
    layerwise convolution a group normalization is added. Further weight initialization for convolutions is provided in
    _init_weight(). Additional possibility to specify the number of final output channels.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: number of input channels.
        conv_out_channels: number of output channels of each atrous conv.
        out_channels: number of output channels of final convolution.
            If None, uses len(kernel_sizes) * conv_out_channels
        kernel_sizes: a sequence of four convolutional kernel sizes.
            Defaults to (1, 3, 3, 3) for four (dilated) convolutions.
        dilations: a sequence of four convolutional dilation parameters.
            Defaults to (1, 2, 4, 6) for four (dilated) convolutions.
        norm_type: final kernel-size-one convolution normalization type.
            Defaults to batch norm.
        acti_type: final kernel-size-one convolution activation type.
            Defaults to leaky ReLU.
        bias: whether to have a bias term in convolution blocks. Defaults to False.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.

    Raises:
        ValueError: When ``kernel_sizes`` length differs from ``dilations``.
    Nr#   r'   r'   r'   r#            BATCH	LEAKYRELUFr3   intr4   conv_out_channelsr5   
int | Nonekernel_sizesSequence[int]	dilations	norm_typetuple | str | None	acti_typebiasboolreturnNonec
              
     s   t  ||||||||	 t }
| jD ]}tddd}||_td|dd|_| 	|}|

| q|
| _|d u r?t|| }tdt|| |d||d| _d S )Nr#   N)orderingr%   norm_dimr'   )r3   r4   r5   r$   r%   r&   )r,   r-   r.   
ModuleListconvsr   convr   adn_init_weightappendlenconv_k1)r2   r3   r4   rO   r5   rQ   rS   rT   rV   rW   Z	new_convs_convZtmp_convr6   r8   r9   r-   t   s*   


zDaf3dASPP.__init__c                 C  s.   |  D ]}t|tjrtjj|j q|S N)modules
isinstancer.   Conv3dr:   initkaiming_normal_weight)r2   r`   mr8   r8   r9   rb      s
   zDaf3dASPP._init_weight)NrG   rH   rL   rM   F)r3   rN   r4   rN   rO   rN   r5   rP   rQ   rR   rS   rR   rT   rU   rV   rU   rW   rX   rY   rZ   )rB   rC   rD   rE   r-   rb   rF   r8   r8   r6   r9   r   W   s    !'r   c                      s6   e Zd ZdZdZddddddiff fd	d
	Z  ZS )r   a  
    ResNetBottleneck block as used in 'Deep Attentive Features for Prostate Segmentation in 3D
    Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.
    Instead of Batch Norm Group Norm is used, instead of ReLU PReLU activation is used.
    Initial expansion is 2 instead of 4 and second convolution uses groups.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (taking expansion into account).
        spatial_dims: number of spatial dimensions of the input image.
        stride: stride to use for second conv layer.
        downsample: which downsample layer to use.
        norm: which normalization layer to use. Defaults to group.
    rI   r'   r#   Nr   r    r   c           	   	     s   t t j|f }tt||d}t|tjr*t|||| j d|dd||| j d}t 	||||| ||d| _
||d| _||| j d| _|||dd|ddd| _t | _d S )	N)namer3   r#   F)r$   striderW   )channelsr'   r   )r$   r(   rp   groupsrW   )r   CONVr   r   ri   r.   r/   	expansionr,   r-   bn1bn2bn3conv2PReLUrelu)	r2   	in_planesplanesr3   rp   
downsampler%   	conv_type
norm_layerr6   r8   r9   r-      s   zDaf3dResNetBottleneck.__init__)rB   rC   rD   rE   rt   r-   rF   r8   r8   r6   r9   r      s
    r   c                      s2   e Zd ZdZddddddiff fdd		Z  ZS )
r   a-  
    ResNetDilatedBottleneck as used in 'Deep Attentive Features for Prostate Segmentation in 3D
    Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.
    Same as Daf3dResNetBottleneck but dilation of 2 is used in second convolution.
    Args:
        in_planes: number of input channels.
        planes: number of output channels (taking expansion into account).
        spatial_dims: number of spatial dimensions of the input image.
        stride: stride to use for second conv layer.
        downsample: which downsample layer to use.
    r'   r#   Nr   r    r   c              
     sB   t  |||||| ttj|f }|||d|ddddd| _d S )Nr'   rI   r   F)r$   rp   r(   dilationrr   rW   )r,   r-   r   rs   rx   )r2   r{   r|   r3   rp   r}   r%   r~   r6   r8   r9   r-      s
   z%Daf3dResNetDilatedBottleneck.__init__rB   rC   rD   rE   r-   rF   r8   r8   r6   r9   r      s    r   c                      s8   e Zd ZdZ												dd fddZ  ZS ) r   a  
    ResNet as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'
    <https://arxiv.org/pdf/1907.01743.pdf>.
    Uses two Daf3dResNetBottleneck blocks followed by two Daf3dResNetDilatedBottleneck blocks.

    Args:
        layers: how many layers to use.
        block_inplanes: determine the size of planes at each step. Also tunable with widen_factor.
        spatial_dims: number of spatial dimensions of the input image.
        n_input_channels: number of input channels for first convolutional layer.
        conv1_t_size: size of first convolution layer, determines kernel and padding.
        conv1_t_stride: stride of first convolution layer.
        no_max_pool: bool argument to determine if to use maxpool layer.
        shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'.
            - 'A': using `self._downsample_basic_block`.
            - 'B': kernel_size 1 conv + norm.
        widen_factor: widen output for each layer.
        num_classes: number of output (classifications).
        feed_forward: whether to add the FC layer for the output, default to `True`.
        bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.

    r'      r#   FB      ?  Tlayers	list[int]block_inplanesr3   rN   n_input_channelsconv1_t_sizetuple[int] | intconv1_t_strideno_max_poolrX   shortcut_typestrwiden_factorfloatnum_classesfeed_forwardbias_downsamplec                   s   t  t|||||||||	|
|| d| _ttj|f }ttj|f }||| jddddd| _|dd| _	t
 | _| t|d |d ||| _| jt|d	 |d	 ||dd
| _| jt|d |d ||d	d
| _| jt|d |d ||d	d
| _d S )Nr   r   )r#   rI   rI   )r'   r'   r'   F)r$   rp   r(   rW   r   r   r#   )rp   rI   r'   )r,   r-   r   r{   r   rs   r   ZGROUPconv1ru   r.   ry   rz   _make_layerr   layer1layer2r   layer3layer4)r2   r   r   r3   r   r   r   r   r   r   r   r   r   r~   rT   r6   r8   r9   r-     sL   
zDaf3dResNet.__init__)
r'   r'   r   r#   Fr   r   r   TT)r   r   r   r   r3   rN   r   rN   r   r   r   r   r   rX   r   r   r   r   r   rN   r   rX   r   rX   r   r8   r8   r6   r9   r      s    r   c                      s(   e Zd ZdZ fddZdd Z  ZS )r   a   
    Backbone for 3D Feature Pyramid Network in DAF3D module based on 'Deep Attentive Features for Prostate Segmentation in
    3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.

    Args:
        n_input_channels: number of input channels for the first convolution.
    c                   sz   t    tg dg d|ddd}t| }tj|d d  | _tj|dd  | _|d | _	|d | _
|d	 | _d S )
N)r'   rJ   rK   r'   )            rI   F)r   r   r   r   r   r'      rK   r   )r,   r-   r   listchildrenr.   r/   layer0r   r   r   r   )r2   r   netZnet_modulesr6   r8   r9   r-   G  s   


zDaf3dBackbone.__init__c                 C  s6   |  |}| |}| |}| |}| |}|S rg   )r   r   r   r   r   )r2   xr   r   r   r   r   r8   r8   r9   r@   W  s   




zDaf3dBackbone.forwardrA   r8   r8   r6   r9   r   >  s    r   c                      s0   e Zd ZdZ	dd fd
dZdddZ  ZS )r   a0  
    Feature Pyramid Network as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'
    <https://arxiv.org/pdf/1907.01743.pdf>.
    Omits 3x3x3 convolution of layer_blocks and interpolates resulting feature maps to be the same size as
    feature map with highest resolution.

    Args:
        spatial_dims: 2D or 3D images
        in_channels_list: number of channels for each feature map that is passed to the module
        out_channels: number of channels of the FPN representation
        extra_blocks: if provided, extra operations will be performed.
            It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names
    Nr3   rN   in_channels_listr   r5   extra_blocksExtraFPNBlock | Nonec                   sf   t  |||| t | _|D ] }|dkrtdt|||ddddddd	fd
}| j| qd S )Nr   z(in_channels=0 is currently not supportedr#   NAr"   r   r   r   r   )r$   r+   r&   r%   )r,   r-   r.   r^   inner_blocks
ValueErrorr   rc   )r2   r3   r   r5   r   r4   inner_block_moduler6   r8   r9   r-   q  s    
	zDaf3dFPN.__init__r   dict[str, Tensor]rY   c                   s   t   }t   }| |d d}g }|| tt|d ddD ]#}| || |}|jdd  }tj	||dd}	||	 }|
d| q%| jd urW| |||\}}|d g fdd|dd  D  }tt t||}
|
S )	NrI   	trilinearsizemoder   c                   s,   g | ]}t j| d   dd ddqS )feat1rI   Nr   r   Finterpolater   ).0lr   r8   r9   
<listcomp>  s   , z$Daf3dFPN.forward.<locals>.<listcomp>r#   )r   keysvaluesget_result_from_inner_blocksrc   rangerd   shaper   r   insertr   r   zip)r2   r   namesx_values
last_innerresultsidxinner_lateral
feat_shapeinner_top_downr?   r8   r   r9   r@     s    

$zDaf3dFPN.forwardrg   )r3   rN   r   r   r5   rN   r   r   )r   r   rY   r   rA   r8   r8   r6   r9   r   `  s
    r   c                      s(   e Zd ZdZ		dd fddZ  ZS )r   a  
    Same as BackboneWithFPN but uses custom Daf3DFPN as feature pyramid network

    Args:
        backbone: backbone network
        return_layers: a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        in_channels_list: number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels: number of channels in the FPN.
        spatial_dims: 2D or 3D images
        extra_blocks: if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names
    Nbackbone	nn.Modulereturn_layersdict[str, str]r   r   r5   rN   r3   rP   r   r   rY   rZ   c                   s   t  |||||| |d u r6t|drt|jtr|j}nt|jtjr(d}nt|jtj	r2d}nt
dt||||| _d S )Nr3   rI   r'   zZCould not determine value of  `spatial_dims` from backbone, please provide explicit value.)r,   r-   hasattrri   r3   rN   r   r.   Conv2drj   r   r   fpn)r2   r   r   r   r5   r3   r   r6   r8   r9   r-     s   	zDaf3dBackboneWithFPN.__init__)NN)r   r   r   r   r   r   r5   rN   r3   rP   r   r   rY   rZ   r   r8   r8   r6   r9   r     s
    r   c                      s*   e Zd ZdZd fdd	Zdd Z  ZS )r   az  
    DAF3D network based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'
    <https://arxiv.org/pdf/1907.01743.pdf>.
    The network consists of a 3D Feature Pyramid Network which is applied on the feature maps of a 3D ResNet,
    followed by a custom Attention Module and an ASPP module.
    During training the supervised signal consists of the outputs of the FPN (four Single Layer Features, SLFs),
    the outputs of the attention module (four Attentive Features) and the final prediction.
    They are individually compared to the ground truth, the final loss consists of a weighted sum of all
    individual losses (see DAF3D tutorial for details).
    There is an additional possiblity to return all supervised signals as well as the Attentive Maps in validation
    mode to visualize inner functionality of the network.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        visual_output: whether to return all SLFs, Attentive Maps, Refined SLFs in validation mode
            can be used to visualize inner functionality of the network
    Fc                   s  t    || _tt|dddddg dddd	| _tjd|d
d| _ddddf}dd
ddf}t	t
dddd
d||dt
dddddd
||dt
dddddd
||d| _tddd||d| _t
dddd
d||d| _tjd|d
d| _tdddddd|d dd	| _d S )Nr   Zfeat2Zfeat3Zfeat4)r   r   r   r   )r   r   r   i   r   r'   )r   r   r   r5   r3   r#   )r$   r   r   r   r   prelug      ?)num_parametersrk   r   r   )r3   r4   r5   r$   r+   r%   r&   )r3   r4   r5   r$   r+   r(   r%   r&      )r3   r4   r5   r%   r&   r   )r$   r+   r%   r&   )r'   r'   r'   r'   ))r#   r#   r#   )r#   rK   rK   )r#      r   )r#      r   T)	r3   r4   rO   r5   rQ   rS   rT   rV   rW   )r,   r-   visual_outputr   r   backbone_with_fpnr.   rj   predict1r/   r   fuser   	attentionr1   predict2r   aspp)r2   r4   r5   r   
group_normZ	act_prelur6   r8   r9   r-     sv   
	

zDAF3D.__init__c                   s8  t  }fdd|D }t|d  fdd|D }tt| \}}fdd|D }fdd|D }t|d}		|	}

|
}jrg|| |g }fdd|D }|S jrtj| dd  d	d
}fdd|| | D }|g| }|S tj| dd  d	d
}|S )Nc                      g | ]}  |qS r8   )r   r   r<   r2   r8   r9   r   *      z!DAF3D.forward.<locals>.<listcomp>r#   c                   s   g | ]} | qS r8   )r   r   )r=   r2   r8   r9   r   .  s    c                   r   r8   r   )r   afr   r8   r9   r   2  r   c                   r   r8   r   )r   amr   r8   r9   r   5  r   c                   (   g | ]}t j|  d d ddqS rI   Nr   r   r   r   or   r8   r9   r   ?  s   ( rI   r   r   c                   r   r   r   r   r   r8   r9   r   C  s    )r   r   r   r   r:   r;   tupler   r1   r   r   trainingr   r   r   r   )r2   r   Zsingle_layer_featuresZsupervised1Zattentive_features_mapsZatt_featuresZatt_mapsZsupervised2Zsupervised3Zattentive_mlfr   Zsupervised_finaloutputZsupervised_innerr8   )r=   r2   r   r9   r@   %  s.   




zDAF3D.forward)FrA   r8   r8   r6   r9   r     s    ?r   ).
__future__r   collectionsr   collections.abcr   r   	functoolsr   r:   torch.nnr.   torch.nn.functional
functionalr   r   monai.networks.blocksr   Zmonai.networks.blocks.asppr	   Z(monai.networks.blocks.backbone_fpn_utilsr
   "monai.networks.blocks.convolutionsr   Z-monai.networks.blocks.feature_pyramid_networkr   r   monai.networks.layers.factoriesr   r   monai.networks.layers.utilsr   Zmonai.networks.nets.resnetr   r   __all__Moduler   r   r   r   r   r   r   r   r   r8   r8   r8   r9   <module>   s4   *K.U"D.