U
    Ph                     @  s   d dl mZ d dlZd dlmZmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZ d dlmZmZmZ d d	lmZ d d
lmZ d dlmZmZ d dlmZ ddddgZG dd de	ZG dd deZG dd deZ e  Z!Z"dS )    )annotationsN)CallableSequence)Any)Dataset)
DataLoader)KeysCollection)decollate_batchno_collationpad_list_data_collate)PadListDataCollate)InvertibleTransform)MapTransform	Transform)firstBatchInverseTransform
Decollated
DecollateDDecollateDictc                   @  s<   e Zd ZdddddddZdd	d
dZddddZdS )_BatchInverseDatasetzSequence[Any]r   boolNone)data	transformpad_collation_usedreturnc                 C  s   || _ || _|| _d S N)r   invertible_transformr   )selfr   r   r    r   ]/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/inverse_batch_transform.py__init__"   s    z_BatchInverseDataset.__init__int)indexc                 C  sD   t | j| }| jrt|}t| jts8t	d |S | j|S )NzGtransform is not invertible, can't invert transform for the input data.)
dictr   r   r   inverse
isinstancer   r   warningswarn)r   r#   r   r   r   r    __getitem__'   s    

z _BatchInverseDataset.__getitem__)r   c                 C  s
   t | jS r   )lenr   )r   r   r   r    __len__2   s    z_BatchInverseDataset.__len__N)__name__
__module____qualname__r!   r)   r+   r   r   r   r    r       s   r   c                	   @  sF   e Zd ZdZeddddfddddd	d	d
dddZdddddZdS )r   z
    Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert
    them all.
    r   TNr   TorchDataLoaderzCallable | Nonez
int | Noner   r   )r   loader
collate_fnnum_workersdetach	pad_batchr   c                 C  sZ   || _ |j| _|dkr|jn|| _|| _|| _|| _|| _|jjtjkpRt	|jt
| _dS )a  
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run `transforms` and generate the batch of data.
            collate_fn: how to collate data after inverse transformations.
                default won't do any collation, so the output will be a list of size batch size.
            num_workers: number of workers when run data loader for inverse transforms,
                default to 0 as only run 1 iteration and multi-processing may be even slower.
                if the transforms are really slow, set num_workers for multi-processing.
                if set to `None`, use the `num_workers` of the transform data loader.
            detach: whether to detach the tensors. Scalars tensors will be detached into number types
                instead of torch tensors.
            pad_batch: when the items in a batch indicate different batch size,
                whether to pad all the sequences to the longest.
                If False, the batch size will be the length of the shortest sequence.
            fill_value: the value to fill the padded sequences when `pad_batch=True`.

        N)r   
batch_sizer2   r1   r3   r4   
fill_value__doc__r   r&   r   r   )r   r   r0   r1   r2   r3   r4   r6   r   r   r    r!   <   s     zBatchInverseTransform.__init__zdict[str, Any]r   )r   r   c              
   C  s   t || j| j| jd}t|| j| j}t|| j| j	| j
d}z
t|W S  tk
r } z&t|}d|krp|d7 }t||W 5 d }~X Y nX d S )Nr3   padr6   )r5   r2   r1   z
equal sizezP
MONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`.)r	   r3   r4   r6   r   r   r   r   r5   r2   r1   r   RuntimeErrorstr)r   r   Zdecollated_dataZinv_dsZ
inv_loaderrere_strr   r   r    __call__c   s       
zBatchInverseTransform.__call__)r,   r-   r.   r7   r
   r!   r>   r   r   r   r    r   6   s   	'c                      s>   e Zd ZdZddddddd fd	d
ZddddZ  ZS )r   a?  
    Decollate a batch of data. If input is a dictionary, it also supports to only decollate specified keys.
    Note that unlike most MapTransforms, it will delete the other keys that are not specified.
    if `keys=None`, it will decollate all the data in the input.
    It replicates the scalar values to every item of the decollated list.

    Args:
        keys: keys of the corresponding items to decollate, note that it will delete other keys not specified.
            if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`.
        detach: whether to detach the tensors. Scalars tensors will be detached into number types
            instead of torch tensors.
        pad_batch: when the items in a batch indicate different batch size,
            whether to pad all the sequences to the longest.
            If False, the batch size will be the length of the shortest sequence.
        fill_value: the value to fill the padded sequences when `pad_batch=True`.
        allow_missing_keys: don't raise exception if key is missing.

    NTFzKeysCollection | Noner   r   )keysr3   r4   allow_missing_keysr   c                   s$   t  || || _|| _|| _d S r   )superr!   r3   r4   r6   )r   r?   r3   r4   r6   r@   	__class__r   r    r!      s    zDecollated.__init__zdict | list)r   c                 C  sj   t | jdkr"| jd d kr"|}n2t|ts4tdi }| |D ]}|| ||< qBt|| j| j| j	dS )N   r   z@input data is not a dictionary, but specified keys to decollate.r8   )
r*   r?   r&   r$   	TypeErrorkey_iteratorr	   r3   r4   r6   )r   r   dkeyr   r   r    r>      s    
zDecollated.__call__)NTTNF)r,   r-   r.   r7   r!   r>   __classcell__r   r   rB   r    r   r   s        )#
__future__r   r'   collections.abcr   r   typingr   torch.utils.datar   torch.utils.data.dataloaderr   r/   monai.configr   Zmonai.data.dataloadermonai.data.utilsr	   r
   r   monai.transforms.croppad.batchr   monai.transforms.inverser   monai.transforms.transformr   r   monai.utilsr   __all__r   r   r   r   r   r   r   r   r    <module>   s"   <0