o
    )i                     @  s   d dl mZ d dlZd dlmZ d dlmZ edddd r%ddddZnddddZedddd r:ddddZnddddZG dd dejZ	G dd dej
jZG dd dejZG dd dejZG dd dejZdS )    )annotationsN)nn)optional_importztorch.nn.functionalmish)name   Finplaceboolc                 C     t jjj| |dS N)r   )torchr   
functionalr   xr    r   b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/activation.py
monai_mish      r   c                 C  s   | t t jj|  S N)r   tanhr   r   softplusr   r   r   r   r      s   siluc                 C  r
   r   )r   r   r   r   r   r   r   r   monai_swish    r   r   c                 C  s
   t | S r   )SwishImplementationapplyr   r   r   r   r   %   s   
c                      s,   e Zd ZdZd
 fdd	Zddd	Z  ZS )Swishai  Applies the element-wise function:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.


    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['swish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
          ?c                      t    || _d S r   )super__init__alpha)selfr    	__class__r   r   r   @   s   

zSwish.__init__inputtorch.Tensorreturnc                 C  s   |t | j|  S r   )r   sigmoidr    r!   r$   r   r   r   forwardD   s   zSwish.forward)r   )r$   r%   r&   r%   __name__
__module____qualname____doc__r   r)   __classcell__r   r   r"   r   r   )   s    r   c                   @  s(   e Zd ZdZedd Zedd ZdS )r   zMemory efficient implementation for training
    Follows recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()
    c                 C  s   |t | }| | |S r   )r   r'   save_for_backward)ctxr$   resultr   r   r   r)   P   s   
zSwishImplementation.forwardc                 C  s,   | j d }t|}||d|d|     S )Nr   r   )saved_tensorsr   r'   )r1   grad_outputr$   Zsigmoid_inputr   r   r   backwardV   s   

zSwishImplementation.backwardN)r+   r,   r-   r.   staticmethodr)   r5   r   r   r   r   r   H   s    
r   c                      .   e Zd ZdZdd fddZdd	d
Z  ZS )MemoryEfficientSwisha%  Applies the element-wise function:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.

    Memory efficient implementation for training following recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

    From Pytorch 1.7.0+, the optimized version of `Swish` named `SiLU` is implemented,
    this class will utilize `torch.nn.functional.silu` to do the calculation if meets the version.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['memswish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Fr   r	   c                   r   r   r   r   r   r!   r   r"   r   r   r   |      

zMemoryEfficientSwish.__init__r$   r%   c                 C     t || jS r   )r   r   r(   r   r   r   r)         zMemoryEfficientSwish.forwardFr   r	   r$   r%   r*   r   r   r"   r   r8   ]   s    r8   c                      r7   )Misha  Applies the element-wise function:

    .. math::
        \text{Mish}(x) = x * tanh(\text{softplus}(x)).

    Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.

    From Pytorch 1.9.0+, the optimized version of `Mish` is implemented,
    this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['mish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Fr   r	   c                   r   r   r9   r:   r"   r   r   r      r;   zMish.__init__r$   r%   c                 C  r<   r   )r   r   r(   r   r   r   r)      r=   zMish.forwardr>   r?   r@   r*   r   r   r"   r   rA      s    rA   c                   @  s   e Zd ZdZdddZdS )GEGLUa  Applies the element-wise function:

    .. math::
        \text{GEGLU}(x) = x_1 * \text{Sigmoid}(x_2)

    where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.

    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.

    Shape:
        - Input: :math:`(N, *, 2 * D)`
        - Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions
    r$   r%   c                 C  s"   |j ddd\}}|tj| S )N   )dim)chunkr   r   gelu)r!   r$   r   Zgater   r   r   r)      s   zGEGLU.forwardNr@   )r+   r,   r-   r.   r)   r   r   r   r   rB      s    rB   r>   r?   )
__future__r   r   r   monai.utilsr   r   r   Moduler   autogradFunctionr   r8   rA   rB   r   r   r   r   <module>   s   ("