U
    Phn=                     @  s(  d dl mZ d dlZd dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 e
dZe
ddd	d  Ze
d
dd	d  Ze
ddd	d  Ze
ddd	d  ZdddddddgZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd deZG dd dejjZdS )    )annotationsN)Sequence)nn)PathLike)optional_importtransformersload_tf_weights_in_bert)nameztransformers.utilscached_filez&transformers.models.bert.modeling_bertBertEmbeddings	BertLayerBertPreTrainedModelBertAttention
BertOutputBertMixedLayerPooler
MultiModal	Transchexc                      s<   e Zd ZdZdd fddZdd ZedddZ  ZS )r   zModule to load BERT pre-trained weights.
    Based on:
    LXMERT
    https://github.com/airsplay/lxmert
    BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    Nonereturnc                   s   t    d S N)super__init__)selfinputskwargs	__class__ R/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/transchex.pyr   )   s    zBertPreTrainedModel.__init__c                 C  sx   t |tjtjfr*|jjjd| jjd n(t |t	jj
rR|jj  |jjd t |tjrt|jd k	rt|jj  d S )N        )meanstd      ?)
isinstancer   Linear	Embeddingweightdatanormal_configinitializer_rangetorch	LayerNormbiaszero_fill_)r   moduler   r   r    init_bert_weights,   s    z%BertPreTrainedModel.init_bert_weightsNFbert-base-uncasedpytorch_model.binc
                   sZ  t ||	|d}| ||||f|
|}d krL|sLtj|tj sDdnd d|rZt||S g }g } D ]H}d }d|kr|dd}d|kr|dd}|rj|| || qjt	||D ]\}}
||< qg g g  tdd  d k	r_d fd
d	d	}t|dsJtdd  D rJd}||d |S )N)	cache_dircpu)map_locationgammar(   betar/   	_metadata c              	     sh   d kri n |d d i }| ||d  | j D ]"\}}|d k	r@||| d  q@d S )NT.)get_load_from_state_dict_modulesitems)r2   prefixlocal_metadatar	   child
error_msgsloadmetadatamissing_keys
state_dictunexpected_keysr   r    rH   _   s           z1BertPreTrainedModel.from_pretrained.<locals>.loadbertc                 s  s   | ]}| d V  qdS )bert.N)
startswith).0sr   r   r    	<genexpr>i   s     z6BertPreTrainedModel.from_pretrained.<locals>.<genexpr>rN   )rC   )r<   )r
   r-   rH   cudais_availabler   keysreplaceappendzippopgetattrcopyr;   hasattrany)clsnum_language_layersnum_vision_layersnum_mixed_layersbert_configrK   r6   Zfrom_tfpath_or_repo_idfilenamer   r   Zweights_pathmodelZold_keysnew_keyskeynew_keyold_keyZstart_prefixr   rF   r    from_pretrained5   s@    


	$z#BertPreTrainedModel.from_pretrained)NNFr4   r5   )	__name__
__module____qualname____doc__r   r3   classmethodrj   __classcell__r   r   r   r    r       s   	     c                      s6   e Zd ZdZdd fddZdd Zdd	 Z  ZS )
r   zsBERT attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   sz   t    |j| _t|j|j | _| j| j | _t|j| j| _	t|j| j| _
t|j| j| _t|j| _d S r   )r   r   num_attention_headsinthidden_sizeattention_head_sizeall_head_sizer   r&   queryrg   valueDropoutattention_probs_dropout_probdropoutr   r+   r   r   r    r   u   s    
zBertAttention.__init__c                 C  s6   |  d d | j| jf }|j| }|ddddS )Nr=   r            )sizerq   rt   viewpermute)r   xZnew_x_shaper   r   r    transpose_for_scores   s    
z"BertAttention.transpose_for_scoresc                 C  s   |  |}| |}| |}| |}| |}| |}t||dd}	|	t| j	 }	| 
tjdd|	}
t|
|}|dddd }| d d | jf }|j| }|S )Nr=   )dimr   r|   r}   r~   )rv   rg   rw   r   r-   matmul	transposemathsqrtrt   rz   r   Softmaxr   
contiguousr   ru   r   )r   hidden_statescontextZmixed_query_layerZmixed_key_layerZmixed_value_layerquery_layerZ	key_layerZvalue_layerZattention_scoresattention_probsZcontext_layerZnew_context_layer_shaper   r   r    forward   s    






zBertAttention.forward)rk   rl   rm   rn   r   r   r   rp   r   r   r   r    r   o   s   
c                      s.   e Zd ZdZdd fddZdd Z  ZS )r   zpBERT output layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   sB   t    t|j|j| _tjj|jdd| _t|j	| _
d S )N-q=)eps)r   r   r   r&   rs   denser-   r.   rx   hidden_dropout_probrz   r{   r   r   r    r      s    
zBertOutput.__init__c                 C  s&   |  |}| |}| || }|S r   )r   rz   r.   )r   r   input_tensorr   r   r    r      s    

zBertOutput.forwardrk   rl   rm   rn   r   r   rp   r   r   r   r    r      s   c                      s.   e Zd ZdZdd fddZdd Z  ZS )r   zyBERT cross attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   s6   t    t|| _t|| _t|| _t|| _d S r   )r   r   r   att_xr   output_xatt_youtput_yr{   r   r   r    r      s
    



zBertMixedLayer.__init__c                 C  s0   |  ||}| ||}| ||| ||fS r   )r   r   r   r   )r   r   yr   r   r   r   r    r      s    zBertMixedLayer.forwardr   r   r   r   r    r      s   c                      s.   e Zd ZdZdd fddZdd Z  ZS )r   zpBERT pooler layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   s&   t    t||| _t | _d S r   )r   r   r   r&   r   Tanh
activation)r   rs   r   r   r    r      s    
zPooler.__init__c                 C  s(   |d d df }|  |}| |}|S Nr   )r   r   )r   r   Zfirst_token_tensorZpooled_outputr   r   r    r      s    

zPooler.forwardr   r   r   r   r    r      s   c                      s8   e Zd ZdZdddddd fddZdd	d
Z  ZS )r   z?
    Multimodal Transformers From Pretrained BERT Weights"
    rr   dictr   )r_   r`   ra   rb   r   c                   s   t    tdtf| _t j _t fddt	|D  _
t fddt	|D  _t fddt	|D  _  j dS )z
        Args:
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            bert_config: configuration for bert language transformer encoder.

        objc                   s   g | ]}t  jqS r   r   r+   rP   _r   r   r    
<listcomp>   s     z'MultiModal.__init__.<locals>.<listcomp>c                   s   g | ]}t  jqS r   r   r   r   r   r    r      s     c                   s   g | ]}t  jqS r   )r   r+   r   r   r   r    r      s     N)r   r   typeobjectr+   r   
embeddingsr   
ModuleListrangelanguage_encodervision_encodermixed_encoderapplyr3   )r   r_   r`   ra   rb   r   r   r    r      s    

zMultiModal.__init__Nc                 C  sb   |  ||}| jD ]}||d d }q| jD ]}|||d }q,| jD ]}|||\}}qF||fS r   )r   r   r   r   )r   	input_idstoken_type_idsvision_featsattention_maskZlanguage_featureslayerr   r   r    r      s    


zMultiModal.forward)NNNr   r   r   r   r    r      s   c                "      sp   e Zd ZdZd#ddddddddddddddddddddddddddddddddd  fddZd$d!d"Z  ZS )%r   z
    TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language
    Transformers for Chest X-ray Analysis"
       r!   皙?Fgelu{Gz?   r      rM      r   absolute4.10.2r|   T:w  r4   r5   rr   zSequence[int] | intzint | tuple[int, int]floatboolstrzstr | PathLiker   ) in_channelsimg_size
patch_sizenum_classesr_   r`   ra   rs   drop_outry   gradient_checkpointing
hidden_actr   r,   intermediate_sizelayer_norm_epsmax_position_embeddings
model_typerq   num_hidden_layerspad_token_idposition_embedding_typetransformers_versiontype_vocab_size	use_cache
vocab_sizechunk_size_feed_forward
is_decoderadd_cross_attentionrc   rd   r   c            !        s8  t    |
d||||||||||||||||||||d} d|	  krPdksZn td|d |d  dks|d |d  dkrtdtj|||| ||d| _|| _|d | jd  |d | jd   | _tj	||| j| jd| _
t|| _ttd| j|| _t|d	| _tj|	| _tj||| _dS )
a  
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            num_classes: number of classes if classification is used.
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed transformer layers.
            drop_out: fraction of the input units to drop.
            path_or_repo_id: This can be either:
                - a string, the *model id* of a model repo on huggingface.co.
                - a path to a *directory* potentially containing the file.
            filename: The name of the file to locate in `path_or_repo`.

        The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

        Examples:

        .. code-block:: python

            # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,
            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
            net = Transchex(in_channels=3,
                                 img_size=(224, 224),
                                 num_classes=3,
                                 num_language_layers=2,
                                 num_vision_layers=2,
                                 num_mixed_layers=2,
                                 drop_out=0.2)

        N)ry   Zclassifier_dropoutr   r   r   rs   r,   r   r   r   r   rq   r   r   r   r   r   r   r   r   r   r   r   r}   z'dropout_rate should be between 0 and 1.z+img_size should be divisible by patch_size.)r_   r`   ra   rb   rc   rd   )r   out_channelskernel_sizestride)rs   )r   r   
ValueErrorr   rj   
multimodalr   num_patchesr   Conv2dvision_projr.   norm_vision_pos	Parameterr-   zerospos_embed_visr   poolerrx   dropr&   cls_head)!r   r   r   r   r   r_   r`   ra   rs   r   ry   r   r   r   r,   r   r   r   r   rq   r   r   r   r   r   r   r   r   r   r   rc   rd   rb   r   r   r    r      sb    B
(	&   zTranschex.__init__Nc           	      C  s   t |dd}|jt|  jd}d| d }| |d	dd}| 
|}|| j }| j||||d\}}| |}| | |}|S )Nr}   r|   )dtyper$   g     )r   r   r   r   )r-   	ones_like	unsqueezetonext
parametersr   r   flattenr   r   r   r   r   r   r   )	r   r   r   r   r   Zhidden_state_langZhidden_state_visZpooled_featureslogitsr   r   r    r   j  s    

   

zTranschex.forward)r   r!   r   Fr   r   r   r   r   r   rM   r   r   r   r   r   r|   Tr   r   FFr4   r5   )NNr   r   r   r   r    r      s6                           Ru)
__future__r   r   collections.abcr   r-   r   monai.config.type_definitionsr   monai.utilsr   r   r   r
   r   r   __all__Moduler   r   r   r   r   r   r   r   r   r   r    <module>   s&   O&"