o
    iB(                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlmZ d dlmZ d dlmZ d dlmZmZ d d	lmZmZ d%ddZd&ddZejfd'd!d"Zejfd'd#d$ZdS )(    )annotationsN)Sequence)Any)Tensor)standardize_empty_box)
SpatialPad)compute_divisible_spatial_sizeconvert_pad_mode)PytorchPadModeensure_tuple_repinput_imageslist[Tensor] | Tensorspatial_dimsintreturnNonec                 C  s   t | trt| j|d krtd|d  d| j ddS t | trA| D ]}t|j|d kr>td|d  d|j dq%dS td)	am  
    Validate the input dimensionality (raise a `ValueError` if invalid).

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
       z`When input_images is a Tensor, its need to be (spatial_dims + 2)-D.In this case, it should be a z-D Tensor, got Tensor shape .   zsWhen input_images is a List[Tensor], each element should have be (spatial_dims + 1)-D.In this case, it should be a z2input_images needs to be a List[Tensor] or Tensor.N)
isinstancer   lenshape
ValueErrorlist)r   r   img r   k/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/detector_utils.pycheck_input_images   s.   
	
r   targetslist[dict[str, Tensor]] | Nonetarget_label_keystrtarget_box_keylist[dict[str, Tensor]]c           	      C  s  |du rt dt| t|krt dt|  dt| dtt|D ]}|| }|| vs7|| vrFt | d| d|  d|| }t|tjsZt dt| dt|jd	ksj|jd
 d	| kr|	 dkrt
d|j dd	|  d nt dd	|  d|j dt|st d|j dt||d|| |< || }t|rt
d|j d | || |< q%|S )a  
    Validate the input images/targets during training (raise a `ValueError` if invalid).

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        targets: a list of dict. Each dict with two keys: target_box_key and target_label_key,
            ground-truth boxes present in the image.
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
        target_label_key: the expected key of target labels.
        target_box_key: the expected key of target boxes.
    Nz4Please provide ground truth targets during training.z4len(input_images) should equal to len(targets), got z, r   z and z# are expected keys in targets. Got z0Expected target boxes to be of type Tensor, got r   r   z)Warning: Given target boxes has shape of zA. The detector reshaped it with boxes = torch.reshape(boxes, [0, z]).z2Expected target boxes to be a tensor of shape [N, z], got z.).z0Expected target boxes to be a float tensor, got r   z Warning: Given target labels is z*. The detector converted it to torch.long.)r   r   rangekeysr   torchr   typer   numelwarningswarnis_floating_pointdtyper   long)	r   r   r   r    r"   itargetboxeslabelsr   r   r   check_training_targets7   s@    


r4   size_divisibleint | Sequence[int]modePytorchPadMode | strkwargsr   tuple[Tensor, list[list[int]]]c                   s  t |}t| tr[t| j d  t |d} fddt|D }dd |D ddd }t|dkrA|  g| jd  fS t| |d}t	j
| |fd	|i| g| jd  fS fd
d| D }	| d jd }
| d j}| d j}t|	}tj|dd\}}t|kst|krtdtt||d}tjt|	|
gt| ||d}td|d|d|}t| D ]\}}||||df< q|dd |	D fS )a  
    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.
    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
    Padded size (H, W) or (H, W, D) is divisible by size_divisible.
    Default padding uses constant padding with value 0.0

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2D or 3D.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for `torch.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    N)spatial_shapekc                   s&   g | ]\}}d t | |  d fqS )r   )max).0r0   sp_i)	orig_sizer   r   
<listcomp>   s   & zpad_images.<locals>.<listcomp>c                 S  s$   g | ]}|d d d D ]}|qqS )Nr$   r   )r>   sublistvalr   r   r   rA      s   $ r$   r   )dstr7   r7   c                   s   g | ]
}|j   d  qS )N)r   )r>   r   r%   r   r   rA      s    )dimzG Require len(max_spatial_size_t) == spatial_dims ==len(size_divisible).)r.   deviceend)spatial_sizemethodr7   .c                 S  s   g | ]}t |qS r   )r   )r>   ssr   r   r   rA      s    r   )r   r   r   r   r   r   	enumerater=   r	   Fpadr.   rF   r(   tensorr   r   zerosr   )r   r   r5   r7   r9   new_sizeZall_pad_widthpt_pad_widthmode_image_sizesin_channelsr.   rF   Zimage_sizes_tZmax_spatial_size_t_max_spatial_sizeimagespadderidxr   r   )r@   r   r   
pad_imageso   s0   

(


 rZ   c                 K  s*   t | | t||}t| |||fi |S )aV  
    Preprocess the input images, including

    - validate of the inputs
    - pad the inputs so that the output spatial sizes are divisible by `size_divisible`.
      It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.
      Padded size (H, W) or (H, W, D) is divisible by size_divisible.
      Default padding uses constant padding with value 0.0

    Args:
        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),
            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).
        spatial_dims: number of spatial dimensions of the images, 2 or 3.
        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.
            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.
        mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for `torch.pad` function.

    Return:
        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor
        - image_sizes, the original spatial size of each image
    )r   r   rZ   )r   r   r5   r7   r9   r   r   r   preprocess_images   s   

r[   )r   r   r   r   r   r   )r   r   r   r   r   r   r    r!   r"   r!   r   r#   )r   r   r   r   r5   r6   r7   r8   r9   r   r   r:   )
__future__r   r+   collections.abcr   typingr   r(   torch.nn.functionalnn
functionalrL   r   monai.data.box_utilsr   monai.transforms.croppad.arrayr   monai.transforms.utilsr   r	   monai.utilsr
   r   r   r4   CONSTANTrZ   r[   r   r   r   r   <module>   s"   

<H