o
    )iæ  ã                   @  sJ   d dl mZ d dlZd dlmZ d dlmZ ddiZG dd„ dejƒZdS )	é    )ÚannotationsN)Únn)Ú	model_zooÚ clip_encoding_universal_model_32zphttps://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_universal_model.pthc                      s6   e Zd ZdZ					dd‡ fdd„Zdd„ Z‡  ZS )ÚTextEncoderaV  
    Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.
    The text to vision encoder loads the pre-trained or random initialized weights with connection to 2D/3D vision models.

    Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al.,
    Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>"

    Connecting text and medical 3D image, based on: "Liu et al.,
    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
    é   é   é   r   TÚout_channelsÚintÚspatial_dimsÚtext_dimÚhidden_sizeÚencodingÚstrÚ
pretrainedÚboolÚreturnÚNonec           	        s¢   t ƒ  ¡  || _|| _|dvrtdƒ‚| jdkr!t ||¡| _dS |  dt	 
||¡¡ |r@t| j }tj|dd}| ¡ | j_nt| j› dƒ t ||¡| _dS )	a#  
        Args:
            out_channels: number of output channels, to control text-based embedding for classes.
            spatial_dims: number of spatial dims.
            text_dim: dimension of text embeddings.
            hidden_size: dimension of hidden features, compatible to different vision feature dimensions.
            encoding: the text embedding type, default to use clip text pretrained weights.
            pretrained: whether to load pretrained weights from e.g., (CLIP) to initialize text embeddings, default to False.
        )é   r   z#spatial dimension should be 2 or 3.Úrand_embeddingÚtext_embeddingÚcpu)Úmap_locationzD is not implemented, and can not be downloaded, please load your ownN)ÚsuperÚ__init__r   r   Ú
ValueErrorr   Ú	Embeddingr   Úregister_bufferÚtorchÚrandnÚurl_mapr   Úload_urlÚfloatÚdataÚprintÚLinearÚtext_to_vision)	Úselfr
   r   r   r   r   r   Ú	model_urlÚpretrain_state_dict©Ú	__class__© úf/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/text_embedding.pyr   &   s   


zTextEncoder.__init__c                 C  st   | j dkr
| jj}nt| jƒ tj |  | j¡¡}| jdkr+| 	d¡ 	d¡ 	d¡}|S | jdkr8| 	d¡ 	d¡}|S )Nr   r   r   )
r   r   Úweightr%   r   Ú
functionalÚrelur'   r   Ú	unsqueeze)r(   r   r-   r-   r.   ÚforwardM   s   




ýzTextEncoder.forward)r   r   r	   r   T)r
   r   r   r   r   r   r   r   r   r   r   r   r   r   )Ú__name__Ú
__module__Ú__qualname__Ú__doc__r   r3   Ú__classcell__r-   r-   r+   r.   r      s    ù'r   )	Ú
__future__r   r   r   Útorch.utilsr   r!   ÚModuler   r-   r-   r-   r.   Ú<module>   s   þ