U
    PhU&                     @  sZ   d dl mZ d dlmZ d dlZd dlmZ d dlmZ ed\Z	Z
G dd dejZdS )    )annotations)castN)optional_importztorchvision.modelsc                
      sZ   e Zd ZdZdddd	d
ddddd fddZdddddZddd	ddddZ  ZS )MILModela  
    Multiple Instance Learning (MIL) model, with a backbone classification model.
    Currently, it only works for 2D images, a typical use case is for classification of the
    digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`,
    where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
    extracted from every original image in the batch. A tutorial example is available at:
    https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm, available values (Defaults to ``"att"``):

            - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL).
            - ``"max"`` - retain only the instance with the max probability for loss calculation.
            - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712.
            - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
            - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.

        pretrained: init backbone with pretrained weights, defaults to ``True``.
        backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
            or a string name of a torchvision model).
            Defaults to ``None``, in which case ResNet50 is used.
        backbone_num_features: Number of output features of the backbone CNN
            Defaults to ``None`` (necessary only when using a custom backbone)
        trans_blocks: number of the blocks in `TransformEncoder` layer.
        trans_dropout: dropout rate in `TransformEncoder` layer.

    attTN           intstrboolzstr | nn.Module | Nonez
int | NonefloatNone)num_classesmil_mode
pretrainedbackbonebackbone_num_featurestrans_blockstrans_dropoutreturnc                   s2  t    |dkr"tdt| | dkr>tdt| |  _t  _d  _	|d krt
j|d}|jj}	tj |_i  _|dkrڇ fdd}
|j|
d	 |j|
d
 |j|
d |j|
d nt|trTtt
|d }|d krtdt| ||d}t|dd d k	r@|jj}	tj |_ntdt| dn2t|tjr~|}|}	|d krtdntd|d k	r|dkrtdt|  jdkrn` jdkrtt|	dt tdd _n, jdkr@tj|	d|d}tj||d _	tt|	dt tdd _nڈ jdkr
ttjtjdd|d|dttddtjtjdd|d|dttddtjtjdd|d|dtjtjd d|d|dg}| _	|	d }	tt|	dt tdd _ntdt| t|	| _| _d S )!Nr   z$Number of classes must be positive: )meanmaxr   	att_transatt_trans_pyramidzUnsupported mil_mode: )r   r   c                   s    fdd}|S )Nc                   s   |j  < d S )N)extra_outputs)moduleinputoutput)
layer_nameself Q/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/milmodel.pyhookW   s    z5MILModel.__init__.<locals>.forward_hook.<locals>.hookr    )r   r"   r   )r   r!   forward_hookU   s    z'MILModel.__init__.<locals>.forward_hooklayer1layer2layer3layer4zUnknown torch vision modelfcz4Unable to detect FC layer for the torchvision model z0. Please initialize the backbone model manually.zJNumber of endencoder features must be provided for a custom backbone modelzUnsupported backbone)r   r   r   r   z.Custom backbone is not supported for the mode:)r   r   r   i      r      )d_modelnheaddropout)
num_layers   i   i   i 	  ) super__init__
ValueErrorr
   lowerr   nn
Sequential	attentiontransformermodelsresnet50r)   in_featurestorchIdentityr   r%   register_forward_hookr&   r'   r(   
isinstancegetattrModuleLinearTanhTransformerEncoderLayerTransformerEncoder
ModuleListmyfcnet)r   r   r   r   r   r   r   r   rH   nfcr$   Ztorch_modelr8   transformer_list	__class__r#   r!   r2   6   s    








(& 

&zMILModel.__init__ztorch.Tensor)xr   c           
      C  s  |j }| jdkr,| |}tj|dd}nx| jdkrV| |}tj|dd\}}nN| jdkr| |}tj|dd}tj|| dd}| |}n| jdkr| j	d k	r|
ddd}| 	|}|
ddd}| |}tj|dd}tj|| dd}| |}n| jd	kr| j	d k	rtj| jd
 dd|d |d d
ddd}tj| jd dd|d |d d
ddd}tj| jd dd|d |d d
ddd}tj| jd dd|d |d d
ddd}ttj| j	}	|	d |}|	d tj||fdd}|	d tj||fdd}|	d tj||fdd}|
ddd}| |}tj|dd}tj|| dd}| |}ntdt| j |S )Nr   r*   )dimr   r   r   r      r   r%   )rO      r&   r'   r(   rP   zWrong model mode)shaper   rG   r<   r   r   r7   softmaxsumr8   permuter   reshaper   r5   rF   catr3   r
   )
r   rM   sh_al1l2l3l4rJ   r    r    r!   	calc_head   sJ    







0000
zMILModel.calc_headF)rM   no_headr   c                 C  s`   |j }||d |d  |d |d |d }| |}||d |d d}|s\| |}|S )Nr   r*   rO   rP   r   rQ   )rR   rV   rH   r_   )r   rM   r`   rX   r    r    r!   forward   s    (

zMILModel.forward)r   TNNr   r   )F)__name__
__module____qualname____doc__r2   r_   ra   __classcell__r    r    rK   r!   r      s          "w7r   )
__future__r   typingr   r<   torch.nnr5   monai.utils.moduler   r9   rY   rA   r   r    r    r    r!   <module>   s   