U
    Php+                     @  s   d Z ddlmZ ddlZddlmZ ddlZddlmZ ddlm	Z	m
Z
mZmZmZ ddlmZ ddddd	d
dZG dd dZdS )z
This script is modified from torchvision to support N-D images,

https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py
    )annotationsN)Sequence)Tensor)COMPUTE_DTYPECenterSizeModeStandardModeconvert_box_modeis_valid_box_values)look_up_optionr   )gt_boxes	proposalsweightsreturnc           	      C  s:  | j d |j d krtdtt|ddgd }t| sBtdt|sRtdt|ttd}t| ttd}|d	| d|d	d	d	|f |d	d	d	|f   |d	d	|d	f  }||d	 dt	
|d	d	|d	f |d	d	|d	f   }t	j||fd
d}t	| s.t	| r6td|S )a  
    Encode a set of proposals with respect to some reference ground truth (gt) boxes.

    Args:
        gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
        proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
        weights: the weights for ``(cx, cy, w, h) or (cx,cy,cz, w,h,d)``

    Return:
        encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
    r   z8gt_boxes.shape[0] should be equal to proposals.shape[0].         z?gt_boxes is not valid. Please check if it contains empty boxes.z@proposals is not valid. Please check if it contains empty boxes.src_modedst_modeN   dimztargets is NaN or Inf.)shape
ValueErrorr
   lenr	   r   r   r   	unsqueezetorchlogcatisnananyisinf)	r   r   r   spatial_dimsZ	ex_cccwhdZ	gt_cccwhdZtargets_dxyzZtargets_dwhdtargets r$   Y/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/detection/utils/box_coder.pyencode_boxes@   s,    && r&   c                   @  sl   e Zd ZdZdddddddZd	d	d
dddZddddddZdd	ddddZddddddZdS )BoxCodera  
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.

    Args:
        weights: 4-element tuple or 6-element tuple
        boxes_xform_clip: high threshold to prevent sending too large values into torch.exp()

    Example:
        .. code-block:: python

            box_coder = BoxCoder(weights=[1., 1., 1., 1., 1., 1.])
            gt_boxes = torch.tensor([[1,2,1,4,5,6],[1,3,2,7,8,9]])
            proposals = gt_boxes + torch.rand(gt_boxes.shape)
            rel_gt_boxes = box_coder.encode_single(gt_boxes, proposals)
            gt_back = box_coder.decode_single(rel_gt_boxes, proposals)
            # We expect gt_back to be equal to gt_boxes
    NzSequence[float]zfloat | NoneNone)r   boxes_xform_clipr   c                 C  s:   |d krt d}tt|ddgd | _|| _|| _d S )Ng     @O@r   r   r   )mathr   r
   r   r"   r   r)   )selfr   r)   r$   r$   r%   __init__}   s
    
zBoxCoder.__init__zSequence[Tensor]ztuple[Tensor])r   r   r   c                 C  sN   dd |D }t jt|dd}t jt|dd}| ||}||d}|S )a  
        Encode a set of proposals with respect to some ground truth (gt) boxes.

        Args:
            gt_boxes: list of gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            proposals: list of boxes to be encoded, each element is Mx4 or Mx6 torch tensor.
                The box mode is assumed to be ``StandardMode``

        Return:
            A tuple of encoded gt, target of box regression that is used to
                convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
        c                 S  s   g | ]}t |qS r$   )r   .0br$   r$   r%   
<listcomp>   s     z#BoxCoder.encode.<locals>.<listcomp>r   r   )r   r   tupleencode_singlesplit)r+   r   r   boxes_per_imageZconcat_gt_boxesZconcat_proposalsZconcat_targetsr#   r$   r$   r%   encode   s    zBoxCoder.encoder   c                 C  s.   |j }|j}tj| j||d}t|||}|S )a  
        Encode proposals with respect to ground truth (gt) boxes.

        Args:
            gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Return:
            encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.
        dtypedevice)r7   r8   r   	as_tensorr   r&   )r+   r   r   r7   r8   r   r#   r$   r$   r%   r2      s
    zBoxCoder.encode_single)	rel_codesreference_boxesr   c                 C  s   t |trt |tjstddd |D }tjt|dd}d}|D ]}||7 }qF|dkrh||d}| ||}|dkr||dd| j	 }|S )a  
        From a set of original reference_boxes and encoded relative box offsets,

        Args:
            rel_codes: encoded boxes, Nx4 or Nx6 torch tensor.
            reference_boxes: a list of reference boxes, each element is Mx4 or Mx6 torch tensor.
                The box mode is assumed to be ``StandardMode``

        Return:
            decoded boxes, Nx1x4 or Nx1x6 torch tensor. The box mode will be ``StandardMode``
        zInput arguments wrong type.c                 S  s   g | ]}| d qS )r   )sizer-   r$   r$   r%   r0      s     z#BoxCoder.decode.<locals>.<listcomp>r   r   r   )

isinstancer   r   r   r   r   r1   reshapedecode_singler"   )r+   r:   r;   r4   concat_boxesbox_sumval
pred_boxesr$   r$   r%   decode   s    
zBoxCoder.decodec                 C  s  | |j}|jd }g }t|ttd}t| jD ]2}|dd|| j f }|dd|f }|dd|d|f | j|  }	|dd| j| d|f | j|| j   }
t	j
|
 t| jd}
|	|dddf  |dddf  }t	|
|dddf  }| |	j}t	| s*t	| r2tdt	jd|j|jd| }|||  |||  q2|ddd |d	dd  }t	j|dd
d	}|S )a  
        From a set of original boxes and encoded relative box offsets,

        Args:
            rel_codes: encoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor.
            reference_boxes: reference boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Return:
            decoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor. The box mode will to be ``StandardMode``
        r=   r   N)maxzpred_whd_axis is NaN or Inf.g      ?r6   r   r   r   )tor7   r   r   r   r   ranger"   r   r   clampr   r)   expr   r    r!   r   tensorr8   appendstackflatten)r+   r:   r;   offsetrD   Zboxes_cccwhdaxisZwhd_axisZctr_xyz_axisZ	dxyz_axisZ	dwhd_axisZpred_ctr_xyx_axisZpred_whd_axisZc_to_c_whd_axisZpred_boxes_finalr$   r$   r%   r@      s,    
 ,$ zBoxCoder.decode_single)N)	__name__
__module____qualname____doc__r,   r5   r2   rE   r@   r$   r$   r$   r%   r'   i   s   r'   )rT   
__future__r   r*   collections.abcr   r   r   monai.data.box_utilsr   r   r   r   r	   monai.utils.moduler
   r&   r'   r$   r$   r$   r%   <module>.   s   )