U
    jPh*                     @  sp   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ G dd deZG d	d
 d
eZdS )    )annotationsN)_Loss)get_act_layer)LossReduction)StrEnumc                   @  s   e Zd ZdZdZdZdS )AdversarialCriterionsZbceZhingeleast_squaresN)__name__
__module____qualname__BCEHINGELEAST_SQUARE r   r   R/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/adversarial_loss.pyr      s   r   c                      s   e Zd ZdZejejdfddddd fdd	Zd
dd
dddZ	d
d
dddZ
dddddddZd
d
d
dddZ  ZS )PatchAdversarialLossa-  
    Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.
    Warning: due to the possibility of using different criterions, the output of the discrimination
    mustn't be passed to a final activation layer. That is taken care of internally within the loss.

    Args:
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.

        criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
            Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
            through an activation layer prior to calling the loss.
        no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
    FzLossReduction | strstrboolNone)	reduction	criterionno_activation_leastsqreturnc                   s   t  jt|d | ttkr4tddt d| _d| _	|  |tj
krjtd| _tjj|d| _nV|tjkrtd| _d| _	n:|tjkr|rd | _ntd	d
difd| _tjj|d| _|| _|| _d S )N)r   zGUnrecognised criterion entered for Adversarial Loss. Must be one in: %sz, g      ?g        SIGMOIDZTANHg      	LEAKYRELUnegative_slopeg?)name)super__init__r   lowerlistr   
ValueErrorjoin
real_label
fake_labelr   r   
activationtorchnnBCELossloss_fctr   r   MSELossr   r   )selfr   r   r   	__class__r   r   r   2   s.    




zPatchAdversarialLoss.__init__ztorch.Tensor)inputtarget_is_realr   c                 C  sJ   |r
| j n| j}td|| |d j}|d |	|S )a  
        Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.

        Args:
            input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale
            discriminator). This is used to match the shape.
            target_is_real: whether the input is real or wannabe-real (1s) or fake (0s).
        Returns:
           r   F)
r#   r$   r&   tensorfill_typetodevicerequires_grad_	expand_as)r+   r.   r/   Zfilling_labelZlabel_tensorr   r   r   get_target_tensorT   s    
&
z&PatchAdversarialLoss.get_target_tensor)r.   r   c                 C  s8   t d|d  |d j}|d ||S )z
        Gets a zero tensor.

        Args:
            input: tensor which shape you want the zeros tensor to correspond to.
        Returns:
        r   F)r&   r1   r3   r4   r5   r6   r7   )r+   r.   Zzero_label_tensorr   r   r   get_zero_tensorc   s    	$
z$PatchAdversarialLoss.get_zero_tensorztorch.Tensor | listz!torch.Tensor | list[torch.Tensor])r.   r/   for_discriminatorr   c                 C  s$  |s|sd}t d t|ts&|g}g }t|D ]8\}}| jtjkrZ|| 	|| q2|| 
| q2g }t|D ]Z\}}| jdk	r| |}| jtjkr|s| | || }	n| ||| }	||	 qx|dk	r | jtjkrtt|}
n$| jtjkrtt|}
n|}
|
S )aL  

        Args:
            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors
                or a tensor; they shouldn't have gone through an activation layer.
            target_is_real: whereas the input corresponds to discriminator output for real or fake images
            for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last
                case, target_is_real is set to True, as the generator wants the input to be dimmed as real.
        Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale
            discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
            summed or mean loss over the tensor and discriminator/s.

        TzVariable target_is_real has been set to False, but for_discriminator is setto False. To optimise a generator, target_is_real must be set to True.N)warningswarn
isinstancer    	enumerater   r   r   appendr8   r9   r%   _forward_singler   r   MEANr&   meanstackSUMsum)r+   r.   r/   r:   target__Zdisc_outZ	loss_listZdisc_indZloss_lossr   r   r   forwardp   s6    



zPatchAdversarialLoss.forward)r.   targetr   c                 C  sX   | j tjks| j tjkr&| ||}n.| j tjkrTt|d | |}t	| }|S )Nr0   )
r   r   r   r   r)   r   r&   minr9   rB   )r+   r.   rJ   rI   minvalr   r   r   r@      s    z$PatchAdversarialLoss._forward_single)r	   r
   r   __doc__r   rA   r   r   r   r8   r9   rI   r@   __classcell__r   r   r,   r   r      s   "6r   )
__future__r   r;   r&   torch.nn.modules.lossr   monai.networks.layers.utilsr   monai.utilsr   monai.utils.enumsr   r   r   r   r   r   r   <module>   s   