o
    -iS                     @  s  d dl mZ d dlZd dlZd dlZd dlmZ d dlmZ d dl	Z	d dl	m
Z
 d dlmZ d dlmZ d dlmZmZmZmZ d d	lmZ d d
lmZ g dZddddddddddd
Zi dddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8ZG d9d: d:e
jZG d;d< d<e
jZG d=d> d>eZG d?d@ d@eZG dAdB dBeeZ dudGdHZ!dvdOdPZ"dwdWdXZ#dxd_d`Z$dydcddZ%dzdidjZ&d{dmdnZ'd|dqdrZ(G dsdt dteZ)dS )}    )annotationsN)reduce)
NamedTuple)nn)	model_zoo)BaseEncoder)ActConvPadPool)get_norm_layer)look_up_option)EfficientNetEfficientNetBNget_efficientnet_image_sizedrop_connectEfficientNetBNFeatures	BlockArgsEfficientNetEncoder)      ?r      皙?r   )r   皙?   r   r   )r   333333?i  333333?r   )r   ffffff?i,  r   r   )r   ?i|  皙?r   )g?皙@i  r   r   )r   g@i        ?r   )       @g@iX  r    r   )r   g@i  r    r   )g333333@g333333@i   r    r   
efficientnet-b0efficientnet-b1efficientnet-b2efficientnet-b3efficientnet-b4efficientnet-b5efficientnet-b6efficientnet-b7zefficientnet-b8zefficientnet-l2r#   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pthr$   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pthr%   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pthr&   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pthr'   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pthr(   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pthr)   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pthr*   zdhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pthzb0-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pthzb1-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pthzb2-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pthzb3-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pthzb4-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pthzb5-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pthzb6-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pthzb7-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pthzb8-apzhhttps://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pthc                      sF   e Zd Zdddddfdfd# fddZd$ddZd%d&d!d"Z  ZS )'MBConvBlockTbatchMbP?{Gz?epsmomentumr   spatial_dimsintin_channelsout_channelskernel_sizestride
image_size	list[int]expand_ratiose_ratiofloat | Noneid_skipbool | Nonenormstr | tupledrop_connect_ratereturnNonec                   s  t    td|f }td|f }|| _|| _|	| _|| _|| _|| _	|dur:d|  k r1dkr:n nd| _
|| _nd| _
|}|| }| jdkra|||ddd	| _t| j|| _t|
||d
| _nt | _t | _t | _|||||| jdd| _t| j|| _t|
||d
| _t|| j}| j
r|d| _tdt|| j }|||dd| _t| jdg| | _|||dd| _t| jdg| | _|}|||ddd	| _t| j|| _t|
||d
| _ t!d dd| _"dS )a  
        Mobile Inverted Residual Bottleneck Block.

        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: size of the kernel for conv ops.
            stride: stride to use for conv ops.
            image_size: input image resolution.
            expand_ratio: expansion ratio for inverted bottleneck.
            se_ratio: squeeze-excitation ratio for se layers.
            id_skip: whether to use skip connection.
            norm: feature normalization type and arguments. Defaults to batch norm.
            drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.

        References:
            [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
            [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
            [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
        convadaptiveavgN        r   TF   )r4   r5   r6   biasnamer2   channels)r4   r5   groupsr6   r7   rH   )r4   r5   r6   memswishinplace)#super__init__r	   r   r4   r5   r=   r7   r:   rA   has_ser;   _expand_conv_make_same_padder_expand_conv_paddingr   _bn0r   Identity_depthwise_conv_depthwise_conv_padding_bn1_calculate_output_image_size_se_adaptpoolmaxr3   
_se_reduce_se_reduce_padding
_se_expand_se_expand_padding_project_conv_project_conv_padding_bn2r   _swish)selfr2   r4   r5   r6   r7   r8   r:   r;   r=   r?   rA   	conv_typeadaptivepool_typeinpoupZnum_squeezed_channelsZ	final_oup	__class__ b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/efficientnet.pyrQ   M   sZ   
# 




zMBConvBlock.__init__inputstorch.Tensorc                 C  s   |}| j dkr| | |}| |}| |}| | |}| |}| |}| jrO| 	|}| 
| |}| |}| | |}t|| }| | |}| |}| jrz| jdkrz| j| jkrz| jrvt|| j| jd}|| }|S )zMBConvBlock"s forward function.

        Args:
            inputs: Input tensor.

        Returns:
            Output of this block after processing.
        rG   )ptraining)r:   rS   rU   rV   re   rX   rY   rZ   rR   r\   r^   r_   r`   ra   torchsigmoidrb   rc   rd   r=   r7   r4   r5   rA   r   rr   )rf   ro   xZ
x_squeezedrm   rm   rn   forward   s*   








zMBConvBlock.forwardmemory_efficientboolc                 C  s,   |rt d dd| _dS t d dd| _dS )zSets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        rM   TrN   swishr   alphaN)r   re   )rf   rw   rm   rm   rn   	set_swish   s   ,zMBConvBlock.set_swish)r2   r3   r4   r3   r5   r3   r6   r3   r7   r3   r8   r9   r:   r3   r;   r<   r=   r>   r?   r@   rA   r<   rB   rC   ro   rp   Trw   rx   rB   rC   )__name__
__module____qualname__rQ   rv   r|   __classcell__rm   rm   rk   rn   r+   K   s    
h(r+   c                      s^   e Zd Zdddddddddd	d
fddf
d* fddZd+d,d"d#Zd-d&d'Zd.d(d)Z  ZS )/r           r   r   r   r,   r-   r.   r/      blocks_args_str	list[str]r2   r3   r4   num_classeswidth_coefficientfloatdepth_coefficientdropout_rater8   r?   r@   rA   depth_divisorrB   rC   c                   s  t    |dvrtdtd|f }td|f }dd |D }t|ts)td|g kr1td|| _|| _|| _	|
| _
|g| }d	}td
||}|| j	|d|dd| _t| j|| _t|	||d| _t||}d}t | _g | _t| jD ]/\}}|jt|j||t|j||t|j|d}|| j|< ||j7 }|jdkr| j| qx| jt| j d}t| jD ]\}}| j
}|r|t|| 9 }t }| t!|t"||j|j|j#|j||j$|j%|j&|	|d |d7 }t||j}|jdkr|j|jdd}t'|jd D ]1}| j
}|r|t|| 9 }| t!|t"||j|j|j#|j||j$|j%|j&|	|d |d7 }q| j t!|| q||krMtd|j}td||}|||ddd| _(t| j(|| _)t|	||d| _*|d| _+t,|| _-t.|| j| _/t0d  | _1| 2  dS )a  
        EfficientNet based on `Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/pdf/1905.11946.pdf>`_.
        Adapted from `EfficientNet-PyTorch <https://github.com/lukemelas/EfficientNet-PyTorch>`_.

        Args:
            blocks_args_str: block definitions.
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            num_classes: number of output classes.
            width_coefficient: width multiplier coefficient (w in paper).
            depth_coefficient: depth multiplier coefficient (d in paper).
            dropout_rate: dropout rate for dropout layers.
            image_size: input image resolution.
            norm: feature normalization type and arguments.
            drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.
            depth_divisor: depth divisor for channel rounding.

        )rG   r   r   z#spatial_dims can only be 1, 2 or 3.rD   rE   c                 S  s   g | ]}t |qS rm   )r   from_string.0srm   rm   rn   
<listcomp>  s    z)EfficientNet.__init__.<locals>.<listcomp>zblocks_args must be a listzblock_args must be non-emptyr       r   F)r6   r7   rH   rI   r   )input_filtersoutput_filters
num_repeatrG   )r2   r4   r5   r6   r7   r8   r:   r;   r=   r?   rA   )r   r7   z,total number of blocks created != num_blocksi   )r6   rH   rM   N)3rP   rQ   
ValueErrorr	   r   
isinstancelistZ_blocks_argsr   r4   rA   _round_filters
_conv_stemrT   _conv_stem_paddingr   rV   r[   r   
Sequential_blocksextract_stacks	enumerate_replacer   r   _round_repeatsr   r7   appendlenr   
add_modulestrr+   r6   r:   r;   r=   range
_conv_head_conv_head_paddingrZ   _avg_poolingDropout_dropoutLinear_fcr   re   _initialize_weights)rf   r   r2   r4   r   r   r   r   r8   r?   rA   r   rg   rh   Zblocks_argsZcurrent_image_sizer7   r5   
num_blocksidxZ
block_argsZ	stack_idxZblk_drop_connect_rate	sub_stack_Zhead_in_channelsrk   rm   rn   rQ      s   
 









zEfficientNet.__init__Trw   rx   c                 C  sB   |rt d  nt d dd| _| jD ]}|D ]}|| qqdS )z
        Sets swish function as memory efficient (for training) or standard (for JIT export).

        Args:
            memory_efficient: whether to use memory-efficient version of swish.

        rM   ry   r   rz   N)r   re   r   r|   )rf   rw   r   blockrm   rm   rn   r|     s   
zEfficientNet.set_swishro   rp   c                 C  sx   |  | |}| | |}| |}| | |}| | |}| |}|j	dd}| 
|}| |}|S )a!  
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a torch Tensor of classification prediction in shape ``(Batch, num_classes)``.
        rG   )	start_dim)r   r   re   rV   r   r   r   rZ   r   flattenr   r   )rf   ro   ru   rm   rm   rn   rv     s   




zEfficientNet.forwardc                 C  s   |   D ]r\}}t|tjtjtjfr7ttj|j	d|j
 }|jjdtd|  |jdur6|jj  qt|tjtjtjfrP|jjd |jj  qt|tjrv|jd}d}dt||  }|jj| | |jj  qdS )a  
        Args:
            None, initializes weights for conv/linear/batchnorm layers
            following weight init methods from
            `official Tensorflow EfficientNet implementation
            <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L61>`_.
            Adapted from `EfficientNet-PyTorch's init method
            <https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/efficientnet_builder.py>`_.
        rG   r   r!   Nr   )named_modulesr   r   Conv1dConv2dConv3dr   operatormulr6   r5   weightdatanormal_mathsqrtrH   zero_BatchNorm1dBatchNorm2dBatchNorm3dfill_r   sizeuniform_)rf   r   mfan_outfan_in
init_rangerm   rm   rn   r     s$   

z EfficientNet._initialize_weights)r   r   r2   r3   r4   r3   r   r3   r   r   r   r   r   r   r8   r3   r?   r@   rA   r   r   r3   rB   rC   r~   r   r}   )rB   rC   )r   r   r   rQ   r|   rv   r   r   rm   rm   rk   rn   r      s      3
r   c                	      s8   e Zd Zdddddddddfd	fd fddZ  ZS )r   Tr   r   r   r,   r-   r.   r/   F
model_namer   
pretrainedrx   progressr2   r3   r4   r   r?   r@   adv_proprB   rC   c	                      g d}	|t vrdt  }
td| d|
 dt | \}}}}}t j|	|||||||||d
 |rB|dkrDt| ||| dS dS dS )	a  
        Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models
        model_name is mandatory argument as there is no EfficientNetBN itself,
        it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model

        Args:
            model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
            pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch
                norm is used.
            progress: whether to show download progress for pretrained weights download.
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            num_classes: number of output classes.
            norm: feature normalization type and arguments.
            adv_prop: whether to use weights trained with adversarial examples.
                This argument only works when `pretrained` is `True`.

        Examples::

            # for pretrained spatial 2D ImageNet
            >>> image_size = get_efficientnet_image_size("efficientnet-b0")
            >>> inputs = torch.rand(1, 3, image_size, image_size)
            >>> model = EfficientNetBN("efficientnet-b0", pretrained=True)
            >>> model.eval()
            >>> outputs = model(inputs)

            # create spatial 2D
            >>> model = EfficientNetBN("efficientnet-b0", spatial_dims=2)

            # create spatial 3D
            >>> model = EfficientNetBN("efficientnet-b0", spatial_dims=3)

            # create EfficientNetB7 for spatial 2D
            >>> model = EfficientNetBN("efficientnet-b7", spatial_dims=2)

        zr1_k3_s11_e1_i32_o16_se0.25zr2_k3_s22_e6_i16_o24_se0.25zr2_k5_s22_e6_i24_o40_se0.25zr3_k3_s22_e6_i40_o80_se0.25zr3_k5_s11_e6_i80_o112_se0.25zr4_k5_s22_e6_i112_o192_se0.25zr1_k3_s11_e6_i192_o320_se0.25, invalid model_name  found, must be one of  
r   r2   r4   r   r   r   r   r8   rA   r?   r   Nefficientnet_paramsjoinkeysr   rP   rQ   _load_state_dictrf   r   r   r   r2   r4   r   r?   r   r   model_name_stringZweight_coeffZdepth_coeffr8   r   Zdropconnect_raterk   rm   rn   rQ     s(   0zEfficientNetBN.__init__r   r   r   rx   r   rx   r2   r3   r4   r3   r   r3   r?   r@   r   rx   rB   rC   )r   r   r   rQ   r   rm   rm   rk   rn   r     s    r   c                	      sB   e Zd Zdddddddddfd	fd fddZdddZ  ZS ) r   Tr   r   r   r,   r-   r.   r/   Fr   r   r   rx   r   r2   r3   r4   r   r?   r@   r   rB   rC   c	                   r   )	a{  
        Initialize EfficientNet-B0 to EfficientNet-B7 models as a backbone, the backbone can
        be used as an encoder for segmentation and objection models.
        Compared with the class `EfficientNetBN`, the only different place is the forward function.

        This class refers to `PyTorch image models <https://github.com/rwightman/pytorch-image-models>`_.

        r   r   r   r   r   r   r   Nr   r   rk   rm   rn   rQ   7  s(   zEfficientNetBNFeatures.__init__ro   rp   c                 C  sp   |  | |}| | |}g }d| jv r|| t| jD ]\}}||}|d | jv r5|| q!|S )z
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a list of torch Tensors.
        r   rG   )r   r   re   rV   r   r   r   r   )rf   ro   ru   featuresir   rm   rm   rn   rv   n  s   



zEfficientNetBNFeatures.forwardr   r}   )r   r   r   rQ   rv   r   rm   rm   rk   rn   r   5  s    7r   c                   @  sP   e Zd ZdZg dZedddZeddd	ZedddZedddZ	dS )r   zI
    Wrap the original efficientnet to an encoder for flexible-unet.
    r"   rB   
list[dict]c                 C  s>   g }| j D ]}||dddddddddfd	|v d
 q|S )zN
        Get the initialization parameter for efficientnet backbones.
        Tr   r   r   r,   r-   r.   r/   ap)r   r   r   r2   r4   r   r?   r   )backbone_namesr   )clsparameter_listbackbone_namerm   rm   rn   get_encoder_parameters  s   
z*EfficientNetEncoder.get_encoder_parameterslist[tuple[int, ...]]c                 C  s   g dS )zS
        Get number of efficientnet backbone output feature maps' channel.
        )
      (   p   i@  r   )r   r   0   x   i`  )r   r   r      i  )r   r   8      i  )r   r   @      i   )r   r   H      i@  )r   r   P   r   i  )r   r   X      i  )r   h   r   i  i`  rm   r   rm   rm   rn   num_channels_per_output  s   z+EfficientNetEncoder.num_channels_per_outputr9   c                 C  s
   dgd S )z
        Get number of efficientnet backbone output feature maps.
        Since every backbone contains the same 5 output feature maps,
        the number list should be `[5] * 10`.
           
   rm   r   rm   rm   rn   num_outputs  s   
zEfficientNetEncoder.num_outputsr   c                 C  s   | j S )z2
        Get names of efficient backbone.
        )r   r   rm   rm   rn   get_encoder_names  s   z%EfficientNetEncoder.get_encoder_namesN)rB   r   )rB   r   )rB   r9   )rB   r   )
r   r   r   __doc__r   classmethodr   r   r  r  rm   rm   rm   rn   r     s    r   r   r   rB   r3   c                 C  sB   | t vrdt  }td|  d| dt |  \}}}}}|S )z
    Get the input image size for a given efficientnet model.

    Args:
        model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b7].

    Returns:
        Image size for single spatial dimension as integer.

    r   r   r   r   )r   r   r   r   )r   r   r   resrm   rm   rn   r     s
   r   ro   rp   rq   r   rr   rx   c           
      C  s   |dk s|dkrt d| |s| S | jd }d| }t| jd }|dgdg|  }tj|| j| jd}||7 }t|}| | | }	|	S )ah  
    Drop connect layer that drops individual connections.
    Differs from dropout as dropconnect drops connections instead of whole neurons as in dropout.

    Based on `Deep Networks with Stochastic Depth <https://arxiv.org/pdf/1603.09382.pdf>`_.
    Adapted from `Official Tensorflow EfficientNet utils
    <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py>`_.

    This function is generalized for MONAI's N-Dimensional spatial activations
    e.g. 1D activations [B, C, H], 2D activations [B, C, H, W] and 3D activations [B, C, H, W, D]

    Args:
        inputs: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims.
        p: probability to use for dropping connections.
        training: whether in training or evaluation mode.

    Returns:
        output: output tensor after applying drop connection.
    rF   r   z$p must be in range of [0, 1], found r   rG   r   )dtypedevice)r   shaper   rs   randr  r  floor)
ro   rq   rr   
batch_size	keep_probnum_dimsZrandom_tensor_shaperandom_tensorZbinary_tensoroutputrm   rm   rn   r     s   

r   model	nn.Modulearchr   r   rC   c                 C  s   |r| dd d }t|td }|d u rtd| d d S t| }tj||d}|  }td}|	 D ]\}}	t
|d|}
|
|v rT|	j||
 jkrT||
 ||< q7| | d S )	Nzefficientnet-z-apzpretrained weights of z is not provided)r   z(.+)\.\d+(\.\d+\..+)z\1\2)splitr   url_mapprintr   load_url
state_dictrecompileitemssubr	  load_state_dict)r  r  r   r   	model_urlZpretrain_state_dictmodel_state_dictpatternkeyvalueZpretrain_keyrm   rm   rn   r     s   
r   r8   r9   r6   tuple[int, ...]dilationr7   c                 C  sl   t |}t |dkr|| }t |dkr|| }dd t| |||D }dd |D }dd t|D }|S )a/  
    Helper for getting padding (nn.ConstantPadNd) to be used to get SAME padding
    conv operations similar to Tensorflow's SAME padding.

    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)

    Args:
        image_size: input image/feature spatial size.
        kernel_size: conv kernel's spatial size.
        dilation: conv dilation rate for Atrous conv.
        stride: stride for conv operation.

    Returns:
        paddings for ConstantPadNd padder to be used on input tensor to conv op.
    rG   c                 S  sD   g | ]\}}}}t t|| d  | |d  |  d  | dqS )rG   r   )r]   r   ceil)r   Z_i_sZ_k_s_d_srm   rm   rn   r   E  s    
.z-_get_same_padding_conv_nd.<locals>.<listcomp>c                 S  s    g | ]}|d  ||d   fqS )r   rm   )r   _prm   rm   rn   r   J  s     c                 S  s   g | ]	}|D ]}|qqS rm   rm   )r   innerouterrm   rm   rn   r   N  s    )r   zipreversed)r8   r6   r%  r7   r  	_pad_size	_paddingsZ_paddings_retrm   rm   rn   _get_same_padding_conv_nd(  s   r0  conv_op!nn.Conv1d | nn.Conv2d | nn.Conv3dc                 C  sH   t || j| j| j}tdt|d f }t|dkr ||ddS t S )a
  
    Helper for initializing ConstantPadNd with SAME padding similar to Tensorflow.
    Uses output of _get_same_padding_conv_nd() to get the padding size.

    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)

    Args:
        conv_op: nn.ConvNd operation to extract parameters for op from
        image_size: input image/feature spatial size

    Returns:
        If padding required then nn.ConstandNd() padder initialized to paddings otherwise nn.Identity()
    Zconstantpadr   r   rF   )paddingr#  )	r0  r6   r%  r7   r
   r   sumr   rW   )r1  r8   r3  padderrm   rm   rn   rT   R  s
   rT   filtersr   r<   r   c                 C  sR   |s| S |}|}| | }t |t||d  | | }|d| k r%||7 }t|S )aN  
    Calculate and round number of filters based on width coefficient multiplier and depth divisor.

    Args:
        filters: number of input filters.
        width_coefficient: width coefficient for model.
        depth_divisor: depth divisor to use.

    Returns:
        new_filters: new number of filters after calculation.
    r   g?)r]   r3   )r6  r   r   
multiplierdivisorZfilters_floatZnew_filtersrm   rm   rn   r   j  s   r   repeatsr   c                 C  s   |s| S t t||  S )a  
    Re-calculate module's repeat number of a block based on depth coefficient multiplier.

    Args:
        repeats: number of original repeats.
        depth_coefficient: depth coefficient for model.

    Returns:
        new repeat: new number of repeat after calculating.
    r3   r   r&  )r9  r   rm   rm   rn   r     s   r   input_image_sizeint | tuple[int]c                   sL   t  trt fdd D }|std   d   fdd| D S )a5  
    Calculates the output image size when using _make_same_padder with a stride.
    Required for static padding.

    Args:
        input_image_size: input image/feature spatial size.
        stride: Conv2d operation"s stride.

    Returns:
        output_image_size: output image/feature spatial size.
    c                 3  s    | ]	} d  |kV  qdS )r   Nrm   r   r7   rm   rn   	<genexpr>  s    z/_calculate_output_image_size.<locals>.<genexpr>z&unequal strides are not possible, got r   c                   s   g | ]}t t|  qS rm   r:  )r   Zim_szr=  rm   rn   r     s    z0_calculate_output_image_size.<locals>.<listcomp>)r   tupleallr   )r;  r7   Zall_strides_equalrm   r=  rn   r[     s   
r[   c                   @  sl   e Zd ZU dZded< ded< ded< ded< ded< ded< d	ed
< dZded< edddZdd ZdS )r   zq
    BlockArgs object to assist in decoding string notation
        of arguments for MBConvBlock definition.
    r3   r   r6   r7   r:   r   r   rx   r=   Nr<   r;   block_stringr   c                 C  s<  |  d}i }|D ]}t d|}t|dkr#|dd \}}|||< q	d|v r0t|d dkpct|d dkoC|d d |d d kpct|d dkoc|d d |d d koc|d d |d d k}|sjtd	tt|d
 t|d t|d d t|d t|d t|d d| vd|v rt|d dS ddS )a>  
        Get a BlockArgs object from a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: "r1_k3_s11_e1_i32_o16_se0.25".

        Returns:
            BlockArgs: namedtuple defined at the top of this function.
        r   z(\d.*)r   Nr   rG   r   r   zinvalid stride option receivedrker   oZnoskipse)r   r6   r7   r:   r   r   r=   r;   )r  r  r   r   r   r3   r   )rA  opsoptionsopsplitsr"  r#  Zstride_checkrm   rm   rn   r     s6   
&>




zBlockArgs.from_stringc                 C  sT   d| j  d| j d| j | j d| j d| j d| j d| j }| js(|d7 }|S )	z
        Return a block string notation for current BlockArgs object

        Returns:
            A string notation of BlockArgs object arguments.
                Example: "r1_k3_s11_e1_i32_o16_se0.25_noskip".
        rB  _kr(  _e_i_o_seZ_noskip)r   r6   r7   r:   r   r   r;   r=   )rf   stringrm   rm   rn   	to_string  s    	zBlockArgs.to_string)rA  r   )	r   r   r   r  __annotations__r;   staticmethodr   rQ  rm   rm   rm   rn   r     s   
 'r   )r   r   rB   r3   )ro   rp   rq   r   rr   rx   rB   rp   )
r  r  r  r   r   rx   r   rx   rB   rC   )
r8   r9   r6   r$  r%  r$  r7   r$  rB   r9   )r1  r2  r8   r9   )r6  r3   r   r<   r   r   rB   r3   )r9  r3   r   r<   rB   r3   )r;  r9   r7   r<  )*
__future__r   r   r   r  	functoolsr   typingr   rs   r   torch.utilsr   monai.networks.blocksr   monai.networks.layers.factoriesr   r	   r
   r   monai.networks.layers.utilsr   monai.utils.moduler   __all__r   r  Moduler+   r   r   r   r   r   r   r   r0  rT   r   r   r[   r   rm   rm   rm   rn   <module>   s   
  yWP
J

/

*


