U
    Ph]	                     @  sf   d dl mZ d dlmZmZ d dlZd dlmZmZ d dl	m
Z
 d dlmZ dgZG dd deZdS )	    )annotations)AnySequenceN)PrepareBatchPrepareBatchExtraInput)ensure_tuple)HoVerNetBranchPrepareBatchHoVerNetc                   @  s8   e Zd ZdZdddddZdd	d
ddddddZdS )r	   a  
    Customized prepare batch callable for trainers or evaluators which support label to be a dictionary.
    Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).
    This assumes label is a dictionary.

    Args:
        extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from
            those keys and passed to the network as extra positional arguments.
    zSequence[str]None)
extra_keysreturnc                 C  s4   t t|dkr&tdt t| t|| _d S )N   z(length of `extra_keys` should be 2, get )lenr   
ValueErrorr   prepare_batch)selfr    r   W/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/pathology/engines/utils.py__init__$   s    zPrepareBatchHoVerNet.__init__NFzdict[str, torch.Tensor]zstr | torch.device | Noneboolr   z7tuple[torch.Tensor, dict[HoVerNetBranch, torch.Tensor]])	batchdatadevicenon_blockingkwargsr   c           
      K  s@   | j |||f|\}}}}tj|tj|d tj|d i}	||	fS )z
        Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
        https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
        `kwargs` supports other args for `Tensor.to()` API.
        r      )r   r   NPNCHV)
r   r   r   r   r   image_labelZextra_label_labelr   r   r   __call__)   s    zPrepareBatchHoVerNet.__call__)NF)__name__
__module____qualname____doc__r   r"   r   r   r   r   r	      s
   
  )
__future__r   typingr   r   torchZmonai.enginesr   r   monai.utilsr   monai.utils.enumsr   __all__r	   r   r   r   r   <module>   s   