U
    Ph                     @  s   d dl mZ d dlZd dlmZ d dlmZ edddd rNdd
dddZndd
dddZedddd rdd
dddZnd d
dddZG dd dejZ	G dd dej
jZG dd dejZG dd dejZG dd dejZdS )!    )annotationsN)nn)optional_importztorch.nn.functionalmish)name   Fboolinplacec                 C  s   t jjj| |dS Nr	   )torchr   
functionalr   xr
    r   U/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/activation.py
monai_mish   s    r   c                 C  s   | t t jj|  S N)r   tanhr   r   softplusr   r   r   r   r      s    siluc                 C  s   t jjj| |dS r   )r   r   r   r   r   r   r   r   monai_swish    s    r   c                 C  s
   t | S r   )SwishImplementationapplyr   r   r   r   r   %   s    c                      s2   e Zd ZdZd	 fdd	ZdddddZ  ZS )
Swishai  Applies the element-wise function:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.


    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['swish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
          ?c                   s   t    || _d S r   )super__init__alpha)selfr   	__class__r   r   r   @   s    
zSwish.__init__torch.Tensor)inputreturnc                 C  s   |t | j|  S r   )r   sigmoidr   r   r#   r   r   r   forwardD   s    zSwish.forward)r   __name__
__module____qualname____doc__r   r'   __classcell__r   r   r    r   r   )   s   r   c                   @  s(   e Zd ZdZedd Zedd ZdS )r   zMemory efficient implementation for training
    Follows recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()
    c                 C  s   |t | }| | |S r   )r   r%   save_for_backward)ctxr#   resultr   r   r   r'   P   s    
zSwishImplementation.forwardc                 C  s,   | j d }t|}||d|d|     S )Nr   r   )saved_tensorsr   r%   )r/   grad_outputr#   Zsigmoid_inputr   r   r   backwardV   s    

zSwishImplementation.backwardN)r)   r*   r+   r,   staticmethodr'   r3   r   r   r   r   r   H   s
   
r   c                      s6   e Zd ZdZddd fddZddd	d
Z  ZS )MemoryEfficientSwisha%  Applies the element-wise function:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.

    Memory efficient implementation for training following recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

    From Pytorch 1.7.0+, the optimized version of `Swish` named `SiLU` is implemented,
    this class will utilize `torch.nn.functional.silu` to do the calculation if meets the version.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['memswish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Fr   r	   c                   s   t    || _d S r   r   r   r
   r   r
   r    r   r   r   |   s    
zMemoryEfficientSwish.__init__r"   r#   c                 C  s   t || jS r   )r   r
   r&   r   r   r   r'      s    zMemoryEfficientSwish.forward)Fr(   r   r   r    r   r5   ]   s   r5   c                      s6   e Zd ZdZddd fddZddd	d
Z  ZS )Misha  Applies the element-wise function:

    .. math::
        \text{Mish}(x) = x * tanh(\text{softplus}(x)).

    Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.

    From Pytorch 1.9.0+, the optimized version of `Mish` is implemented,
    this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['mish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Fr   r	   c                   s   t    || _d S r   r6   r7   r    r   r   r      s    
zMish.__init__r"   r8   c                 C  s   t || jS r   )r   r
   r&   r   r   r   r'      s    zMish.forward)Fr(   r   r   r    r   r9      s   r9   c                   @  s   e Zd ZdZddddZdS )GEGLUa  Applies the element-wise function:

    .. math::
        \text{GEGLU}(x) = x_1 * \text{Sigmoid}(x_2)

    where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.

    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.

    Shape:
        - Input: :math:`(N, *, 2 * D)`
        - Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions
    r"   r8   c                 C  s"   |j ddd\}}|tj| S )N   )dim)chunkr   r   gelu)r   r#   r   Zgater   r   r   r'      s    zGEGLU.forwardN)r)   r*   r+   r,   r'   r   r   r   r   r:      s   r:   )F)F)F)F)
__future__r   r   r   monai.utilsr   r   r   Moduler   autogradFunctionr   r5   r9   r:   r   r   r   r   <module>   s   ("