o
    i
                     @  sT   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 G dd deZ
e
ZdS )    )annotationsN)_Loss)COMPUTE_DTYPEbox_pair_giou)LossReductionc                      s2   e Zd ZdZejfd fddZdddZ  ZS )BoxGIoULossa\  
    Compute the generalized intersection over union (GIoU) loss of a pair of boxes.
    The two inputs should have the same shape. giou_loss = 1.0 - giou

    The range of GIoU is (-1.0, 1.0]. Thus the range of GIoU loss is [0.0, 2.0).

    Args:
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.
            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.
    	reductionLossReduction | strreturnNonec                   s   t  jt|jd d S )N)r   )super__init__r   value)selfr   	__class__ X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/giou_loss.pyr   $   s   zBoxGIoULoss.__init__inputtorch.Tensortargetc                 C  s   |j |j krtd|j  d|j  d|j}t|jtd|jtd}d| }| jtjj	kr3|
 }n| jtjj	kr?| }n| jtjj	krGn	td| j d||S )aN  
        Args:
            input: predicted bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            target: GT bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Raises:
            ValueError: When the two inputs have different shape.
        z"ground truth has different shape (z) from input ())dtypeg      ?zUnsupported reduction: z0, available options are ["mean", "sum", "none"].)shape
ValueErrorr   r   tor   r   r   MEANr   meanSUMsumNONE)r   r   r   	box_dtypegioulossr   r   r   forward'   s   	


zBoxGIoULoss.forward)r   r	   r
   r   )r   r   r   r   r
   r   )	__name__
__module____qualname____doc__r   r   r   r$   __classcell__r   r   r   r   r      s    r   )
__future__r   torchtorch.nn.modules.lossr   monai.data.box_utilsr   r   monai.utilsr   r   r"   r   r   r   r   <module>   s   .