U
    uPh
                     @  sT   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 G dd deZ
e
ZdS )    )annotationsN)_Loss)COMPUTE_DTYPEbox_pair_giou)LossReductionc                      s@   e Zd ZdZejfddd fddZddddd	d
Z  ZS )BoxGIoULossa\  
    Compute the generalized intersection over union (GIoU) loss of a pair of boxes.
    The two inputs should have the same shape. giou_loss = 1.0 - giou

    The range of GIoU is (-1.0, 1.0]. Thus the range of GIoU loss is [0.0, 2.0).

    Args:
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.
            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.
    zLossReduction | strNone)	reductionreturnc                   s   t  jt|jd d S )N)r	   )super__init__r   value)selfr	   	__class__ K/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/giou_loss.pyr   $   s    zBoxGIoULoss.__init__ztorch.Tensor)inputtargetr
   c                 C  s   |j |j kr&td|j  d|j  d|j}t|jtd|jtd}d| }| jtjj	krf|
 }n:| jtjj	kr~| }n"| jtjj	krntd| j d||S )aN  
        Args:
            input: predicted bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
            target: GT bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``

        Raises:
            ValueError: When the two inputs have different shape.
        z"ground truth has different shape (z) from input ())dtypeg      ?zUnsupported reduction: z0, available options are ["mean", "sum", "none"].)shape
ValueErrorr   r   tor   r	   r   MEANr   meanSUMsumNONE)r   r   r   	box_dtypegioulossr   r   r   forward'   s     	
 


zBoxGIoULoss.forward)	__name__
__module____qualname____doc__r   r   r   r"   __classcell__r   r   r   r   r      s   r   )
__future__r   torchtorch.nn.modules.lossr   Zmonai.data.box_utilsr   r   monai.utilsr   r   r    r   r   r   r   <module>   s   .