o
    ix	                     @  sn   d dl mZ d dlmZ d dlmZ d dlZd dlmZm	Z	 d dl
mZ d dlmZ dgZG d	d deZdS )
    )annotations)Sequence)AnyN)PrepareBatchPrepareBatchExtraInput)ensure_tuple)HoVerNetBranchPrepareBatchHoVerNetc                   @  s*   e Zd ZdZdddZ			ddddZdS )r	   a  
    Customized prepare batch callable for trainers or evaluators which support label to be a dictionary.
    Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).
    This assumes label is a dictionary.

    Args:
        extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from
            those keys and passed to the network as extra positional arguments.
    
extra_keysSequence[str]returnNonec                 C  s4   t t|dkrtdt t| t|| _d S )N   z(length of `extra_keys` should be 2, get )lenr   
ValueErrorr   prepare_batch)selfr
    r   d/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/pathology/engines/utils.py__init__%   s   zPrepareBatchHoVerNet.__init__NF	batchdatadict[str, torch.Tensor]devicestr | torch.device | Nonenon_blockingboolkwargsr   7tuple[torch.Tensor, dict[HoVerNetBranch, torch.Tensor]]c           
      K  sD   | j |||fi |\}}}}tj|tj|d tj|d i}	||	fS )z
        Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
        https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
        `kwargs` supports other args for `Tensor.to()` API.
        r      )r   r   NPNCHV)
r   r   r   r   r   image_labelZextra_label_labelr   r   r   __call__*   s   zPrepareBatchHoVerNet.__call__)r
   r   r   r   )NF)
r   r   r   r   r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r&   r   r   r   r   r	      s    

)
__future__r   collections.abcr   typingr   torchZmonai.enginesr   r   monai.utilsr   monai.utils.enumsr   __all__r	   r   r   r   r   <module>   s   