U
    Ph                     @  s   d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	m
Z
 d dlmZmZ d dlmZ d dlmZ d dlmZ G d	d
 d
ZdS )    )annotations)CallableSequenceN)decollate_batchlist_data_collate)SupervisedEvaluatorSupervisedTrainer)IterationEvents)Compose)
CommonKeysc                	   @  s>   e Zd ZdZdddddd	d
ddddZddddddZdS )InteractionaC  
    Ignite process_function used to introduce interactions (simulation of clicks) for DeepEdit Training/Evaluation.

    More details about this can be found at:

        Diaz-Pinto et al., MONAI Label: A framework for AI-assisted Interactive
        Labeling of 3D Medical Images. (2022) https://arxiv.org/abs/2203.12362

    Args:
        deepgrow_probability: probability of simulating clicks in an iteration
        transforms: execute additional transformation during every iteration (before train).
            Typically, several Tensor based transforms composed by `Compose`.
        train: True for training mode or False for evaluation mode
        click_probability_key: key to click/interaction probability
        label_names: Dict of label names
        max_interactions: maximum number of interactions per iteration
    Nprobability   floatzSequence[Callable] | CallableboolzNone | dict[str, int]strintNone)deepgrow_probability
transformstrainlabel_namesclick_probability_keymax_interactionsreturnc                 C  s:   || _ t|tst|n|| _|| _|| _|| _|| _d S )N)r   
isinstancer
   r   r   r   r   r   )selfr   r   r   r   r   r    r   T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/deepedit/interaction.py__init__-   s    	zInteraction.__init__z'SupervisedTrainer | SupervisedEvaluatorzdict[str, torch.Tensor]dict)engine	batchdatar   c           	      C  s  |d krt dtjjddg| jd| j gdrDt| jD ]}||\}}||j	j
}|tj |j  t B |jrtjj  |||j}W 5 Q R X n|||j}W 5 Q R X |tj|i t|dd}tt|D ]>}| jr
dd| j |  nd|| | j< | || ||< qt|}|tj q<nNt|dd}tdt|d tj  D ] }|d tj  |  d9  < qht|}||j	_!|"||S )	Nz.Must provide batch data for current iteration.TFr   )p)detachg      ?r   )#
ValueErrornprandomchoicer   ranger   prepare_batchtostatedevice
fire_eventr	   INNER_ITERATION_STARTEDnetworkevaltorchno_gradampcudaautocastinfererupdater   PREDr   lenr   r   r   r   INNER_ITERATION_COMPLETEDIMAGEbatch
_iteration)	r   r!   r"   jinputs_predictionsZbatchdata_listir   r   r   __call__=   s6    "

zInteraction.__call__)Nr   r   )__name__
__module____qualname____doc__r   rD   r   r   r   r   r      s      r   )
__future__r   collections.abcr   r   numpyr&   r2   
monai.datar   r   monai.enginesr   r   monai.engines.utilsr	   monai.transformsr
   monai.utils.enumsr   r   r   r   r   r   <module>   s   