o
    -i>p                     @  sf  d dl mZ d dlZd dlZd dlZd dlmZ d dlmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZmZ d dlmZmZ d d	lmZmZmZmZ d d
lmZ g dZG dd dejZ G dd dej!Z"G dd dej!Z#G dd dej!Z$G dd dejZ%G dd dej&Z'G dd dejZ(d(ddZ)d)d"d#Z*d*d+d&d'Z+e( Z, Z-Z.dS ),    )annotationsN)OrderedDict)CallableSequence)download_url)UpSample)ConvDropout)get_act_layerget_norm_layer)HoVerNetBranchHoVerNetModeInterpolateModeUpsampleMode)look_up_option)HoVerNetHovernetHoVernetr   c                      s<   e Zd Zddddifdddfd fddZdddZ  ZS )_DenseLayerDecoder        reluinplaceTbatch   r   num_featuresintin_channelsout_channelsdropout_probfloatactstr | tuplenormkernel_sizepaddingreturnNonec	                   s   t    ttjdf }	ttjdf }
t | _| j	dt
|d|d | j	dt|d | j	d|	||ddd	 | j	d
t
|d|d | j	dt|d | j	d|	||||ddd |dkro| j	d|
| dS dS )a1  
        Args:
            num_features: number of internal channels used for the layer
            in_channels: number of the input channels.
            out_channels: number of the output channels.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
            padding: padding value for >1 convolutions.
           zpreact_bna/bnnamespatial_dimschannelszpreact_bna/relur)   conv1   Fr#   biasz
conv1/normzconv1/relu2conv2   )r#   r$   groupsr0   r   dropoutN)super__init__r   CONVr	   DROPOUTnn
Sequentiallayers
add_moduler   r
   )selfr   r   r   r   r    r"   r#   r$   	conv_typedropout_type	__class__ ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/hovernet.pyr6   5   s    

z_DenseLayerDecoder.__init__xtorch.Tensorc                 C  sn   |  |}|jd |jd kr-|jd |jd  d }|d d d d || || f }t||gd}|S )Nr'   r.   )r;   shapetorchcat)r=   rD   x1trimrB   rB   rC   forward_   s   
$z_DenseLayerDecoder.forward)r   r   r   r   r   r   r   r   r    r!   r"   r!   r#   r   r$   r   r%   r&   rD   rE   r%   rE   __name__
__module____qualname__r6   rL   __classcell__rB   rB   r@   rC   r   3   s    
*r   c                      s2   e Zd Zddddifdddfd fddZ  ZS )_DecoderBlockr   r   r   Tr   r   Fr;   r   r   r   r   r   r   r    r!   r"   r#   same_paddingboolr%   r&   c
                   s   t    ttjdf }
|	r|d nd}| d|
||d ||dd |d }t|D ]}t||||||||d}||7 }| d|d	  | q+t|||d
}| d| | d|
||d	dd dS )a  
        Args:
            layers: number of layers in the block.
            num_features: number of internal features used.
            in_channels: number of the input channel.
            out_channels: number of the output channel.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
            same_padding: whether to do padding for >1 convolutions to ensure
                the output size is the same as the input size.
        r'   r   convar2   Fr#   r$   r0   )r    r"   r#   r$   zdenselayerdecoder%dr.   r    r"   	bna_blockZconvfr/   N)r5   r6   r   r7   r<   ranger   _Transition)r=   r;   r   r   r   r   r    r"   r#   rT   r>   r$   _in_channelsilayertransr@   rB   rC   r6   l   s.   

z_DecoderBlock.__init__)r;   r   r   r   r   r   r   r   r   r   r    r!   r"   r!   r#   r   rT   rU   r%   r&   rO   rP   rQ   r6   rR   rB   rB   r@   rC   rS   j   s    
rS   c                      s2   e Zd Zddddifdddfd fddZ  ZS )_DenseLayerr   r   r   Tr   r   r   r   r   r   r   r   r   r    r!   r"   drop_first_norm_relur#   r%   r&   c	                   sZ  t    t | _ttjdf }	ttjdf }
|s0| j	dt
|d|d | j	dt|d | j	d|	||ddd	d
 | j	dt
|d|d | j	dt|d |dkrj|rj| j	d|	|||ddd	d n| j	d|	|||dd	d
 | j	dt
|d|d | j	dt|d | j	d|	||ddd	d
 |dkr| j	d|
| dS dS )a2  Dense Convolutional Block.

        References:
            Huang, Gao, et al. "Densely connected convolutional networks."
            Proceedings of the IEEE conference on computer vision and
            pattern recognition. 2017.

        Args:
            num_features: number of internal channels used for the layer
            in_channels: number of the input channels.
            out_channels: number of the output channels.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            drop_first_norm_relu - omits the first norm/relu for the first layer
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
        r'   z	preact/bnr(   zpreact/relur,   r-   r.   r   FrW   zconv1/bnz
conv1/relu@   r1   r#   strider$   r0   zconv2/bnz
conv2/reluconv3r4   N)r5   r6   r9   r:   r;   r   r7   r	   r8   r<   r   r
   )r=   r   r   r   r   r    r"   rb   r#   r>   r?   r@   rB   rC   r6      s.   

z_DenseLayer.__init__)r   r   r   r   r   r   r   r   r    r!   r"   r!   rb   r   r#   r   r%   r&   r`   rB   rB   r@   rC   ra      s    
ra   c                      s,   e Zd Zdddifdfd fddZ  ZS )r[   r   r   Tr   r   r   r    r!   r"   r%   r&   c                   s6   t    | dt|d|d | dt|d dS )z
        Args:
            in_channels: number of the input channel.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        bnr'   r(   r   r,   N)r5   r6   r<   r   r
   )r=   r   r    r"   r@   rB   rC   r6      s   
	z_Transition.__init__)r   r   r    r!   r"   r!   r%   r&   r`   rB   rB   r@   rC   r[      s    r[   c                      s<   e Zd Zddddifdddfd fddZdddZ  ZS )_ResidualBlockr   r   r   Tr   Fr;   r   r   r   r   r   r   r    r!   r"   freeze_dense_layerrU   freeze_blockr%   r&   c
              	     s   t    t | _ttjdf }
|dkr|
||ddd| _n|
||ddddd| _t||||||dd}| j	d	| t
d|D ]}t||||||d
}| j	d| | qAt|||d
| _|rh| jd |	rq| d dS dS )a>  Residual block.

        References:
            He, Kaiming, et al. "Deep residual learning for image
            recognition." Proceedings of the IEEE conference on computer
            vision and pattern recognition. 2016.

        Args:
            layers: number of layers in the block.
            num_features: number of internal features used.
            in_channels: number of the input channel.
            out_channels: number of the output channel.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            freeze_dense_layer: whether to freeze all dense layers within the block.
            freeze_block: whether to freeze the whole block.

        r'   rc   r.   Fr/   rd   T)r    r"   rb   Zdenselayer_0rX   Zdenselayer_N)r5   r6   r9   r:   r;   r   r7   shortcutra   r<   rZ   r[   rY   requires_grad_)r=   r;   r   r   r   r   r    r"   ri   rj   r>   r^   r]   r@   rB   rC   r6      s&   

z_ResidualBlock.__init__rD   rE   c                 C  s   |  |}| j jdkr|d d d d d dd df }| jD ]+}||}|jdd  |jdd  krC|d d d d d dd df }|| }|}q| |}|S )N)r'   r'   rF   )rk   re   r;   rL   rG   rY   )r=   rD   scr^   rB   rB   rC   rL   ,  s   
 

 
z_ResidualBlock.forward)r;   r   r   r   r   r   r   r   r   r   r    r!   r"   r!   ri   rU   rj   rU   r%   r&   rM   rN   rB   rB   r@   rC   rh      s    
9rh   c                      s@   e Zd Zddddifddddd	fd  fddZd!ddZ  ZS )"_DecoderBranch)   r2   r   r   Tr   r   r'   r   Fdecode_configSequence[int]r    r!   r"   r   r   r   r   r#   rT   rU   r%   r&   c                   s@  t    ttjdf }d}	d}
d}t | _t|D ]\}}t||
|	||||||d	}| j	d|d  | d}	qt | _
t|}|d d }ttd	|d
d|dd|dfg}| j
	d|d  | ttdt|dddfdt|dfd|d|dddfg}| j
	d|d  | tddtjtjdd| _dS )aA  
        Args:
            decode_config: number of layers for each block.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            dropout_prob: dropout rate after each dense layer.
            out_channels: number of the output channel.
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
            same_padding: whether to do padding for >1 convolutions to ensure
                the output size is the same as the input size.
        r'   i          )	r;   r   r   r   r   r    r"   r#   rT   Zdecoderblockr.   i   rV      rc   F)r#   re   r0   r$   rg   r(   r   r,   conv)r#   re   scale_factormodeinterp_moder0   N)r5   r6   r   r7   r9   r:   decoder_blocks	enumeraterS   r<   output_featureslenr   r   r
   r   r   NONTRAINABLEr   BILINEARupsample)r=   rq   r    r"   r   r   r#   rT   r>   r\   _num_features_out_channelsr]   
num_layersblock_i	_pad_sizeZ
_seq_blockr@   rB   rC   r6   A  sR   



z_DecoderBranch.__init__xinrE   
short_cutslist[torch.Tensor]c                 C  s   t |d }|||  }| jD ]7}||}| |}|d8 }|| jd |jd  d }|dkrF||| d d d d || || f 7 }q| jD ]}||}qJ|S )Nr.   rF   r'   r   )r~   r{   r   rG   r}   )r=   r   r   block_numberrD   r   rK   rB   rB   rC   rL     s   

,

z_DecoderBranch.forward)rq   rr   r    r!   r"   r!   r   r   r   r   r#   r   rT   rU   r%   r&   )r   rE   r   r   r%   rE   rN   rB   rB   r@   rC   ro   ?  s    
Iro   c                      sX   e Zd ZdZeZeZejddddddifdd	d
dd	dd	fd' fd d!Z	d(d%d&Z
  ZS ))r   a  HoVerNet model

    References:
      Graham, Simon et al. Hover-net: Simultaneous segmentation
      and classification of nuclei in multi-tissue histology images,
      Medical Image Analysis 2019

      https://github.com/vqdang/hover_net
      https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html

    This network is non-deterministic since it uses `torch.nn.Upsample` with ``UpsampleMode.NONTRAINABLE`` mode which
    is implemented with torch.nn.functional.interpolate(). Please check the link below for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
          a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
        in_channels: number of the input channel.
        np_out_channels: number of the output channel of the nucleus prediction branch.
        out_classes: number of the nuclear type classes.
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        decoder_padding: whether to do padding on convolution layers in the decoders. In the conic branch
            of the referred repository, the architecture is changed to do padding on convolution layers in order to
            get the same output size as the input, and this changed version is used on CoNIC challenge.
            Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
        dropout_prob: dropout rate after each dense layer.
        pretrained_url: if specifying, will loaded the pretrained weights downloaded from the url.
            There are two supported forms of weights:
            1. preact-resnet50 weights coming from the referred hover_net
            repository, each user is responsible for checking the content of model/datasets and the applicable licenses
            and determining if suitable for the intended use. please check the following link for more details:
            https://github.com/vqdang/hover_net#data-format
            2. standard resnet50 weights of torchvision. Please check the following link for more details:
            https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html#ResNet50_Weights
        adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
            value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
            this value should be `True`.
        pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True.
            It is used to extract the expected state dict.
        freeze_encoder: whether to freeze the encoder of the network.
    r   r'   r   r   r   Tr   Fr   Nry   HoVerNetMode | strr   r   np_out_channelsout_classesr    r!   r"   decoder_paddingrU   r   r   pretrained_url
str | Noneadapt_standard_resnetpretrained_state_dict_keyfreeze_encoderr%   r&   c                   s  t    t|tr| }t|t| _| jdkr"|du r"t	d |dkr*t
d|dkr2t
d|dks:|dk r>t
d	d
}d}| jtjkrMd}d}nd}d}ttjdf }ttd|||dd|ddfdt|d|dfdt|dfg| _|}d}|}t | _t|D ]4\}}d}d}|r|dkrd}nd}t|||||||||d	}| jd| | |}|d9 }|d9 }qt | _| jd|||ddddd tddtjtjdd| _t|||d| _ t||d| _!|dkrt|||dnd | _"| # D ]0}t||rtj$%t&'|j( qt|tj)r.tj$*t&'|j(d tj$*t&'|j+d q|	d urI|
r>t,|	|d}nt-|	}t.| | d S d S ) NORIGINALTzl'decoder_padding=True' only works when mode is 'FAST', otherwise the output size may not equal to the input.rs   z5Number of nuclear types classes exceeds maximum (128)r.   z:Number of nuclear type classes should either be None or >1r   z+Dropout can only be in the range 0.0 to 1.0rc   )r   r2      r   r      r'   rv      Frd   rg   r(   r   r,   ru   )	r;   r   r   r   r   r    r"   ri   rj   dZconv_bottleneckrw   )r#   rT   r   )r#   rT   )r   r#   rT   )state_dict_key)/r5   r6   
isinstancestrupperr   r   ry   warningswarn
ValueErrorFASTr   r7   r9   r:   r   r   r
   conv0
res_blocksr|   rh   r<   
bottleneckr   r   r   r   r   r   ro   nucleus_predictionhorizontal_verticaltype_predictionmodulesinitkaiming_normal_rH   	as_tensorweightBatchNorm2d	constant_r0   _remap_standard_resnet_model_remap_preact_resnet_model_load_pretrained_encoder)r=   ry   r   r   r   r    r"   r   r   r   r   r   r   Z_init_featuresZ_block_configZ_ksize_padr>   r\   r   r   r]   r   ri   rj   r   mweightsr@   rB   rC   r6     s   






zHoVerNet.__init__rD   rE   dict[str, torch.Tensor]c                 C  s   | j tjjkr|jd dks|jd dkrtdn|jd dks(|jd dkr,td| |}g }t| jD ]\}}|	|}|dkrJ|
| q8| |}| |}tjj| ||tjj| ||i}| jd urv| |||tjj< |S )NrF   i  rm   z?Input size should be 270 x 270 when using HoVerNetMode.ORIGINALru   z;Input size should be 256 x 256 when using HoVerNetMode.FASTr'   )ry   r   r   valuerG   r   r   r|   r   rL   appendr   r   r   NPr   HVr   r   NC)r=   rD   r   r]   r   outputrB   rB   rC   rL   F  s*   





zHoVerNet.forward)ry   r   r   r   r   r   r   r   r    r!   r"   r!   r   rU   r   r   r   r   r   rU   r   r   r   rU   r%   r&   )rD   rE   r%   r   )rO   rP   rQ   __doc__r   Moder   ZBranchr   r6   rL   rR   rB   rB   r@   rC   r     s$    +
{r   model	nn.Module
state_dictOrderedDict | dictc                   sr   |     fdd D   |   t dkr)td d S tt dt  d d S )Nc                   s2   i | ]\}}| v r | j | j kr||qS rB   )rG   ).0kv
model_dictr   rB   rC   
<dictcomp>f  s    ,z,_load_pretrained_encoder.<locals>.<dictcomp>r   zcno key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct.z out of z* keys are updated with pretrained weights.)	r   itemsupdateload_state_dictr~   keysr   r   print)r   r   rB   r   rC   r   d  s   

 r   	model_urlr   c           
      C  s  t d}t d}t d}t d}tjtj d}t| d|dd tj	
 r-d ntd	}tj||dd
d }t| D ]H}d }	||rSt |d|}	n%||rxt |d|}	||	rlt |d|	}	n||	rxt |d|	}	|	r|| ||	< ||= d|v r||= qB|S )Nz^(conv0\.\/)(.+)$z^(d\d+)\.(.+)$z^(.+\.d\d+)\.units\.(\d+)(.+)$z^(.+\.d\d+)\.blk_bna\.(.+)zpreact-resnet50.pthTFfuzzyfilepathprogresscpumap_locationweights_onlydesczconv0.conv\2zres_blocks.\1.\2z \1.layers.denselayer_\2.layers\3z\1.bna_block.\2Z
upsample2xrecompileospathjoinrH   hubget_dirr   cudais_availabledeviceloadlistr   matchsub)
r   pattern_conv0pattern_blockZpattern_layerZpattern_bnaweights_dirr   r   keynew_keyrB   rB   rC   r   t  s2   







r   r   r   c                 C  s  t d}t d}t d}t d}t d}t d}t d}tjtj d}	t| d	|	d
d tj	
 r<d ntd}
tj|	|
d	d}|d urQ|| }t| D ]l}d }||rht |d|}nP||rut |d|}nC||rt |dd |}||rt |dd |}n&||rt |d|}n||rt |d|}n||rt |d|}|r|| ||< ||= qW|S )Nz^conv1\.(.+)$z^bn1\.(.+)$z^layer(\d+)\.(\d+)\.(.+)$z@^(res_blocks.d\d+\.layers\.denselayer_)(\d+)\.layers\.bn3\.(.+)$zB^(res_blocks.d\d+\.layers\.denselayer_\d+\.layers)\.bn(\d+)\.(.+)$z)^(res_blocks.d\d+).+\.downsample\.0\.(.+)z)^(res_blocks.d\d+).+\.downsample\.1\.(.+)zresnet50.pthTFr   r   r   zconv0.conv.\1zconv0.bn.\1c                 S  s6   dt t| dd  d | d d | d S )Nzres_blocks.dr.   z.layers.denselayer_r'   z.layers.r   )r   r   groupsrB   rB   rC   <lambda>  s    z._remap_standard_resnet_model.<locals>.<lambda>c                 S  s.   |  dtt|  dd  d |  d S )Nr.   r'   z.layers.preact/bn.r   )r   r   r   r   rB   rB   rC   r     s   . z\1.conv\2/bn.\3z\1.shortcut.\2z\1.bna_block.bn.\2r   )r   r   r   Zpattern_bn1r   Zpattern_block_bn3Zpattern_block_bnZpattern_downsample0Zpattern_downsample1r   r   r   r   r   rB   rB   rC   r     sT   














r   )r   r   r   r   )r   r   )N)r   r   r   r   )/
__future__r   r   r   r   collectionsr   collections.abcr   r   rH   torch.nnr9   monai.apps.utilsr   monai.networks.blocksr   monai.networks.layers.factoriesr   r	   monai.networks.layers.utilsr
   r   monai.utils.enumsr   r   r   r   monai.utils.moduler   __all__Moduler   r:   rS   ra   r[   rh   
ModuleListro   r   r   r   r   r   r   HoverNetrB   rB   rB   rC   <module>   s6   79=N] 
I
6