import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .dataloader import TestDataset, load_sts_data
from .model import SimcseModel
from transformers import BertTokenizer

def rag(model, dataloader, device):
    """Find, for every batch, the position of the most similar source/target pair.

    Each item yielded by ``dataloader`` is unpacked as ``(source, target, label)``;
    the label is not used. ``source`` and ``target`` must provide the tensors
    ``input_ids``, ``attention_mask`` and ``token_type_ids``. Both sides are
    encoded by ``model``, compared row by row with cosine similarity, and the
    index of the highest score within the batch is recorded.

    Args:
        model: Encoder called as ``model(input_ids, attention_mask,
            token_type_ids)``; it is switched to evaluation mode first.
        dataloader: Iterable producing ``(source, target, label)`` batches.
        device: Device every input tensor is moved to before encoding.

    Returns:
        A list with one integer per batch: the batch-local index (not a
        dataset-wide index) of the pair with the largest cosine similarity.
    """

    def encode(features):
        # squeeze(1) removes axis 1 when it has size 1; per the original note
        # the inputs arrive as [batch, 1, seq_len] and leave as [batch, seq_len].
        input_ids = features.get('input_ids').squeeze(1).to(device)
        attention_mask = features.get('attention_mask').squeeze(1).to(device)
        token_type_ids = features.get('token_type_ids').squeeze(1).to(device)
        return model(input_ids, attention_mask, token_type_ids)

    model.eval()  # evaluation mode
    best_positions = []
    with torch.no_grad():  # gradients are not needed for inference
        for source, target, _label in dataloader:  # one batch per iteration
            similarity = F.cosine_similarity(encode(source), encode(target), dim=-1)
            top_score, top_position = torch.max(similarity, dim=0)
            print("最大值是:", top_score.item())
            print("最大值的位置是:", top_position.item())
            best_positions.append(top_position.item())
    # One entry is appended per batch, so the length is the batch count.
    print(len(best_positions))
    return best_positions
def prepare_rag_components(custom_args=None):
    """Build the model, test dataloader and device used for retrieval.

    Args:
        custom_args: Optional object whose attributes override the built-in
            settings. Recognised attributes are ``device``,
            ``pretrain_model_path``, ``data_path``, ``pooler``,
            ``batch_size``, ``max_length``, ``dropout`` and ``num_workers``.
            Any attribute that is absent falls back to its default, so a
            partially filled object is accepted. ``None`` uses all defaults.

    Returns:
        A ``(model, test_dataloader, device)`` tuple: the ``SimcseModel``
        moved to ``device``, a non-shuffling ``DataLoader`` over the
        ``Sim_rag.txt`` data found in ``data_path``, and the device itself.
    """
    # Function-scope import so this block does not depend on the file header.
    import os

    defaults = {
        "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        "pretrain_model_path": "./SimCSE/bertsave/",
        "data_path": "data/AstroKGC_test/",
        "pooler": 'first-last-avg',
        "batch_size": 20,
        "max_length": 300,
        "dropout": 0.15,
        "num_workers": 12,
    }

    def setting(name):
        # A value supplied by the caller wins; anything missing uses the default.
        if custom_args is None:
            return defaults[name]
        return getattr(custom_args, name, defaults[name])

    device = setting("device")
    pretrain_model_path = setting("pretrain_model_path")

    # Data loader. os.path.join yields the same path as plain concatenation
    # when data_path ends with a separator, and also inserts the separator
    # when it does not (and accepts path-like objects).
    test_path_sp = os.path.join(setting("data_path"), "Sim_rag.txt")
    test_data_source = load_sts_data(test_path_sp)
    tokenizer = BertTokenizer.from_pretrained(pretrain_model_path)
    test_dataset = TestDataset(
        test_data_source, tokenizer, max_len=setting("max_length")
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=setting("batch_size"),
        shuffle=False,  # keep the dataset order
        num_workers=setting("num_workers"),
    )

    # Model
    model = SimcseModel(
        pretrained_model=pretrain_model_path,
        pooling=setting("pooler"),
        dropout=setting("dropout"),
    ).to(device)

    return model, test_dataloader, device


if __name__ == '__main__':
    # Script entry point: build (model, dataloader, device) with the default
    # settings and pass them straight to rag() in that order.
    rag(*prepare_rag_components())