import os
import re
from datetime import datetime
import threading

import pickle
import time
import faiss
import numpy as np
import openai

from DocumentManager.Constants import DUMMY_EMBED_SHAPE, NET_ACCESS, TIMESLEEP

from .Documents.PDFDocument import PDFDocument
from .Documents.WordDocument import WordDocument
from .Documents.TextDocument import TextDocument
from .Documents.Document import Document 
from DocumentManager.ProgressManager import ProgressManager




class DocumentManager:
    """Index the documents under a directory and answer similarity queries.

    Supported files are loaded (or processed and cached) through their
    ``Document`` subclass, their chunk embeddings are stacked into a FAISS
    inner-product index over L2-normalised vectors (i.e. cosine similarity),
    and queries are embedded with the same embedding function and searched
    against that index.
    """

    # Maps a lower-cased file extension to the Document class that handles it.
    FORMAT_CLASSES: dict[str, type[PDFDocument] | type[WordDocument] | type[TextDocument]] = {
        ".pdf": PDFDocument,
        ".docx": WordDocument,
        ".txt": TextDocument,
    }

    def __init__(self,
                 directory,
                 client: openai.OpenAI,
                 index_path="faiss_index.bin",
                 metadata_path="faiss_metadata.pkl",
                 cache_lifetime=7,
                 chunk_size=500,
                 chunk_overlap=50,
                 batch_size=32,
                 progress_callback=None):
        """Create a manager for ``directory``.

        Args:
            directory: Root directory that is walked recursively for documents.
            client: OpenAI client used for embeddings and reranking.
            index_path: File the FAISS index is written to / read from.
            metadata_path: Pickle file holding the chunks and their metadata.
            cache_lifetime: Age in days after which the cache file of a
                deleted document is removed during a scan.
            chunk_size: Forwarded to ``Document.load_or_process`` as
                ``chunk_size``.
            chunk_overlap: Forwarded to ``Document.load_or_process`` as
                ``chunk_overlap_sentence``.
            batch_size: Forwarded to ``Document.load_or_process``.
            progress_callback: Optional callback registered on the progress
                manager.
        """
        self.directory = directory  # directory of all documents
        self.client = client
        # path -> Document instance
        self.documents: dict[str, PDFDocument | WordDocument | TextDocument] = {}
        self.index = None
        self.all_chunks = []
        self.all_metadata = []
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.cache_lifetime = cache_lifetime
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.batch_size = batch_size
        self.progress_manager = ProgressManager()
        if progress_callback:
            self.progress_manager.set_callback(progress_callback)

        # Held for the duration of a background scan (see _safe_scan).
        self._scan_lock = threading.Lock()

    def set_callback(self, callback):
        """Register ``callback`` on the progress manager."""
        self.progress_manager.set_callback(callback)

    def clean_caches(self):
        """Delete every document cache file found under ``self.directory``.

        Removal is best-effort: a failure is printed and the walk continues.
        """
        for root, _dirs, filenames in os.walk(self.directory):
            for name in filenames:
                if os.path.splitext(name)[1].lower() != Document.CHACHE_EXTENTION:
                    continue
                path = os.path.join(root, name)
                try:
                    os.remove(path)
                except Exception as e:
                    print('ERROR: error in remove document manager cashes:', str(e))

    def scan_directory(self, save=True):
        """Rescan the directory, (re)load every document and rebuild the index.

        Steps: collect supported files and cache files, purge expired cache
        files whose source document no longer exists, load or process each
        document, emit a ``FINISH`` progress update, rebuild the index and
        optionally persist it.

        Args:
            save: When true, write the rebuilt index to disk via ``save_index``.
        """
        self.documents = {}
        print("Scan started...")  # ow scan start
        start_time = datetime.now()

        # Collect document files and cache files in one walk.
        files = []
        cached_files = []
        for root, _dirs, filenames in os.walk(self.directory):
            for name in filenames:
                ext = os.path.splitext(name)[1].lower()
                if ext in self.FORMAT_CLASSES:
                    files.append(os.path.join(root, name))
                if ext == Document.CHACHE_EXTENTION:
                    cached_files.append(os.path.join(root, name))

        # Remove old cache files whose source document was deleted.
        now = datetime.now()
        for cache_path in cached_files:
            if os.path.exists(Document.cached_path2path(cache_path)):
                continue
            try:
                # getmtime is inside the try: the file may vanish between the
                # walk and this point, which must not abort the whole scan.
                last_modified = datetime.fromtimestamp(os.path.getmtime(cache_path))
                if (now - last_modified).days > self.cache_lifetime:
                    os.remove(cache_path)
            except Exception as e:
                print(f"ERROR: Could not remove cache {cache_path}: {e}")

        # Process new/modified files, or just load the unmodified ones.
        for path in files:
            ext = os.path.splitext(path)[1].lower()
            # self.documents was reset above, so every path gets a new instance.
            doc = self.FORMAT_CLASSES[ext](path)
            doc.load_or_process(embedding_func=self.text_to_embedding,
                                chunk_size=self.chunk_size,
                                chunk_overlap_sentence=self.chunk_overlap,
                                batch_size=self.batch_size)
            self.documents[path] = doc

        info = {
            "title": "",
            "id": "FINISH",
            "completed": 1,
            "total": 1,
            "details": {
            }
        }
        self.progress_manager.update(info)

        self.rebuild_index()
        if save:
            self.save_index()

        elapsed = (datetime.now() - start_time).total_seconds()
        print(f"Scan finished (elapsed time: {elapsed:.2f} sec)")

    def rebuild_index(self):
        """Rebuild the in-memory index from ``self.documents``.

        Documents without chunks are skipped. When no document contributes
        embeddings, ``self.index`` is set to ``None``.
        """
        self.all_chunks = []
        self.all_metadata = []
        embeddings_list = []

        for doc in self.documents.values():
            if len(doc.chunks) == 0:
                continue
            self.all_chunks.extend(doc.chunks)
            self.all_metadata.extend(doc.metadata)
            embeddings_list.append(doc.embeddings)

        if not embeddings_list:
            self.index = None
            return

        # FAISS only accepts C-contiguous float32 matrices; coerce explicitly
        # so float64 / list-based embeddings do not fail in normalize_L2/add.
        all_embeddings = np.ascontiguousarray(np.vstack(embeddings_list), dtype="float32")
        index = faiss.IndexFlatIP(all_embeddings.shape[1])
        # Inner product over L2-normalised vectors equals cosine similarity.
        faiss.normalize_L2(all_embeddings)
        index.add(all_embeddings)
        self.index = index

    def save_index(self):
        """Write the index and the chunk/metadata pickle; no-op without an index."""
        if self.index is None:
            return
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump({"chunks": self.all_chunks, "metadata": self.all_metadata}, f)

    def load_index(self):
        """Load the index and its metadata from disk when both files exist.

        Everything is read into locals first and only then assigned, so a
        failed or corrupt metadata load cannot leave an index without chunks.
        """
        if not (os.path.exists(self.index_path) and os.path.exists(self.metadata_path)):
            return

        index = faiss.read_index(self.index_path)
        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only load metadata files written by save_index from a trusted location.
        with open(self.metadata_path, "rb") as f:
            data = pickle.load(f)
        chunks = data["chunks"]
        metadata = data["metadata"]

        self.index = index
        self.all_chunks = chunks
        self.all_metadata = metadata

    def expand_query(self, query):
        """Return ``query`` together with three templated paraphrases of it."""
        return [
            query,
            f"What factors determine {query}?",
            f"What is {query} attributed to?",
            f"Causes of {query}"
        ]

    def _build_query_embedding(self, query):
        """Embed ``query`` and return it as a normalised ``(1, dim)`` float32 row.

        NOTE: only the raw query is embedded; averaging the embeddings of
        ``expand_query(query)`` is a possible alternative that is not used.
        """
        q_emb = np.asarray(self.text_to_embedding(query), dtype="float32")
        # reshape(1, -1) yields the 2-D row FAISS expects for a single query.
        q_emb = np.ascontiguousarray(q_emb.reshape(1, -1))
        faiss.normalize_L2(q_emb)
        return q_emb

    def _search_index(self, q_emb, top_k):
        """Run the FAISS search and return ``(scores, indices)`` for the query row."""
        D, I = self.index.search(q_emb, top_k)
        return D[0], I[0]

    def _valid_hits(self, scores, indices):
        """Return ``(score, index)`` pairs that point at an existing chunk.

        FAISS pads the result with label -1 when the index holds fewer than
        ``top_k`` vectors; used as a list index, -1 would silently return the
        last chunk. Out-of-range labels are dropped as well.
        """
        limit = min(len(self.all_chunks), len(self.all_metadata))
        return [
            (float(score), int(idx))
            for score, idx in zip(scores, indices)
            if 0 <= int(idx) < limit
        ]

    def has_index(self):
        """Return True when an index is currently loaded or built."""
        return self.index is not None

    def search(self, query, top_k=20):
        """Return the chunks (and their metadata) most relevant to ``query``.

        The ``top_k`` nearest chunks are retrieved from the index. With
        ``NET_ACCESS`` enabled they are additionally reranked by ``rerank``,
        which keeps only its own top results; otherwise they are returned in
        index order.

        Returns:
            ``(chunks, metadata)``: two parallel lists, both empty when no
            index is available or nothing was retrieved.
        """
        if self.index is None or top_k <= 0:
            return [], []

        q_emb = self._build_query_embedding(query)
        scores, indices = self._search_index(q_emb, top_k)
        hits = self._valid_hits(scores, indices)

        retrieved_chunks = [self.all_chunks[i] for _, i in hits]
        retrieved_metadata = [self.all_metadata[i] for _, i in hits]
        if not retrieved_chunks or not NET_ACCESS:
            return retrieved_chunks, retrieved_metadata

        ranking = self.rerank(query, retrieved_chunks)
        top_chunks = [retrieved_chunks[i] for i in ranking]
        top_metadata = [retrieved_metadata[i] for i in ranking]
        return top_chunks, top_metadata

    def search_related(self, query, top_k=5, similarity_threshold=0.7):
        """Return retrieved chunks whose similarity reaches ``similarity_threshold``.

        Returns:
            ``(chunks, metadata)``: two parallel lists in index order, both
            empty when no index is available.
        """
        if self.index is None or top_k <= 0:
            return [], []

        q_emb = self._build_query_embedding(query)
        scores, indices = self._search_index(q_emb, top_k)

        retrieved_chunks = []
        retrieved_metadata = []
        for sim, idx in self._valid_hits(scores, indices):
            if sim >= similarity_threshold:
                retrieved_chunks.append(self.all_chunks[idx])
                retrieved_metadata.append(self.all_metadata[idx])

        return retrieved_chunks, retrieved_metadata

    def text_to_embedding(self, text: str | list[str], batch=False):
        """Embed ``text`` with the OpenAI embeddings endpoint.

        Args:
            text: A single string or a list of strings.
            batch: When true, return one embedding per input item; otherwise
                return only the first embedding.

        With ``NET_ACCESS`` disabled, sleeps ``TIMESLEEP`` seconds and returns
        a random array of shape ``DUMMY_EMBED_SHAPE`` regardless of ``batch``.
        """
        if not NET_ACCESS:
            time.sleep(TIMESLEEP)
            return np.random.rand(*DUMMY_EMBED_SHAPE)

        response = self.client.embeddings.create(
            model="text-embedding-ada-002",
            input=text
        )
        if batch:
            return [item.embedding for item in response.data]
        return response.data[0].embedding

    @staticmethod
    def safe_json_parse(text):
        """Parse JSON from ``text`` after removing Markdown code-fence markers.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        import json
        # Strip the code block fences.
        text = text.replace("```json", "").replace("```", "").strip()
        return json.loads(text)

    def rerank(self, query, chunks, top_k=3, relevance_threshold=60, model_name="gpt-5-mini"):
        """Score ``chunks`` against ``query`` with an LLM and rank them.

        Args:
            query: The user question.
            chunks: Candidate passages; their positions are the ids shown to
                the model.
            top_k: Maximum number of indices to return.
            relevance_threshold: Minimum score (0-100) a passage must reach.
            model_name: Model used for scoring.

        Returns:
            Indices into ``chunks``, best score first, at most ``top_k`` long.
            Entries the model returns with a missing, non-numeric,
            out-of-range or repeated id are ignored.

        Raises:
            json.JSONDecodeError: If the model output is not valid JSON.
        """
        if not chunks:
            return []

        numbered_chunks = "\n\n".join(
            [f"[{i}] {chunk}" for i, chunk in enumerate(chunks)]
        )

        prompt = f"""You are a relevance scoring system.

        Question:
        {query}

        Below are numbered text passages.

        {numbered_chunks}

        For EACH passage return a JSON list in this format:

        [
        {{"id": 0, "score": 85}},
        {{"id": 1, "score": 12}},
        ...
        ]

        Score meaning:
        0 = completely unrelated
        50 = somewhat related but does not directly answer
        100 = directly and clearly answers the question

        Return ONLY valid JSON.
        """

        response = self.client.responses.create(
            model=model_name,
            instructions="You are a precise relevance scoring model.",
            input=[
                {"role": "user", "content": prompt}
            ],
            # temperature=0
        )

        # Record token usage before parsing, so the call is still accounted
        # for when the model returns malformed JSON.
        from models.ai_model import OpenAIModel
        OpenAIModel.update_tokens_in_db(model_name=model_name, response=response)

        scored = DocumentManager.safe_json_parse(response.output_text.strip())
        if not isinstance(scored, list):
            scored = []

        # The ids come from the model and are later used as list indices by
        # the caller, so validate them instead of trusting them.
        ranked = []
        seen = set()
        for item in scored:
            if not isinstance(item, dict):
                continue
            try:
                idx = int(item["id"])
                score = float(item["score"])
            except (KeyError, TypeError, ValueError):
                continue
            if not 0 <= idx < len(chunks) or idx in seen:
                continue
            seen.add(idx)
            # Filter based on the threshold.
            if score >= relevance_threshold:
                ranked.append((score, idx))

        # Stable sort: ties keep the order in which the model listed them.
        ranked.sort(key=lambda pair: pair[0], reverse=True)

        # Return only the indices.
        return [idx for _, idx in ranked[:top_k]]

    def update_config(self,
                      cache_lifetime=None,
                      chunk_size=None,
                      chunk_overlap=None,
                      batch_size=None,
                      top_k=None):
        """
        Update DocumentManager configuration dynamically.
        Only provided values will be updated.

        When the effective ``chunk_size`` or ``chunk_overlap`` actually
        changes, all document caches are removed first. A background rescan
        is started in every case. ``top_k`` is only stored on ``self.top_k``.
        """
        # Compare coerced values, and only for parameters that were provided:
        # an omitted (None) argument must not count as a change.
        new_chunk_size = self.chunk_size if chunk_size is None else int(chunk_size)
        new_chunk_overlap = self.chunk_overlap if chunk_overlap is None else int(chunk_overlap)
        base_change = (new_chunk_size != self.chunk_size
                       or new_chunk_overlap != self.chunk_overlap)

        if cache_lifetime is not None:
            self.cache_lifetime = int(cache_lifetime)

        self.chunk_size = new_chunk_size
        self.chunk_overlap = new_chunk_overlap

        if batch_size is not None:
            self.batch_size = int(batch_size)
        if top_k is not None:
            self.top_k = int(top_k)

        # Cached chunks were produced with the old chunking parameters.
        if base_change:
            self.clean_caches()
        # Run scan in background
        self.run_safe_scan_directory()

    def run_safe_scan_directory(self):
        """Start ``_safe_scan`` on a daemon thread and return immediately."""
        threading.Thread(target=self._safe_scan, daemon=True).start()

    def _safe_scan(self):
        """Run ``scan_directory`` unless another guarded scan holds the lock."""
        if not self._scan_lock.acquire(blocking=False):
            print("⚠Scan already running...")
            return

        try:
            self.scan_directory()
        finally:
            self._scan_lock.release()
