
import traceback

from utils import *  # expected to provide Prompt, llm, and parser used below
from utils.clinical_ranges import evaluate_all_metrics
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

# Mapping from LLM title to input data key
title_to_key = {
    "Heart Rate (BPM)": "heart_rate",
    "Breathing Rate": "breathing_rate",
    "Oxygen Saturation": "oxygen_saturation",
    "Blood Pressure": "blood_pressure",
    "Stress Level": "stress_level",
    "Heart Variability": "heart_variability",
    "PRQ": "prq",
    "Activity": "activity",
    "Sleep": "sleep",
    "Equilibrium": "equilibrium",
    "Metabolism": "metabolism",
    "Relaxation": "relaxation",
    "Cardiovascular Age": "cardiovascular_age",
    "Hemoglobin": "hemoglobin",
    "Cholesterol": "cholesterol",
    "A1C Risk": "a1c_risk",
    "Cholesterol Risk": "cholesterol_risk",
    "A1C Range": "a1c_range",
    "Cholesterol Range": "cholesterol_range",
    "Wellness Score": "wellness_score",
    "Atrial Fibrillation": "atrial_fibrillation",
    "HbA1c": "hba1c",
    "Atrial Fibrillation": "atrial_fibrillation",
    "HbA1c": "hba1c",
    "Cardiovascular BMI": "cardiovascular_bmi"
}
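
# Optional helper (a sketch, not required by the code below): a precomputed
# reverse map (metric key -> display title) could replace the linear
# `next(...)` scans over title_to_key performed once per metric in response().
key_to_title = {key: title for title, key in title_to_key.items()}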

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    # ValueError and KeyError are subclasses of Exception, so listing them
    # separately was redundant; retry on any exception raised by the call.
    retry=retry_if_exception_type(Exception)
)
def call_llm_with_retry(data):
    """LLM call with automatic retry on failure"""
    # Use the modern RunnableSequence (prompt | llm | parser) instead of the deprecated LLMChain
    runnable = Prompt | llm | parser
    llm_output = runnable.invoke({"data": data})
    return llm_output
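
# The structured output from `parser` is assumed to expose either a `Data` or a
# `metrics` attribute holding metric entries, each with a `title` and a `tip`
# (see response() below). Roughly, as a sketch with placeholder names:
#
#     class MetricTip(BaseModel):
#         title: str   # e.g. "Heart Rate (BPM)"
#         tip: str     # short wellness suggestion for that metric
#
#     class ScanReport(BaseModel):
#         Data: list[MetricTip] = []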

def response(data, well_score=False):
    try:
        # Calculate status and score programmatically.
        # Accept a Pydantic v2 model (model_dump), a v1 model (dict), or a plain dict.
        data_dict = (
            data.model_dump() if hasattr(data, "model_dump")
            else data.dict() if hasattr(data, "dict")
            else data
        )
        
        # Filter out None values to prevent LLM hallucination
        filtered_data = {k: v for k, v in data_dict.items() if v is not None}
        
        clinical_evaluations = evaluate_all_metrics(filtered_data)
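        # clinical_evaluations is expected to map each input key to a dict with
        # "status" and "score" fields (used in the lookups below).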
        
        # Prepare detailed input for LLM so it knows the calculated status
        llm_input_list = []
        for key, value in filtered_data.items():
            status = clinical_evaluations.get(key, {}).get("status", "Normal")
            title = next((t for t, k in title_to_key.items() if k == key), key.replace("_", " ").title())
            llm_input_list.append(f"{title}: {value} (Status: {status})")
            # If it's wellness_score, also add Overall Wellness Score to LLM input
            if key == "wellness_score":
                llm_input_list.append(f"Overall Wellness Score: {value} (Status: {status})")
        
        llm_input_string = "\n".join(llm_input_list)
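        # Example of the resulting prompt payload (values illustrative):
        #   Heart Rate (BPM): 72 (Status: Normal)
        #   Oxygen Saturation: 98 (Status: Normal)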
        
        # Call LLM with detailed status data
        llm_response_obj = call_llm_with_retry(llm_input_string)

        # Get metrics list from 'Data' or 'metrics'
        metrics_list = getattr(llm_response_obj, 'Data', None)
        if metrics_list is None:
            metrics_list = getattr(llm_response_obj, 'metrics', [])

        # Build formatted_data with programmatic status/score, LLM tips
        formatted_data = []
        
        # Create a map of LLM responses for easy lookup
        llm_metrics_map = {}
        for metric in metrics_list:
            title = getattr(metric, "title", str(metric))
            llm_metrics_map[title] = metric
            # Add normalized key for robust lookup
            norm_key = title.split('(')[0].strip().lower()
            llm_metrics_map[norm_key] = metric

        # Iterate through INPUT data keys to ensure we only return what was sent
        for key, value in filtered_data.items():
            # Find the corresponding title for this key
            title = next((t for t, k in title_to_key.items() if k == key), None)
            if not title:
                # Fallback: try to match key to title format
                title = key.replace("_", " ").title()
            
            # Get programmatic evaluation
            if key in clinical_evaluations:
                status = clinical_evaluations[key]["status"]
                # Use raw value for wellness_score as requested by user
                if key == "wellness_score":
                    score = value
                else:
                    score = clinical_evaluations[key]["score"]
            else:
                status = "Normal"
                score = value if key == "wellness_score" else 5

            # Get tip from LLM response if available
            tip = ""
            if title in llm_metrics_map:
                tip = getattr(llm_metrics_map[title], "tip", "")
            else:
                # Try normalized lookup
                norm_title = title.split('(')[0].strip().lower()
                if norm_title in llm_metrics_map:
                    tip = getattr(llm_metrics_map[norm_title], "tip", "")
            
            # Fallback tip if LLM failed to provide one
            if not tip:
                if status == "Normal" or status == "Optimal":
                    tip = "This state supports overall physical resilience, so continue with your current wellness habits."
                elif status == "High" or status == "Elevated":
                    tip = "Consider a gentle adjustment to your daily routine to help bring this closer to a balanced range."
                elif status == "Low":
                    tip = "Prioritizing restorative habits and balanced nutrition can help support a more optimal level."
                else:
                    tip = "Maintaining consistent daily habits will help support long-term stability in this area."
            
            formatted_data.append({
                "title": title,
                "status": status,
                "tip": tip,
                "score": score  # raw score for all metrics
            })

        # Add Overall Wellness Score as a separate entry with the same score
        if well_score:
            wellness_val = filtered_data.get("wellness_score")
            if wellness_val is not None:
                # Get tip for Overall Wellness Score from LLM map
                tip = ""
                if "Overall Wellness Score" in llm_metrics_map:
                    tip = getattr(llm_metrics_map["Overall Wellness Score"], "tip", "")
                
                # Score-appropriate fallback if LLM didn't provide a tip
                if not tip:
                    if wellness_val >= 8:
                        tip = "Your overall wellness is strong. Maintain your current routines to preserve this balance."
                    elif wellness_val >= 6:
                        tip = "Your wellness is on track but has room to improve. Focus on one area—sleep, movement, or stress—to elevate your baseline."
                    elif wellness_val >= 4:
                        tip = "Your body is showing signs of strain. Prioritize rest, hydration, and reduced intensity today."
                    else:
                        tip = "Your recovery is significantly impaired. Focus on immediate rest and minimizing physical and mental load."
                
                # Explicitly create dict without status
                overall_entry = {
                    "title": "Overall Wellness Score",
                    "tip": tip,
                    "score": wellness_val
                }
                formatted_data.append(overall_entry)
                
        # Final verification: Ensure Overall Wellness Score has no status
        for entry in formatted_data:
            if entry.get("title") == "Overall Wellness Score" and "status" in entry:
                del entry["status"]

        return JSONResponse(
            status_code=200,
            content={
                "ResponseCode": 200,
                "Status": True,
                "Message": "Get user scan report successfully",
                "Data": formatted_data
            }
        )

    except Exception as e:
        # Log the full traceback for debugging, then surface a generic 500 error
        traceback.print_exc()
        print(f"Error in response function: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "ResponseCode": 500,
                "Status": False,
                "Message": "Internal server error. Please try again later.",
                "Data": []
            }
        )
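
# Minimal usage sketch (assumes `app` is the FastAPI instance and `ScanData` is
# the Pydantic request model defined elsewhere in this project; both names are
# illustrative):
#
#     @app.post("/scan-report")
#     def scan_report(payload: ScanData, well_score: bool = False):
#         return response(payload, well_score=well_score)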
