U
    Ph $                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ d dlmZmZ d dlmZmZmZ d dlmZ ed	d
d\ZZddhZdddhZG dd dejZG dd dejZdS )    )annotations)SequenceN)	LayerNorm)build_sincos_position_embedding)Convtrunc_normal_)deprecated_argensure_tuple_repoptional_import)look_up_optionzeinops.layers.torch	Rearrange)nameconv
perceptronnone	learnablesincosc                      s^   e Zd ZdZeddddddddddddddddddd fddZdd Zdd Z  ZS )PatchEmbeddingBlocka  
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Example::

        >>> from monai.networks.blocks import PatchEmbeddingBlock
        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
        >>>                     proj_type="conv", pos_embed_type="sincos")

    	pos_embedz1.2z1.4	proj_typezplease use `proj_type` instead.)r   sinceremovednew_name
msg_suffixr   r              intSequence[int] | intstrfloatNone)in_channelsimg_size
patch_sizehidden_size	num_headsr   r   pos_embed_typedropout_ratespatial_dimsreturnc                   sv  t    d|	  krdks0n td|	 d|| dkrRtd| d| dt|t| _t|t| _t||
}t||
}t	||D ]6\}}||k rtd| jd	kr|| dkrtd
qt
dd t	||D | _t|t
| | _|  | jdkrttj|
f ||||d| _n| jd	krdd|
 }dddd |D  }dddd |D  dddd |D  d}dd t|D }tt| d| f|t| j|| _ttd| j|| _t|	| _| jdkrnx| jdkrt| jdd d!d"d# nV| jd$krTg }t	||D ]\}}|||  q*t |||
| _ntd%| j d&| !| j" dS )'aX  
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            proj_type: patch embedding layer type.
            pos_embed_type: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.
        .. deprecated:: 1.4
            ``pos_embed`` is deprecated in favor of ``proj_type``.
        r      zdropout_rate z should be between 0 and 1.zhidden size z" should be divisible by num_heads .z+patch_size should be smaller than img_size.r   z:patch_size should be divisible by img_size for perceptron.c                 S  s   g | ]\}}|| qS  r,   ).0Zim_dp_dr,   r,   Y/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/patchembedding.py
<listcomp>^   s     z0PatchEmbeddingBlock.__init__.<locals>.<listcomp>r   r!   out_channelskernel_sizestride))hp1)wp2)dp3Nzb c  c                 s  s$   | ]\}}d | d| dV  qdS )(r;   )Nr,   )r-   kvr,   r,   r/   	<genexpr>i   s     z/PatchEmbeddingBlock.__init__.<locals>.<genexpr>zb (c                 S  s   g | ]}|d  qS )r   r,   r-   cr,   r,   r/   r0   j   s     z) (c                 S  s   g | ]}|d  qS )r*   r,   rA   r,   r,   r/   r0   j   s     z c)c                 S  s    i | ]\}}d |d  |qS )pr*   r,   )r-   irC   r,   r,   r/   
<dictcomp>k   s      z0PatchEmbeddingBlock.__init__.<locals>.<dictcomp>z -> r   r   r   {Gz?              @meanstdabr   zpos_embed_type z not supported.)#super__init__
ValueErrorr   SUPPORTED_PATCH_EMBEDDING_TYPESr   SUPPORTED_POS_EMBEDDING_TYPESr&   r	   zipnpprodZ	n_patchesr   Z	patch_dimr   CONVpatch_embeddingsjoin	enumeratenn
Sequentialr   Linear	Parametertorchzerosposition_embeddingsDropoutdropoutr   appendr   apply_init_weights)selfr!   r"   r#   r$   r%   r   r   r&   r'   r(   mrC   charsZ
from_charsZto_charsZaxes_len	grid_sizein_sizeZpa_size	__class__r,   r/   rO   -   s\    



   
2 zPatchEmbeddingBlock.__init__c                 C  sx   t |tjrHt|jddddd t |tjrt|jd k	rttj|jd n,t |tjrttj|jd tj|jd d S )Nr   rF   rG   rH   rI   r   g      ?)	
isinstancerZ   r\   r   weightbiasinit	constant_r   )rf   rg   r,   r,   r/   re      s    z!PatchEmbeddingBlock._init_weightsc                 C  s>   |  |}| jdkr&|ddd}|| j }| |}|S )Nr      )rW   r   flatten	transposer`   rb   )rf   x
embeddingsr,   r,   r/   forward   s    



zPatchEmbeddingBlock.forward)r   r   r   r   r   )	__name__
__module____qualname____doc__r   rO   re   ry   __classcell__r,   r,   rk   r/   r       s        
     *Q	r   c                      sF   e Zd ZdZdddejdfdddddd	d
 fddZdd Z  ZS )
PatchEmbeda0  
    Patch embedding block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if
    specified (3) position embedding is not used.

    Example::

        >>> from monai.networks.blocks import PatchEmbed
        >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    rr   r*   0   r   r   r   ztype[LayerNorm]r    )r#   in_chans	embed_dim
norm_layerr(   r)   c                   sj   t    |dkrtdt||}|| _|| _ttj|f ||||d| _|dk	r`||| _	nd| _	dS )a  
        Args:
            patch_size: dimension of patch size.
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            norm_layer: normalization layer.
            spatial_dims: spatial dimension.
        )rr   r   z#spatial dimension should be 2 or 3.r1   N)
rN   rO   rP   r	   r#   r   r   rV   projnorm)rf   r#   r   r   r   r(   rk   r,   r/   rO      s    

   zPatchEmbed.__init__c           	      C  s   |  }t|dkr|\}}}}}|| jd  dkrXt|d| jd || jd   f}|| jd  dkrt|ddd| jd || jd   f}|| jd  dkrt|ddddd| jd || jd   f}nt|dkr`|\}}}}|| jd  dkr$t|d| jd || jd   f}|| jd  dkr`t|ddd| jd || jd   f}| |}| jd k	r|  }|ddd}| |}t|dkr|d |d |d   }}}|dd	d| j
|||}n:t|dkr|d |d  }}|dd	d| j
||}|S )N   rr   r   r*      r   rs   )sizelenr#   Fpadr   r   ru   rv   viewr   )	rf   rw   x_shape_r9   r5   r7   whwwr,   r,   r/   ry      s6    $(.$(

zPatchEmbed.forward)	rz   r{   r|   r}   rZ   r   rO   ry   r~   r,   r,   rk   r/   r      s   !r   )
__future__r   collections.abcr   numpyrT   r^   torch.nnrZ   torch.nn.functional
functionalr   r   Z%monai.networks.blocks.pos_embed_utilsr   monai.networks.layersr   r   monai.utilsr   r	   r
   monai.utils.moduler   r   r   rQ   rR   Moduler   r   r,   r,   r,   r/   <module>   s   
s