U
    Ph                     @  s   d Z ddlmZ ddlmZmZ ddlmZ ddlZ	ddl
Z
ddlmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZmZ dgZdd ZG dd deZdS )z]
A collection of "vanilla" transforms for crop and pad operations acting on batches of data.
    )annotations)HashableMapping)AnyN)
MetaTensor)list_data_collate)CenterSpatialCrop
SpatialPad)InvertibleTransform)MethodPytorchPadMode	TraceKeysPadListDataCollatec                 C  s@   t || tr0t|| }| ||< t|||< n| || |< |S N)
isinstancetuplelist)
to_replacebatchidx
key_or_idxZbatch_idx_list r   S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/croppad/batch.pyreplace_element    s    r   c                   @  sN   e Zd ZdZejejfddddddZddd	d
Z	e
dddddZdS )r   a	  
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
    tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
    different sizes.

    This can be used on both list and dictionary data.
    Note that in the case of the dictionary data, it may add the transform information to the list of invertible transforms
    if input batch have different spatial shape, so need to call static method: `inverse` before inverting other transforms.

    Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`.
    This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the
    `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This asymmetry is necessary so that we can
    pass the inverse through multiprocessing.

    Args:
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.

    strNone)methodmodereturnc                 K  s   || _ || _|| _d S r   )r   r   kwargs)selfr   r   r   r   r   r   __init__C   s    zPadListDataCollate.__init__r   )r   c                 C  s@  t |d t}|r"t|d  ntt|d }|D  ]}g }|D ]6}t || tjtj	fsb q||
|| jdd  qD|sq6t|jdd}tt|jdd|krq6tf || j| jd| j}t|D ]^\}	}
|
| jdd }||
| }t|||	|}|r| j||	 ||| j||	 |ddd qq6t|S )	zG
        Args:
            batch: batch of data to pad-collate
        r      N)axis)spatial_sizer   r   F)check)	orig_size
extra_info)r   dictr   keysrangelentorchTensornpndarrayappendshapearraymaxallminr	   r   r   r   	enumerater   push_transformpop_transformr   )r    r   Zis_list_of_dictsZ
batch_itemr   Z
max_shapeselem	max_shapepadderr   Zbatch_ir&   paddedr   r   r   __call__H   s4    $

zPadListDataCollate.__call__r(   zdict[Hashable, np.ndarray])datar   c              
   C  s   t | tstdt|  dt| }|D ]}d }t || trL|| j}nt|}||krf|| }|r*t |d tszq*|d 	t
jtjkr*| }t|	t
jd}|d ||| ||< W 5 Q R X q*|S )Nz@Inverse can only currently be applied on dictionaries, got type .F)r   r   RuntimeErrortyper(   r   applied_operationsr
   	trace_keygetr   
CLASS_NAMEr   __name__popr   	ORIG_SIZEtrace_transform)r>   dkey
transformstransform_keyxformZcroppingr   r   r   inverset   s$    

zPadListDataCollate.inverseN)rG   
__module____qualname____doc__r   	SYMMETRICr   CONSTANTr!   r=   staticmethodrP   r   r   r   r   r   ,   s
   ,)rS   
__future__r   collections.abcr   r   typingr   numpyr.   r,   monai.data.meta_tensorr   monai.data.utilsr   Zmonai.transforms.croppad.arrayr   r	   monai.transforms.inverser
   monai.utils.enumsr   r   r   __all__r   r   r   r   r   r   <module>   s   