o
    -i=                     @  s  d dl mZ d dlZd dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 e
dZe
ddd	d  Ze
d
dd	d  Ze
ddd	d  Ze
ddd	d  Zg dZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd deZG dd dejjZdS )    )annotationsN)Sequence)nn)PathLike)optional_importtransformersload_tf_weights_in_bert)nameztransformers.utilscached_filez&transformers.models.bert.modeling_bertBertEmbeddings	BertLayer)BertPreTrainedModelBertAttention
BertOutputBertMixedLayerPooler
MultiModal	Transchexc                      sB   e Zd ZdZd fddZdd Ze					
	dddZ  ZS )r   zModule to load BERT pre-trained weights.
    Based on:
    LXMERT
    https://github.com/airsplay/lxmert
    BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    returnNonec                   s   t    d S N)super__init__)selfinputskwargs	__class__ _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/transchex.pyr   )   s   zBertPreTrainedModel.__init__c                 C  s   t |tjtjfr|jjjd| jjd nt |t	jj
r)|jj  |jjd t |tjr<|jd ur>|jj  d S d S d S )N        )meanstd      ?)
isinstancer   Linear	Embeddingweightdatanormal_configinitializer_rangetorch	LayerNormbiaszero_fill_)r   moduler   r   r   init_bert_weights,   s   z%BertPreTrainedModel.init_bert_weightsNFbert-base-uncasedpytorch_model.binc
                   s`  t ||	|d}| ||||g|
R i |}d u r,|s,tj s"dnd }tj||dd|r3t||S g }g } D ]$}d }d|v rI|dd}d|v rS|dd}|r_|| || q;t	||D ]\}}
||< qeg g g  td	d  d ur_d fdd	d
}t|dstdd  D rd}||d |S )N)	cache_dircpuT)map_locationweights_onlygammar'   betar.   	_metadata c              	     sh   d u ri n	 |d d i }| ||d  | j D ]\}}|d ur1||| d  q d S )NT.)get_load_from_state_dict_modulesitems)r1   prefixlocal_metadatar	   child
error_msgsloadmetadatamissing_keys
state_dictunexpected_keysr   r   rH   `   s    z1BertPreTrainedModel.from_pretrained.<locals>.loadbertc                 s  s    | ]}| d V  qdS )bert.N)
startswith).0sr   r   r   	<genexpr>j   s    z6BertPreTrainedModel.from_pretrained.<locals>.<genexpr>rN   )rC   )r<   )r
   r,   cudais_availablerH   r   keysreplaceappendzippopgetattrcopyr;   hasattrany)clsnum_language_layersnum_vision_layersnum_mixed_layersbert_configrK   r5   Zfrom_tfpath_or_repo_idfilenamer   r   Zweights_pathmodelr7   Zold_keysnew_keyskeynew_keyold_keyZstart_prefixr   rF   r   from_pretrained5   sD   


	 z#BertPreTrainedModel.from_pretrainedr   r   )NNFr3   r4   )	__name__
__module____qualname____doc__r   r2   classmethodrj   __classcell__r   r   r   r   r       s    	r   c                      s2   e Zd ZdZd
 fddZdd Zdd	 Z  ZS )r   zsBERT attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   sz   t    |j| _t|j|j | _| j| j | _t|j| j| _	t|j| j| _
t|j| j| _t|j| _d S r   )r   r   num_attention_headsinthidden_sizeattention_head_sizeall_head_sizer   r%   queryrg   valueDropoutattention_probs_dropout_probdropoutr   r*   r   r   r   r   v   s   
zBertAttention.__init__c                 C  s6   |  d d | j| jf }|j| }|ddddS )Nr=   r            )sizerr   ru   viewpermute)r   xZnew_x_shaper   r   r   transpose_for_scores   s   
z"BertAttention.transpose_for_scoresc                 C  s   |  |}| |}| |}| |}| |}| |}t||dd}	|	t| j	 }	| 
tjdd|	}
t|
|}|dddd }| d d | jf }|j| }|S )Nr=   )dimr   r}   r~   r   )rw   rg   rx   r   r,   matmul	transposemathsqrtru   r{   r   Softmaxr   
contiguousr   rv   r   )r   hidden_statescontextZmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper   r   r   forward   s   






zBertAttention.forwardrk   )rl   rm   rn   ro   r   r   r   rq   r   r   r   r   r   p   s
    
r   c                      *   e Zd ZdZd fddZdd Z  ZS )	r   zpBERT output layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   sB   t    t|j|j| _tjj|jdd| _t|j	| _
d S )N-q=)eps)r   r   r   r%   rt   denser,   r-   ry   hidden_dropout_probr{   r|   r   r   r   r      s   
zBertOutput.__init__c                 C  s&   |  |}| |}| || }|S r   )r   r{   r-   )r   r   input_tensorr   r   r   r      s   

zBertOutput.forwardrk   rl   rm   rn   ro   r   r   rq   r   r   r   r   r      s    r   c                      r   )	r   zyBERT cross attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   s6   t    t|| _t|| _t|| _t|| _d S r   )r   r   r   att_xr   output_xatt_youtput_yr|   r   r   r   r      s
   



zBertMixedLayer.__init__c                 C  s0   |  ||}| ||}| ||| ||fS r   )r   r   r   r   )r   r   yr   r   r   r   r   r      s   zBertMixedLayer.forwardrk   r   r   r   r   r   r      s    r   c                      r   )	r   zpBERT pooler layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    r   r   c                   s&   t    t||| _t | _d S r   )r   r   r   r%   r   Tanh
activation)r   rt   r   r   r   r      s   
zPooler.__init__c                 C  s(   |d d df }|  |}| |}|S Nr   )r   r   )r   r   Zfirst_token_tensorZpooled_outputr   r   r   r      s   

zPooler.forwardrk   r   r   r   r   r   r      s    r   c                      s,   e Zd ZdZd fd
dZdddZ  ZS )r   z?
    Multimodal Transformers From Pretrained BERT Weights"
    r_   rs   r`   ra   rb   dictr   r   c                   s   t    tdtf| _t j _t fddt	|D  _
t fddt	|D  _t fddt	|D  _  j dS )z
        Args:
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            bert_config: configuration for bert language transformer encoder.

        objc                      g | ]}t  jqS r   r   r*   rP   _r   r   r   
<listcomp>       z'MultiModal.__init__.<locals>.<listcomp>c                   r   r   r   r   r   r   r   r      r   c                   r   r   )r   r*   r   r   r   r   r      r   N)r   r   typeobjectr*   r   
embeddingsr   
ModuleListrangelanguage_encodervision_encodermixed_encoderapplyr2   )r   r_   r`   ra   rb   r   r   r   r      s   

zMultiModal.__init__Nc                 C  sb   |  ||}| jD ]	}||d d }q	| jD ]	}|||d }q| jD ]	}|||\}}q#||fS r   )r   r   r   r   )r   	input_idstoken_type_idsvision_featsattention_maskZlanguage_featureslayerr   r   r   r      s   


zMultiModal.forward)
r_   rs   r`   rs   ra   rs   rb   r   r   r   )NNNr   r   r   r   r   r      s    r   c                      s^   e Zd ZdZ											
														dBdC fd=d>ZdDd@dAZ  ZS )Er   z
    TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language
    Transformers for Chest X-ray Analysis"
       r    皙?Fgelu{Gz?   r      rM      r   absolute4.10.2r}   T:w  r3   r4   in_channelsrs   img_sizeSequence[int] | int
patch_sizeint | tuple[int, int]num_classesr_   r`   ra   rt   drop_outfloatrz   gradient_checkpointingbool
hidden_actstrr   r+   intermediate_sizelayer_norm_epsmax_position_embeddings
model_typerr   num_hidden_layerspad_token_idposition_embedding_typetransformers_versiontype_vocab_size	use_cache
vocab_sizechunk_size_feed_forward
is_decoderadd_cross_attentionrc   str | PathLikerd   r   r   c            !        s  t    i d|
ddd|d|d|d|d|d	|d
|d|d|d|d|d|d|d|d||||||dd} d|	  krPdksUtd td|d |d  dksi|d |d  dkrmtdtj|||| ||d| _|| _|d | jd  |d | jd   | _tj	||| j| jd| _
t|| _ttd| j|| _t|d| _tj|	| _tj||| _dS )a  
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            num_classes: number of classes if classification is used.
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed transformer layers.
            drop_out: fraction of the input units to drop.
            path_or_repo_id: This can be either:
                - a string, the *model id* of a model repo on huggingface.co.
                - a path to a *directory* potentially containing the file.
            filename: The name of the file to locate in `path_or_repo`.

        The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

        Examples:

        .. code-block:: python

            # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,
            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
            net = Transchex(in_channels=3,
                                 img_size=(224, 224),
                                 num_classes=3,
                                 num_language_layers=2,
                                 num_vision_layers=2,
                                 num_mixed_layers=2,
                                 drop_out=0.2)

        rz   Zclassifier_dropoutNr   r   r   rt   r+   r   r   r   r   rr   r   r   r   r   r   eager)r   r   r   r   r   Z_attn_implementationr   r~   z'dropout_rate should be between 0 and 1.z+img_size should be divisible by patch_size.)r_   r`   ra   rb   rc   rd   )r   out_channelskernel_sizestride)rt   )r   r   
ValueErrorr   rj   
multimodalr   num_patchesr   Conv2dvision_projr-   norm_vision_pos	Parameterr,   zerospos_embed_visr   poolerry   dropr%   cls_head)!r   r   r   r   r   r_   r`   ra   rt   r   rz   r   r   r   r+   r   r   r   r   rr   r   r   r   r   r   r   r   r   r   r   rc   rd   rb   r   r   r   r      s   
B	
(	&zTranschex.__init__Nc           	      C  s   t |dd}|jt|  jd}d| d }| |d	dd}| 
|}|| j }| j||||d\}}| |}| | |}|S )Nr~   r}   )dtyper#   g     )r   r   r   r   )r,   	ones_like	unsqueezetonext
parametersr   r   flattenr   r   r   r   r   r   r   )	r   r   r   r   r   Zhidden_state_langZhidden_state_visZpooled_featureslogitsr   r   r   r   l  s   



zTranschex.forward)r   r    r   Fr   r   r   r   r   r   rM   r   r   r   r   r   r}   Tr   r   FFr3   r4   )@r   rs   r   r   r   r   r   rs   r_   rs   r`   rs   ra   rs   rt   rs   r   r   rz   r   r   r   r   r   r   r   r+   r   r   rs   r   r   r   rs   r   r   rr   rs   r   rs   r   rs   r   r   r   r   r   rs   r   r   r   rs   r   rs   r   r   r   r   rc   r   rd   r   r   r   )NNr   r   r   r   r   r      s8    vr   )
__future__r   r   collections.abcr   r,   r   monai.config.type_definitionsr   monai.utilsr   r   r   r
   r   r   __all__Moduler   r   r   r   r   r   r   r   r   r   r   <module>   s(   P&"