U
    Ph                     @  sz   d dl mZ d dlmZmZ d dlZd dlmZmZ d dl	m
Z
mZ d dlmZ d dlmZ d dlmZ G d	d
 d
ZdS )    )annotations)CallableSequenceN)decollate_batchlist_data_collate)SupervisedEvaluatorSupervisedTrainer)IterationEvents)Compose)
CommonKeysc                   @  s:   e Zd ZdZdddddddd	d
ZddddddZdS )Interactiona  
    Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
    For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
    This implementation is based on:

        Sakinis et al., Interactive segmentation of medical images through
        fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205

    Args:
        transforms: execute additional transformation during every iteration (before train).
            Typically, several Tensor based transforms composed by `Compose`.
        max_interactions: maximum number of interactions per iteration
        train: training or evaluation
        key_probability: field name to fill probability for every interaction
    probabilityzSequence[Callable] | CallableintboolstrNone)
transformsmax_interactionstrainkey_probabilityreturnc                 C  s.   t |tst|}|| _|| _|| _|| _d S )N)
isinstancer
   r   r   r   r   )selfr   r   r   r    r   T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/deepgrow/interaction.py__init__*   s    
zInteraction.__init__z'SupervisedTrainer | SupervisedEvaluatorzdict[str, torch.Tensor]dict)engine	batchdatar   c           	      C  s*  |d krt dt| jD ] }||\}}||jj}|tj	 |j
  t B |jrtjj  |||j
}W 5 Q R X n|||j
}W 5 Q R X |tj |tj|i t|dd}tt|D ]<}| jrdd| j |  nd|| | j< | || ||< qt|}q|||S )Nz.Must provide batch data for current iteration.T)detachg      ?)
ValueErrorranger   prepare_batchtostatedevice
fire_eventr	   INNER_ITERATION_STARTEDnetworkevaltorchno_gradampcudaautocastinfererINNER_ITERATION_COMPLETEDupdater   PREDr   lenr   r   r   r   
_iteration)	r   r   r   jinputs_predictionsbatchdata_listir   r   r   __call__9   s*    


zInteraction.__call__N)r   )__name__
__module____qualname____doc__r   r;   r   r   r   r   r      s    r   )
__future__r   collections.abcr   r   r*   
monai.datar   r   monai.enginesr   r   monai.engines.utilsr	   monai.transformsr
   monai.utils.enumsr   r   r   r   r   r   <module>   s   