# -*- coding: utf-8 -*-
# @Time    : 2025/2/2 16:20
# @Author  : ljc
# @FileName: config.py
# @Software: PyCharm


# 1. 简介
"""
目的：
    配置 PyTorch 运行环境, 设置计算精度与随机种子.
函数：
    1) default_set
    2) set_all_seeds
解释：
    1) default_set 函数: 设置 PyTorch 默认使用的浮点精度类型和计算设备 (CPU/GPU).
    2) set_all_seeds 函数: 设置所有随机数生成器的种子, 降低随机性.
"""


# 2. 导入所需库
import torch
import numpy as np
import random
import os
# 2.1 默认只使用一块 GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# 3. 设置默认浮点类型、调用 GPU
def default_set(type=32) -> tuple[torch.dtype, torch.device]:
   
    """
    设置 PyTorch 默认使用的浮点精度和计算设备.
    
    输入参数:
    -----------
    参数:
        type (int): 浮点精度位数, 可选值为 16、32 或 64, 默认为 32.
    
    输出参数:
    -----------
    返回:
        tuple: (dtype, device) - 数据类型和计算设备.
        
    异常:
        ValueError: 当提供的 type 参数不是支持的值时抛出.
    """
    
    # 3.1 默认使用 GPU, 如果没有 GPU, 则使用 CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 3.2 设置数据类型, 默认 32 位浮点数
    if type == 16:
        dtype = torch.float16
    elif type == 32:
        dtype = torch.float32
    elif type == 64:
        dtype = torch.float64
    else:
        raise ValueError(f"Not supported floating point precision: {type}, please use 16, 32 or 64")
    
    # 3.3 设置全局默认数据类型
    torch.set_default_dtype(dtype)
    
    # 3.4 如果使用 CUDA 且为 float32, 启用 TF32 加速
    if type == 32 and device.type == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # 3.5 返回数据类型和计算设备
    return dtype, device


# 4. 设置随机种子
def set_all_seeds(seed=666) -> None:

    """
    设置所有随机数生成器的种子, 降低随机性.
    注意: 在 GPU 上使用 32 位浮点数时, 即使设置了所有种子, 结果仍可能有微小差异.
    
    输入参数:
    -----------
    参数:
        seed (int): 随机种子, 默认为 666.
    
    输出参数:
    -----------
    返回:
        None.
    """

    # 4.1 设置 Python 的随机种子
    random.seed(seed)
    
    # 4.2 设置 numpy 的随机种子
    np.random.seed(seed)
    
    # 4.3 设置 PyTorch 的随机种子
    torch.manual_seed(seed)

    # 4.4 GPU 相关设置
    if torch.cuda.is_available():
        # 4.4.1 设置 CUDA 随机种子
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        
        # 4.4.2 尽量减少 CUDNN 的随机性
        # 注意: 这些设置可以减少但不能完全消除 GPU 计算的非确定性
        # 特别是在使用 32 位浮点数和某些优化器 (如 Adam) 时, 结果可能仍有差异
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # 4.5 设置 Python 的 hash 种子以确保字典等的一致性
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    # 4.6 注意: 要获得更高的确定性, 可考虑添加以下代码 (可能降低性能):
    # if hasattr(torch, 'use_deterministic_algorithms'):
    #     torch.use_deterministic_algorithms(True)
    # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"