o
    i                     @  sf   d dl mZ d dlZd dlmZmZ d dlmZ ddddZddddZddddZ		ddddZ
dS )    )annotationsN)Tensornn)SlidingWindowInfererhead_outputsdict[str, list[Tensor]]keyslist[str] | NonereturnNonec                 C  sb   |du r
t |  }|D ]"}| | }t|tr|g| |< qt|d tr+t || |< qtddS )a  
    An in-place function. We expect ``head_outputs`` to be Dict[str, List[Tensor]].
    Yet if it is Dict[str, Tensor], this func converts it to Dict[str, List[Tensor]].
    It will be modified in-place.

    Args:
        head_outputs: a Dict[str, List[Tensor]] or Dict[str, Tensor], will be modifier in-place
        keys: the keys in head_output that need to have value type List[Tensor]. If not provided, will use head_outputs.keys().
    Nr   zMThe output of network should be Dict[str, List[Tensor]] or Dict[str, Tensor].)listr   
isinstancer   
ValueError)r   r   kZvalue_k r   j/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/predict_utils.pyensure_dict_value_to_list_   s   

r   c                   sV   |du r
t   } fdd|D }tt|}t|dkr)td| ddS )ai  
    We expect the values in ``head_outputs``: Dict[str, List[Tensor]] to have the same length.
    Will raise ValueError if not.

    Args:
        head_outputs: a Dict[str, List[Tensor]] or Dict[str, Tensor]
        keys: the keys in head_output that need to have values (List) with same length.
            If not provided, will use head_outputs.keys().
    Nc                   s   g | ]}t  | qS r   )len).0r   r   r   r   
<listcomp>9   s    z1check_dict_values_same_length.<locals>.<listcomp>   z>The values in the input dict should have the same length, Got .)r   r   torchuniquetensorr   r   )r   r   Znum_output_levels_listnum_output_levelsr   r   r   check_dict_values_same_length,   s   
r   imagesr   network	nn.Modulelist[Tensor]c                 C  sh   || }t |ttfrt|S t|| |du rt| }t|| g }|D ]
}|t|| 7 }q'|S )aA  
    Decompose the output of network (a dict) into a list.

    Args:
        images: input of the network
        keys: the keys in the network output whose values will be output in this func.
            If not provided, will use all keys.

    Return:
        network output values concat to a single List[Tensor]
    N)r   tupler   r   r   r   )r   r   r   r   head_outputs_sequencer   r   r   r   _network_sequence_output?   s   

r$   	list[str]infererSlidingWindowInferer | Nonec           	      C  sj   |du rt d|| t||d}t|t| }i }t|D ]\}}t||| ||d   ||< q|S )aB  
    Predict network dict output with an inferer. Compared with directly output network(images),
    it enables a sliding window inferer that can be used to handle large inputs.

    Args:
        images: input of the network, Tensor sized (B, C, H, W) or  (B, C, H, W, D)
        network: a network that takes an image Tensor sized (B, C, H, W) or (B, C, H, W, D) as input
            and outputs a dictionary Dict[str, List[Tensor]] or Dict[str, Tensor].
        keys: the keys in the output dict, should be network output keys or a subset of them.
        inferer: a SlidingWindowInferer to handle large inputs.

    Return:
        The predicted head_output from network, a Dict[str, List[Tensor]]

    Example:
        .. code-block:: python

            # define a naive network
            import torch
            import monai
            class NaiveNet(torch.nn.Module):
                def __init__(self, ):
                    super().__init__()

                def forward(self, images: torch.Tensor):
                    return {"cls": torch.randn(images.shape), "box_reg": [torch.randn(images.shape)]}

            # create a predictor
            network = NaiveNet()
            inferer = monai.inferers.SlidingWindowInferer(
                roi_size = (128, 128, 128),
                overlap = 0.25,
                cache_roi_weight_map = True,
            )
            network_output_keys=["cls", "box_reg"]
            images = torch.randn((2, 3, 512, 512, 512))  # a large input
            head_outputs = predict_with_inferer(images, network, network_output_keys, inferer)

    NzFPlease set inferer as a monai.inferers.inferer.SlidingWindowInferer(*))r   r   )r   r$   r   	enumerater   )	r   r   r   r&   r#   r   r   ir   r   r   r   predict_with_inferer\   s   *"r*   )N)r   r   r   r	   r
   r   )r   r   r   r    r   r	   r
   r!   )
r   r   r   r    r   r%   r&   r'   r
   r   )
__future__r   r   r   r   monai.inferersr   r   r   r$   r*   r   r   r   r   <module>   s   