o
    i                     @  s   d dl mZ d dlZd dlmZmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZ d dlmZmZmZ d d	lmZ d d
lmZ d dlmZmZ d dlmZ g dZG dd de	ZG dd deZG dd deZ e  Z!Z"dS )    )annotationsN)CallableSequence)Any)Dataset)
DataLoader)KeysCollection)decollate_batchno_collationpad_list_data_collate)PadListDataCollate)InvertibleTransform)MapTransform	Transform)first)BatchInverseTransform
Decollated
DecollateDDecollateDictc                   @  s*   e Zd Zdd	d
ZdddZdddZdS )_BatchInverseDatasetdataSequence[Any]	transformr   pad_collation_usedboolreturnNonec                 C  s   || _ || _|| _d S N)r   invertible_transformr   )selfr   r   r    r    j/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/inverse_batch_transform.py__init__"   s   
z_BatchInverseDataset.__init__indexintc                 C  sD   t | j| }| jrt|}t| jtst	d |S | j|S )NzGtransform is not invertible, can't invert transform for the input data.)
dictr   r   r   inverse
isinstancer   r   warningswarn)r   r#   r   r    r    r!   __getitem__'   s   

z _BatchInverseDataset.__getitem__c                 C  s
   t | jS r   )lenr   )r   r    r    r!   __len__2   s   
z_BatchInverseDataset.__len__N)r   r   r   r   r   r   r   r   )r#   r$   )r   r$   )__name__
__module____qualname__r"   r*   r,   r    r    r    r!   r       s    

r   c                   @  s0   e Zd ZdZeddddfdddZdddZdS )r   z
    Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert
    them all.
    r   TNr   r   loaderTorchDataLoader
collate_fnCallable | Nonenum_workers
int | Nonedetachr   	pad_batchr   r   c                 C  sZ   || _ |j| _|du r|jn|| _|| _|| _|| _|| _|jjtjkp)t	|jt
| _dS )a  
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run `transforms` and generate the batch of data.
            collate_fn: how to collate data after inverse transformations.
                default won't do any collation, so the output will be a list of size batch size.
            num_workers: number of workers when run data loader for inverse transforms,
                default to 0 as only run 1 iteration and multi-processing may be even slower.
                if the transforms are really slow, set num_workers for multi-processing.
                if set to `None`, use the `num_workers` of the transform data loader.
            detach: whether to detach the tensors. Scalars tensors will be detached into number types
                instead of torch tensors.
            pad_batch: when the items in a batch indicate different batch size,
                whether to pad all the sequences to the longest.
                If False, the batch size will be the length of the shortest sequence.
            fill_value: the value to fill the padded sequences when `pad_batch=True`.

        N)r   
batch_sizer4   r2   r6   r7   
fill_value__doc__r   r'   r   r   )r   r   r0   r2   r4   r6   r7   r9   r    r    r!   r"   <   s   
zBatchInverseTransform.__init__r   dict[str, Any]r   c              
   C  s   t || j| j| jd}t|| j| j}t|| j| j	| j
d}zt|W S  ty@ } zt|}d|v r7|d7 }t||d }~ww )Nr6   padr9   )r8   r4   r2   z
equal sizezP
MONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`.)r	   r6   r7   r9   r   r   r   r   r8   r4   r2   r   RuntimeErrorstr)r   r   Zdecollated_dataZinv_dsZ
inv_loaderrere_strr    r    r!   __call__c   s   

zBatchInverseTransform.__call__)r   r   r0   r1   r2   r3   r4   r5   r6   r   r7   r   r   r   )r   r;   r   r   )r-   r.   r/   r:   r
   r"   rB   r    r    r    r!   r   6   s    	'r   c                      s8   e Zd ZdZ					dd fddZdddZ  ZS )r   a?  
    Decollate a batch of data. If input is a dictionary, it also supports to only decollate specified keys.
    Note that unlike most MapTransforms, it will delete the other keys that are not specified.
    if `keys=None`, it will decollate all the data in the input.
    It replicates the scalar values to every item of the decollated list.

    Args:
        keys: keys of the corresponding items to decollate, note that it will delete other keys not specified.
            if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`.
        detach: whether to detach the tensors. Scalars tensors will be detached into number types
            instead of torch tensors.
        pad_batch: when the items in a batch indicate different batch size,
            whether to pad all the sequences to the longest.
            If False, the batch size will be the length of the shortest sequence.
        fill_value: the value to fill the padded sequences when `pad_batch=True`.
        allow_missing_keys: don't raise exception if key is missing.

    NTFkeysKeysCollection | Noner6   r   r7   allow_missing_keysr   r   c                   s$   t  || || _|| _|| _d S r   )superr"   r6   r7   r9   )r   rC   r6   r7   r9   rE   	__class__r    r!   r"      s   
zDecollated.__init__r   dict | listc                 C  sj   t | jdkr| jd d u r|}nt|tstdi }| |D ]}|| ||< q!t|| j| j| j	dS )N   r   z@input data is not a dictionary, but specified keys to decollate.r<   )
r+   rC   r'   r%   	TypeErrorkey_iteratorr	   r6   r7   r9   )r   r   dkeyr    r    r!   rB      s   
zDecollated.__call__)NTTNF)
rC   rD   r6   r   r7   r   rE   r   r   r   )r   rI   )r-   r.   r/   r:   r"   rB   __classcell__r    r    rG   r!   r   r   s    r   )#
__future__r   r(   collections.abcr   r   typingr   torch.utils.datar   torch.utils.data.dataloaderr   r1   monai.configr   Zmonai.data.dataloadermonai.data.utilsr	   r
   r   monai.transforms.croppad.batchr   monai.transforms.inverser   monai.transforms.transformr   r   monai.utilsr   __all__r   r   r   r   r   r    r    r    r!   <module>   s$   <0