import pandas as pd
import numpy as np
from datetime import date, timedelta
from typing import Dict, Any, List, Optional
from health.services.query import HealthDataQuery
from health.services.manual_log_storage import ManualLogStorage
from health.utils.logging_config import setup_logger

logger = setup_logger(__name__)

class HealthAnalyst:
    """
    Advanced analytics engine for health data using Pandas.
    Calculates correlations, trends, and lifestyle impacts.
    """

    def __init__(self) -> None:
        """Set up the two data sources the analyses read from.

        ``query`` is the object asked for per-metric ranges and
        ``manual_storage`` the one asked for manual logs; both are created
        with their default constructors.
        """
        self.query, self.manual_storage = HealthDataQuery(), ManualLogStorage()

    def get_dataframe(self, days: int = 30) -> pd.DataFrame:
        """
        Fetch all relevant data for the last N days and merge into a single DataFrame.
        """
        end_date = date.today()
        start_date = end_date - timedelta(days=days)
        
        # 1. Fetch Daily Metrics (Steps, Sleep, HR, Stess, Body Battery)
        # We need to fetch each metric range and merge by date
        metrics_to_fetch = ["sleep", "heart_rate", "stress", "body_battery", "steps", "hrv", "rhr"]
        
        data_map = {} # date -> dict of metrics
        
        # Initialize dates
        curr = start_date
        while curr <= end_date:
            data_map[curr.isoformat()] = {"date": curr.isoformat()}
            curr += timedelta(days=1)

        # Batch fetch metrics (this might be slow if we do it one by one, 
        # but HealthDataQuery is built for this. Ideally we optimize query later)
        for metric in metrics_to_fetch:
            points = self.query.get_metric_range(metric, start_date, end_date)
            for p in points:
                # Key extraction logic similar to Reader
                # Simplified extraction for DataFrame
                val = np.nan
                if hasattr(p, 'value'): val = p.value
                elif isinstance(p, dict):
                    # Try common keys
                    if 'average_heart_rate' in p: val = p['average_heart_rate']
                    elif 'overall_sleep_score' in p: val = p['overall_sleep_score']
                    elif 'resting_heart_rate' in p: val = p['resting_heart_rate']
                    elif 'average_stress_level' in p: val = p['average_stress_level']
                    elif 'charged' in p: val = p.get('charged', 0) # Body battery charged? 
                    # Actually for body battery we usually want max or charged
                    
                # Store in map
                d_str = p.get('calendar_date') or p.get('date')
                if d_str and d_str in data_map:
                    data_map[d_str][metric] = val

        # 2. Fetch Manual Logs (Alcohol, Fasting, Supplements)
        logs = self.manual_storage.get_logs_in_range(start_date, end_date)
        for log in logs:
            d_str = log.log_date
            if d_str in data_map:
                # Alcohol (Boolean & Amount)
                data_map[d_str]["alcohol_units"] = sum([1 for _ in log.alcohol_entries])
                data_map[d_str]["has_alcohol"] = 1 if log.alcohol_entries else 0
                
                # Fasting (Categorical)
                data_map[d_str]["fasting_mode"] = log.fasting_mode or "Normal"
                
                # Supplements (One-hotish checks)
                data_map[d_str]["has_magnesium"] = 1 if any("magnesium" in s.supplement_name.lower() for s in log.supplement_entries) else 0

        # Convert to DataFrame
        df = pd.DataFrame(list(data_map.values()))
        df['date'] = pd.to_datetime(df['date'])
        df = df.sort_values('date').set_index('date')
        return df

    def analyze_recovery_correlations(self, days: int = 90) -> Dict[str, Any]:
        """
        Analyze impact of Alcohol and Activity on Next-Day Recovery (Sleep, HRV, RHR).
        """
        df = self.get_dataframe(days)
        
        if df.empty:
            return {"error": "No data available"}

        # Shift recovery metrics to represent "Next Day"
        # We want to see if Alcohol on Day T affects Sleep on Day T (technically sleep starts on Day T night)
        # Actually usually Alcohol on Day T affects Sleep Score of the night of Day T (which serves Day T+1 morning)
        # Garmin assigns Sleep Date to the morning it ends. 
        # So Alcohol on Jan 1 affects Sleep of Jan 2.
        
        # Let's align: 
        # Feature: Alcohol(T)
        # Target: Sleep(T+1), RHR(T+1)
        
        df['next_sleep'] = df['sleep'].shift(-1)
        df['next_rhr'] = df['rhr'].shift(-1)
        df['next_hrv'] = df['hrv'].shift(-1)
        
        results = {}
        
        # 1. Alcohol Impact
        alcohol_days = df[df['has_alcohol'] == 1]
        sober_days = df[df['has_alcohol'] == 0]
        
        if len(alcohol_days) > 0 and len(sober_days) > 0:
            results['alcohol_impact'] = {
                "sample_size_alcohol": len(alcohol_days),
                "sample_size_sober": len(sober_days),
                "avg_sleep_alcohol": float(alcohol_days['next_sleep'].mean()),
                "avg_sleep_sober": float(sober_days['next_sleep'].mean()),
                "avg_rhr_alcohol": float(alcohol_days['next_rhr'].mean()),
                "avg_rhr_sober": float(sober_days['next_rhr'].mean()),
                "sleep_diff": float(alcohol_days['next_sleep'].mean() - sober_days['next_sleep'].mean())
            }
            
        return results

    def analyze_fitness_trends(self, days: int = 90) -> Dict[str, Any]:
        """
        Analyze RHR vs Activity trends.
        """
        df = self.get_dataframe(days)
        if df.empty: return {"error": "No data"}
        
        # Calculate rolling averages (7-day)
        df['rhr_roll'] = df['rhr'].rolling(7).mean()
        df['steps_roll'] = df['steps'].rolling(7).mean()
        
        # Simple Linear Trend
        # Drop NaNs
        valid = df.dropna(subset=['rhr'])
        if len(valid) > 10:
            # Simple slope (not robust but indicative)
            # rhr change over period
            start_rhr = valid.iloc[0]['rhr_roll'] if not np.isnan(valid.iloc[0]['rhr_roll']) else valid.iloc[0]['rhr']
            end_rhr = valid.iloc[-1]['rhr_roll'] if not np.isnan(valid.iloc[-1]['rhr_roll']) else valid.iloc[-1]['rhr']
            
            return {
                "period_days": days,
                "start_rhr_7d_avg": float(start_rhr),
                "end_rhr_7d_avg": float(end_rhr),
                "trend": "improving" if end_rhr < start_rhr else "declining"
            }
        
        return {"msg": "Not enough data for trend analysis"}

    def analyze_lifestyle_impact(self, days: int = 30) -> Dict[str, Any]:
        """
        Analyze Fasting and Supplements.
        """
        df = self.get_dataframe(days)
        if df.empty: return {"error": "No data"}
        
        results = {}
        
        # Fasting Impact
        # Compare "Normal" vs other modes
        fasting_stats = df.groupby('fasting_mode')['sleep'].mean().to_dict()
        results['fasting_sleep_scores'] = fasting_stats
        
        return results

    def compare_groups(self, condition_col: str, target_col: str, days: int = 90) -> Dict[str, Any]:
        """
        Compare average of 'target_col' when 'condition_col' is non-zero vs zero.
        Example: Sleep Score when has_magnesium=1 vs 0.
        """
        df = self.get_dataframe(days)
        if df.empty: return {"error": "No data"}
        
        if condition_col not in df.columns or target_col not in df.columns:
            return {"error": f"Columns not found: {condition_col} or {target_col}"}
            
        group_true = df[df[condition_col] > 0][target_col]
        group_false = df[df[condition_col] == 0][target_col]
        
        if len(group_true) == 0 or len(group_false) == 0:
            return {"error": "One group has no data (e.g. never took supplement)"}
            
        avg_true = group_true.mean()
        avg_false = group_false.mean()
        diff_pct = ((avg_true - avg_false) / avg_false) * 100 if avg_false != 0 else 0
        
        return {
            "condition": condition_col,
            "target": target_col,
            "avg_with_condition": float(avg_true),
            "avg_without_condition": float(avg_false),
            "sample_with": len(group_true),
            "sample_without": len(group_false),
            "difference_pct": float(diff_pct),
            "verdict": "better" if avg_true > avg_false else "worse"
        }
    def analyze_lagged_correlation(self, driver: str, target: str, lag: int = 1, days: int = 90) -> Dict[str, Any]:
        """
        Analyze if 'driver' on Day T correlates with 'target' on Day T+lag.
        Example: Alcohol (T) -> Sleep Score (T+1).
        
        Args:
            driver: Column name of the driver (e.g., 'alcohol_units', 'stress').
            target: Column name of the target (e.g., 'sleep', 'hrv', 'rhr').
            lag: Days to shift target (1 means next day).
            days: Lookback window.
        """
        df = self.get_dataframe(days)
        if df.empty: return {"error": "No data"}
        
        # Verify columns exist
        if driver not in df.columns or target not in df.columns:
            return {"error": f"Columns not found: {driver} or {target}"}
        
        # Shift target
        # lag=1 means we align Driver(T) with Target(T+1)
        target_col = f"target_lag_{lag}"
        df[target_col] = df[target].shift(-lag)
        
        # Filter valid rows
        valid = df.dropna(subset=[driver, target_col])
        if len(valid) < 5:
            return {"error": "Not enough data points (<5)"}
            
        # Calculate Correlation
        corr = valid[driver].corr(valid[target_col])
        
        return {
            "driver": driver,
            "target": target,
            "lag_days": lag,
            "correlation": float(corr) if not pd.isna(corr) else 0.0,
            "sample_size": len(valid),
            "msg": f" Correlation {corr:.2f} (1.0 is perfect positive, -1.0 is perfect negative)"
        }
