from SimCSE.Rag import prepare_rag_components,rag
import difflib
import json
import os
from openai import OpenAI
import re
from Prompt import Multi_element1,Multi_element2,Text_impress,Multi_element3,Multi_element4

# Shared OpenAI-compatible client used by every LLM call in this module.
# The credential is read from the SILICONFLOW_API_KEY environment variable so
# that no secret is committed to source control.  When the variable is unset,
# ``None`` is passed and the OpenAI SDK applies its own OPENAI_API_KEY lookup.
# NOTE(review): the key that used to be hard-coded here should be treated as
# leaked and revoked.
client = OpenAI(
    api_key=os.environ.get("SILICONFLOW_API_KEY"),
    base_url="https://api.siliconflow.cn/v1",
)

def Data_Init(data_file, template_file):
    """Load the data set and the template set from two JSON files.

    Args:
        data_file: Path of the UTF-8 encoded JSON file holding the data.
        template_file: Path of the UTF-8 encoded JSON file holding the templates.

    Returns:
        A ``(data, template)`` tuple with the decoded content of each file.
    """
    def _read_json(path):
        # Both inputs are read the same way: UTF-8 text decoded as JSON.
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    return _read_json(data_file), _read_json(template_file)

def validate_triple(triple):
    """Check that every entry of ``triple`` is a well-formed triple.

    An entry is well formed when it has exactly three elements and none of
    them is empty or whitespace-only.  The first offending entry is printed
    together with the reason, and checking stops there.

    Args:
        triple: Iterable of candidate triples (sequences of strings).

    Returns:
        ``True`` when all entries pass (including when there are none),
        ``False`` as soon as one entry fails.
    """
    for candidate in triple:
        # The length is checked first, so the blank check only runs on
        # entries that already have three elements.
        if len(candidate) != 3:
            reason = "三元组必须包含三个元素"
        elif not all(part.strip() for part in candidate):
            reason = "存在空元素"
        else:
            continue
        print(candidate)
        print(reason)
        return False
    return True

def Topic_Extract(input):
    """Send each text to the chat model and collect the cleaned replies.

    Every text is prefixed with the ``Text_impress`` prompt and sent as a
    single user message, one request per text and in order.  The
    ``<think>...</think>`` section is removed from each reply, and the
    remainder is stripped, printed and collected.

    Args:
        input: List of text strings to process.

    Returns:
        A list with one cleaned reply string per input text, in input order.
    """
    # All prompts are built up front, before the first request is sent.
    prompts = [Text_impress + text for text in input]

    replies = []
    for prompt in prompts:
        completion = client.chat.completions.create(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
            messages=[{'role': 'user', 'content': prompt}],
            stream=False
        )
        raw_reply = completion.choices[0].message.content
        # Remove the think section; DOTALL lets it span several lines.
        cleaned = re.sub(r'<think>.*?</think>', '', raw_reply, flags=re.DOTALL).strip()
        print(cleaned)
        replies.append(cleaned)

    return replies

def SimCSE(data_impress,template_impress):
    """Run the SimCSE retrieval step over the data and template texts.

    The text pairs are first written to disk with ``loadtext``; the retrieval
    components are then built with ``prepare_rag_components`` and handed to
    ``rag``.

    Args:
        data_impress: List of data-side text strings.
        template_impress: List of template-side text strings.

    Returns:
        Whatever ``rag`` returns.  The caller in this file indexes the
        template list with its elements, so presumably one template index per
        data item (confirm against ``SimCSE.Rag``).
    """
    loadtext(data_impress, template_impress)
    encoder, loader, target_device = prepare_rag_components()
    return rag(encoder, loader, target_device)

def loadtext(data_impress,template_impress):
    """Write every (data, template) text pair to the SimCSE input file.

    The file ``data/AstroKGC_test/Sim_rag.txt`` is rewritten from scratch with
    one line per pair, data-major (all templates for the first data text, then
    all templates for the second, and so on).  Each line has the form
    ``2012test-0043||<data text>||<template text>||1``; the leading id and the
    trailing ``1`` are constants.

    Args:
        data_impress: List of data-side text strings.
        template_impress: List of template-side text strings.
    """
    pairs_path = "data/AstroKGC_test/Sim_rag.txt"
    # Create the target folder on demand, as the other writers in this module do.
    os.makedirs(os.path.dirname(pairs_path), exist_ok=True)

    # A line break inside a text would split its record, so remove them.
    # Each text is cleaned once rather than once per pair.
    data_lines = [text.replace('\n', '') for text in data_impress]
    template_lines = [text.replace('\n', '') for text in template_impress]

    # A single 'w' open truncates the old content and keeps the handle for all
    # pairs, instead of reopening the file in append mode for every line.
    with open(pairs_path, 'w', encoding='utf-8') as file:
        for cleaned_data in data_lines:
            for cleaned_template in template_lines:
                file.write(f"2012test-0043||{cleaned_data}||{cleaned_template}||1\n")

def _load_or_extract_impress(cache_path, texts):
    """Return the ``Topic_Extract`` output cached at ``cache_path``, computing and saving it when absent."""
    if os.path.isfile(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    impress = Topic_Extract(texts)
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'w', encoding='utf-8') as f:
        json.dump(impress, f, ensure_ascii=False, indent=2)
    return impress


def Topic_Aware_Rag(data,template):
    """Build one extraction prompt per data item from its retrieved template.

    A previously saved prompt list is returned as-is when one exists.
    Otherwise the first element of every data and template entry is passed
    through ``Topic_Extract`` (each result cached in its own JSON file),
    ``SimCSE`` selects a template index per data item, and each prompt is
    assembled as ``Multi_element1 + template[k][0] + Multi_element2 +
    template[k][1] + Multi_element3 + data[i][0] + Multi_element4``.  The
    prompt list is saved to ``data/Paper_Triple/AstroKGC_input.json``.

    Args:
        data: Sequence of entries whose element ``[0]`` is the text to process.
        template: Sequence of entries whose elements ``[0]`` and ``[1]`` are
            inserted into the prompt.

    Returns:
        The list of prompt strings, one per element returned by ``SimCSE``.
    """
    output_path = 'data/Paper_Triple/AstroKGC_input.json'

    # The lookup used to test only 'TAR_input.json' while the result was saved
    # under 'AstroKGC_input.json', so the function never reused its own output.
    # Both names are checked now; 'TAR_input.json' keeps precedence so existing
    # set-ups behave as before.
    for cache_path in ('data/Paper_Triple/TAR_input.json', output_path):
        if os.path.isfile(cache_path):
            with open(cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)

    data_texts = [article[0] for article in data]
    templates_texts = [article[0] for article in template]

    template_impress = _load_or_extract_impress(
        'data/Template_Triple/template_impress.json', templates_texts)
    data_impress = _load_or_extract_impress(
        'data/Paper_Triple/data_impress.json', data_texts)

    index = SimCSE(data_impress, template_impress)

    Raginput = []
    for i, template_idx in enumerate(index):
        chosen = template[template_idx]
        Raginput.append(
            Multi_element1 + str(chosen[0])
            + Multi_element2 + str(chosen[1])
            + Multi_element3 + str(data[i][0])
            + Multi_element4
        )

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(Raginput, f, ensure_ascii=False, indent=2)
    return Raginput
def Extract_Triple(input, max_attempts=None):
    """Ask the chat model for triples for each prompt and parse the replies.

    Each prompt is prefixed with ``Text_impress`` and sent as one user
    message.  The ``<think>...</think>`` section is removed from the reply,
    every ``[...]`` group is split on commas into a list of stripped items
    (surrounding single quotes removed), and the result is checked with
    ``validate_triple``.  A prompt whose reply fails validation is sent again
    until a reply validates.

    NOTE(review): ``Text_impress`` is the same prefix ``Topic_Extract`` uses;
    confirm that prepending it to the extraction prompts is intended.

    Args:
        input: List of prompt strings.
        max_attempts: Maximum number of requests per prompt.  ``None`` (the
            default) retries without limit, which is the historical behaviour.

    Returns:
        A list with one entry per prompt, each entry being the list of parsed
        triples for that prompt.

    Raises:
        ValueError: If ``max_attempts`` is given but smaller than 1.
        RuntimeError: If a prompt still fails validation after
            ``max_attempts`` requests.
    """
    if max_attempts is not None and max_attempts < 1:
        raise ValueError("max_attempts must be a positive integer or None")

    prompts = [Text_impress + s for s in input]
    output = []
    for i, prompt in enumerate(prompts):
        # Progress is reported once per prompt, not once per retry.
        if i % 5 == 0:
            print(f"正在执行{i}")

        attempts = 0
        while True:
            attempts += 1
            response = client.chat.completions.create(
                model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
                messages=[{'role': 'user', 'content': prompt}],
                stream=False
            )
            # Remove the think section; DOTALL lets it span several lines.
            clean_output = re.sub(r'<think>.*?</think>', '',
                                  response.choices[0].message.content,
                                  flags=re.DOTALL).strip()
            # Each bracketed group is one triple: drop the brackets, split on
            # commas, then trim whitespace and surrounding single quotes.
            triples = [
                [item.strip().strip("'") for item in triple_str.strip('[]').split(',')]
                for triple_str in re.findall(r'\[.*?\]', clean_output)
            ]
            if validate_triple(triples):
                output.append(triples)
                break
            # A malformed reply is retried; the optional cap keeps a prompt
            # that never validates from looping (and billing) forever.
            if max_attempts is not None and attempts >= max_attempts:
                raise RuntimeError(
                    f"no valid triples for input {i} after {attempts} attempts"
                )

    return output

def Score(gold,pre):
    """Compute precision, recall and F1 of predicted triples against gold triples.

    ``gold[i]`` and ``pre[i]`` hold the triples of the same sample.  A gold
    triple and a predicted triple match when their space-joined texts have a
    ``difflib.SequenceMatcher`` ratio of at least 0.85.  Matching is
    one-to-one within a sample: each predicted triple can satisfy at most one
    gold triple and each gold triple is counted at most once, so neither
    precision nor recall can exceed 1.

    Args:
        gold: List of samples, each a list of ``[head, relation, tail]`` triples.
        pre: List of samples of predicted triples, with at least as many
            samples as ``gold``.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats.  A ratio whose
        denominator is zero (no predictions, no gold triples, or no matches at
        all) is reported as ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    # Pair samples by position; only the samples present in ``gold`` count.
    samples = [(gold_triples, pre[i]) for i, gold_triples in enumerate(gold)]

    gold_total = sum(len(gold_triples) for gold_triples, _ in samples)
    pred_total = sum(len(pred_triples) for _, pred_triples in samples)

    matched = 0
    for gold_triples, pred_triples in samples:
        # Predictions still available for matching in this sample.
        unused = [f"{p[0]} {p[1]} {p[2]}" for p in pred_triples]
        for g in gold_triples:
            gold_text = f"{g[0]} {g[1]} {g[2]}"
            for pos, pred_text in enumerate(unused):
                if difflib.SequenceMatcher(None, gold_text, pred_text).ratio() >= 0.85:
                    matched += 1
                    # Consume the prediction so it cannot match another gold triple.
                    del unused[pos]
                    break

    precision = matched / pred_total if pred_total else 0.0
    recall = matched / gold_total if gold_total else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1

