o
    .i8$                     @  s   d dl mZ d dlZd dlmZ d dlmZmZ eddZe	ddd#d$ddZ
e	ddd#d$ddZe	ddd%d&ddZe	ddd'd(dd ZG d!d" d"ejZdS ))    )annotationsN)ComponentStoreunsqueeze_rightNoiseSchedulesz%Functions to generate noise scheduleslinear_betazLinear beta schedule-C6?{Gz?num_train_timestepsint
beta_startfloatbeta_endc                 C  s   t j||| t jdS )a  
    Linear beta noise schedule function.

    Args:
        num_train_timesteps: number of timesteps
        beta_start: start of beta range, default 1e-4
        beta_end: end of beta range, default 2e-2

    Returns:
        betas: beta schedule tensor
    dtypetorchlinspacefloat32r	   r   r    r   e/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/schedulers/scheduler.py_linear_beta+   s   r   Zscaled_linear_betazScaled linear beta schedulec                 C  s    t j|d |d | t jdd S )a  
    Scaled linear beta noise schedule function.

    Args:
        num_train_timesteps: number of timesteps
        beta_start: start of beta range, default 1e-4
        beta_end: end of beta range, default 2e-2

    Returns:
        betas: beta schedule tensor
          ?r      r   r   r   r   r   _scaled_linear_beta;   s    r   Zsigmoid_betazSigmoid beta schedule   	sig_rangec                 C  s&   t | || }t |||  | S )aB  
    Sigmoid beta noise schedule function.

    Args:
        num_train_timesteps: number of timesteps
        beta_start: start of beta range, default 1e-4
        beta_end: end of beta range, default 2e-2
        sig_range: pos/neg range of sigmoid input, default 6

    Returns:
        betas: beta schedule tensor
    )r   r   sigmoid)r	   r   r   r   betasr   r   r   _sigmoid_betaK   s   r   cosinezCosine scheduleMb?sc                 C  s   t d| | d }t ||  | d|  t j d d }||d   }d|dd |dd   }t |dd	}d| }t j|dd
}|||fS )z
    Cosine noise schedule, see https://arxiv.org/abs/2102.09672

    Args:
        num_train_timesteps: number of timesteps
        s: smoothing factor, default 8e-3 (see referenced paper)

    Returns:
        (betas, alphas, alpha_cumprod) values
    r      r   r         ?Ng        g+?dim)r   r   cospiitemclipcumprod)r	   r"   xalphas_cumprodr   alphasr   r   r   _cosine_beta]   s   (
r0   c                      s8   e Zd ZdZdd fd
dZdddZdddZ  ZS )	Schedulera  
    Base class for other schedulers based on a noise schedule function.

    This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here
    the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,
    which is the name of a component in NoiseSchedules. These components must all be callables which return either
    the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions
    can be provided by using the NoiseSchedules.add_def, for example:

    .. code-block:: python

        from monai.networks.schedulers import NoiseSchedules, DDPMScheduler

        @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function")
        def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)

        scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")

    All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of
    timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through
    the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules
    to get a listing of stored objects with their docstring descriptions.

    Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
    type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended
    to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are
    still used for some schedules but these are provided as keyword arguments now.

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules,
            a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple
        schedule_args: arguments to pass to the schedule function
      r   r	   r
   schedulestrreturnNonec                   s   t    ||d< t| di |}t|tr |\| _| _| _n|| _d| j | _tj	| jdd| _|| _
td| _d | _t|d dd| _d S )Nr	   r$   r   r&   r#   r%   r   )super__init__r   
isinstancetupler   r/   r.   r   r,   r	   tensoronenum_inference_stepsarange	timesteps)selfr	   r3   schedule_argsZnoise_sched	__class__r   r   r8      s   

zScheduler.__init__original_samplestorch.Tensornoiser?   c                 C  sf   | j j|j|jd| _ ||j}t| j | d |j}td| j |  d |j}|| ||  }|S )aB  
        Add noise to the original samples.

        Args:
            original_samples: original samples
            noise: noise to add to samples
            timesteps: timesteps tensor indicating the timestep to be computed for each sample.

        Returns:
            noisy_samples: sample with added noise
        devicer   r   r#   r.   torH   r   r   ndim)r@   rD   rF   r?   Zsqrt_alpha_cumprodsqrt_one_minus_alpha_prodZnoisy_samplesr   r   r   	add_noise   s   zScheduler.add_noisesamplec                 C  sf   | j j|j|jd| _ ||j}t| j | d |j}td| j |  d |j}|| ||  }|S )NrG   r   r#   rI   )r@   rN   rF   r?   Zsqrt_alpha_prodrL   Zvelocityr   r   r   get_velocity   s   zScheduler.get_velocity)r2   r   )r	   r
   r3   r4   r5   r6   )rD   rE   rF   rE   r?   rE   r5   rE   )rN   rE   rF   rE   r?   rE   r5   rE   )__name__
__module____qualname____doc__r8   rM   rO   __classcell__r   r   rB   r   r1   s   s
    $
r1   )r   r   )r	   r
   r   r   r   r   )r   r   r   )r	   r
   r   r   r   r   r   r   )r!   )r	   r
   r"   r   )
__future__r   r   torch.nnnnmonai.utilsr   r   r   add_defr   r   r   r0   Moduler1   r   r   r   r   <module>   s    




