o
    .i5                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
 ddlmZ ddlmZ G d	d
 d
e
Z	dddZG dd deZdS )    )annotations)UnionN)LogisticNormal)StrEnum   )DDPMPredictionType)	Schedulerc                   @  s   e Zd ZdZejZdS )RFlowPredictionTypez
    Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.

    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    N)__name__
__module____qualname____doc__r   V_PREDICTION r   r   j/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/schedulers/rectified_flow.pyr	   +   s    
r	            ?     c           	      C  sD   | | } || d|  }|| }||  d|d |    }|| }|S )aQ  
    Applies a transformation to the timestep based on image resolution scaling.

    Args:
        t (torch.Tensor): The original timestep(s).
        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
        scale (float): Scaling factor for the transformation.
        num_train_timesteps (int): Total number of training timesteps.
        spatial_dim (int): Number of spatial dimensions in the image.

    Returns:
        torch.Tensor: Transformed timestep(s).
    r   r   r   )	tinput_img_size_numelbase_img_size_numelscalenum_train_timestepsspatial_dimZratio_spacerationew_tr   r   r   timestep_transform5   s   r   c                   @  s`   e Zd ZdZ											
d5d6ddZd7d d!Z	"	"d8d9d)d*Zd+d, Z	"d:d;d3d4Zd"S )<RFlowSchedulera  
    A rectified flow scheduler for guiding the diffusion process in a generative model.

    Supports uniform and logit-normal sampling methods, timestep transformation for
    different resolutions, and noise addition during diffusion.

    Args:
        num_train_timesteps (int): Total number of training timesteps.
        use_discrete_timesteps (bool): Whether to use discrete timesteps.
        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        use_timestep_transform (bool): Whether to apply timestep transformation.
            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.
        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
        spatial_dim (int): 2 or 3, incidcating 2D or 3D images, used only if use_timestep_transform=True.

    Example:

        .. code-block:: python

            # define a scheduler
            noise_scheduler = RFlowScheduler(
                num_train_timesteps = 1000,
                use_discrete_timesteps = True,
                sample_method = 'logit-normal',
                use_timestep_transform = True,
                base_img_size_numel = 32 * 32 * 32,
                spatial_dim = 3
            )

            # during training
            inputs = torch.ones(2,4,64,64,32)
            noise = torch.randn_like(inputs)
            timesteps = noise_scheduler.sample_timesteps(inputs)
            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
            predicted_velocity = diffusion_unet(
                x=noisy_inputs,
                timesteps=timesteps
            )
            loss = loss_l1(predicted_velocity, (inputs - noise))

            # during inference
            noisy_inputs = torch.randn(2,4,64,64,32)
            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:])
            noise_scheduler.set_timesteps(
                num_inference_steps=30, input_img_size_numel=input_img_size_numel)
            )
            all_next_timesteps = torch.cat(
                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
            )
            for t, next_t in tqdm(
                zip(noise_scheduler.timesteps, all_next_timesteps),
                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
            ):
                predicted_velocity = diffusion_unet(
                    x=noisy_inputs,
                    timesteps=timesteps
                )
                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
            final_output = noisy_inputs
    r   Tuniform        r   Fr   r   r   r   intuse_discrete_timestepsboolsample_methodstrlocfloatr   use_timestep_transformtransform_scalesteps_offsetr   r   c                   s   t j _| _| _|	 _|
 _|dvrtd| d| _|dkr8t	t
|gt
|g _ fdd _| _| _| _d S )N)r   logit-normalzsample_method = z:, which has to be chosen from ['uniform', 'logit-normal'].r+   c                   s(    j | jd fd d df | jS )Nr   )distributionsampleshapetodevice)xselfr   r   <lambda>   s   ( z)RFlowScheduler.__init__.<locals>.<lambda>)r	   r   prediction_typer   r"   r   r   
ValueErrorr$   r   torchtensorr,   sample_tr(   r)   r*   )r3   r   r"   r$   r&   r   r(   r)   r*   r   r   r   r2   r   __init__   s    

zRFlowScheduler.__init__original_samplestorch.Tensornoise	timestepsreturnc                 C  s   |  | j }d| }|jdkr!|d jdg|jdd R  }n|jdkr7|d jdg|jdd R  }ntd|j || d| |  }|S )	aV  
        Add noise to the original samples.

        Args:
            original_samples: original samples
            noise: noise to add to samples
            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.

        Returns:
            noisy_samples: sample with added noise
        r      ).NNNNN   ).NNNz9noise tensor has to be 4D or 5D tensor, yet got shape of )r'   r   ndimexpandr.   r6   )r3   r;   r=   r>   Z
timepointsnoisy_samplesr   r   r   	add_noise   s   
"
"zRFlowScheduler.add_noiseNnum_inference_stepsr0   str | torch.device | Noner   
int | NoneNonec                   s   |j ks	|dk rtd| dj  dj  d|_fddtjD }jr2dd |D }jr? fd	d|D }t|tj	}jrQ|tj
}t||_ jj7  _d
S )a  
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
            input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True.
        r   z`num_inference_steps`: zM should be at least 1, and cannot be larger than `self.num_train_timesteps`: zG as the unet model trained with this scheduler can only handle maximal z timesteps.c                   s    g | ]}d | j    j qS )r   )rG   r   ).0ir2   r   r   
<listcomp>   s    z0RFlowScheduler.set_timesteps.<locals>.<listcomp>c                 S  s   g | ]}t t|qS r   )r!   roundrK   r   r   r   r   rM      s    c              	     s$   g | ]}t | jjjd qS )r   r   r   r   )r   r   r   r   rO   r   r3   r   r   rM      s    N)r   r6   rG   ranger"   r(   nparrayastypefloat16int64r7   
from_numpyr/   r>   r*   )r3   rG   r0   r   r>   Ztimesteps_npr   rQ   r   set_timesteps   s.   

zRFlowScheduler.set_timestepsc                 C  s   | j dkrtj|jd f|jd| j }n| j dkr"| || j }| jr)| }| j	rIt
t|jdd }t||| j| jt|jd d}|S )z
        Randomly samples training timesteps using the chosen sampling method.

        Args:
            x_start (torch.Tensor): The input tensor for sampling.

        Returns:
            torch.Tensor: Sampled timesteps.
        r   r   )r0   r+      NrP   )r$   r7   randr.   r0   r   r9   r"   longr(   prodr8   r   r   len)r3   Zx_startr   r   r   r   r   sample_timesteps   s    

 
zRFlowScheduler.sample_timestepsmodel_outputtimestepr-   next_timestepUnion[int, None]!tuple[torch.Tensor, torch.Tensor]c           	      C  s   t | drt| jtstd|}|dur#t|}t|| | j }n| jdkr/dt| j nd}|||  }||| | j  }||fS )a  
        Predicts the next sample in the diffusion process.

        Args:
            model_output (torch.Tensor): Output from the trained diffusion model.
            timestep (int): Current timestep in the diffusion chain.
            sample (torch.Tensor): Current sample in the process.
            next_timestep (Union[int, None]): Optional next timestep.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info.
        rG   znum_inference_steps is missing or not an integer in the class.Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it.Nr   r   r    )hasattr
isinstancerG   r!   AttributeErrorr'   r   )	r3   r`   ra   r-   rb   Zv_preddtpred_post_samplepred_original_sampler   r   r   step  s   zRFlowScheduler.step)
r   Tr   r    r   Fr   r   r   r   )r   r!   r"   r#   r$   r%   r&   r'   r   r'   r(   r#   r)   r'   r*   r!   r   r!   r   r!   )r;   r<   r=   r<   r>   r<   r?   r<   )NN)rG   r!   r0   rH   r   rI   r?   rJ   )N)
r`   r<   ra   r!   r-   r<   rb   rc   r?   rd   )	r
   r   r   r   r:   rF   rY   r_   rk   r   r   r   r   r   P   s(    C
$.r   )r   r   r   r   )
__future__r   typingr   numpyrS   r7   Ztorch.distributionsr   monai.utilsr   ddpmr   	schedulerr   r	   r   r   r   r   r   r   <module>   s   
