U
    PhB(                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlmZ d dlmZ d dlmZ d dlmZmZ d d	lmZmZ d
dddddZd
ddddddddZejfd
ddddddddZejfd
ddddddddZdS )    )annotationsN)Sequence)Any)Tensor)standardize_empty_box)
SpatialPad)compute_divisible_spatial_sizeconvert_pad_mode)PytorchPadModeensure_tuple_repzlist[Tensor] | TensorintNone)input_imagesspatial_dimsreturnc                 C  s   t | tr:t| j|d krtd|d  d| j dnLt | tr~| D ]2}t|j|d krHtd|d  d|j dqHntddS )	am  
    Validate the input dimensionality (raise a `ValueError` if invalid).

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
       z`When input_images is a Tensor, its need to be (spatial_dims + 2)-D.In this case, it should be a z-D Tensor, got Tensor shape .   zsWhen input_images is a List[Tensor], each element should have be (spatial_dims + 1)-D.In this case, it should be a z2input_images needs to be a List[Tensor] or Tensor.N)
isinstancer   lenshape
ValueErrorlist)r   r   img r   ^/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/detection/utils/detector_utils.pycheck_input_images   s    	

r   zlist[dict[str, Tensor]] | Nonestrzlist[dict[str, Tensor]])r   targetsr   target_label_keytarget_box_keyr   c           	      C  s  |dkrt dt| t|kr>t dt|  dt| dtt|D ]>}|| }|| ksp|| krt | d| d|  d|| }t|tjst dt| dt|jd	ks|jd
 d	| kr"|	 dkrt
d|j dd	|  d nt dd	|  d|j dt|s@t d|j dt||d|| |< || }t|rJt
d|j d | || |< qJ|S )a  
    Validate the input images/targets during training (raise a `ValueError` if invalid).

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        targets: a list of dict. Each dict with two keys: target_box_key and target_label_key,
            ground-truth boxes present in the image.
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
        target_label_key: the expected key of target labels.
        target_box_key: the expected key of target boxes.
    Nz4Please provide ground truth targets during training.z4len(input_images) should equal to len(targets), got z, r   z and z# are expected keys in targets. Got z0Expected target boxes to be of type Tensor, got r   r   z)Warning: Given target boxes has shape of zA. The detector reshaped it with boxes = torch.reshape(boxes, [0, z]).z2Expected target boxes to be a tensor of shape [N, z], got z.).z0Expected target boxes to be a float tensor, got r   z Warning: Given target labels is z*. The detector converted it to torch.long.)r   r   rangekeysr   torchr   typer   numelwarningswarnis_floating_pointdtyper   long)	r   r   r   r   r    itargetboxeslabelsr   r   r   check_training_targets7   s:    "
r1   zint | Sequence[int]zPytorchPadMode | strr   ztuple[Tensor, list[list[int]]])r   r   size_divisiblemodekwargsr   c                   s  t |}t| trt| j d  t |d} fddt|D }dd |D ddd }t|dkr|  g| jd  fS t| |d}t	j
| |fd	|i| g| jd  fS fd
d| D }	| d jd }
| d j}| d j}t|	}tj|dd\}}t|ks"t|kr*tdtt||d}tjt|	|
gt| ||d}tf |d|d|}t| D ]\}}||||df< qx|dd |	D fS )a  
    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.
    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
    Padded size (H, W) or (H, W, D) is divisible by size_divisible.
    Default padding uses constant padding with value 0.0

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2D or 3D.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for `torch.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    N)spatial_shapekc                   s&   g | ]\}}d t | |  d fqS )r   )max).0r-   sp_i)	orig_sizer   r   
<listcomp>   s     zpad_images.<locals>.<listcomp>c                 S  s$   g | ]}|d d d D ]}|qqS )Nr!   r   )r8   sublistvalr   r   r   r;      s       r!   r   )dstr3   r3   c                   s   g | ]}|j   d  qS )N)r   )r8   r   r"   r   r   r;      s     )dimzG Require len(max_spatial_size_t) == spatial_dims ==len(size_divisible).)r+   deviceend)spatial_sizemethodr3   .c                 S  s   g | ]}t |qS r   )r   )r8   ssr   r   r   r;      s     )r   r   r   r   r   r   	enumerater7   r	   Fpadr+   r@   r%   tensorr   r   zerosr   )r   r   r2   r3   r4   new_sizeZall_pad_widthpt_pad_widthmode_image_sizesin_channelsr+   r@   Zimage_sizes_tZmax_spatial_size_t_max_spatial_sizeimagespadderidxr   r   )r:   r   r   
pad_imageso   s0    

(


 rT   c                 K  s&   t | | t||}t| |||f|S )aV  
    Preprocess the input images, including

    - validate of the inputs
    - pad the inputs so that the output spatial sizes are divisible by `size_divisible`.
      It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
      Padded size (H, W) or (H, W, D) is divisible by size_divisible.
      Default padding uses constant padding with value 0.0

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for `torch.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    )r   r   rT   )r   r   r2   r3   r4   r   r   r   preprocess_images   s    

rU   )
__future__r   r(   collections.abcr   typingr   r%   torch.nn.functionalnn
functionalrF   r   monai.data.box_utilsr   Zmonai.transforms.croppad.arrayr   monai.transforms.utilsr   r	   monai.utilsr
   r   r   r1   CONSTANTrT   rU   r   r   r   r   <module>   s    <H