o
    .i,                     @  sh   d dl mZ d dlZd dlZd dlmZ ddlmZ G dd deZ	G dd	 d	eZ
G d
d deZdS )    )annotationsN)StrEnum   )	Schedulerc                   @  s    e Zd ZdZdZdZdZdZdS )DDPMVarianceTypez
    Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
    to the denoised sample.
    Zfixed_smallZfixed_largelearnedlearned_rangeN)__name__
__module____qualname____doc__FIXED_SMALLFIXED_LARGELEARNEDLEARNED_RANGE r   r   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/schedulers/ddpm.pyr   *   s    r   c                   @  s   e Zd ZdZdZdZdZdS )DDPMPredictionTypea4  
    Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.

    epsilon: predicting the noise of the diffusion process
    sample: directly predicting the noisy sample
    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    epsilonsamplev_predictionN)r	   r
   r   r   EPSILONSAMPLEV_PREDICTIONr   r   r   r   r   6   s
    r   c                      sf   e Zd ZdZddejdejddfd- fddZd.d/ddZ	d0d d!Z
d.d1d$d%Z	d.d2d+d,Z  ZS )3DDPMScheduleraZ  
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models"
    https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of noise schedule function in component store
        variance_type: member of DDPMVarianceType
        clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
        prediction_type: member of DDPMPredictionType
        clip_sample_min: minimum clipping value when clip_sample equals True
        clip_sample_max: maximum clipping value when clip_sample equals True
        schedule_args: arguments to pass to the schedule function
    i  linear_betaTg      g      ?num_train_timestepsintschedulestrvariance_typeclip_sampleboolprediction_typeclip_sample_minfloatclip_sample_maxreturnNonec           	        sb   t  j||fi | |tj vrtd|tj vr!td|| _||g| _|| _	|| _
d S )Nz?Argument `variance_type` must be a member of `DDPMVarianceType`zCArgument `prediction_type` must be a member of `DDPMPredictionType`)super__init__r   __members__values
ValueErrorr   r!   clip_sample_valuesr    r#   )	selfr   r   r    r!   r#   r$   r&   schedule_args	__class__r   r   r*   U   s   

zDDPMScheduler.__init__Nnum_inference_stepsdevicestr | torch.device | Nonec                 C  sx   || j krtd| d| j  d| j  d|| _| j | j }td||  ddd tj}t	|
|| _dS )a/  
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
        z`num_inference_steps`: z3 cannot be larger than `self.num_train_timesteps`: zG as the unet model trained with this scheduler can only handle maximal z timesteps.r   N)r   r-   r3   nparangeroundastypeint64torch
from_numpyto	timesteps)r/   r3   r4   
step_ratior?   r   r   r   set_timestepsm   s   
&zDDPMScheduler.set_timestepstimestepx_0torch.Tensorx_tc           
      C  sv   | j | }| j| }|dkr| j|d  n| j}| | j|  d|  }| d|  d|  }|| ||  }	|	S )z
        Compute the mean of the posterior at timestep t.

        Args:
            timestep: current timestep.
            x0: the noise-free input.
            x_t: the input noised to timestep t.

        Returns:
            Returns the mean
        r   r   )alphasalphas_cumprodonesqrtbetas)
r/   rB   rC   rE   alpha_talpha_prod_talpha_prod_t_prevZx_0_coefficientZx_t_coefficientmeanr   r   r   	_get_mean   s   

zDDPMScheduler._get_meanpredicted_variancetorch.Tensor | Nonec           	      C  s   | j | }|dkr| j |d  n| j}d| d|  | j|  }| jtjkr/tj|dd}|S | jtjkr<| j| }|S | jtj	krH|durH|S | jtj
kri|duri|}| j| }|d d }|| d| |  }|S )z
        Compute the variance of the posterior at timestep t.

        Args:
            timestep: current timestep.
            predicted_variance: variance predicted by the model.

        Returns:
            Returns the variance
        r   r   g#B;)minN   )rG   rH   rJ   r    r   r   r<   clampr   r   r   )	r/   rB   rP   rL   rM   varianceZmin_logZmax_logfracr   r   r   _get_variance   s"   

	
zDDPMScheduler._get_variancemodel_outputr   	generatortorch.Generator | None!tuple[torch.Tensor, torch.Tensor]c                 C  s  |j d |j d d kr| jdv rtj||j d dd\}}nd}| j| }|dkr1| j|d  n| j}d| }d| }	| jtjkrO||d |  |d  }
n| jtj	krX|}
n| jtj
krj|d | |d |  }
| jrzt|
| jd | jd }
|d | j|  | }| j| d |	 | }||
 ||  }td}|dkrtj| |j|j||jd}| j||d	d | }|| }||
fS )
a7  
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            generator: random number generator.

        Returns:
            pred_prev_sample: Predicted previous sample
        r   rS   )r   r   )dimNr   g      ?)dtypelayoutrY   r4   )rP   )shaper    r<   splitrG   rH   r#   r   r   r   r   r!   rT   r.   rJ   rF   tensorrandnsizer]   r^   r4   rW   )r/   rX   rB   r   rY   rP   rL   rM   beta_prod_tbeta_prod_t_prevpred_original_samplepred_original_sample_coeffcurrent_sample_coeffpred_prev_samplerU   noiser   r   r   step   s@   "

zDDPMScheduler.step)r   r   r   r   r    r   r!   r"   r#   r   r$   r%   r&   r%   r'   r(   )N)r3   r   r4   r5   r'   r(   )rB   r   rC   rD   rE   rD   r'   rD   )rB   r   rP   rQ   r'   rD   )
rX   rD   rB   r   r   rD   rY   rZ   r'   r[   )r	   r
   r   r   r   r   r   r   r*   rA   rO   rW   rk   __classcell__r   r   r1   r   r   D   s    
"r   )
__future__r   numpyr7   r<   monai.utilsr   	schedulerr   r   r   r   r   r   r   r   <module>   s   