o
    .i,8                     @  sL   d dl mZ d dlZd dlZddlmZ ddlmZ eZ	G dd deZ
dS )    )annotationsN   )DDPMPredictionType)	Schedulerc                      sf   e Zd ZdZdddddejddfd. fddZd/d0ddZd1d!d"Z	#	d2d3d*d+Z	d4d,d-Z
  ZS )5DDIMSchedulera  
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion
    Implicit Models" https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of noise schedule function in component store
        clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
            For the final step there is no previous alpha. When this option is `True` the previous alpha product is
            fixed to `1`, otherwise it uses the value of alpha at step 0.
        steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type: member of DDPMPredictionType
        clip_sample_min: minimum clipping value when clip_sample equals True
        clip_sample_max: maximum clipping value when clip_sample equals True
        schedule_args: arguments to pass to the schedule function

    i  Zlinear_betaTr   g            ?num_train_timestepsintschedulestrclip_sampleboolset_alpha_to_onesteps_offsetprediction_typeclip_sample_minfloatclip_sample_maxreturnNonec	           
        s   t  j||fi |	 |tj vrtd|| _|r tdn| j	d | _
d| _ttd| jd d d tj| _|| _||g| _|| _|  | | j d S )NzAArgument `prediction_type` must be a member of DDIMPredictionTyper   r   )super__init__DDIMPredictionType__members__values
ValueErrorr   torchtensoralphas_cumprodfinal_alpha_cumprodZinit_noise_sigma
from_numpynparanger   astypeint64	timestepsr   clip_sample_valuesr   set_timesteps)
selfr   r
   r   r   r   r   r   r   Zschedule_args	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/schedulers/ddim.pyr   B   s   (
zDDIMScheduler.__init__Nnum_inference_stepsdevicestr | torch.device | Nonec                 C  s   || j krtd| d| j  d| j  d|| _| j | j }| j|kr/td| j d| dtd||  d	d	d
  tj	}t
||| _|  j| j7  _d	S )a/  
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
        z`num_inference_steps`: z3 cannot be larger than `self.num_train_timesteps`: zG as the unet model trained with this scheduler can only handle maximal z timesteps.z`steps_offset`: zR cannot be greater than or equal to `num_train_timesteps // num_inference_steps : z@` as this will cause timesteps to exceed the max train timestep.r   Nr   )r   r   r.   r   r"   r#   roundcopyr$   r%   r   r!   tor&   )r)   r.   r/   
step_ratior&   r,   r,   r-   r(   h   s&   


*zDDIMScheduler.set_timestepstimestepprev_timesteptorch.Tensorc                 C  sJ   | j | }|dkr| j | n| j}d| }d| }|| d||   }|S )Nr   r   )r   r    )r)   r5   r6   alpha_prod_talpha_prod_t_prevbeta_prod_tbeta_prod_t_prevvariancer,   r,   r-   _get_variance   s   
zDDIMScheduler._get_variance        model_outputsampleeta	generatortorch.Generator | None!tuple[torch.Tensor, torch.Tensor]c                 C  s  || j | j  }| j| }|dkr| j| n| j}d| }	|}
|}| jtjkr6||	d |  |d  }
|}n3| jtjkrK|}
||d |
  |	d  }n| jtjkri|d | |	d |  }
|d | |	d |  }| j	ryt
|
| jd | jd }
| ||}||d  }d| |d  d | }|d |
 | }|dkrt
t
|r|jnd}t
j|j|j||d}| ||d | | }|| }||
fS )a  
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            eta: weight of noise for added noise in diffusion step.
            generator: random number generator.

        Returns:
            pred_prev_sample: Predicted previous sample
            pred_original_sample: Predicted original sample
        r   r         ?   cpu)dtyperB   r/   )r   r.   r   r    r   r   EPSILONSAMPLEV_PREDICTIONr   r   clampr'   r=   r/   	is_tensorrandnshaperH   )r)   r?   r5   r@   rA   rB   r6   r8   r9   r:   pred_original_samplepred_epsilonr<   Z	std_dev_tpred_sample_directionZpred_prev_sampler/   noiser,   r,   r-   step   s:   #
zDDIMScheduler.stepc                 C  s  || j | j  }| j| }|dkr| j| n| j}d| }|}|}	| jtjkr6||d |  |d  }|}	n3| jtjkrK|}||d |  |d  }	n| jtjkri|d | |d |  }|d | |d |  }	| j	ryt
|| jd | jd }d| d |	 }
|d | |
 }||fS )a?  
        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.

        Returns:
            pred_prev_sample: Predicted previous sample
            pred_original_sample: Predicted original sample
        r   r   rE   )r   r.   r   r    r   r   rI   rJ   rK   r   r   rL   r'   )r)   r?   r5   r@   r6   r8   r9   r:   rP   rQ   rR   Zpred_post_sampler,   r,   r-   reversed_step   s,   
zDDIMScheduler.reversed_step)r   r	   r
   r   r   r   r   r   r   r	   r   r   r   r   r   r   r   r   )N)r.   r	   r/   r0   r   r   )r5   r	   r6   r	   r   r7   )r>   N)r?   r7   r5   r	   r@   r7   rA   r   rB   rC   r   rD   )r?   r7   r5   r	   r@   r7   r   rD   )__name__
__module____qualname____doc__r   rI   r   r(   r=   rT   rU   __classcell__r,   r,   r*   r-   r   +   s"    &
Vr   )
__future__r   numpyr"   r   ddpmr   	schedulerr   r   r   r,   r,   r,   r-   <module>   s   