o
    i                     @  sz   d dl mZ d dlmZmZ d dlZd dlmZmZ d dl	m
Z
mZ d dlmZ d dlmZ d dlmZ G d	d
 d
ZdS )    )annotations)CallableSequenceN)decollate_batchlist_data_collate)SupervisedEvaluatorSupervisedTrainer)IterationEvents)Compose)
CommonKeysc                   @  s(   e Zd ZdZ	ddddZdddZdS )Interactiona  
    Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
    For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
    This implementation is based on:

        Sakinis et al., Interactive segmentation of medical images through
        fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205

    Args:
        transforms: execute additional transformation during every iteration (before train).
            Typically, several Tensor based transforms composed by `Compose`.
        max_interactions: maximum number of interactions per iteration
        train: training or evaluation
        key_probability: field name to fill probability for every interaction
    probability
transformsSequence[Callable] | Callablemax_interactionsinttrainboolkey_probabilitystrreturnNonec                 C  s.   t |ts	t|}|| _|| _|| _|| _d S )N)
isinstancer
   r   r   r   r   )selfr   r   r   r    r   a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/deepgrow/interaction.py__init__*   s   

zInteraction.__init__engine'SupervisedTrainer | SupervisedEvaluator	batchdatadict[str, torch.Tensor]dictc           	   
   C  sN  |d u rt dt| jD ]}||\}}||jj}|tj	 |j
  t / |jrMtd |||j
}W d    n1 sGw   Y  n|||j
}W d    n1 s^w   Y  |tj |tj|i t|dd}tt|D ]}| jrdd| j |  nd|| | j< | || ||< q}t|}q|||S )Nz.Must provide batch data for current iteration.cudaT)detachg      ?)
ValueErrorranger   prepare_batchtostatedevice
fire_eventr	   INNER_ITERATION_STARTEDnetworkevaltorchno_gradampautocastinfererINNER_ITERATION_COMPLETEDupdater   PREDr   lenr   r   r   r   
_iteration)	r   r   r   jinputs_predictionsbatchdata_listir   r   r   __call__9   s2   


zInteraction.__call__N)r   )
r   r   r   r   r   r   r   r   r   r   )r   r   r   r    r   r!   )__name__
__module____qualname____doc__r   r>   r   r   r   r   r      s
    r   )
__future__r   collections.abcr   r   r.   
monai.datar   r   monai.enginesr   r   monai.engines.utilsr	   monai.transformsr
   monai.utils.enumsr   r   r   r   r   r   <module>   s   