o
    .i8                     @  sd   d dl mZ d dlmZ d dlZd dlZd dlmZ ddl	m
Z
 G dd deZG d	d
 d
e
ZdS )    )annotations)AnyN)StrEnum   )	Schedulerc                   @  s   e Zd ZdZdZdZdS )PNDMPredictionTypea  
    Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.

    epsilon: predicting the noise of the diffusion process
    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    epsilonv_predictionN)__name__
__module____qualname____doc__EPSILONV_PREDICTION r   r   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/schedulers/pndm.pyr   ,   s    r   c                      sf   e Zd ZdZddddejdfd( fddZd)d*ddZd+ddZd,d d!Z	d-d#d$Z
d.d&d'Z  ZS )/PNDMScheduleraS  
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,
    "Pseudo Numerical Methods for Diffusion Models on Manifolds"  https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of noise schedule function in component store
        skip_prk_steps:
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms step.
        set_alpha_to_one:
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        prediction_type: member of DDPMPredictionType
        steps_offset:
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        schedule_args: arguments to pass to the schedule function
    i  linear_betaFr   num_train_timestepsintschedulestrskip_prk_stepsboolset_alpha_to_oneprediction_typesteps_offsetreturnNonec                   s   t  j||fi | |tj vrtd|| _|r tdn| j	d | _
d| _d| _|| _|| _t | _d| _t | _g | _| | d S )NzAArgument `prediction_type` must be a member of PNDMPredictionTypeg      ?r      )super__init__r   __members__values
ValueErrorr   torchtensoralphas_cumprodfinal_alpha_cumprodinit_noise_sigma
pndm_orderr   r   Tensorcur_model_outputcounter
cur_sampleetsset_timesteps)selfr   r   r   r   r   r   schedule_args	__class__r   r   r!   P   s   


zPNDMScheduler.__init__Nnum_inference_stepsdevicestr | torch.device | Nonec                 C  s`  || j krtd| d| j  d| j  d|| _| j | j }td||  tj| _|  j| j	7  _| j
rHtg | _| jddd | _nDt| j| j d dttd| j | d g| j }|dd dd	d ddd  | _| jdd
 ddd  | _t| j| jgtj}t||| _t| j| _g | _d| _dS )a/  
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
        z`num_inference_steps`: z3 cannot be larger than `self.num_train_timesteps`: zG as the unet model trained with this scheduler can only handle maximal z timesteps.r   N   r   )r   r$   r5   nparangeroundastypeint64Z
_timestepsr   r   arrayprk_timestepsZplms_timestepsr*   repeattilecopyconcatenater%   
from_numpyto	timestepslenr/   r-   )r1   r5   r6   
step_ratiorA   rH   r   r   r   r0   w   s:   
 *
zPNDMScheduler.set_timestepsmodel_outputtorch.Tensortimestepsampletuple[torch.Tensor, Any]c                 C  s>   | j t| jk r| js| j|||ddfS | j|||ddfS )an  
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).
        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
        Returns:
            pred_prev_sample: Predicted previous sample
        )rK   rM   rN   N)r-   rI   rA   r   step_prk	step_plms)r1   rK   rM   rN   r   r   r   step   s   zPNDMScheduler.stepc                 C  s(  | j du r	td| jd rdn| j| j  d }|| }| j| jd d  }| jd dkr<d| | _| j| || _n;| jd d dkrO|  jd| 7  _n(| jd d dkrb|  jd| 7  _n| jd	 d dkrw| jd|  }t	
 | _| j dkr| jn|}| ||||}|  jd7  _|S )
a  
        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.

        Returns:
            pred_prev_sample: Predicted previous sample
        NaNumber of inference steps is 'None', you need to run 'set_timesteps' after creating the schedulerr9   r   r   gUUUUUU?r   gUUUUUU?   )r5   r$   r-   r   rA   r,   r/   appendr.   r%   r+   numel_get_prev_sample)r1   rK   rM   rN   Zdiff_to_prevprev_timestepr.   prev_sampler   r   r   rP      s,   


zPNDMScheduler.step_prkr   c                 C  s  | j du r	td| jst| jdk rt| j d|| j| j   }| jdkr7| jdd | _| j| n
|}|| j| j   }t| jdkrS| jdkrS|}|| _	nqt| jdkrq| jdkrq|| jd  d	 }| j	}t
 | _	nSt| jd	krd| jd  | jd
  d	 }n=t| jdkrd| jd  d| jd
   d| jd   d }ndd| jd  d| jd
   d| jd   d| jd    }| ||||}|  jd7  _|S )a  
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.

        Returns:
            pred_prev_sample: Predicted previous sample
        NrS   rT   zW can only be run AFTER scheduler has been run in 'prk' mode for at least 12 iterations r   r:   r   r8   r9               gUUUUUU?7   ;   %   	   )r5   r$   r   rI   r/   r4   r   r-   rU   r.   r%   r+   rW   )r1   rK   rM   rN   rX   rY   r   r   r   rQ      s:   


0<zPNDMScheduler.step_plmsrX   c                 C  s   | j | }|dkr| j | n| j}d| }d| }| jtjkr+|d | |d |  }|| d }	||d  || | d  }
|	| || | |
  }|S )Nr   r   g      ?)r'   r(   r   r   r   )r1   rN   rM   rX   rK   alpha_prod_talpha_prod_t_prevbeta_prod_tbeta_prod_t_prevZsample_coeffZmodel_output_denom_coeffrY   r   r   r   rW     s   


zPNDMScheduler._get_prev_sample)r   r   r   r   r   r   r   r   r   r   r   r   r   r   )N)r5   r   r6   r7   r   r   )rK   rL   rM   r   rN   rL   r   rO   )rK   rL   rM   r   rN   rL   r   rL   )rK   rL   rM   r   rN   rL   r   r   )rN   rL   rM   r   rX   r   rK   rL   )r
   r   r   r   r   r   r!   r0   rR   rP   rQ   rW   __classcell__r   r   r3   r   r   8   s    '
.

*4r   )
__future__r   typingr   numpyr;   r%   monai.utilsr   	schedulerr   r   r   r   r   r   r   <module>   s   