o
    i                     @  s   d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	m
Z
 d dlmZmZ d dlmZ d dlmZ d dlmZ G d	d
 d
ZdS )    )annotations)CallableSequenceN)decollate_batchlist_data_collate)SupervisedEvaluatorSupervisedTrainer)IterationEvents)Compose)
CommonKeysc                   @  s,   e Zd ZdZ			ddddZdddZdS )InteractionaC  
    Ignite process_function used to introduce interactions (simulation of clicks) for DeepEdit Training/Evaluation.

    More details about this can be found at:

        Diaz-Pinto et al., MONAI Label: A framework for AI-assisted Interactive
        Labeling of 3D Medical Images. (2022) https://arxiv.org/abs/2203.12362

    Args:
        deepgrow_probability: probability of simulating clicks in an iteration
        transforms: execute additional transformation during every iteration (before train).
            Typically, several Tensor based transforms composed by `Compose`.
        train: True for training mode or False for evaluation mode
        click_probability_key: key to click/interaction probability
        label_names: Dict of label names
        max_interactions: maximum number of interactions per iteration
    Nprobability   deepgrow_probabilityfloat
transformsSequence[Callable] | Callabletrainboollabel_namesNone | dict[str, int]click_probability_keystrmax_interactionsintreturnNonec                 C  s:   || _ t|tst|n|| _|| _|| _|| _|| _d S )N)r   
isinstancer
   r   r   r   r   r   )selfr   r   r   r   r   r    r   a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/deepedit/interaction.py__init__-   s   	
zInteraction.__init__engine'SupervisedTrainer | SupervisedEvaluator	batchdatadict[str, torch.Tensor]dictc           	   
   C  s  |d u rt dtjjddg| jd| j gdrt| jD ]}||\}}||j	j
}|tj |j  t / |jr]td |||j}W d    n1 sWw   Y  n|||j}W d    n1 snw   Y  |tj|i t|dd}tt|D ]}| jrdd| j |  nd|| | j< | || ||< qt|}|tj qn&t|dd}tdt|d	 tj D ]}|d	 tj |  d	9  < qt|}||j	_ |!||S )
Nz.Must provide batch data for current iteration.TFr   )pcuda)detachg      ?r   )"
ValueErrornprandomchoicer   ranger   prepare_batchtostatedevice
fire_eventr	   INNER_ITERATION_STARTEDnetworkevaltorchno_gradampautocastinfererupdater   PREDr   lenr   r   r   r   INNER_ITERATION_COMPLETEDIMAGEbatch
_iteration)	r   r"   r$   jinputs_predictionsZbatchdata_listir   r   r    __call__=   s@    

zInteraction.__call__)Nr   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r"   r#   r$   r%   r   r&   )__name__
__module____qualname____doc__r!   rH   r   r   r   r    r      s    r   )
__future__r   collections.abcr   r   numpyr+   r7   
monai.datar   r   monai.enginesr   r   monai.engines.utilsr	   monai.transformsr
   monai.utils.enumsr   r   r   r   r   r    <module>   s   