o
    i+                     @  s|   d Z ddlmZ ddlZddlmZ ddlZddlmZ ddlm	Z	m
Z
mZmZmZ ddlmZ dddZG dd dZdS )z
This script is modified from torchvision to support N-D images,

https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py
    )annotationsN)Sequence)Tensor)COMPUTE_DTYPECenterSizeModeStandardModeconvert_box_modeis_valid_box_values)look_up_optiongt_boxesr   	proposalsweightsreturnc           	      C  s6  | j d |j d krtdtt|ddgd }t| s!tdt|s)tdt|ttd}t| ttd}|d	| d|d	d	d	|f |d	d	d	|f   |d	d	|d	f  }||d	 dt	
|d	d	|d	f |d	d	|d	f   }t	j||fd
d}t	| st	| rtd|S )a  
    Encode a set of proposals with respect to some reference ground truth (gt) boxes.

    Args:
        gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
        proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
        weights: the weights for ``(cx, cy, w, h) or (cx,cy,cz, w,h,d)``

    Return:
        encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
    r   z8gt_boxes.shape[0] should be equal to proposals.shape[0].         z?gt_boxes is not valid. Please check if it contains empty boxes.z@proposals is not valid. Please check if it contains empty boxes.src_modedst_modeN   dimztargets is NaN or Inf.)shape
ValueErrorr
   lenr	   r   r   r   	unsqueezetorchlogcatisnananyisinf)	r   r   r   spatial_dimsZ	ex_cccwhdZ	gt_cccwhdZtargets_dxyzZtargets_dwhdtargets r$   f/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/box_coder.pyencode_boxes@   s,   &&r&   c                   @  sD   e Zd ZdZddd	d
ZdddZdddZdddZdddZdS ) BoxCodera  
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.

    Args:
        weights: 4-element tuple or 6-element tuple
        boxes_xform_clip: high threshold to prevent sending too large values into torch.exp()

    Example:
        .. code-block:: python

            box_coder = BoxCoder(weights=[1., 1., 1., 1., 1., 1.])
            gt_boxes = torch.tensor([[1,2,1,4,5,6],[1,3,2,7,8,9]])
            proposals = gt_boxes + torch.rand(gt_boxes.shape)
            rel_gt_boxes = box_coder.encode_single(gt_boxes, proposals)
            gt_back = box_coder.decode_single(rel_gt_boxes, proposals)
            # We expect gt_back to be equal to gt_boxes
    Nr   Sequence[float]boxes_xform_clipfloat | Noner   Nonec                 C  s:   |d u r	t d}tt|ddgd | _|| _|| _d S )Ng     @O@r   r   r   )mathr   r
   r   r"   r   r)   )selfr   r)   r$   r$   r%   __init__}   s
   

zBoxCoder.__init__r   Sequence[Tensor]r   tuple[Tensor]c                 C  sN   dd |D }t jt|dd}t jt|dd}| ||}||d}|S )a  
        Encode a set of proposals with respect to some ground truth (gt) boxes.

        Args:
            gt_boxes: list of gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            proposals: list of boxes to be encoded, each element is Mx4 or Mx6 torch tensor.
                The box mode is assumed to be ``StandardMode``

        Return:
            A tuple of encoded gt, target of box regression that is used to
                convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
        c                 S  s   g | ]}t |qS r$   )r   .0br$   r$   r%   
<listcomp>   s    z#BoxCoder.encode.<locals>.<listcomp>r   r   )r   r   tupleencode_singlesplit)r-   r   r   boxes_per_imageZconcat_gt_boxesZconcat_proposalsZconcat_targetsr#   r$   r$   r%   encode   s   zBoxCoder.encoder   c                 C  s.   |j }|j}tj| j||d}t|||}|S )a  
        Encode proposals with respect to ground truth (gt) boxes.

        Args:
            gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Return:
            encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
        dtypedevice)r;   r<   r   	as_tensorr   r&   )r-   r   r   r;   r<   r   r#   r$   r$   r%   r6      s
   zBoxCoder.encode_single	rel_codesreference_boxesc                 C  s   t |trt |tjstddd |D }tjt|dd}d}|D ]}||7 }q#|dkr4||d}| ||}|dkrH||dd| j	 }|S )a  
        From a set of original reference_boxes and encoded relative box offsets,

        Args:
            rel_codes: encoded boxes, Nx4 or Nx6 torch tensor.
            reference_boxes: a list of reference boxes, each element is Mx4 or Mx6 torch tensor.
                The box mode is assumed to be ``StandardMode``

        Return:
            decoded boxes, Nx1x4 or Nx1x6 torch tensor. The box mode will be ``StandardMode``
        zInput arguments wrong type.c                 S  s   g | ]}| d qS )r   )sizer1   r$   r$   r%   r4      s    z#BoxCoder.decode.<locals>.<listcomp>r   r   r   )

isinstancer   r   r   r   r   r5   reshapedecode_singler"   )r-   r>   r?   r8   concat_boxesbox_sumval
pred_boxesr$   r$   r%   decode   s   
zBoxCoder.decodec                 C  s  | |j}|jd }g }t|ttd}t| jD ]}|dd|| j f }|dd|f }|dd|d|f | j|  }	|dd| j| d|f | j|| j   }
t	j
|
 t| jd}
|	|dddf  |dddf  }t	|
|dddf  }| |	j}t	| st	| rtdt	jd|j|jd| }|||  |||  q|ddd |d	dd  }t	j|dd
d	}|S )a  
        From a set of original boxes and encoded relative box offsets,

        Args:
            rel_codes: encoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor.
            reference_boxes: reference boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Return:
            decoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor. The box mode will to be ``StandardMode``
        rA   r   N)maxzpred_whd_axis is NaN or Inf.g      ?r:   r   r   r   )tor;   r   r   r   r   ranger"   r   r   clampr   r)   expr   r    r!   r   tensorr<   appendstackflatten)r-   r>   r?   offsetrH   Zboxes_cccwhdaxisZwhd_axisZctr_xyz_axisZ	dxyz_axisZ	dwhd_axisZpred_ctr_xyx_axisZpred_whd_axisZc_to_c_whd_axisZpred_boxes_finalr$   r$   r%   rD      s0   
 ,$zBoxCoder.decode_single)N)r   r(   r)   r*   r   r+   )r   r/   r   r/   r   r0   )r   r   r   r   r   r   )r>   r   r?   r/   r   r   )r>   r   r?   r   r   r   )	__name__
__module____qualname____doc__r.   r9   r6   rI   rD   r$   r$   r$   r%   r'   i   s    


r'   )r   r   r   r   r   r   r   r   )rX   
__future__r   r,   collections.abcr   r   r   monai.data.box_utilsr   r   r   r   r	   monai.utils.moduler
   r&   r'   r$   r$   r$   r%   <module>   s   -
)