U
    Ph4&                     @  s.  d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	 d dl
Zd dlZd dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ d dlmZ d dlmZ d dlmZ d dl m!Z!m"Z" d dl#m$Z$m%Z%m&Z& erd dl'm'Z' dZ(ne&ddd\Z'Z(dgZ)e%* Z+dd Z,G dd dZ-dS )    )annotationsN)Callabledeepcopy)TYPE_CHECKINGAny)NdarrayOrTensor)
DataLoader)Dataset)decollate_batchpad_list_data_collate)Compose)PadListDataCollate)InvertibleTransform)Invertd)Randomizable)modestack)
CommonKeysPostFixoptional_import)tqdmTr   )nameTestTimeAugmentationc                 C  s   | S )N )xr   r   V/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/test_time_augmentation.py	_identity.   s    r   c                   @  sr   e Zd ZdZdedejejddeddeddfdddd	d
dddd
d	ddddddZ	dd Z
dddddddZdS )r   a  
    Class for performing test time augmentations. This will pass the same image through the network multiple times.

    The user passes transform(s) to be applied to each realization, and provided that at least one of those transforms
    is random, the network's output will vary. Provided that inverse transformations exist for all supplied spatial
    transforms, the inverse can be applied to each realization of the network's output. Once in the same spatial
    reference, the results can then be combined and metrics computed.

    Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's
    dependency on the applied random transforms.

    Reference:
        Wang et al.,
        Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional
        neural networks,
        https://doi.org/10.1016/j.neucom.2019.01.103

    Args:
        transform: transform (or composed) to be applied to each realization. At least one transform must be of type
        `RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
            . All random transforms must be of type `InvertibleTransform`.
        batch_size: number of realizations to infer at once.
        num_workers: how many subprocesses to use for data.
        inferrer_fn: function to use to perform inference.
        device: device on which to perform inference.
        image_key: key used to extract image from input dictionary.
        orig_key: the key of the original input data in the dict. will get the applied transform information
            for this input data, then invert them for the expected data with `image_key`.
        orig_meta_keys: the key of the metadata of original input data, will get the `affine`, `data_shape`, etc.
            the metadata is a dictionary object which contains: filename, original_shape, etc.
            if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
        meta_key_postfix: use `key_{postfix}` to fetch the metadata according to the key data,
            default is `meta_dict`, the metadata is a dictionary object.
            For example, to handle key `image`,  read/write affine matrices from the
            metadata `image_meta_dict` dictionary's `affine` field.
            this arg only works when `meta_keys=None`.
        to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
        output_device: if converted the inverted data to Tensor, move the inverted results to target device
            before `post_func`, default to "cpu".
        post_func: post processing for the inverted data, should be a callable function.
        return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True`
            will return the full data. Dimensions will be same size as when passing a single image through
            `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
        progress: whether to display a progress bar.

    Example:
        .. code-block:: python

            model = UNet(...).to(device)
            transform = Compose([RandAffined(keys, ...), ...])
            transform.set_random_state(seed=123)  # ensure deterministic evaluation

            tt_aug = TestTimeAugmentation(
                transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device
            )
            mode, mean, std, vvc = tt_aug(test_data)
    r   cpuTNFr   intr   zstr | torch.deviceboolz
str | NoneNone)	transform
batch_sizenum_workersinferrer_fndevicenearest_interporig_meta_keys	to_tensoroutput_device	post_funcreturn_full_dataprogressreturnc                 C  sb   || _ || _|| _|| _|| _|| _|| _|| _tj	| _
t| j
|||	|
||||d	| _|   d S )N)	keysr"   	orig_keysr(   meta_key_postfixr'   r)   r&   r+   )r"   r#   r$   r%   r&   	image_keyr,   r-   r   PRED	_pred_keyr   inverter_check_transforms)selfr"   r#   r$   r%   r&   r2   orig_keyr'   r(   r1   r)   r*   r+   r,   r-   r   r   r   __init__m   s*    zTestTimeAugmentation.__init__c                 C  s   t | jts| jgn| jj}tdd |D }tdd |D }t|dkrZtd t	||D ]&\}}|rd|sdtdt
|j  qddS )zVShould be at least 1 random transform, and all random transforms should be invertible.c                 S  s   g | ]}t |tqS r   )
isinstancer   .0tr   r   r   
<listcomp>   s     z:TestTimeAugmentation._check_transforms.<locals>.<listcomp>c                 S  s   g | ]}t |tqS r   )r:   r   r;   r   r   r   r>      s     r   zdTTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms.zKNot all applied random transform(s) are invertible. Problematic transform: N)r:   r"   r   
transformsnparraysumwarningswarnziptype__name__)r7   tsZrandomsZinvertiblesrir   r   r   r6      s    z&TestTimeAugmentation._check_transforms
   zdict[str, Any]zQtuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float] | NdarrayOrTensor)datanum_examplesr.   c                   s  t | |j dkrtd fddt|D }t|j}t|jjtd}g }t	rjj
rjt|n|D ]>}|j j|j< |fddt|D  qnt|d}jr|S t|dd}	|d}
|d}| |   }|	|
||fS )a  
        Args:
            data: dictionary data to be processed.
            num_examples: number of realizations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
                calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
                is `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then
                concatenating across the first dimension containing `num_examples`. This allows the user to perform
                their own analysis if desired.
        r   z.num_examples should be multiple of batch size.c                   s   g | ]}t  qS r   r   )r<   _)dr   r   r>      s     z1TestTimeAugmentation.__call__.<locals>.<listcomp>)r$   r#   
collate_fnc                   s"   g | ]}  t| j qS r   )r5   r   inverser4   )r<   rJ   )r7   r   r   r>      s     )dim)dictr#   
ValueErrorranger
   r"   r	   r$   r   has_tqdmr-   r   r%   r2   tor&   r4   extendr   r   r,   r   meanstditem)r7   rL   rM   data_indsdloutsboutput_moderY   rZ   Zvvcr   )rO   r7   r   __call__   s$    


zTestTimeAugmentation.__call__)rK   )rG   
__module____qualname____doc__r   r   IMAGELABELDEFAULT_POST_FIXr9   r6   rc   r   r   r   r   r   2   s$   >(* ).
__future__r   rC   collections.abcr   copyr   typingr   r   numpyr@   torchmonai.config.type_definitionsr   Zmonai.data.dataloaderr	   Zmonai.data.datasetr
   monai.data.utilsr   r   monai.transforms.composer   monai.transforms.croppad.batchr   monai.transforms.inverser   Z monai.transforms.post.dictionaryr   monai.transforms.transformr   0monai.transforms.utils_pytorch_numpy_unificationr   r   monai.utilsr   r   r   r   rV   __all__metari   r   r   r   r   r   r   <module>   s2   