o
    &i"                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ d dlmZmZ d dlmZmZ d dlmZ ed	d
d\ZZddhZh dZG dd dejZG dd dejZdS )    )annotations)SequenceN)	LayerNorm)build_sincos_position_embedding)Convtrunc_normal_)ensure_tuple_repoptional_import)look_up_optionzeinops.layers.torch	Rearrange)nameconv
perceptron>   sincos	learnablenonec                      s<   e Zd ZdZ				dd fddZdd Zdd Z  ZS )PatchEmbeddingBlocka  
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Example::

        >>> from monai.networks.blocks import PatchEmbeddingBlock
        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
        >>>                     proj_type="conv", pos_embed_type="sincos")

    r   r              in_channelsintimg_sizeSequence[int] | int
patch_sizehidden_size	num_heads	proj_typestrpos_embed_typedropout_ratefloatspatial_dimsreturnNonec
                   st  t    d|  krdksn td| d|| dkr)td| d| dt|t| _t|t| _t||	}t||	}t	||D ]\}
}|
|k rPtd| jd	kr_|
| dkr_td
qDt
dd t	||D | _t|t
| | _|  | jdkrttj|	f ||||d| _nP| jd	krdd|	 }dddd |D  }dddd |D  dddd |D  d}dd t|D }tt| d| fi |t| j|| _ttd| j|| _t|| _| jdkrn<| jdkrt| jdd d!d"d# n+| jd$kr)g }t	||D ]\}}|||  qt |||	| _n	td%| j d&| !| j" dS )'a  
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            proj_type: patch embedding layer type.
            pos_embed_type: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.
        r      zdropout_rate z should be between 0 and 1.zhidden size z" should be divisible by num_heads .z+patch_size should be smaller than img_size.r   z:patch_size should be divisible by img_size for perceptron.c                 S  s   g | ]\}}|| qS  r&   ).0Zim_dp_dr&   r&   f/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/patchembedding.py
<listcomp>X   s    z0PatchEmbeddingBlock.__init__.<locals>.<listcomp>r   r   out_channelskernel_sizestride))hp1)wp2)dp3Nzb c  c                 s  s&    | ]\}}d | d| dV  qdS )(r5   )Nr&   )r'   kvr&   r&   r)   	<genexpr>c   s   $ z/PatchEmbeddingBlock.__init__.<locals>.<genexpr>zb (c                 S     g | ]}|d  qS )r   r&   r'   cr&   r&   r)   r*   d       z) (c                 S  r;   )r$   r&   r<   r&   r&   r)   r*   d   r>   z c)c                 S  s    i | ]\}}d |d  |qS )pr$   r&   )r'   ir?   r&   r&   r)   
<dictcomp>e   s     z0PatchEmbeddingBlock.__init__.<locals>.<dictcomp>z -> r   r   r   {Gz?              @meanstdabr   zpos_embed_type z not supported.)#super__init__
ValueErrorr
   SUPPORTED_PATCH_EMBEDDING_TYPESr   SUPPORTED_POS_EMBEDDING_TYPESr   r   zipnpprodZ	n_patchesr   Z	patch_dimr   CONVpatch_embeddingsjoin	enumeratenn
Sequentialr   Linear	Parametertorchzerosposition_embeddingsDropoutdropoutr   appendr   apply_init_weights)selfr   r   r   r   r   r   r   r   r!   mr?   charsZ
from_charsZto_charsZaxes_len	grid_sizein_sizeZpa_size	__class__r&   r)   rK   -   sV   





2$
zPatchEmbeddingBlock.__init__c                 C  s   t |tjr)t|jddddd t |tjr%|jd ur'tj|jd d S d S d S t |tjrAtj|jd tj|jd d S d S )Nr   rB   rC   rD   rE   r   g      ?)	
isinstancerV   rX   r   weightbiasinit	constant_r   )rb   rc   r&   r&   r)   ra   {   s   z!PatchEmbeddingBlock._init_weightsc                 C  s>   |  |}| jdkr|ddd}|| j }| |}|S )Nr      )rS   r   flatten	transposer\   r^   )rb   x
embeddingsr&   r&   r)   forward   s   



zPatchEmbeddingBlock.forward)r   r   r   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r   r"   r#   )__name__
__module____qualname____doc__rK   ra   ru   __classcell__r&   r&   rg   r)   r       s    N	r   c                      s8   e Zd ZdZdddejdfd fddZdd Z  ZS )
PatchEmbeda0  
    Patch embedding block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if
    specified (3) position embedding is not used.

    Example::

        >>> from monai.networks.blocks import PatchEmbed
        >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    rn   r$   0   r   r   r   in_chansr   	embed_dim
norm_layertype[LayerNorm]r!   r"   r#   c                   sl   t    |dvrtdt||}|| _|| _ttj|f ||||d| _|dur1||| _	dS d| _	dS )a  
        Args:
            patch_size: dimension of patch size.
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            norm_layer: normalization layer.
            spatial_dims: spatial dimension.
        )rn   r   z#spatial dimension should be 2 or 3.r+   N)
rJ   rK   rL   r   r   r~   r   rR   projnorm)rb   r   r}   r~   r   r!   rg   r&   r)   rK      s   


zPatchEmbed.__init__c           	      C  s  |  }t|dkri|\}}}}}|| jd  dkr,t|d| jd || jd   f}|| jd  dkrIt|ddd| jd || jd   f}|| jd  dkrht|ddddd| jd || jd   f}nDt|dkr|\}}}}|| jd  dkrt|d| jd || jd   f}|| jd  dkrt|ddd| jd || jd   f}| |}| jd ur|  }|ddd}| |}t|dkr|d |d |d }}}|dd	d| j
|||}|S t|dkr|d |d }}|dd	d| j
||}|S )N   rn   r   r$      r   ro   )sizelenr   Fpadr   r   rq   rr   viewr~   )	rb   rs   x_shape_r3   r/   r1   whwwr&   r&   r)   ru      s:   $(,$(

zPatchEmbed.forward)r   r   r}   r   r~   r   r   r   r!   r   r"   r#   )	rv   rw   rx   ry   rV   r   rK   ru   rz   r&   r&   rg   r)   r{      s    !r{   )
__future__r   collections.abcr   numpyrP   rZ   torch.nnrV   torch.nn.functional
functionalr   r   Z%monai.networks.blocks.pos_embed_utilsr   monai.networks.layersr   r   monai.utilsr   r	   monai.utils.moduler
   r   r   rM   rN   Moduler   r{   r&   r&   r&   r)   <module>   s    m