U
    PhU                     @  sT  d dl mZ d dlZd dlZd dlZd dlmZ d dlmZm	Z	m
Z
 d dlZd dlZd dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZmZmZ erd dlm Z! dZ"d dl#Z#dZ$ned\Z!Z"ed\Z#Z$dgZ%G dd dZ&G dd de&Z'G dd de&Z(dddddZ)dddddZ*G dd dZ+dS )    )annotationsN)partial)TYPE_CHECKINGAnyCallable)	Optimizer)DEFAULT_PROTOCOL)
DataLoader)	eval_mode)ExponentialLRLinearLR)StateCachercopy_to_deviceoptional_importTzmatplotlib.pyplottqdmLearningRateFinderc                   @  sD   e Zd ZdddddddZedd Zd	d
 Zdd Zdd ZdS )DataLoaderIterr	   r   Nonedata_loaderimage_extractorlabel_extractorreturnc                 C  s>   t |tstdt| d|| _t|| _|| _|| _d S )NzLoader has unsupported type: z1. Expected type was `torch.utils.data.DataLoader`)	
isinstancer	   
ValueErrortyper   iter	_iteratorr   r   selfr   r   r    r    O/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/optimizers/lr_finder.py__init__/   s    

zDataLoaderIter.__init__c                 C  s   | j jS N)r   datasetr   r    r    r!   r$   9   s    zDataLoaderIter.datasetc                 C  s   |  |}| |}||fS r#   )r   r   )r   
batch_dataimageslabelsr    r    r!   inputs_labels_from_batch=   s    

z'DataLoaderIter.inputs_labels_from_batchc                 C  s   | S r#   r    r%   r    r    r!   __iter__B   s    zDataLoaderIter.__iter__c                 C  s   t | j}| |S r#   )nextr   r)   )r   batchr    r    r!   __next__E   s    
zDataLoaderIter.__next__N)	__name__
__module____qualname__r"   propertyr$   r)   r*   r-   r    r    r    r!   r   -   s   

r   c                      s4   e Zd Zddddddd fddZd	d
 Z  ZS )TrainDataLoaderIterTr	   r   boolr   )r   r   r   
auto_resetr   c                   s   t  ||| || _d S r#   )superr"   r4   )r   r   r   r   r4   	__class__r    r!   r"   L   s    zTrainDataLoaderIter.__init__c                 C  sf   zt | j}| |\}}W n@ tk
r\   | js4 t| j| _t | j}| |\}}Y nX ||fS r#   )r+   r   r)   StopIterationr4   r   r   )r   r,   inputsr(   r    r    r!   r-   R   s    

zTrainDataLoaderIter.__next__)T)r.   r/   r0   r"   r-   __classcell__r    r    r6   r!   r2   J   s    r2   c                      s@   e Zd ZdZddddd fddZdd	 Z fd
dZ  ZS )ValDataLoaderItera  This iterator will reset itself **only** when it is acquired by
    the syntax of normal `iterator`. That is, this iterator just works
    like a `torch.data.DataLoader`. If you want to restart it, you
    should use it like:

        ```
        loader_iter = ValDataLoaderIter(data_loader)
        for batch in loader_iter:
            ...

        # `loader_iter` should run out of values now, you can restart it by:
        # 1. the way we use a `torch.data.DataLoader`
        for batch in loader_iter:        # __iter__ is called implicitly
            ...

        # 2. passing it into `iter()` manually
        loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`
        ```
    r	   r   r   r   c                   s&   t  ||| t| j| _d| _d S Nr   )r5   r"   lenr   	run_limitrun_counterr   r6   r    r!   r"   u   s    zValDataLoaderIter.__init__c                 C  s"   | j | jkrt| j| _d| _ | S r<   )r?   r>   r   r   r   r%   r    r    r!   r*   z   s    zValDataLoaderIter.__iter__c                   s   |  j d7  _ t  S )N   )r?   r5   r-   r%   r6   r    r!   r-      s    zValDataLoaderIter.__next__)r.   r/   r0   __doc__r"   r*   r-   r:   r    r    r6   r!   r;   `   s   r;   r   ztorch.Tensor)xr   c                 C  s   t | tr| d n| d }|S )z3Default callable for getting image from batch data.imager   r   dictrB   outr    r    r!   default_image_extractor   s    rH   c                 C  s   t | tr| d n| d }|S )z3Default callable for getting label from batch data.labelr@   rD   rF   r    r    r!   default_label_extractor   s    rJ   c                   @  s  e Zd ZdZddddeedfddddd	d
d	ddd	ddddZddddZdee	dddddddddfdddddddddddd	d	dd d!d"Z
d#dd$d%d&Zd'd( Zd=d)dd	dd*d+d,Zd>d-d	dd.d/d0Zd?ddd2d3d4d5Zd@ddd6d3d7d8ZdAddd	d9d	d9d:d;d<ZdS )Br   a  Learning rate range test.

    The learning rate range test increases the learning rate in a pre-training run
    between two boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning rates
    and what is the optimal learning rate.

    Example (fastai approach):
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100)
    >>> lr_finder.get_steepest_gradient()
    >>> lr_finder.plot() # to inspect the loss-learning rate graph

    Example (Leslie Smith's approach):
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear")

    Gradient accumulation is supported; example:
    >>> train_data = ...    # prepared dataset
    >>> desired_bs, real_bs = 32, 4         # batch size
    >>> accumulation_steps = desired_bs // real_bs     # required steps for accumulation
    >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)
    >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps)

    By default, image will be extracted from data loader with x["image"] and x[0], depending on whether
    batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader
    returns something other than this, pass a callable function to extract it, e.g.:
    >>> image_extractor = lambda x: x["input"]
    >>> label_extractor = lambda x: x[100]
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor)

    References:
    Modified from: https://github.com/davidtvs/pytorch-lr-finder.
    Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    NTFz	nn.Moduler   ztorch.nn.Modulezstr | torch.device | Noner3   z
str | Noneztypes.ModuleTypeintr   )model	optimizer	criteriondevicememory_cache	cache_diramppickle_modulepickle_protocolverboser   c                 C  s   || _ |   || _|| _g g d| _|| _|| _|| _|
| _t	| j
 j| _t||||	d| _| jd| j  | jd| j   |r|n| j| _dS )a8  Constructor.

        Args:
            model: wrapped model.
            optimizer: wrapped optimizer.
            criterion: wrapped loss function.
            device: device on which to test. run a string ("cpu" or "cuda") with an
                optional ordinal for the device type (e.g. "cuda:X", where is the ordinal).
                Alternatively, can be an object representing the device on which the
                computation will take place. Default: None, uses the same device as `model`.
            memory_cache: if this flag is set to True, `state_dict` of
                model and optimizer will be cached in memory. Otherwise, they will be saved
                to files under the `cache_dir`.
            cache_dir: path for storing temporary files. If no path is
                specified, system-wide temporary directory is used. Notice that this
                parameter will be ignored if `memory_cache` is True.
            amp: use Automatic Mixed Precision
            pickle_module: module used for pickling metadata and objects, default to `pickle`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: can be specified to override the default protocol, default to `2`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            verbose: verbose output
        Returns:
            None
        lrloss)	in_memoryrQ   rS   rT   rL   rM   N)rM   _check_for_schedulerrL   rN   historyrP   rQ   rR   rU   r+   
parametersrO   model_devicer   state_cacherstore
state_dict)r   rL   rM   rN   rO   rP   rQ   rR   rS   rT   rU   r    r    r!   r"      s&    )   zLearningRateFinder.__init__)r   c                 C  s:   | j | jd | j| jd | j | j dS )z9Restores the model and optimizer to their initial states.rL   rM   N)rL   load_state_dictr^   retrieverM   tor]   r%   r    r    r!   reset   s    zLearningRateFinder.resetg      $@d   expg?   r@   r	   zDataLoader | Noner   zfloat | Nonefloatstr)train_loader
val_loaderr   r   start_lrend_lrnum_iter	step_modesmooth_f
diverge_thaccumulation_stepsnon_blocking_transferr4   r   c                 C  s   g g d| _ td }| j| j |   |r:| | |dkrJtd| dkrft	| j
||}n*| dkrt| j
||}ntd| |	dk s|	dkrtd	t|||}|rt|||}| jrtrttjd
d}tjj}nt}t}||D ]}| jr ts td|d  d|  | j|||d}|rD| j||d}| j d | d  |  |dkrt|}n6|	dkr|	| d|	 | j d d   }||k r|}| j d | ||
| kr| jr|d  qq|r| jrtd |   dS )aj  Performs the learning rate range test.

        Args:
            train_loader: training set data loader.
            val_loader: validation data loader (if desired).
            image_extractor: callable function to get the image from a batch of data.
                Default: `x["image"] if isinstance(x, dict) else x[0]`.
            label_extractor: callable function to get the label from a batch of data.
                Default: `x["label"] if isinstance(x, dict) else x[1]`.
            start_lr : the starting learning rate for the range test.
                The default is the optimizer's learning rate.
            end_lr: the maximum learning rate to test. The test may stop earlier than
                this if the result starts diverging.
            num_iter: the max number of iterations for test.
            step_mode: schedule for increasing learning rate: (`linear` or `exp`).
            smooth_f: the loss smoothing factor within the `[0, 1[` interval. Disabled
                if set to `0`, otherwise loss is smoothed using exponential smoothing.
            diverge_th: test is stopped when loss surpasses threshold:
                `diverge_th * best_loss`.
            accumulation_steps: steps for gradient accumulation. If set to `1`,
                gradients are not accumulated.
            non_blocking_transfer: when `True`, moves data to device asynchronously if
                possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
            auto_reset: if `True`, returns model and optimizer to original states at end
                of test.
        Returns:
            None
        rV   infr@   z `num_iter` must be larger than 1rf   linearz#expected one of (exp, linear), got r   z$smooth_f is outside the range [0, 1[zComputing optimal learning rate)descz+Computing optimal learning rate, iteration /)rs   rW   rX   z%Stopping early, the loss has divergedzResetting model and optimizerN)r[   rh   rL   rc   rO   rZ   _set_learning_rater   lowerr   rM   r   r2   r;   rU   has_tqdmr   r   trangewriterangeprint_train_batch	_validateappendget_lrsteprd   )r   rj   rk   r   r   rl   rm   rn   ro   rp   rq   rr   rs   r4   	best_lossZlr_schedule
train_iterval_iterr|   Ztprint	iterationrX   r    r    r!   
range_test   s\    .






zLearningRateFinder.range_testzfloat | list)new_lrsr   c                 C  s\   t |ts|gt| jj }t|t| jjkr8tdt| jj|D ]\}}||d< qFdS )z#Set learning rate(s) for optimizer.zYLength of `new_lrs` is not equal to the number of parameter groups in the given optimizerrW   N)r   listr=   rM   param_groupsr   zip)r   r   param_groupnew_lrr    r    r!   ry   y  s    
z%LearningRateFinder._set_learning_ratec                 C  s"   | j jD ]}d|krtdqdS )z/Check optimizer doesn't already have scheduler.
initial_lrz0Optimizer already has a scheduler attached to itN)rM   r   RuntimeError)r   r   r    r    r!   rZ     s    z'LearningRateFinder._check_for_schedulerr2   )r   rr   rs   r   c              
   C  s   | j   d}| j  t|D ]}t|\}}t||g| j|d\}}|  |}| ||}	|	| }	| j	rt
| jdr|d | dk}
tjj	j|	| j|
d}|  W 5 Q R X n|	  ||	 7 }q | j  |S )Nr   rO   non_blockingZ
_amp_stashr@   )delay_unscale)rL   trainrM   	zero_gradr~   r+   r   rO   rN   rR   hasattrtorchcudaZ
scale_lossbackwarditemr   )r   r   rr   rs   
total_lossir9   r(   outputsrX   r   Zscaled_lossr    r    r!   r     s"    



zLearningRateFinder._train_batchr;   )r   rs   r   c              	   C  sx   d}t | jV |D ]J\}}t||g| j|d\}}| |}| ||}|| t| 7 }qW 5 Q R X |t|j S )Nr   r   )r
   rL   r   rO   rN   r   r=   r$   )r   r   rs   Zrunning_lossr9   r(   r   rX   r    r    r!   r     s      

 zLearningRateFinder._validater   ztuple[list, list])
skip_startskip_endr   c                 C  sd   |dk rt d|dk r t d| jd }| jd }t|| d }||| }||| }||fS )zGet learning rates and their corresponding losses

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the end.
        r   zskip_start cannot be negativezskip_end cannot be negativerW   rX   r@   )r   r[   r=   )r   r   r   lrslossesend_idxr    r    r!   get_lrs_and_losses  s    

z%LearningRateFinder.get_lrs_and_lossesz)tuple[float, float] | tuple[(None, None)]c                 C  sZ   |  ||\}}z&tt| }|| || fW S  tk
rT   td Y dS X dS )aC  Get learning rate which has steepest gradient and its corresponding loss

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the end.

        Returns:
            Learning rate which has steepest gradient and its corresponding loss
        zBFailed to compute the gradients, there might not be enough points.)NNN)r   npgradientarrayargminr   r   )r   r   r   r   r   Zmin_grad_idxr    r    r!   get_steepest_gradient  s    
z(LearningRateFinder.get_steepest_gradientz
Any | None)r   r   log_lraxsteepest_lrr   c              	   C  s   t std dS | ||\}}d}|dkr:t \}}||| |r| ||\}	}
|	dk	r|j|	|
dddddd |	  |r|
d	 |d
 |d |dk	rt  |S )a  Plots the learning rate range test.

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the start.
            log_lr: True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale.
            ax: the plot is created in the specified matplotlib axes object and the
                figure is not be shown. If `None`, then the figure and axes object are
                created in this method and the figure is shown.
            steepest_lr: plot the learning rate which had the steepest gradient.

        Returns:
            The `matplotlib.axes.Axes` object that contains the plot. Returns `None` if
            `matplotlib` is not installed.
        z(Matplotlib is missing, can't plot resultNK   ored   zsteepest gradient)smarkercolorzorderrI   logzLearning rateLoss)has_matplotlibwarningswarnr   pltsubplotsplotr   scatterlegend
set_xscale
set_xlabel
set_ylabelshow)r   r   r   r   r   r   r   r   figZlr_at_steepest_gradZloss_at_steepest_gradr    r    r!   r     s8    
	


zLearningRateFinder.plot)T)T)r   r   )r   r   )r   r   TNT)r.   r/   r0   rA   pickler   r"   rd   rH   rJ   r   ry   rZ   r   r   r   r   r   r    r    r    r!   r      sH   +$@
*z #     ),
__future__r   r   typesr   	functoolsr   typingr   r   r   numpyr   r   torch.nnnntorch.optimr   torch.serializationr   torch.utils.datar	   monai.networks.utilsr
   Zmonai.optimizers.lr_schedulerr   r   monai.utilsr   r   r   matplotlib.pyplotpyplotr   r   r   r{   __all__r   r2   r;   rH   rJ   r   r    r    r    r!   <module>   s8   %