o
    ,i&                     @  sZ   d dl mZ d dlmZ d dlZd dlmZ d dlmZ ed\Z	Z
G dd dejZdS )    )annotations)castN)optional_importztorchvision.modelsc                      sF   e Zd ZdZ						d d! fddZd"ddZd#d$ddZ  ZS )%MILModela  
    Multiple Instance Learning (MIL) model, with a backbone classification model.
    Currently, it only works for 2D images, a typical use case is for classification of the
    digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`,
    where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
    extracted from every original image in the batch. A tutorial example is available at:
    https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm, available values (Defaults to ``"att"``):

            - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL).
            - ``"max"`` - retain only the instance with the max probability for loss calculation.
            - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712.
            - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
            - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.

        pretrained: init backbone with pretrained weights, defaults to ``True``.
        backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
            or a string name of a torchvision model).
            Defaults to ``None``, in which case ResNet50 is used.
        backbone_num_features: Number of output features of the backbone CNN
            Defaults to ``None`` (necessary only when using a custom backbone)
        trans_blocks: number of the blocks in `TransformEncoder` layer.
        trans_dropout: dropout rate in `TransformEncoder` layer.
    attTN           num_classesintmil_modestr
pretrainedboolbackbonestr | nn.Module | Nonebackbone_num_features
int | Nonetrans_blockstrans_dropoutfloatreturnNonec                   s.  t    |dkrtdt| | dvrtdt| |  _t  _d  _	|d u rtt
j|r8t
jjnd d}|jj}	tj |_i  _|dkrs fdd}
|j|
d	 |j|
d
 |j|
d |j|
d nSt|trtt
|d }|d u rtdt| ||rdnd d}t|dd d ur|jj}	tj |_n tdt| dt|tjr|}|}	|d u rtdntd|d ur|dvrtdt|  jdv rn jdkrtt|	dt tdd _n jdkrtj|	d|d}tj||d _	tt|	dt tdd _nm jdkrttjtjdd|d|dttddtjtjdd|d|dttd dtjtjdd|d|dtjtjd!d|d|dg}| _	|	d }	tt|	dt tdd _ntdt| t|	| _ | _!d S )"Nr   z$Number of classes must be positive: )meanmaxr   	att_transatt_trans_pyramidzUnsupported mil_mode: )weightsr   c                   s    fdd}|S )Nc                   s   |j  < d S )N)extra_outputs)moduleinputoutput)
layer_nameself ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/milmodel.pyhookV   s   z5MILModel.__init__.<locals>.forward_hook.<locals>.hookr#   )r!   r%   r"   )r!   r$   forward_hookT   s   z'MILModel.__init__.<locals>.forward_hooklayer1layer2layer3layer4zUnknown torch vision modelDEFAULTfcz4Unable to detect FC layer for the torchvision model z0. Please initialize the backbone model manually.zJNumber of endencoder features must be provided for a custom backbone modelzUnsupported backbone)r   r   r   r   z.Custom backbone is not supported for the mode:)r   r   r   i      r      )d_modelnheaddropout)
num_layers   i   i   i 	  )"super__init__
ValueErrorr   lowerr   nn
Sequential	attentiontransformermodelsresnet50ResNet50_WeightsIMAGENET1K_V1r-   in_featurestorchIdentityr   r(   register_forward_hookr)   r*   r+   
isinstancegetattrModuleLinearTanhTransformerEncoderLayerTransformerEncoder
ModuleListmyfcnet)r"   r	   r   r   r   r   r   r   rN   nfcr'   Ztorch_modelr<   transformer_list	__class__r&   r$   r6   5   s   







&&

&
zMILModel.__init__xtorch.Tensorc           
      C  s  |j }| jdkr| |}tj|dd}|S | jdkr+| |}tj|dd\}}|S | jdkrL| |}tj|dd}tj|| dd}| |}|S | jdkr| j	d ur|
ddd}| 	|}|
ddd}| |}tj|dd}tj|| dd}| |}|S | jd	krH| j	d urHtj| jd
 dd|d |d d
ddd}tj| jd dd|d |d d
ddd}tj| jd dd|d |d d
ddd}tj| jd dd|d |d d
ddd}ttj| j	}	|	d |}|	d tj||fdd}|	d tj||fdd}|	d tj||fdd}|
ddd}| |}tj|dd}tj|| dd}| |}|S tdt| j )Nr   r.   )dimr   r   r   r      r   r(   )rV      r)   r*   r+   rW   zWrong model mode)shaper   rM   rB   r   r   r;   softmaxsumr<   permuter   reshaper   r9   rL   catr7   r   )
r"   rS   sh_al1l2l3l4rP   r#   r#   r$   	calc_head   sR   

0

,


%


0000

zMILModel.calc_headFno_headc                 C  s`   |j }||d |d  |d |d |d }| |}||d |d d}|s.| |}|S )Nr   r.   rV   rW   r   rX   )rY   r]   rN   rf   )r"   rS   rg   r_   r#   r#   r$   forward   s   (

zMILModel.forward)r   TNNr   r   )r	   r
   r   r   r   r   r   r   r   r   r   r
   r   r   r   r   )rS   rT   r   rT   )F)rS   rT   rg   r   r   rT   )__name__
__module____qualname____doc__r6   rf   rh   __classcell__r#   r#   rQ   r$   r      s    
w7r   )
__future__r   typingr   rB   torch.nnr9   monai.utilsr   r=   r`   rG   r   r#   r#   r#   r$   <module>   s   