import logging
from datetime import datetime
from enum import Enum

from flask_babel import _
from flask_babel import lazy_gettext as _l
from sqlalchemy import Enum as SQLEnum

from config import AVERAGE_INPUT_COST_PER_1M, AVERAGE_OUTPUT_COST_PER_1M
from config import DEFULT_TEXT_2_VOICE_MODEL_NAME, DEFULT_TEXT_MODEL_NAME, DEFULT_TEXT_STREAM_MODEL_NAME, DEFULT_VOICE_2_TEXT_MODEL_NAME
from extensions import db



class ModelTokenUsage(db.Model):
    """Token counts and cost recorded for a single model request."""

    __tablename__ = "model_token_usage"

    id = db.Column(db.Integer, primary_key=True)

    # Foreign key to openai_models.id; a usage row must reference a model.
    model_id = db.Column(db.Integer, db.ForeignKey("openai_models.id"), nullable=False)

    input_tokens = db.Column(db.Integer, default=0)
    output_tokens = db.Column(db.Integer, default=0)

    # Final cost of this request (stored so it does not have to be recomputed).
    total_cost = db.Column(db.Float, nullable=False)

    # Optional: set when usage should be tracked per user.
    user_id = db.Column(db.Integer, nullable=True)

    created_at = db.Column(db.DateTime, default=datetime.utcnow)





class Balance(db.Model):
    """Account balance in USD: the initial charge and the amount remaining."""

    id = db.Column(db.Integer, primary_key=True)

    initial_charge_usd = db.Column(db.Float, default=0.0, nullable=False)
    remaining_balance_usd = db.Column(db.Float, default=0.0, nullable=False)

    # Refreshed automatically on every UPDATE of the row.
    updated_at = db.Column(
        db.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )



class PromptType(Enum):
    """Kinds of prompt stored in SystemPrompt.type and PromptFile.prompt_type."""

    TEXT = "text"
    ANSWER_STRUCT = "answer_struct"
    CONVERSATION = "conversation"
    TEXT_TO_VOICE = "text2voice"
    REFERENCES = "references"
    


class SystemPrompt(db.Model):
    """System prompt text, at most one row per PromptType (``type`` is unique)."""

    __tablename__ = "system_prompts"

    id = db.Column(db.Integer, primary_key=True)

    # Prompt type; the unique constraint allows a single row per type.
    type = db.Column(db.Enum(PromptType), unique=True, nullable=False)

    # Prompt text.
    content = db.Column(db.Text, nullable=False)

    # Active / inactive flag.
    is_active = db.Column(db.Boolean, default=True)

    # Id of the admin user who last edited the prompt.
    updated_by = db.Column(db.Integer, nullable=True)
    updated_at = db.Column(
        db.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    created_at = db.Column(db.DateTime, default=datetime.utcnow)



class PromptFile(db.Model):
    """An uploaded file associated with a prompt type."""

    __tablename__ = "prompt_files"

    id = db.Column(db.Integer, primary_key=True)

    # Optional: set when files should be tracked per user.
    user_id = db.Column(db.Integer, nullable=True)

    # Not unique: several files may share the same prompt type.
    prompt_type = db.Column(db.Enum(PromptType), unique=False, nullable=False)

    filename = db.Column(db.String(255), nullable=False)
    filepath = db.Column(db.String(500), nullable=False)

    uploaded_at = db.Column(db.DateTime, default=datetime.utcnow)



class ModelType(Enum):
    """Task categories a model can serve."""

    TEXT = "text"
    TEXT_STREAM = "text_stream"
    TEXT_TO_VOICE = "text_to_voice"
    VOICE_TO_TEXT = "voice_to_text"
    TEXT_EMBED = "text_embed"

    @property
    def label(self):
        """Return the display label for this member, translated via gettext."""
        # Built on each access so the translation is resolved at call time.
        labels_by_value = {
            "text": _("متن به متن"),
            "text_stream": _("متن به متن (استریم)"),
            "text_to_voice": _("متن به صوت"),
            "voice_to_text": _("صوت به متن"),
            "text_embed": _("متن به بردار"),
        }
        return labels_by_value[self.value]




class OpenAIModel(db.Model):
    """A model known to the application, with its per-1M-token pricing.

    Also provides `update_tokens_in_db`, which records the token usage of a
    request and deducts its cost from the balance.
    """

    __tablename__ = "openai_models"

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(100), unique=True, nullable=False)
    model_type = db.Column(SQLEnum(ModelType, name="model_type_enum"), nullable=False)
    input_cost_per_1m = db.Column(db.Float, nullable=False)
    output_cost_per_1m = db.Column(db.Float, nullable=False)
    cache_cost_per_1m = db.Column(db.Float, nullable=True, default=0)
    supports_temperature = db.Column(db.Boolean, nullable=False, default=True)
    supports_web_search = db.Column(db.Boolean, nullable=False, default=False)

    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    usages = db.relationship("ModelTokenUsage", backref="model", lazy=True)

    # Relationships to the per-task configuration rows. Each one names its
    # foreign key explicitly because AITaskModelConfig has several foreign keys
    # pointing at this table.
    # NOTE(review): no relationship is declared for
    # AITaskModelConfig.text_stream_model_id.
    text_model_config = db.relationship(
        "AITaskModelConfig",
        backref="text_model",
        foreign_keys="AITaskModelConfig.text_model_id",
        lazy=True,
    )
    voice_to_text_model_config = db.relationship(
        "AITaskModelConfig",
        backref="voice_to_text_model",
        foreign_keys="AITaskModelConfig.voice_to_text_model_id",
        lazy=True,
    )
    text_to_voice_model_config = db.relationship(
        "AITaskModelConfig",
        backref="text_to_voice_model",
        foreign_keys="AITaskModelConfig.text_to_voice_model_id",
        lazy=True,
    )

    @classmethod
    def update_tokens_in_db(cls, model_name: str, response=None, user_id=None,
                            input_tokens: int = 0, output_tokens: int = 0):
        """Record token usage for one request and deduct its cost from the balance.

        Args:
            model_name: Value matched against ``OpenAIModel.name``.
            response: Optional API response. When it carries a ``usage``
                object, its ``input_tokens`` / ``output_tokens`` /
                ``total_tokens`` override the explicit counts.
            user_id: Optional user id stored on the usage row.
            input_tokens: Input token count used when ``response`` has no usage.
            output_tokens: Output token count used when ``response`` has no usage.

        Raises:
            Exception: Whatever the commit raises; the session is rolled back
                before the error is re-raised.
        """
        logger = logging.getLogger(__name__)

        # Prefer the counts reported by the API; fall back to the explicit
        # arguments when there is no response or it carries no usage data.
        reported = getattr(response, "usage", None)
        if reported is not None:
            input_tokens = reported.input_tokens
            output_tokens = reported.output_tokens
            total_tokens = reported.total_tokens
        else:
            total_tokens = input_tokens + output_tokens

        logger.debug(
            "%s token usage: input=%s output=%s total=%s",
            model_name, input_tokens, output_tokens, total_tokens,
        )

        model = cls.query.filter_by(name=model_name).first()
        if model is None:
            # Unknown model: price the request with the configured average rates.
            logger.warning("Unknown OpenAI model used: %s", model_name)
            input_cost_per_1M = AVERAGE_INPUT_COST_PER_1M
            output_cost_per_1M = AVERAGE_OUTPUT_COST_PER_1M
        else:
            input_cost_per_1M = model.input_cost_per_1m
            output_cost_per_1M = model.output_cost_per_1m

        input_cost = (input_tokens / 1_000_000) * input_cost_per_1M
        output_cost = (output_tokens / 1_000_000) * output_cost_per_1M
        total_cost = round(input_cost + output_cost, 6)

        logger.debug("%s total cost $: %s", model_name, total_cost)

        if model is not None:
            db.session.add(
                ModelTokenUsage(
                    model_id=model.id,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_cost=total_cost,
                    user_id=user_id,
                )
            )
        # else: ModelTokenUsage.model_id is NOT NULL, so a row without a model
        # cannot be stored; inserting one would fail the commit and also lose
        # the balance deduction below. The cost is still charged.

        # Deduct in SQL (column expression) rather than read-modify-write.
        Balance.query.filter_by(id=1).update({
            Balance.remaining_balance_usd: Balance.remaining_balance_usd - total_cost
        })

        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for the caller, then surface the error.
            db.session.rollback()
            raise


class AITaskModelConfig(db.Model):
    """Selects which OpenAIModel is used for each task type."""

    __tablename__ = "ai_task_model_configs"

    id = db.Column(db.Integer, primary_key=True)

    # One optional model per task type; each is a foreign key to openai_models.id.
    text_model_id = db.Column(db.Integer, db.ForeignKey('openai_models.id'), nullable=True)
    text_stream_model_id = db.Column(db.Integer, db.ForeignKey('openai_models.id'), nullable=True)
    voice_to_text_model_id = db.Column(db.Integer, db.ForeignKey('openai_models.id'), nullable=True)
    text_to_voice_model_id = db.Column(db.Integer, db.ForeignKey('openai_models.id'), nullable=True)

    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    @classmethod
    def get_current_config(cls):
        """Return the first configuration row with ``is_active`` set, or None."""
        return cls.query.filter_by(is_active=True).first()

    @classmethod
    def get_model_for_task(cls, task_type: ModelType) -> "OpenAIModel | None":
        """Return the model configured for ``task_type``.

        Returns None when there is no active configuration, when the task type
        has no configuration slot (e.g. ``ModelType.TEXT_EMBED``), or when the
        slot is empty.
        """
        config = cls.get_current_config()
        if not config:
            return None

        model_id_by_task = {
            ModelType.TEXT: config.text_model_id,
            ModelType.TEXT_STREAM: config.text_stream_model_id,
            ModelType.VOICE_TO_TEXT: config.voice_to_text_model_id,
            ModelType.TEXT_TO_VOICE: config.text_to_voice_model_id,
        }
        model_id = model_id_by_task.get(task_type)
        if model_id is None:
            # Unsupported task type or unset slot: nothing to look up.
            return None
        return OpenAIModel.query.get(model_id)

    @classmethod
    def get_model_name_for_task(cls, task_type: ModelType):
        """Return the model name to use for ``task_type``.

        Uses the configured model when one can be loaded; otherwise falls back
        to the matching ``DEFULT_*`` constant from ``config``. Returns None for
        task types that have neither a configured model nor a default.
        """
        model = None
        try:
            model = cls.get_model_for_task(task_type)
        except Exception:
            # Best effort: a failed lookup must not prevent falling back to the
            # default name, but it should not go unnoticed either.
            logging.getLogger(__name__).exception(
                "Could not load configured model for task %s; using default",
                task_type,
            )

        if model is not None:
            return model.name

        default_name_by_task = {
            ModelType.TEXT: DEFULT_TEXT_MODEL_NAME,
            ModelType.TEXT_STREAM: DEFULT_TEXT_STREAM_MODEL_NAME,
            ModelType.VOICE_TO_TEXT: DEFULT_VOICE_2_TEXT_MODEL_NAME,
            ModelType.TEXT_TO_VOICE: DEFULT_TEXT_2_VOICE_MODEL_NAME,
        }
        return default_name_by_task.get(task_type)
